import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
# UNet 모델 구조 정의
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
self.encoder1 = nn.Sequential(
nn.Conv2d(in_channels, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True)
)
self.pool1 = nn.MaxPool2d(2)
self.encoder2 = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True)
)
self.pool2 = nn.MaxPool2d(2)
self.encoder3 = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(inplace=True)
)
self.pool3 = nn.MaxPool2d(2)
self.bottleneck = nn.Sequential(
nn.Conv2d(256, 512, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, padding=1),
nn.ReLU(inplace=True)
)
self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.decoder3 = nn.Sequential(
nn.Conv2d(512, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(inplace=True)
)
self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.decoder2 = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True)
)
self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.decoder1 = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True)
)
self.final_conv = nn.Conv2d(64, out_channels, 1)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
bottleneck = self.bottleneck(self.pool3(enc3))
dec3 = self.decoder3(torch.cat((enc3, self.upconv3(bottleneck)), 1))
dec2 = self.decoder2(torch.cat((enc2, self.upconv2(dec3)), 1))
dec1 = self.decoder1(torch.cat((enc1, self.upconv1(dec2)), 1))
final_output = self.final_conv(dec1)
return final_output
<궁금증>
Q1. class 함수로 구현해야 하나요?
A1. 이렇게 구현해주면 UNet(in_channels, out_channels) 를 이용해서 UNet 모델을 사용할 수 있습니다.
Q2. in_channels 와 out_channels 값은 어떤 값으로 설정해줘야 하나요?
A2. 저는 임의의 어떤 Brain Tumor Image 를 입력받으면, 이 이미지에서 Tumor 인 지점을 예측하고 싶었습니다.
그렇기 때문에 in_channels = 1, out_channels = 2 로 설정해줘야 합니다.
(out_channels = 2 인 이유는 Tumor 이거나 아니거나 이기 때문입니다.)
Q3. 이 코드 파일을 이용하는 방법은 뭔가요?
A3. class 함수로 모델을 정의한 이유 입니다. 이 모델은 models_web.ipynb 파일에 구현되어 있는데
사진처럼 같은 위치에 다른 .ipynb 파일을 생성하고 새로운 커널에
import import_ipynb
import models_web
model = models_web.UNet(1, num_classes)
이렇게 코드를 작성해주면 models_web.ipynb 에 구현된 UNet 구조를 가지고 올 수 있습니다.
(전체 코드는 차후 소프트웨어 등록이 끝나면 주석과 함께 첨부할 예정입니다.)
'AI > 국제공동프로젝트' 카테고리의 다른 글
6. (코드 공유)Brain Tumor 예측 인공지능 - Flask 웹, HTML, UNet, pytoch, Jupyter Notebook (최종) (0) | 2023.08.03 |
---|---|
5. Brain Tumor 예측 인공지능 - 웹 구현하기(1), Flask 웹 개발 (0) | 2023.07.28 |
4. Brain Tumor 예측 인공지능 - 이미지 예측하기 (0) | 2023.07.28 |
3. Brain Tumor 예측 인공지능 - 모델 훈련시키기 (0) | 2023.07.28 |
1. Brain Tumor 예측 인공지능 - 개발 환경설정 (0) | 2023.07.28 |