AI/국제공동프로젝트

2. Brain Tumor 예측 인공지능 - UNet 모델 구조 구현

살랑춤춰요 2023. 7. 28. 18:10
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models

# UNet 모델 구조 정의
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        self.encoder1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2)

        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2)

        self.encoder3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.decoder3 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.final_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))

        bottleneck = self.bottleneck(self.pool3(enc3))

        dec3 = self.decoder3(torch.cat((enc3, self.upconv3(bottleneck)), 1))
        dec2 = self.decoder2(torch.cat((enc2, self.upconv2(dec3)), 1))
        dec1 = self.decoder1(torch.cat((enc1, self.upconv1(dec2)), 1))

        final_output = self.final_conv(dec1)
        return final_output

 

<궁금증>

Q1. class 함수로 구현해야 하나요?

A1. 이렇게 구현해주면 UNet(in_channels, out_channels) 를 이용해서 UNet 모델을 사용할 수 있습니다.

 

Q2. in_channels 와 out_channels 값은 어떤 값으로 설정해줘야 하나요?

A2. 저는 임의의 어떤 Brain Tumor Image 를 입력받으면, 이 이미지에서 Tumor 인 지점을 예측하고 싶었습니다.

그렇기 때문에 in_channels = 1, out_channels = 2 로 설정해줘야 합니다.

 

(out_channels = 2 인 이유는 Tumor 이거나 아니거나 이기 때문입니다.)

 

Q3. 이 코드 파일을 이용하는 방법은 뭔가요?

A3. class 함수로 모델을 정의한 이유 입니다. 이 모델은 models_web.ipynb 파일에 구현되어 있는데

사진처럼 같은 위치에 다른 .ipynb 파일을 생성하고 새로운 커널에

import import_ipynb
import models_web
 
model = models_web.UNet(1, num_classes)

이렇게 코드를 작성해주면 models_web.ipynb 에 구현된 UNet 구조를 가지고 올 수 있습니다.

 

(전체 코드는 차후 소프트웨어 등록이 끝나면 주석과 함께 첨부할 예정입니다.)