개발/AI

[PyTorch] GoogLeNet: 구현 연습

jykim23 2023. 11. 29. 16:50

ConvBlock

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv2d followed by ReLU, the basic building block used throughout GoogLeNet.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (convolution filters).
        kernel_size: convolution kernel size.
        padding: zero-padding added to each spatial side (default 0).
        stride: convolution stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.ReLU(),
        )

    # Implement `forward`, not `__call__`: nn.Module.__call__ runs registered
    # hooks and then dispatches to `forward`; overriding `__call__` skips them.
    def forward(self, x):
        """Apply the convolution and then ReLU to ``x``."""
        return self.layers(x)

 

Inception

class Inception(nn.Module):
    """GoogLeNet Inception module: four parallel branches concatenated on channels.

    Branches:
        1. 1x1 conv -> ``ch1x1`` channels
        2. 1x1 reduce to ``ch3x3red`` -> 3x3 conv -> ``ch3x3`` channels
        3. 1x1 reduce to ``ch5x5red`` -> 5x5 conv -> ``ch5x5`` channels
        4. 3x3 max-pool (stride 1) -> 1x1 projection -> ``pool_proj`` channels

    Every branch uses stride 1 with padding matched to its kernel, so spatial
    size is preserved and the output has
    ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.

    Args:
        in_channels: channels of the input feature map.
        ch1x1: output channels of branch 1.
        ch3x3red: reduction width before the 3x3 conv in branch 2.
        ch3x3: output channels of branch 2.
        ch5x5red: reduction width before the 5x5 conv in branch 3.
        ch5x5: output channels of branch 3.
        pool_proj: output channels of the projection in branch 4.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()
        # Branch 1: 1x1 convolution.
        self.branch1 = nn.Sequential(
            ConvBlock(in_channels=in_channels, out_channels=ch1x1, kernel_size=1, padding=0, stride=1),
        )
        # Branch 2: 1x1 channel reduction, then 3x3 convolution.
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels=in_channels, out_channels=ch3x3red, kernel_size=1, padding=0, stride=1),
            ConvBlock(in_channels=ch3x3red, out_channels=ch3x3, kernel_size=3, padding=1, stride=1),
        )
        # Branch 3: 1x1 channel reduction, then 5x5 convolution.
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels=in_channels, out_channels=ch5x5red, kernel_size=1, padding=0, stride=1),
            ConvBlock(in_channels=ch5x5red, out_channels=ch5x5, kernel_size=5, padding=2, stride=1),
        )
        # Branch 4: size-preserving 3x3 max-pool, then 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            ConvBlock(in_channels=in_channels, out_channels=pool_proj, kernel_size=1, padding=0, stride=1),
        )

    # Implement `forward`, not `__call__`, so nn.Module hooks still run.
    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on the channel axis."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        # dim=-3 is the channel axis for both batched (N, C, H, W) and
        # unbatched (C, H, W) input; dim=1 would hit H in the unbatched case.
        return torch.cat([branch(x) for branch in branches], dim=-3)

 

GoogLeNet

class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classifier, layer widths per Szegedy et al. 2014, Table 1.

    This implementation omits the paper's auxiliary classifiers, local
    response normalization and dropout.

    Args:
        num_classes: output size of the final linear layer (default 1000).
        in_channels: channels of the input image (default 3).
        verbose: if True, ``forward`` prints the tensor shape after each stage.
            Defaults to True so existing callers keep the per-stage shape trace.
    """

    def __init__(self, num_classes=1000, in_channels=3, verbose=True):
        super().__init__()
        self.verbose = verbose

        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64, kernel_size=7, padding=3, stride=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        self.conv2 = nn.Sequential(
            ConvBlock(in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1),
            ConvBlock(in_channels=64, out_channels=192, kernel_size=3, padding=1, stride=1),
        )
        self.max_pool2 = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        # Inception(in, 1x1, 3x3reduce, 3x3, 5x5reduce, 5x5, pool_proj); the
        # trailing comment is the resulting channel count (1x1+3x3+5x5+pool).
        self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32)        # -> 256
        self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64)      # -> 480
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64)      # -> 512
        self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64)     # -> 512
        self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64)     # -> 512
        self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64)     # -> 528
        # 3x3-reduce width is 160 in the paper (it was mistyped as 260; the
        # output channel count hid the typo because reduce widths don't add up).
        self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128)   # -> 832
        self.max_pool4 = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128)   # -> 832
        self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128)   # -> 1024
        # Adaptive pooling always yields 1x1, so inputs other than 224x224 work;
        # for the 7x7 map a 224x224 input produces it equals AvgPool2d(7).
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Run the network on ``x`` and return the final linear layer's output.

        Args:
            x: image batch, expected (N, in_channels, H, W); the flatten
               below keeps dim 0 as the batch dimension.

        Returns:
            Tensor of shape (N, num_classes); no softmax is applied.
        """
        stages = (
            ("conv1", self.conv1),
            ("max_pool1", self.max_pool1),
            ("conv2", self.conv2),
            ("max_pool2", self.max_pool2),
            ("inception_3a", self.inception_3a),
            ("inception_3b", self.inception_3b),
            ("max_pool3", self.max_pool3),
            ("inception_4a", self.inception_4a),
            ("inception_4b", self.inception_4b),
            ("inception_4c", self.inception_4c),
            ("inception_4d", self.inception_4d),
            ("inception_4e", self.inception_4e),
            ("max_pool4", self.max_pool4),
            ("inception_5a", self.inception_5a),
            ("inception_5b", self.inception_5b),
            ("avg_pool", self.avg_pool),
        )
        for name, stage in stages:
            x = stage(x)
            if self.verbose:
                # Pad names to 12 chars so the shapes line up in a column.
                print(f"{name:<12} : {x.shape}")

        x = torch.flatten(x, 1)
        return self.fc(x)

 

 

실행

if __name__ == "__main__":
    # Smoke test: push a random 2-image batch through the network and show
    # the output shape. Guarded so importing this file has no side effects.
    input_tensor = torch.rand(size=(2, 3, 224, 224))
    model = GoogLeNet()
    model.eval()
    # Call the module itself rather than `.forward` so nn.Module hooks run;
    # no_grad skips building an autograd graph for this inference-only pass.
    with torch.no_grad():
        output = model(input_tensor)
    print(f'total        : {output.shape}')

 

출력

conv1        : torch.Size([2, 64, 112, 112])
max_pool1    : torch.Size([2, 64, 56, 56])
conv2        : torch.Size([2, 192, 56, 56])
max_pool2    : torch.Size([2, 192, 28, 28])
inception_3a : torch.Size([2, 256, 28, 28])
inception_3b : torch.Size([2, 480, 28, 28])
max_pool3    : torch.Size([2, 480, 14, 14])
inception_4a : torch.Size([2, 512, 14, 14])
inception_4b : torch.Size([2, 512, 14, 14])
inception_4c : torch.Size([2, 512, 14, 14])
inception_4d : torch.Size([2, 528, 14, 14])
inception_4e : torch.Size([2, 832, 14, 14])
max_pool4    : torch.Size([2, 832, 7, 7])
inception_5a : torch.Size([2, 832, 7, 7])
inception_5b : torch.Size([2, 1024, 7, 7])
avg_pool     : torch.Size([2, 1024, 1, 1])
total        : torch.Size([2, 1000])