DenseNet 이해하기



DenseNet이란?

DenseNet은 저번 글에서 설명한 ResNet과 유사하게 입력값과 출력값을 더하는 형태의 모델이지만, 명백하게 다릅니다.

1*_Y7-f9GpV7F93siM1js0cg

위 사진은 DenseNet모델 구조이며, 이미지가 들어가고 이후 Dense Block이라는 구조를 지나 convolution layer 와 pooling layer를 거쳐가는 것을 알 수 있습니다.

여기서 convolution layer는 1x1 convolution을 사용합니다. ( 1x1 convolution은 적은 연산량으로 채널을 압축시키거나, 증가시킬 수 있는 장점이 있습니다.)

image

먼저 DenseNet과 ResNet의 차이를 설명하겠습니다.

ResNet은 H(x) = F(x) + x의 꼴로 입력값 + 출력값이며, 입력에서 뽑은 특성과 출력에서 뽑은 더 복잡한 특성을 모두 사용한다는 장점이 있습니다.
또한 더하기 연산한 역전파 계산시 기울기가 1이기에, 기울기 소실이 발생하지 않게 할수있는 장점이 모델이었습니다.

DenseNet도 마찬가지로 H(x) = F(x) + x의 꼴로 입력값 + 출력값 이지만, DenseNet에서의 “+”는 ResNet은과 다르게 ADD가 아닌 CONCAT입니다.
간단한 파이썬코드로 Add와 Concatation 차이를 알아보겠습니다.

[Add]

X = [[1,2],[3,4]]
Y = [[5,6],[7,8]] 
print(X+Y)

>>> [[6, 8], [10, 12]]

[Concatation]


X = torch.Tensor([[1,2],[3,4]])
Y = torch.Tensor([[5,6],[7,8]]) 
print(torch.cat([X, Y]), 0)
print(torch.cat([X, Y]).shape)

>>> tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
>>> torch.Size([4, 2])

보시는 것과 같이 Add는 값과 값을 더하는 것과 같다면, Concatation은 차원과 차원을 붙인다는 개념에 가깝습니다. 즉, ResNet과 DeenseNet은 명백하게 다르다는 것 입니다.


Dense Block

image

이번엔 Dense Block을 확대해 이해해보겠습니다. 여기선 입력값이 “BN / RELU / CONV” (BatchNomalization - ReLU - Convolution) layer을 거친 후의 결과 값과 거치기 전의 입력 값을 Concatation합니다. (사진에선 이 과정을 4번 반복합니다.)

여기서 굳이 add가 아닌 concat을 하는 이유가 뭘까요? 이전 feature map 정보를 재사용을 통해 네트워크의 잠재력을 활용함으로써, 학습하기 쉽게 채널을 압축 시킵니다. 더불어 원래 층을 더 쌓는다면 그만큼 연산량이 증가해야하지만, 기존의 연산을 재활용하기에 연산량 또한 효율적입니다.

DenseNet Architectures

image

위는 DenseNet 기본적인 구조입니다. 121 / 169 / 201 / 264 등 깊이에 따라 Dense Block을 어떻게 사용하는지 다르며, 필터의 크기 층의 개수등 모두 다르기 때문에, 자신의 입력 데이터에 적합한 모델을 이용하시길 바랍니다.

Pytorch Code

class BottleNeck(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        inner_channels = 4 * growth_rate

        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, inner_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
            nn.Conv2d(inner_channels, growth_rate, 3, stride=1, padding=1, bias=False)
        )

        self.shortcut = nn.Sequential()

    def forward(self, x):
        return torch.cat([self.shortcut(x), self.residual(x)], 1)
class Transition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.down_sample = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
            nn.AvgPool2d(2, stride=2)
        )

    def forward(self, x):
        return self.down_sample(x)


Reference
  • https://arxiv.org/pdf/1608.06993.pdf
  • https://deep-learning-study.tistory.com/545