인공지능을 좋아하는 곧미남

BackBone Encoder Layer에서 Feature Map 추출 본문

code_study/pytorch

BackBone Encoder Layer에서 Feature Map 추출

곧미남 2022. 1. 17. 17:23

Image Feature를 학습하기 위해 BackBone으로 사용되는 많은 모델들이 있습니다. 그 중 ResNet50 구조에서 원하는 layer에서 feature map을 추출할 수 있는 코드를 pytorch를 이용해 구현해보겠습니다.

 

최종적인 목표는 Multi Layer의 Feature를 사용해서 Image Resolution(Scale)에 강건한 모델을 만들고 싶어 이런 방법을 생각해보았습니다. 

 

오늘의 내용은 아래의 목차와 같습니다.

 

< INDEX >

 

1. Multi-Scale를 사용하는 이유 및 장점

 

2. ResNet50 구조

 

3. ResNet50 구조에서 원하는 Layer의 Feature Map을 추출

 

4. 내용 고찰 정리


1. Multi-Scale를 사용하는 이유 및 장점

Feature pyramid network의 내용을 참고해보면, pyramid 형식으로 feature map의 해상도(size)별로 가져와 결합하여 Image Scale에 강건한 학습 모델을 만들어 냈습니다. 

 

좀 더 자세히 말씀드리면, Convolutional network에서 입력층에 보다 가까울수록 feature map은 높은 해상도를 가지고 저수준의 특징(가장자리, 곡선)을 대표합니다. 

 

그리고 입력층보다 멀리 떨어질수록 feature map은 낮은 해상도를 가지고 좀 더 상세한 이미지의 질감과 객체의 일부분 등 class를 추론할 수 있는 고수준의 특징을 대표합니다.

 

즉, 서로다른 해상도의 feature map을 결합한 구성으로 Scale에 강건한 학습 모델을 구성할 수 있습니다.


2. ResNet50 구조

총 4개의 Block으로 이루어져 있는 ResNet 구조에서 Block의 마지막 Layer의 Feature를 추출해보겠습니다.


3. ResNet50 구조에서 원하는 Layer의 Feature Map을 추출

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """ 변경전
        # x = self.layer1(x)
        # x = self.layer2(x)
        # x = self.layer3(x)
        # x = self.layer4(x)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        # return x
        """
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # layer1, 2, 3, 4의 output feature map으로 lstm conv layer 모델에 input하기 위해 추출.
        feature_1 = self.layer1(x)
        feature_2 = self.layer2(x)
        feature_3 = self.layer3(x)
        feature_4 = self.layer4(x)

        return feature_1, feature_2, feature_3, feature_4


    def forward(self, x):
        feature_1, feature_2, feature_3, feature_4 = self._forward_impl(x)

        return feature_1, feature_2, feature_3, feature_4

def _forward_impl(self, x): 메서드에서 주석에 적혀있는 변경전에서 현제 feature_1, 2, 3, 4 변수를 추가했고 각 block 끝에서 출력되는 Feature map의 값을 출력할 수 있습니다.


4. 내용 고찰 정리

Multi Scale에 강건하기 위한 모델을 만들기 위해 API 코드를 직접 커스텀하고 있음에 발전했다고 느끼고 있습니다.

 

이후에 ConvLSTM 관련 미래 예측 Model에 해당 내용이 첨가되는데 거기서는 내가 직접 커스텀한 코드를 설명하면 더욱 좋을 것 같습니다.

 

읽어주셔서 감사합니다.

'code_study > pytorch' 카테고리의 다른 글

[torchvision] ImageFolder  (0) 2022.05.16
pytorch의 def forward(self, x)  (0) 2022.01.18
torch.nn.Sequential  (0) 2022.01.14
pytorch DataLoader  (0) 2022.01.10
[pytorch] torch.nn.Module.parameters(), named_parameters()  (0) 2021.12.10
Comments