AI/국제공동프로젝트

4. Brain Tumor 예측 인공지능 - 이미지 예측하기

살랑춤춰요 2023. 7. 28. 18:29

모델의 학습이 끝났으면 이제 이 모델을 이용해 예측해봐야겠죠.

 

import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

import import_ipynb
import models_web

class predictUnet:
    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = models_web.UNet(1, 2)
        self.model.eval()
        self.model.to(self.device)
       
        self.load_model_weights(model_path)

    def load_model_weights(self, model_path):
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            model_name = os.path.splitext(os.path.basename(model_path))[0]
            print(f"Loaded model weights successfully - {model_name}.")
        else:
            print("Model weights not found. Please make sure the path is correct.")

    def predict(self, image_path):
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        image_tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        image_tensor = image_tensor.to(self.device)

        with torch.no_grad():
            output = self.model(image_tensor)
            output = output.cpu().numpy().squeeze()
            output = output[1, :, :]
           
            output = cv2.resize(output, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
            _, output = cv2.threshold(output, 0.4, 1, cv2.THRESH_BINARY)
            output = output.astype(np.uint8) * 255
            output = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)
            merged = cv2.addWeighted(cv2.cvtColor(image, cv2.COLOR_GRAY2BGR), 0.5, output, 0.5, 0)

            output_folder = './static/output_imgs'
            os.makedirs(output_folder, exist_ok=True)

            # 웹에서 모델이 예측한 이미지를 시각화 하기 위해 image 를 저장한다
            # Flask 라이브러리는 static 폴더 내 image 를 불러올 수 있기 때문에 static 폴더 내에 저장한다
            cv2.imwrite(os.path.join(output_folder, 'original_image.png'), image)
            cv2.imwrite(os.path.join(output_folder, 'predicted_mask.png'), output)
            cv2.imwrite(os.path.join(output_folder, 'merged_image.png'), merged)

이렇게 구현했습니다. class 함수는 정말 많이 쓰는 함수 중 한개입니다.(저만 그럴수도 있구요)

이렇게 해주면 새로운 ipynb 파일을 한개 만들고 그 파일에

import import_ipynb
from predict_Unet import predictUnet
 
model_path = '학습된 모델이 저장된 위치를 여기에 입력합니다'
predictor = predictUnet(model_path)
predictor.predict('예측하고 싶은 이미지가 저자된 위치를 여기에 입력합니다.')

이렇게 코드를 작성하면 예측이 됩니다.

주의해야할 점은 저장된(모델, 이미지) 파일을 불러올 땐 주소만 입력하면 안됩니다!

"주소+그 파일이름" 형태여야 내가 원하는 그 파일만 읽어옵니다.

 

(전체 코드는 차후 소프트웨어 등록이 끝나면 주석과 함께 첨부할 예정입니다.)