모델의 학습이 끝났으면 이제 이 모델을 이용해 예측해봐야겠죠.
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
import import_ipynb
import models_web
class predictUnet:
def __init__(self, model_path):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = models_web.UNet(1, 2)
self.model.eval()
self.model.to(self.device)
self.load_model_weights(model_path)
def load_model_weights(self, model_path):
if os.path.exists(model_path):
self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model_name = os.path.splitext(os.path.basename(model_path))[0]
print(f"Loaded model weights successfully - {model_name}.")
else:
print("Model weights not found. Please make sure the path is correct.")
def predict(self, image_path):
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image_tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
image_tensor = image_tensor.to(self.device)
with torch.no_grad():
output = self.model(image_tensor)
output = output.cpu().numpy().squeeze()
output = output[1, :, :]
output = cv2.resize(output, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
_, output = cv2.threshold(output, 0.4, 1, cv2.THRESH_BINARY)
output = output.astype(np.uint8) * 255
output = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)
merged = cv2.addWeighted(cv2.cvtColor(image, cv2.COLOR_GRAY2BGR), 0.5, output, 0.5, 0)
output_folder = './static/output_imgs'
os.makedirs(output_folder, exist_ok=True)
# 웹에서 모델이 예측한 이미지를 시각화 하기 위해 image 를 저장한다
# Flask 라이브러리는 static 폴더 내 image 를 불러올 수 있기 때문에 static 폴더 내에 저장한다
cv2.imwrite(os.path.join(output_folder, 'original_image.png'), image)
cv2.imwrite(os.path.join(output_folder, 'predicted_mask.png'), output)
cv2.imwrite(os.path.join(output_folder, 'merged_image.png'), merged)
이렇게 구현했습니다. class 함수는 정말 많이 쓰는 함수 중 한개입니다.(저만 그럴수도 있구요)
이렇게 해주면 새로운 ipynb 파일을 한개 만들고 그 파일에
import import_ipynb
from predict_Unet import predictUnet
model_path = '학습된 모델이 저장된 위치를 여기에 입력합니다'
predictor = predictUnet(model_path)
predictor.predict('예측하고 싶은 이미지가 저자된 위치를 여기에 입력합니다.')
이렇게 코드를 작성하면 예측이 됩니다.
주의해야할 점은 저장된(모델, 이미지) 파일을 불러올 땐 주소만 입력하면 안됩니다!
"주소+그 파일이름" 형태여야 내가 원하는 그 파일만 읽어옵니다.
(전체 코드는 차후 소프트웨어 등록이 끝나면 주석과 함께 첨부할 예정입니다.)
'AI > 국제공동프로젝트' 카테고리의 다른 글
6. (코드 공유)Brain Tumor 예측 인공지능 - Flask 웹, HTML, UNet, pytoch, Jupyter Notebook (최종) (0) | 2023.08.03 |
---|---|
5. Brain Tumor 예측 인공지능 - 웹 구현하기(1), Flask 웹 개발 (0) | 2023.07.28 |
3. Brain Tumor 예측 인공지능 - 모델 훈련시키기 (0) | 2023.07.28 |
2. Brain Tumor 예측 인공지능 - UNet 모델 구조 구현 (0) | 2023.07.28 |
1. Brain Tumor 예측 인공지능 - 개발 환경설정 (0) | 2023.07.28 |