You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1279 lines
73 KiB
Plaintext
1279 lines
73 KiB
Plaintext
7 months ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Load Dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import os\n",
|
||
|
"from glob import glob\n",
|
||
|
"import numpy as np\n",
|
||
|
"import torch\n",
|
||
|
"from torch.utils.data import Dataset\n",
|
||
|
"from PIL import Image\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from torchvision import transforms, datasets\n",
|
||
|
"import random\n",
|
||
|
"import cv2\n",
|
||
|
"\n",
|
||
|
"class CustomDataset(Dataset):\n",
|
||
|
" def __init__(self, list_imgs, list_masks, transform=None):\n",
|
||
|
" self.list_imgs = list_imgs\n",
|
||
|
" self.list_masks = list_masks\n",
|
||
|
" self.transform = transform\n",
|
||
|
"\n",
|
||
|
" def __len__(self):\n",
|
||
|
" return len(self.list_imgs)\n",
|
||
|
"\n",
|
||
|
" def __getitem__(self, index):\n",
|
||
|
" img_path = self.list_imgs[index]\n",
|
||
|
" mask_path = self.list_masks[index]\n",
|
||
|
"\n",
|
||
|
" img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)\n",
|
||
|
" mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)\n",
|
||
|
"\n",
|
||
|
" # 이미지 크기를 512x512로 변경\n",
|
||
|
" img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)\n",
|
||
|
" mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)\n",
|
||
|
"\n",
|
||
|
" img = img.astype(np.float32) / 255.0\n",
|
||
|
" mask = mask.astype(np.float32) / 255.0\n",
|
||
|
"\n",
|
||
|
" if img.ndim == 2:\n",
|
||
|
" img = img[:, :, np.newaxis]\n",
|
||
|
" if mask.ndim == 2:\n",
|
||
|
" mask = mask[:, :, np.newaxis]\n",
|
||
|
"\n",
|
||
|
" data = {'input': img, 'label': mask}\n",
|
||
|
"\n",
|
||
|
" if self.transform:\n",
|
||
|
" data = self.transform(data)\n",
|
||
|
" \n",
|
||
|
" return data\n",
|
||
|
"\n",
|
||
|
"def create_datasets(img_dir, mask_dir, train_ratio=0.7, val_ratio=0.2, transform=None):\n",
|
||
|
" list_imgs = sorted(glob(os.path.join(img_dir, '**', '*.bmp'), recursive=True))\n",
|
||
|
" list_masks = sorted(glob(os.path.join(mask_dir, '**', '*.bmp'), recursive=True))\n",
|
||
|
"\n",
|
||
|
" # combined = list(zip(list_imgs, list_masks))\n",
|
||
|
" # random.shuffle(combined)\n",
|
||
|
" # list_imgs, list_masks = zip(*combined)\n",
|
||
|
"\n",
|
||
|
" num_imgs = len(list_imgs)\n",
|
||
|
" num_train = int(num_imgs * train_ratio)\n",
|
||
|
" num_val = int(num_imgs * val_ratio)\n",
|
||
|
"\n",
|
||
|
" # train_set = CustomDataset(list_imgs[:num_train], list_masks[:num_train], transform)\n",
|
||
|
" # val_set = CustomDataset(list_imgs[num_train:num_train + num_val], list_masks[num_train:num_train + num_val], transform)\n",
|
||
|
" test_set = CustomDataset(list_imgs[num_train + num_val:], list_masks[num_train + num_val:], transform)\n",
|
||
|
"\n",
|
||
|
" # return train_set, val_set, test_set\n",
|
||
|
" return test_set\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Transforms (Augmentation)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 트렌스폼 구현하기\n",
|
||
|
"class ToTensor(object):\n",
|
||
|
" # def __call__(self, data):\n",
|
||
|
" # label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
" # label = label.transpose((2, 0, 1)).astype(np.float32)\n",
|
||
|
" # input = input.transpose((2, 0, 1)).astype(np.float32)\n",
|
||
|
"\n",
|
||
|
" # data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}\n",
|
||
|
"\n",
|
||
|
" # return data\n",
|
||
|
" def __call__(self, data):\n",
|
||
|
" label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
" # 이미지가 이미 그레이스케일이면 채널 차원 추가\n",
|
||
|
" if label.ndim == 2:\n",
|
||
|
" label = label[:, :, np.newaxis]\n",
|
||
|
" if input.ndim == 2:\n",
|
||
|
" input = input[:, :, np.newaxis]\n",
|
||
|
"\n",
|
||
|
" # 채널을 첫 번째 차원으로 이동\n",
|
||
|
" label = label.transpose((2, 0, 1)).astype(np.float32)\n",
|
||
|
" input = input.transpose((2, 0, 1)).astype(np.float32)\n",
|
||
|
"\n",
|
||
|
" data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}\n",
|
||
|
"\n",
|
||
|
" return data\n",
|
||
|
"\n",
|
||
|
"class Normalization(object):\n",
|
||
|
" def __init__(self, mean=0.5, std=0.5):\n",
|
||
|
" self.mean = mean\n",
|
||
|
" self.std = std\n",
|
||
|
"\n",
|
||
|
" def __call__(self, data):\n",
|
||
|
" label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
" input = (input - self.mean) / self.std\n",
|
||
|
"\n",
|
||
|
" data = {'label': label, 'input': input}\n",
|
||
|
"\n",
|
||
|
" return data\n",
|
||
|
"\n",
|
||
|
"class RandomFlip(object):\n",
|
||
|
" def __call__(self, data):\n",
|
||
|
" label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
" if np.random.rand() > 0.5:\n",
|
||
|
" label = np.fliplr(label)\n",
|
||
|
" input = np.fliplr(input)\n",
|
||
|
"\n",
|
||
|
" if np.random.rand() > 0.5:\n",
|
||
|
" label = np.flipud(label)\n",
|
||
|
" input = np.flipud(input)\n",
|
||
|
"\n",
|
||
|
" data = {'label': label, 'input': input}\n",
|
||
|
"\n",
|
||
|
" return data\n",
|
||
|
" \n",
|
||
|
"# class Resize(object):\n",
|
||
|
"# def __init__(self, output_size):\n",
|
||
|
"# assert isinstance(output_size, (int, tuple))\n",
|
||
|
"# self.output_size = output_size\n",
|
||
|
"\n",
|
||
|
"# def __call__(self, data):\n",
|
||
|
"# label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
"# h, w = input.shape[:2]\n",
|
||
|
"# if isinstance(self.output_size, int):\n",
|
||
|
"# if h > w:\n",
|
||
|
"# new_h, new_w = self.output_size * h / w, self.output_size\n",
|
||
|
"# else:\n",
|
||
|
"# new_h, new_w = self.output_size, self.output_size * w / h\n",
|
||
|
"# else:\n",
|
||
|
"# new_h, new_w = self.output_size\n",
|
||
|
"\n",
|
||
|
"# new_h, new_w = int(new_h), int(new_w)\n",
|
||
|
"\n",
|
||
|
"# input = cv2.resize(input, (new_w, new_h))\n",
|
||
|
"# label = cv2.resize(label, (new_w, new_h))\n",
|
||
|
"\n",
|
||
|
"# return {'label': label, 'input': input}\n",
|
||
|
"\n",
|
||
|
"class Rotate(object):\n",
|
||
|
" def __init__(self, angle_range):\n",
|
||
|
" assert isinstance(angle_range, (tuple, list)) and len(angle_range) == 2\n",
|
||
|
" self.angle_min, self.angle_max = angle_range\n",
|
||
|
"\n",
|
||
|
" def __call__(self, data):\n",
|
||
|
" label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
" # NumPy 배열로 변환 (필요한 경우)\n",
|
||
|
" if not isinstance(input, np.ndarray):\n",
|
||
|
" input = np.array(input)\n",
|
||
|
" if not isinstance(label, np.ndarray):\n",
|
||
|
" label = np.array(label)\n",
|
||
|
"\n",
|
||
|
" # (H, W, C) 형태를 (H, W)로 변경 (필요한 경우)\n",
|
||
|
" if input.ndim == 3 and input.shape[2] == 1:\n",
|
||
|
" input = input.squeeze(2)\n",
|
||
|
" if label.ndim == 3 and label.shape[2] == 1:\n",
|
||
|
" label = label.squeeze(2)\n",
|
||
|
"\n",
|
||
|
" # 랜덤 각도 선택 및 회전 적용\n",
|
||
|
" angle = np.random.uniform(self.angle_min, self.angle_max)\n",
|
||
|
" h, w = input.shape[:2]\n",
|
||
|
" center = (w / 2, h / 2)\n",
|
||
|
" rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)\n",
|
||
|
" input = cv2.warpAffine(input, rot_matrix, (w, h))\n",
|
||
|
" label = cv2.warpAffine(label, rot_matrix, (w, h))\n",
|
||
|
"\n",
|
||
|
" return {'label': label, 'input': input}\n",
|
||
|
" \n",
|
||
|
"# class Crop(object):\n",
|
||
|
"# def __init__(self, output_size):\n",
|
||
|
"# assert isinstance(output_size, (int, tuple))\n",
|
||
|
"# if isinstance(output_size, int):\n",
|
||
|
"# self.output_size = (output_size, output_size)\n",
|
||
|
"# else:\n",
|
||
|
"# assert len(output_size) == 2\n",
|
||
|
"# self.output_size = output_size\n",
|
||
|
"\n",
|
||
|
"# def __call__(self, data):\n",
|
||
|
"# label, input = data['label'], data['input']\n",
|
||
|
"\n",
|
||
|
"# h, w = input.shape[:2]\n",
|
||
|
"# new_h, new_w = self.output_size\n",
|
||
|
"\n",
|
||
|
"# top = np.random.randint(0, h - new_h)\n",
|
||
|
"# left = np.random.randint(0, w - new_w)\n",
|
||
|
"\n",
|
||
|
"# input = input[top: top + new_h, left: left + new_w]\n",
|
||
|
"# label = label[top: top + new_h, left: left + new_w]\n",
|
||
|
"\n",
|
||
|
"# return {'label': label, 'input': input}\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# UNet Model (Original)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"## 라이브러리 불러오기\n",
|
||
|
"import os\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"from torch.utils.data import DataLoader\n",
|
||
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||
|
"\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"## 네트워크 구축하기\n",
|
||
|
"class UNet(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(UNet, self).__init__()\n",
|
||
|
"\n",
|
||
|
" # Convolution + BatchNormalization + Relu 정의하기\n",
|
||
|
" def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True): \n",
|
||
|
" layers = []\n",
|
||
|
" layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
|
||
|
" kernel_size=kernel_size, stride=stride, padding=padding,\n",
|
||
|
" bias=bias)]\n",
|
||
|
" layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
|
||
|
" layers += [nn.ReLU()]\n",
|
||
|
"\n",
|
||
|
" cbr = nn.Sequential(*layers)\n",
|
||
|
"\n",
|
||
|
" return cbr\n",
|
||
|
"\n",
|
||
|
" # 수축 경로(Contracting path)\n",
|
||
|
" self.enc1_1 = CBR2d(in_channels=1, out_channels=64)\n",
|
||
|
" self.enc1_2 = CBR2d(in_channels=64, out_channels=64)\n",
|
||
|
"\n",
|
||
|
" self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
|
||
|
" self.enc2_2 = CBR2d(in_channels=128, out_channels=128)\n",
|
||
|
"\n",
|
||
|
" self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
|
||
|
" self.enc3_2 = CBR2d(in_channels=256, out_channels=256)\n",
|
||
|
"\n",
|
||
|
" self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
|
||
|
" self.enc4_2 = CBR2d(in_channels=512, out_channels=512)\n",
|
||
|
"\n",
|
||
|
" self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
|
||
|
"\n",
|
||
|
" # 확장 경로(Expansive path)\n",
|
||
|
" self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
|
||
|
"\n",
|
||
|
" self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,\n",
|
||
|
" kernel_size=2, stride=2, padding=0, bias=True)\n",
|
||
|
"\n",
|
||
|
" self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)\n",
|
||
|
" self.dec4_1 = CBR2d(in_channels=512, out_channels=256)\n",
|
||
|
"\n",
|
||
|
" self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,\n",
|
||
|
" kernel_size=2, stride=2, padding=0, bias=True)\n",
|
||
|
"\n",
|
||
|
" self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)\n",
|
||
|
" self.dec3_1 = CBR2d(in_channels=256, out_channels=128)\n",
|
||
|
"\n",
|
||
|
" self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,\n",
|
||
|
" kernel_size=2, stride=2, padding=0, bias=True)\n",
|
||
|
"\n",
|
||
|
" self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)\n",
|
||
|
" self.dec2_1 = CBR2d(in_channels=128, out_channels=64)\n",
|
||
|
"\n",
|
||
|
" self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,\n",
|
||
|
" kernel_size=2, stride=2, padding=0, bias=True)\n",
|
||
|
"\n",
|
||
|
" self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)\n",
|
||
|
" self.dec1_1 = CBR2d(in_channels=64, out_channels=64)\n",
|
||
|
"\n",
|
||
|
" self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n",
|
||
|
" \n",
|
||
|
" # forward 함수 정의하기\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" enc1_1 = self.enc1_1(x)\n",
|
||
|
" enc1_2 = self.enc1_2(enc1_1)\n",
|
||
|
" pool1 = self.pool1(enc1_2)\n",
|
||
|
"\n",
|
||
|
" enc2_1 = self.enc2_1(pool1)\n",
|
||
|
" enc2_2 = self.enc2_2(enc2_1)\n",
|
||
|
" pool2 = self.pool2(enc2_2)\n",
|
||
|
"\n",
|
||
|
" enc3_1 = self.enc3_1(pool2)\n",
|
||
|
" enc3_2 = self.enc3_2(enc3_1)\n",
|
||
|
" pool3 = self.pool3(enc3_2)\n",
|
||
|
"\n",
|
||
|
" enc4_1 = self.enc4_1(pool3)\n",
|
||
|
" enc4_2 = self.enc4_2(enc4_1)\n",
|
||
|
" pool4 = self.pool4(enc4_2)\n",
|
||
|
"\n",
|
||
|
" enc5_1 = self.enc5_1(pool4)\n",
|
||
|
"\n",
|
||
|
" dec5_1 = self.dec5_1(enc5_1)\n",
|
||
|
"\n",
|
||
|
" unpool4 = self.unpool4(dec5_1)\n",
|
||
|
" cat4 = torch.cat((unpool4, enc4_2), dim=1)\n",
|
||
|
" dec4_2 = self.dec4_2(cat4)\n",
|
||
|
" dec4_1 = self.dec4_1(dec4_2)\n",
|
||
|
"\n",
|
||
|
" unpool3 = self.unpool3(dec4_1)\n",
|
||
|
" cat3 = torch.cat((unpool3, enc3_2), dim=1)\n",
|
||
|
" dec3_2 = self.dec3_2(cat3)\n",
|
||
|
" dec3_1 = self.dec3_1(dec3_2)\n",
|
||
|
"\n",
|
||
|
" unpool2 = self.unpool2(dec3_1)\n",
|
||
|
" cat2 = torch.cat((unpool2, enc2_2), dim=1)\n",
|
||
|
" dec2_2 = self.dec2_2(cat2)\n",
|
||
|
" dec2_1 = self.dec2_1(dec2_2)\n",
|
||
|
"\n",
|
||
|
" unpool1 = self.unpool1(dec2_1)\n",
|
||
|
" cat1 = torch.cat((unpool1, enc1_2), dim=1)\n",
|
||
|
" dec1_2 = self.dec1_2(cat1)\n",
|
||
|
" dec1_1 = self.dec1_1(dec1_2)\n",
|
||
|
"\n",
|
||
|
" x = self.fc(dec1_1)\n",
|
||
|
"\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# UNet Model (Mini)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"## 라이브러리 불러오기\n",
|
||
|
"import os\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"from torch.utils.data import DataLoader\n",
|
||
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||
|
"\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"## 네트워크 구축하기\n",
|
||
|
"class UNet(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(UNet, self).__init__()\n",
|
||
|
"\n",
|
||
|
" # Convolution + BatchNormalization + Relu 정의하기\n",
|
||
|
" def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True): \n",
|
||
|
" layers = []\n",
|
||
|
" layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
|
||
|
" kernel_size=kernel_size, stride=stride, padding=padding,\n",
|
||
|
" bias=bias)]\n",
|
||
|
" layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
|
||
|
" layers += [nn.ReLU()]\n",
|
||
|
"\n",
|
||
|
" cbr = nn.Sequential(*layers)\n",
|
||
|
"\n",
|
||
|
" return cbr\n",
|
||
|
"\n",
|
||
|
" # 수축 경로(Contracting path)\n",
|
||
|
" self.enc1_1 = CBR2d(in_channels=1, out_channels=64)\n",
|
||
|
" self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
|
||
|
" self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
|
||
|
" self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
|
||
|
" self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
|
||
|
"\n",
|
||
|
" self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
|
||
|
"\n",
|
||
|
" # 확장 경로(Expansive path)의 깊이 감소\n",
|
||
|
" self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
|
||
|
" self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)\n",
|
||
|
"\n",
|
||
|
" self.dec4_1 = CBR2d(in_channels=512 + 512, out_channels=256)\n",
|
||
|
" self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)\n",
|
||
|
"\n",
|
||
|
" self.dec3_1 = CBR2d(in_channels=256 + 256, out_channels=128)\n",
|
||
|
" self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)\n",
|
||
|
"\n",
|
||
|
" self.dec2_1 = CBR2d(in_channels=128 + 128, out_channels=64)\n",
|
||
|
" self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)\n",
|
||
|
"\n",
|
||
|
" self.dec1_1 = CBR2d(in_channels=64 + 64, out_channels=64)\n",
|
||
|
" self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n",
|
||
|
" \n",
|
||
|
" # forward 함수 정의하기\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" enc1_1 = self.enc1_1(x)\n",
|
||
|
" pool1 = self.pool1(enc1_1)\n",
|
||
|
"\n",
|
||
|
" enc2_1 = self.enc2_1(pool1)\n",
|
||
|
" pool2 = self.pool2(enc2_1)\n",
|
||
|
"\n",
|
||
|
" enc3_1 = self.enc3_1(pool2)\n",
|
||
|
" pool3 = self.pool3(enc3_1)\n",
|
||
|
"\n",
|
||
|
" enc4_1 = self.enc4_1(pool3)\n",
|
||
|
" pool4 = self.pool4(enc4_1)\n",
|
||
|
"\n",
|
||
|
" enc5_1 = self.enc5_1(pool4)\n",
|
||
|
"\n",
|
||
|
" dec5_1 = self.dec5_1(enc5_1)\n",
|
||
|
"\n",
|
||
|
" unpool4 = self.unpool4(dec5_1)\n",
|
||
|
" cat4 = torch.cat((unpool4, enc4_1), dim=1)\n",
|
||
|
" dec4_1 = self.dec4_1(cat4)\n",
|
||
|
"\n",
|
||
|
" unpool3 = self.unpool3(dec4_1)\n",
|
||
|
" cat3 = torch.cat((unpool3, enc3_1), dim=1)\n",
|
||
|
" dec3_1 = self.dec3_1(cat3)\n",
|
||
|
"\n",
|
||
|
" unpool2 = self.unpool2(dec3_1)\n",
|
||
|
" cat2 = torch.cat((unpool2, enc2_1), dim=1)\n",
|
||
|
" dec2_1 = self.dec2_1(cat2)\n",
|
||
|
"\n",
|
||
|
" unpool1 = self.unpool1(dec2_1)\n",
|
||
|
" cat1 = torch.cat((unpool1, enc1_1), dim=1)\n",
|
||
|
" dec1_1 = self.dec1_1(cat1)\n",
|
||
|
"\n",
|
||
|
" x = self.fc(dec1_1)\n",
|
||
|
"\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Model - Load, Save"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"## 네트워크 저장하기\n",
|
||
|
"def save(ckpt_dir, net, optim, epoch):\n",
|
||
|
" if not os.path.exists(ckpt_dir):\n",
|
||
|
" os.makedirs(ckpt_dir)\n",
|
||
|
"\n",
|
||
|
" torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},\n",
|
||
|
" \"%s/model_epoch%d.pth\" % (ckpt_dir, epoch))\n",
|
||
|
"\n",
|
||
|
"## 네트워크 불러오기\n",
|
||
|
"def load(ckpt_dir, net, optim):\n",
|
||
|
" if not os.path.exists(ckpt_dir):\n",
|
||
|
" epoch = 0\n",
|
||
|
" return net, optim, epoch\n",
|
||
|
"\n",
|
||
|
" ckpt_lst = os.listdir(ckpt_dir)\n",
|
||
|
" ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))\n",
|
||
|
"\n",
|
||
|
" dict_model = torch.load('%s/%s' % (ckpt_dir, ckpt_lst[-1]))\n",
|
||
|
"\n",
|
||
|
" net.load_state_dict(dict_model['net'])\n",
|
||
|
" optim.load_state_dict(dict_model['optim'])\n",
|
||
|
" epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])\n",
|
||
|
"\n",
|
||
|
" return net, optim, epoch"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Hyperparameters"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 훈련 파라미터 설정하기\n",
|
||
|
"lr = 1e-3\n",
|
||
|
"batch_size = 4\n",
|
||
|
"num_epoch = 10\n",
|
||
|
"\n",
|
||
|
"# base_dir = './2nd_Battery/unet'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-mini'\n",
|
||
|
"base_dir = './2nd_Battery/unet-dice-loss'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-focal-loss'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-sgd'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-rmsprop'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-l1'\n",
|
||
|
"# base_dir = './2nd_Battery/unet-l2'\n",
|
||
|
"ckpt_dir = os.path.join(base_dir, \"checkpoint\")\n",
|
||
|
"log_dir = os.path.join(base_dir, \"log\")\n",
|
||
|
"\n",
|
||
|
"# 네트워크 생성하기\n",
|
||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||
|
"net = UNet().to(device)\n",
|
||
|
"\n",
|
||
|
"# 손실함수 정의하기\n",
|
||
|
"fn_loss = nn.BCEWithLogitsLoss().to(device)\n",
|
||
|
"\n",
|
||
|
"# Optimizer 설정하기\n",
|
||
|
"optim = torch.optim.Adam(net.parameters(), lr=lr)\n",
|
||
|
"\n",
|
||
|
"# 그 밖에 부수적인 functions 설정하기\n",
|
||
|
"fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)\n",
|
||
|
"fn_denorm = lambda x, mean, std: (x * std) + mean\n",
|
||
|
"fn_class = lambda x: 1.0 * (x > 0.95)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - Dice Loss"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class DiceLoss(nn.Module):\n",
|
||
|
" def __init__(self, smooth=1e-6):\n",
|
||
|
" super(DiceLoss, self).__init__()\n",
|
||
|
" self.smooth = smooth\n",
|
||
|
"\n",
|
||
|
" def forward(self, preds, targets):\n",
|
||
|
" preds = torch.sigmoid(preds)\n",
|
||
|
" intersection = (preds * targets).sum()\n",
|
||
|
" dice = (2. * intersection + self.smooth) / (preds.sum() + targets.sum() + self.smooth)\n",
|
||
|
" return 1 - dice\n",
|
||
|
"\n",
|
||
|
"fn_loss = DiceLoss().to(device)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - Focal Loss"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class FocalLoss(nn.Module):\n",
|
||
|
" def __init__(self, alpha=0.8, gamma=2.0):\n",
|
||
|
" super(FocalLoss, self).__init__()\n",
|
||
|
" self.alpha = alpha\n",
|
||
|
" self.gamma = gamma\n",
|
||
|
"\n",
|
||
|
" def forward(self, preds, targets):\n",
|
||
|
" BCE = nn.functional.binary_cross_entropy_with_logits(preds, targets, reduction='none')\n",
|
||
|
" BCE_exp = torch.exp(-BCE)\n",
|
||
|
" focal_loss = self.alpha * (1 - BCE_exp) ** self.gamma * BCE\n",
|
||
|
" return focal_loss.mean()\n",
|
||
|
"\n",
|
||
|
"fn_loss = FocalLoss().to(device)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - SGD"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - RMSProp"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"optim = torch.optim.RMSprop(net.parameters(), lr=lr, alpha=0.9)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - L1"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class L1Loss(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(L1Loss, self).__init__()\n",
|
||
|
"\n",
|
||
|
" def forward(self, preds, targets):\n",
|
||
|
" return torch.mean(torch.abs(preds - targets))\n",
|
||
|
" \n",
|
||
|
"fn_loss = L1Loss().to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# TC - L2"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class L2Loss(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(L2Loss, self).__init__()\n",
|
||
|
"\n",
|
||
|
" def forward(self, preds, targets):\n",
|
||
|
" return torch.mean((preds - targets) ** 2)\n",
|
||
|
" \n",
|
||
|
"fn_loss = L2Loss().to(device)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Test"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"dir_testset = 'C:/Users/pinb/Desktop/testimages/testset'\n",
|
||
|
"dir_groundtruth = 'C:/Users/pinb/Desktop/testimages/maskset'\n",
|
||
|
"# transform = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), Rotate(angle_range=(-90, 90)), ToTensor()])\n",
|
||
|
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), ToTensor()])\n",
|
||
|
"test_set = create_datasets(dir_testset, dir_groundtruth, 0, 0, transform=transform)\n",
|
||
|
"\n",
|
||
|
"# data = test_set.__getitem__(0) # 이미지 불러오기\n",
|
||
|
"\n",
|
||
|
"# input_img = data['input']\n",
|
||
|
"# label = data['label']\n",
|
||
|
"\n",
|
||
|
"# # 이미지 시각화\n",
|
||
|
"# plt.subplot(121)\n",
|
||
|
"# plt.imshow(input_img.reshape(input_img.shape[0], input_img.shape[1]), cmap='gray')\n",
|
||
|
"# plt.title('Input Image')\n",
|
||
|
"\n",
|
||
|
"# plt.subplot(122)\n",
|
||
|
"# plt.imshow(label.reshape(label.shape[0], label.shape[1]), cmap='gray')\n",
|
||
|
"# plt.title('Label')\n",
|
||
|
"\n",
|
||
|
"# plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"loader_test = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)\n",
|
||
|
"\n",
|
||
|
"# 그밖에 부수적인 variables 설정하기\n",
|
||
|
"num_data_test = len(test_set)\n",
|
||
|
"num_batch_test = np.ceil(num_data_test / batch_size)\n",
|
||
|
"\n",
|
||
|
"# 결과 디렉토리 생성하기\n",
|
||
|
"result_dir = os.path.join(base_dir, 'result')\n",
|
||
|
"if not os.path.exists(result_dir):\n",
|
||
|
" os.makedirs(os.path.join(result_dir, 'gt'))\n",
|
||
|
" os.makedirs(os.path.join(result_dir, 'img'))\n",
|
||
|
" os.makedirs(os.path.join(result_dir, 'pr'))\n",
|
||
|
" os.makedirs(os.path.join(result_dir, 'numpy'))\n",
|
||
|
"\n",
|
||
|
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"TEST: BATCH 0001 / 0250 | LOSS 0.3965\n",
|
||
|
"TEST: BATCH 0002 / 0250 | LOSS 0.3255\n",
|
||
|
"TEST: BATCH 0003 / 0250 | LOSS 0.3926\n",
|
||
|
"TEST: BATCH 0004 / 0250 | LOSS 0.3913\n",
|
||
|
"TEST: BATCH 0005 / 0250 | LOSS 0.3963\n",
|
||
|
"TEST: BATCH 0006 / 0250 | LOSS 0.3929\n",
|
||
|
"TEST: BATCH 0007 / 0250 | LOSS 0.4026\n",
|
||
|
"TEST: BATCH 0008 / 0250 | LOSS 0.3988\n",
|
||
|
"TEST: BATCH 0009 / 0250 | LOSS 0.4022\n",
|
||
|
"TEST: BATCH 0010 / 0250 | LOSS 0.3956\n",
|
||
|
"TEST: BATCH 0011 / 0250 | LOSS 0.3933\n",
|
||
|
"TEST: BATCH 0012 / 0250 | LOSS 0.3834\n",
|
||
|
"TEST: BATCH 0013 / 0250 | LOSS 0.3889\n",
|
||
|
"TEST: BATCH 0014 / 0250 | LOSS 0.3885\n",
|
||
|
"TEST: BATCH 0015 / 0250 | LOSS 0.3923\n",
|
||
|
"TEST: BATCH 0016 / 0250 | LOSS 0.3851\n",
|
||
|
"TEST: BATCH 0017 / 0250 | LOSS 0.3819\n",
|
||
|
"TEST: BATCH 0018 / 0250 | LOSS 0.3872\n",
|
||
|
"TEST: BATCH 0019 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0020 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0021 / 0250 | LOSS 0.3858\n",
|
||
|
"TEST: BATCH 0022 / 0250 | LOSS 0.3819\n",
|
||
|
"TEST: BATCH 0023 / 0250 | LOSS 0.3796\n",
|
||
|
"TEST: BATCH 0024 / 0250 | LOSS 0.3749\n",
|
||
|
"TEST: BATCH 0025 / 0250 | LOSS 0.3713\n",
|
||
|
"TEST: BATCH 0026 / 0250 | LOSS 0.3668\n",
|
||
|
"TEST: BATCH 0027 / 0250 | LOSS 0.3637\n",
|
||
|
"TEST: BATCH 0028 / 0250 | LOSS 0.3670\n",
|
||
|
"TEST: BATCH 0029 / 0250 | LOSS 0.3629\n",
|
||
|
"TEST: BATCH 0030 / 0250 | LOSS 0.3630\n",
|
||
|
"TEST: BATCH 0031 / 0250 | LOSS 0.3604\n",
|
||
|
"TEST: BATCH 0032 / 0250 | LOSS 0.3624\n",
|
||
|
"TEST: BATCH 0033 / 0250 | LOSS 0.3675\n",
|
||
|
"TEST: BATCH 0034 / 0250 | LOSS 0.3665\n",
|
||
|
"TEST: BATCH 0035 / 0250 | LOSS 0.3683\n",
|
||
|
"TEST: BATCH 0036 / 0250 | LOSS 0.3713\n",
|
||
|
"TEST: BATCH 0037 / 0250 | LOSS 0.3750\n",
|
||
|
"TEST: BATCH 0038 / 0250 | LOSS 0.3744\n",
|
||
|
"TEST: BATCH 0039 / 0250 | LOSS 0.3734\n",
|
||
|
"TEST: BATCH 0040 / 0250 | LOSS 0.3742\n",
|
||
|
"TEST: BATCH 0041 / 0250 | LOSS 0.3724\n",
|
||
|
"TEST: BATCH 0042 / 0250 | LOSS 0.3735\n",
|
||
|
"TEST: BATCH 0043 / 0250 | LOSS 0.3712\n",
|
||
|
"TEST: BATCH 0044 / 0250 | LOSS 0.3719\n",
|
||
|
"TEST: BATCH 0045 / 0250 | LOSS 0.3730\n",
|
||
|
"TEST: BATCH 0046 / 0250 | LOSS 0.3756\n",
|
||
|
"TEST: BATCH 0047 / 0250 | LOSS 0.3745\n",
|
||
|
"TEST: BATCH 0048 / 0250 | LOSS 0.3750\n",
|
||
|
"TEST: BATCH 0049 / 0250 | LOSS 0.3743\n",
|
||
|
"TEST: BATCH 0050 / 0250 | LOSS 0.3746\n",
|
||
|
"TEST: BATCH 0051 / 0250 | LOSS 0.3741\n",
|
||
|
"TEST: BATCH 0052 / 0250 | LOSS 0.3739\n",
|
||
|
"TEST: BATCH 0053 / 0250 | LOSS 0.3728\n",
|
||
|
"TEST: BATCH 0054 / 0250 | LOSS 0.3740\n",
|
||
|
"TEST: BATCH 0055 / 0250 | LOSS 0.3737\n",
|
||
|
"TEST: BATCH 0056 / 0250 | LOSS 0.3734\n",
|
||
|
"TEST: BATCH 0057 / 0250 | LOSS 0.3737\n",
|
||
|
"TEST: BATCH 0058 / 0250 | LOSS 0.3753\n",
|
||
|
"TEST: BATCH 0059 / 0250 | LOSS 0.3751\n",
|
||
|
"TEST: BATCH 0060 / 0250 | LOSS 0.3742\n",
|
||
|
"TEST: BATCH 0061 / 0250 | LOSS 0.3749\n",
|
||
|
"TEST: BATCH 0062 / 0250 | LOSS 0.3773\n",
|
||
|
"TEST: BATCH 0063 / 0250 | LOSS 0.3777\n",
|
||
|
"TEST: BATCH 0064 / 0250 | LOSS 0.3785\n",
|
||
|
"TEST: BATCH 0065 / 0250 | LOSS 0.3801\n",
|
||
|
"TEST: BATCH 0066 / 0250 | LOSS 0.3787\n",
|
||
|
"TEST: BATCH 0067 / 0250 | LOSS 0.3781\n",
|
||
|
"TEST: BATCH 0068 / 0250 | LOSS 0.3788\n",
|
||
|
"TEST: BATCH 0069 / 0250 | LOSS 0.3800\n",
|
||
|
"TEST: BATCH 0070 / 0250 | LOSS 0.3795\n",
|
||
|
"TEST: BATCH 0071 / 0250 | LOSS 0.3797\n",
|
||
|
"TEST: BATCH 0072 / 0250 | LOSS 0.3793\n",
|
||
|
"TEST: BATCH 0073 / 0250 | LOSS 0.3797\n",
|
||
|
"TEST: BATCH 0074 / 0250 | LOSS 0.3800\n",
|
||
|
"TEST: BATCH 0075 / 0250 | LOSS 0.3795\n",
|
||
|
"TEST: BATCH 0076 / 0250 | LOSS 0.3802\n",
|
||
|
"TEST: BATCH 0077 / 0250 | LOSS 0.3791\n",
|
||
|
"TEST: BATCH 0078 / 0250 | LOSS 0.3789\n",
|
||
|
"TEST: BATCH 0079 / 0250 | LOSS 0.3807\n",
|
||
|
"TEST: BATCH 0080 / 0250 | LOSS 0.3817\n",
|
||
|
"TEST: BATCH 0081 / 0250 | LOSS 0.3813\n",
|
||
|
"TEST: BATCH 0082 / 0250 | LOSS 0.3818\n",
|
||
|
"TEST: BATCH 0083 / 0250 | LOSS 0.3824\n",
|
||
|
"TEST: BATCH 0084 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0085 / 0250 | LOSS 0.3825\n",
|
||
|
"TEST: BATCH 0086 / 0250 | LOSS 0.3827\n",
|
||
|
"TEST: BATCH 0087 / 0250 | LOSS 0.3823\n",
|
||
|
"TEST: BATCH 0088 / 0250 | LOSS 0.3820\n",
|
||
|
"TEST: BATCH 0089 / 0250 | LOSS 0.3821\n",
|
||
|
"TEST: BATCH 0090 / 0250 | LOSS 0.3822\n",
|
||
|
"TEST: BATCH 0091 / 0250 | LOSS 0.3821\n",
|
||
|
"TEST: BATCH 0092 / 0250 | LOSS 0.3827\n",
|
||
|
"TEST: BATCH 0093 / 0250 | LOSS 0.3826\n",
|
||
|
"TEST: BATCH 0094 / 0250 | LOSS 0.3826\n",
|
||
|
"TEST: BATCH 0095 / 0250 | LOSS 0.3841\n",
|
||
|
"TEST: BATCH 0096 / 0250 | LOSS 0.3844\n",
|
||
|
"TEST: BATCH 0097 / 0250 | LOSS 0.3822\n",
|
||
|
"TEST: BATCH 0098 / 0250 | LOSS 0.3822\n",
|
||
|
"TEST: BATCH 0099 / 0250 | LOSS 0.3838\n",
|
||
|
"TEST: BATCH 0100 / 0250 | LOSS 0.3844\n",
|
||
|
"TEST: BATCH 0101 / 0250 | LOSS 0.3837\n",
|
||
|
"TEST: BATCH 0102 / 0250 | LOSS 0.3838\n",
|
||
|
"TEST: BATCH 0103 / 0250 | LOSS 0.3841\n",
|
||
|
"TEST: BATCH 0104 / 0250 | LOSS 0.3848\n",
|
||
|
"TEST: BATCH 0105 / 0250 | LOSS 0.3858\n",
|
||
|
"TEST: BATCH 0106 / 0250 | LOSS 0.3860\n",
|
||
|
"TEST: BATCH 0107 / 0250 | LOSS 0.3857\n",
|
||
|
"TEST: BATCH 0108 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0109 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0110 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0111 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0112 / 0250 | LOSS 0.3877\n",
|
||
|
"TEST: BATCH 0113 / 0250 | LOSS 0.3886\n",
|
||
|
"TEST: BATCH 0114 / 0250 | LOSS 0.3885\n",
|
||
|
"TEST: BATCH 0115 / 0250 | LOSS 0.3892\n",
|
||
|
"TEST: BATCH 0116 / 0250 | LOSS 0.3893\n",
|
||
|
"TEST: BATCH 0117 / 0250 | LOSS 0.3906\n",
|
||
|
"TEST: BATCH 0118 / 0250 | LOSS 0.3905\n",
|
||
|
"TEST: BATCH 0119 / 0250 | LOSS 0.3903\n",
|
||
|
"TEST: BATCH 0120 / 0250 | LOSS 0.3891\n",
|
||
|
"TEST: BATCH 0121 / 0250 | LOSS 0.3886\n",
|
||
|
"TEST: BATCH 0122 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0123 / 0250 | LOSS 0.3876\n",
|
||
|
"TEST: BATCH 0124 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0125 / 0250 | LOSS 0.3861\n",
|
||
|
"TEST: BATCH 0126 / 0250 | LOSS 0.3864\n",
|
||
|
"TEST: BATCH 0127 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0128 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0129 / 0250 | LOSS 0.3869\n",
|
||
|
"TEST: BATCH 0130 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0131 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0132 / 0250 | LOSS 0.3872\n",
|
||
|
"TEST: BATCH 0133 / 0250 | LOSS 0.3864\n",
|
||
|
"TEST: BATCH 0134 / 0250 | LOSS 0.3869\n",
|
||
|
"TEST: BATCH 0135 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0136 / 0250 | LOSS 0.3864\n",
|
||
|
"TEST: BATCH 0137 / 0250 | LOSS 0.3864\n",
|
||
|
"TEST: BATCH 0138 / 0250 | LOSS 0.3862\n",
|
||
|
"TEST: BATCH 0139 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0140 / 0250 | LOSS 0.3863\n",
|
||
|
"TEST: BATCH 0141 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0142 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0143 / 0250 | LOSS 0.3868\n",
|
||
|
"TEST: BATCH 0144 / 0250 | LOSS 0.3866\n",
|
||
|
"TEST: BATCH 0145 / 0250 | LOSS 0.3860\n",
|
||
|
"TEST: BATCH 0146 / 0250 | LOSS 0.3858\n",
|
||
|
"TEST: BATCH 0147 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0148 / 0250 | LOSS 0.3861\n",
|
||
|
"TEST: BATCH 0149 / 0250 | LOSS 0.3863\n",
|
||
|
"TEST: BATCH 0150 / 0250 | LOSS 0.3861\n",
|
||
|
"TEST: BATCH 0151 / 0250 | LOSS 0.3863\n",
|
||
|
"TEST: BATCH 0152 / 0250 | LOSS 0.3864\n",
|
||
|
"TEST: BATCH 0153 / 0250 | LOSS 0.3853\n",
|
||
|
"TEST: BATCH 0154 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0155 / 0250 | LOSS 0.3852\n",
|
||
|
"TEST: BATCH 0156 / 0250 | LOSS 0.3852\n",
|
||
|
"TEST: BATCH 0157 / 0250 | LOSS 0.3855\n",
|
||
|
"TEST: BATCH 0158 / 0250 | LOSS 0.3847\n",
|
||
|
"TEST: BATCH 0159 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0160 / 0250 | LOSS 0.3835\n",
|
||
|
"TEST: BATCH 0161 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0162 / 0250 | LOSS 0.3844\n",
|
||
|
"TEST: BATCH 0163 / 0250 | LOSS 0.3842\n",
|
||
|
"TEST: BATCH 0164 / 0250 | LOSS 0.3830\n",
|
||
|
"TEST: BATCH 0165 / 0250 | LOSS 0.3832\n",
|
||
|
"TEST: BATCH 0166 / 0250 | LOSS 0.3833\n",
|
||
|
"TEST: BATCH 0167 / 0250 | LOSS 0.3833\n",
|
||
|
"TEST: BATCH 0168 / 0250 | LOSS 0.3838\n",
|
||
|
"TEST: BATCH 0169 / 0250 | LOSS 0.3848\n",
|
||
|
"TEST: BATCH 0170 / 0250 | LOSS 0.3849\n",
|
||
|
"TEST: BATCH 0171 / 0250 | LOSS 0.3848\n",
|
||
|
"TEST: BATCH 0172 / 0250 | LOSS 0.3847\n",
|
||
|
"TEST: BATCH 0173 / 0250 | LOSS 0.3845\n",
|
||
|
"TEST: BATCH 0174 / 0250 | LOSS 0.3841\n",
|
||
|
"TEST: BATCH 0175 / 0250 | LOSS 0.3843\n",
|
||
|
"TEST: BATCH 0176 / 0250 | LOSS 0.3841\n",
|
||
|
"TEST: BATCH 0177 / 0250 | LOSS 0.3842\n",
|
||
|
"TEST: BATCH 0178 / 0250 | LOSS 0.3844\n",
|
||
|
"TEST: BATCH 0179 / 0250 | LOSS 0.3841\n",
|
||
|
"TEST: BATCH 0180 / 0250 | LOSS 0.3836\n",
|
||
|
"TEST: BATCH 0181 / 0250 | LOSS 0.3840\n",
|
||
|
"TEST: BATCH 0182 / 0250 | LOSS 0.3843\n",
|
||
|
"TEST: BATCH 0183 / 0250 | LOSS 0.3849\n",
|
||
|
"TEST: BATCH 0184 / 0250 | LOSS 0.3855\n",
|
||
|
"TEST: BATCH 0185 / 0250 | LOSS 0.3857\n",
|
||
|
"TEST: BATCH 0186 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0187 / 0250 | LOSS 0.3862\n",
|
||
|
"TEST: BATCH 0188 / 0250 | LOSS 0.3857\n",
|
||
|
"TEST: BATCH 0189 / 0250 | LOSS 0.3854\n",
|
||
|
"TEST: BATCH 0190 / 0250 | LOSS 0.3859\n",
|
||
|
"TEST: BATCH 0191 / 0250 | LOSS 0.3862\n",
|
||
|
"TEST: BATCH 0192 / 0250 | LOSS 0.3868\n",
|
||
|
"TEST: BATCH 0193 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0194 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0195 / 0250 | LOSS 0.3863\n",
|
||
|
"TEST: BATCH 0196 / 0250 | LOSS 0.3869\n",
|
||
|
"TEST: BATCH 0197 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0198 / 0250 | LOSS 0.3877\n",
|
||
|
"TEST: BATCH 0199 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0200 / 0250 | LOSS 0.3869\n",
|
||
|
"TEST: BATCH 0201 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0202 / 0250 | LOSS 0.3869\n",
|
||
|
"TEST: BATCH 0203 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0204 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0205 / 0250 | LOSS 0.3862\n",
|
||
|
"TEST: BATCH 0206 / 0250 | LOSS 0.3867\n",
|
||
|
"TEST: BATCH 0207 / 0250 | LOSS 0.3871\n",
|
||
|
"TEST: BATCH 0208 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0209 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0210 / 0250 | LOSS 0.3872\n",
|
||
|
"TEST: BATCH 0211 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0212 / 0250 | LOSS 0.3878\n",
|
||
|
"TEST: BATCH 0213 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0214 / 0250 | LOSS 0.3873\n",
|
||
|
"TEST: BATCH 0215 / 0250 | LOSS 0.3877\n",
|
||
|
"TEST: BATCH 0216 / 0250 | LOSS 0.3881\n",
|
||
|
"TEST: BATCH 0217 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0218 / 0250 | LOSS 0.3879\n",
|
||
|
"TEST: BATCH 0219 / 0250 | LOSS 0.3872\n",
|
||
|
"TEST: BATCH 0220 / 0250 | LOSS 0.3865\n",
|
||
|
"TEST: BATCH 0221 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0222 / 0250 | LOSS 0.3873\n",
|
||
|
"TEST: BATCH 0223 / 0250 | LOSS 0.3876\n",
|
||
|
"TEST: BATCH 0224 / 0250 | LOSS 0.3872\n",
|
||
|
"TEST: BATCH 0225 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0226 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0227 / 0250 | LOSS 0.3870\n",
|
||
|
"TEST: BATCH 0228 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0229 / 0250 | LOSS 0.3878\n",
|
||
|
"TEST: BATCH 0230 / 0250 | LOSS 0.3881\n",
|
||
|
"TEST: BATCH 0231 / 0250 | LOSS 0.3879\n",
|
||
|
"TEST: BATCH 0232 / 0250 | LOSS 0.3878\n",
|
||
|
"TEST: BATCH 0233 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0234 / 0250 | LOSS 0.3877\n",
|
||
|
"TEST: BATCH 0235 / 0250 | LOSS 0.3876\n",
|
||
|
"TEST: BATCH 0236 / 0250 | LOSS 0.3875\n",
|
||
|
"TEST: BATCH 0237 / 0250 | LOSS 0.3876\n",
|
||
|
"TEST: BATCH 0238 / 0250 | LOSS 0.3874\n",
|
||
|
"TEST: BATCH 0239 / 0250 | LOSS 0.3877\n",
|
||
|
"TEST: BATCH 0240 / 0250 | LOSS 0.3883\n",
|
||
|
"TEST: BATCH 0241 / 0250 | LOSS 0.3881\n",
|
||
|
"TEST: BATCH 0242 / 0250 | LOSS 0.3882\n",
|
||
|
"TEST: BATCH 0243 / 0250 | LOSS 0.3882\n",
|
||
|
"TEST: BATCH 0244 / 0250 | LOSS 0.3878\n",
|
||
|
"TEST: BATCH 0245 / 0250 | LOSS 0.3884\n",
|
||
|
"TEST: BATCH 0246 / 0250 | LOSS 0.3887\n",
|
||
|
"TEST: BATCH 0247 / 0250 | LOSS 0.3890\n",
|
||
|
"TEST: BATCH 0248 / 0250 | LOSS 0.3886\n",
|
||
|
"TEST: BATCH 0249 / 0250 | LOSS 0.3883\n",
|
||
|
"TEST: BATCH 0250 / 0250 | LOSS 0.3879\n",
|
||
|
"AVERAGE TEST: BATCH 0250 / 0250 | LOSS 0.3879\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Run the trained network on the test set and save gt / img / pr results.\n",
"with torch.no_grad():\n",
"    net.eval()\n",
"    loss_arr = []\n",
"    # Running count of samples saved so far, so file ids stay contiguous\n",
"    # (0, 1, 2, ...) whatever the batch size of loader_test is.\n",
"    sample_offset = 0\n",
"\n",
"    for batch, data in enumerate(loader_test, 1):\n",
"        # forward pass\n",
"        label = data['label'].to(device)\n",
"        input = data['input'].to(device)\n",
"\n",
"        output = net(input)\n",
"\n",
"        # compute the loss\n",
"        loss = fn_loss(output, label)\n",
"\n",
"        loss_arr += [loss.item()]\n",
"\n",
"        # running mean of the per-batch losses seen so far\n",
"        print(\"TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"              (batch, num_batch_test, np.mean(loss_arr)))\n",
"\n",
"        # convert tensors to numpy arrays for saving\n",
"        label = fn_tonumpy(label)\n",
"        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
"        output = fn_tonumpy(fn_class(output))\n",
"\n",
"        # save test results: one png + one npy per sample and per kind\n",
"        for j in range(label.shape[0]):\n",
"            sample_id = sample_offset + j\n",
"\n",
"            gt = label[j].squeeze()\n",
"            img = input[j].squeeze()\n",
"            pr = output[j].squeeze()\n",
"\n",
"            plt.imsave(os.path.join(result_dir, 'gt', 'gt_%04d.png' % sample_id), gt, cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'img', 'img_%04d.png' % sample_id), img, cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'pr', 'pr_%04d.png' % sample_id), pr, cmap='gray')\n",
"            np.save(os.path.join(result_dir, 'numpy', 'gt_%04d.npy' % sample_id), gt)\n",
"            np.save(os.path.join(result_dir, 'numpy', 'img_%04d.npy' % sample_id), img)\n",
"            np.save(os.path.join(result_dir, 'numpy', 'pr_%04d.npy' % sample_id), pr)\n",
"\n",
"        sample_offset += label.shape[0]\n",
"\n",
"print(\"AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"      (batch, num_batch_test, np.mean(loss_arr)))"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Visualize"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\pinb\\AppData\\Local\\Temp\\ipykernel_19912\\3510449017.py:45: RuntimeWarning: invalid value encountered in divide\n",
|
||
|
" precision = tp / (tp + fp) # precision = TP / (TP + FP)\n",
|
||
|
"C:\\Users\\pinb\\AppData\\Local\\Temp\\ipykernel_19912\\3510449017.py:46: RuntimeWarning: invalid value encountered in divide\n",
|
||
|
" recall = tp / (tp + fn) # recall = TP / (TP + FN)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"precision: 0.7764027600652067\n",
|
||
|
"recall: 0.7843549615385272\n",
|
||
|
"accuracy: 0.9770164763057941\n",
|
||
|
"f1: 0.7741721124958945\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq8AAACaCAYAAACHSaGqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABP6UlEQVR4nO2deXgUVdbG396XdKcT0iELJAHCEjEsAqIBFNSwiaDsuLHIiEJQtmEchm9YRIn7KMigoyOgMoAsMi4IssoAQZA9YQtICBCykaSz9n6/P5iq6e50J91JdzqdnN/z1JP0rVtVp+7pqnr71rnnChhjDARBEARBEAQRAAj9bQBBEARBEARBuAuJV4IgCIIgCCJgIPFKEARBEARBBAwkXgmCIAiCIIiAgcQrQRAEQRAEETCQeCUIgiAIgiACBhKvBEEQBEEQRMBA4pUgCIIgCIIIGEi8EgRBEARBEAEDiVcPWbt2LQQCAbKysvxtCkEQRKNHIBBgyZIlHm0zefJktGnTxif2EN7F2TNxwIABGDBggN9sIpo+JF4JogaWL1+O7du3+9sMgiAIgiD+i9jfBgQazz//PCZMmACZTOZvU4gGYPny5RgzZgyeeuopf5tCEAFJVVUVxGLPHjWfffYZrFarjywiCCLQoZ5XDxGJRJDL5RAIBP42hSAIF1itVuj1en+bEVD4qs3kcrnH4lUikVAHgYdUVFT42wSiEdBcvgckXj3EMb6nTZs2eOKJJ3DgwAH06tULCoUCXbp0wYEDBwAA27ZtQ5cuXSCXy9GzZ0+cOnWq2j43b96Mzp07Qy6XIzExEd9++y3FfDUAnM/kcjni4+Px6aefYsmSJfwPE4FAgIqKCqxbtw4CgQACgQCTJ0/2r9HNDM4fFy9exLhx4xAcHIywsDDMmjXLTmgJBALMnDkT69evx7333guZTIadO3f60XL/4Y02u3XrFl544QVERERAJpPh3nvvxRdffFHtWHq9HkuWLEHHjh0hl8sRFRWFUaNG4erVq3bHsY15LSsrw+zZs9GmTRvIZDK0bNkSAwcOxMmTJ/k6zu5/FRUVmDdvHmJiYiCTydCpUye89957YIzZ1ePOa/v27UhMTOTtb0rfB87H58+fxzPPPIPQ0FD069cPAPD111+jZ8+eUCgUaNGiBSZMmIAbN25U28evv/6Kxx9/HKGhoQgKCkLXrl3x0Ucf8evPnj2LyZMno127dpDL5YiMjMQLL7yAO3fuNNh5NgauX7+OGTNmoFOnTlAoFAgLC8PYsWOdjnspKSnBnDlz+O9269atMXHiRBQWFvJ1artmDhw4AIFAwGsIjqysLAgEAqxdu5Yvmzx5MlQqFa5evYrHH38carUazz77LADgP//5D8aOHYvY2FjIZDLExMRgzpw5qKqqqmY3d68IDw+HQqFAp06dsHDhQgDA/v37IRAI8O2331bb7l//+hcEAgHS0tI8bdZ6Q2EDXuDKlSt45pln8NJLL+G5557De++9h+HDh+OTTz7BX/7yF8yYMQMAkJqainHjxuHSpUsQCu/+bvjxxx8xfvx4dOnSBampqSguLsbUqVPRqlUrf55Sk+fUqVMYMmQIoqKisHTpUlgsFrz++usIDw/n63z11Vf4wx/+gN69e2PatGkAgPj4eH+Z3KwZN24c2rRpg9TUVBw9ehQrVqxAcXExvvzyS77Ovn378M0332DmzJnQarXN/sdfXdssLy8PDz74IC8Cw8PD8dNPP2Hq1KkoLS3F7NmzAQAWiwVPPPEE9u7diwkTJmDWrFkoKyvD7t27kZ6e7vJaefnll7FlyxbMnDkTnTt3xp07d3Do0CFcuHABPXr0cLoNYwwjRozA/v37MXXqVHTv3h27du3C/PnzcevWLfztb3+zq3/o0CFs27YNM2bMgFqtxooVKzB69GhkZ2cjLCzMOw3cCBg7diw6dOiA5cuXgzGGN998E3/9618xbtw4/OEPf0BBQQFWrlyJhx
9+GKdOnUJISAgAYPfu3XjiiScQFRWFWbNmITIyEhcuXMAPP/yAWbNm8XV+//13TJkyBZGRkcjIyMA//vEPZGRk4OjRo83m7ePx48dx5MgRTJgwAa1bt0ZWVhZWr16NAQMG4Pz581AqlQCA8vJyPPTQQ7hw4QJeeOEF9OjRA4WFhfjuu+9w8+ZNaLXaOl8zNWE2mzF48GD069cP7733Hm/P5s2bUVlZienTpyMsLAzHjh3DypUrcfPmTWzevJnf/uzZs3jooYcgkUgwbdo0tGnTBlevXsX333+PN998EwMGDEBMTAzWr1+PkSNH2h17/fr1iI+PR1JSUj1auI4wwiPWrFnDALBr164xxhiLi4tjANiRI0f4Ort27WIAmEKhYNevX+fLP/30UwaA7d+/ny/r0qULa926NSsrK+PLDhw4wACwuLg4X59Os2X48OFMqVSyW7du8WWZmZlMLBYz28siKCiITZo0yQ8WEowxtnjxYgaAjRgxwq58xowZDAA7c+YMY4wxAEwoFLKMjAx/mNmoqG+bTZ06lUVFRbHCwkK78gkTJjCNRsMqKysZY4x98cUXDAD74IMPqtlgtVr5/wGwxYsX8581Gg1LSUmp8RwmTZpkd//bvn07A8DeeOMNu3pjxoxhAoGAXblyxe54UqnUruzMmTMMAFu5cmWNxw0UOB8//fTTfFlWVhYTiUTszTfftKt77tw5JhaL+XKz2czatm3L4uLiWHFxsV1dW79xfrZlw4YNDAA7ePAgX+b4TGSMsf79+7P+/fvX4wwbD87aIS0tjQFgX375JV+2aNEiBoBt27atWn2uXd25Zvbv319NJzDG2LVr1xgAtmbNGr5s0qRJDAD785//7JbdqampTCAQ2OmShx9+mKnVarsyW3sYY2zBggVMJpOxkpISviw/P5+JxWK7a7shobABL9C5c2e7Xx4PPPAAAODRRx9FbGxstfLff/8dAJCTk4Nz585h4sSJUKlUfL3+/fujS5cuDWF6s8RisWDPnj146qmnEB0dzZe3b98eQ4cO9aNlhCtSUlLsPr/yyisAgB07dvBl/fv3R+fOnRvUrsZMXdqMMYatW7di+PDhYIyhsLCQXwYPHgydTse/3t+6dSu0Wi2/X1tq6pULCQnBr7/+ipycHLfPZceOHRCJRHj11VftyufNmwfGGH766Se78uTkZLterK5duyI4OJi/9zYVXn75Zf7/bdu2wWq1Yty4cXZ+i4yMRIcOHbB//34Ad986Xbt2DbNnz+Z7Yjls/aZQKPj/9Xo9CgsL8eCDDwKAXYhHU8e2HUwmE+7cuYP27dsjJCTErh22bt2Kbt26VeudBP7XrnW9Zmpj+vTpNdpdUVGBwsJC9OnTB4wxPnyxoKAABw8exAsvvGCnVRztmThxIgwGA7Zs2cKXbdq0CWazGc8991yd7a4PJF69gKPTNRoNACAmJsZpeXFxMYC7sTTAXdHkiLMywjvk5+ejqqqK2j2A6NChg93n+Ph4CIVCu7iztm3bNrBVjZu6tFlBQQFKSkrwj3/8A+Hh4XbLlClTANy9fgDg6tWr6NSpk8eDsd555x2kp6cjJiYGvXv3xpIlS2oVldevX0d0dDTUarVd+T333MOvt8XxngwAoaGh/L23qWDrv8zMTDDG0KFDh2q+u3Dhgp3fACAxMbHGfRcVFWHWrFmIiIiAQqFAeHg4fzydTuejM2p8VFVVYdGiRXystVarRXh4OEpKSuza4erVq7W2aV2vmZoQi8Vo3bp1tfLs7GxMnjwZLVq0gEqlQnh4OPr37w/gf/7jrrva7E5ISMD999+P9evX82Xr16/Hgw8+6LdnJsW8egGRSORROXMYYEAQhGc466Ww7WkgquNOm3HpqZ577jlMmjTJ6X66du1aLzvGjRuHhx56CN9++y1+/vlnvPvuu3j77bexbds2r735aC73Xlv/Wa1WCAQC/PTTT07P3/btnjuMGzcOR44cwfz589G9e3eoVCpYrVYMGT
KkWaUxe+WVV7BmzRrMnj0bSUlJ0Gg0EAgEmDBhgk/awVUPrMVicVouk8n4MTS2dQcOHIiioiK89tprSEhIQFBQEG7
|
||
|
"text/plain": [
|
||
|
"<Figure size 800x600 with 6 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# base_dir = './2nd_Battery/unet'\n",
"# base_dir = './2nd_Battery/unet-mini'\n",
"base_dir = './2nd_Battery/unet-dice-loss'\n",
"# base_dir = './2nd_Battery/unet-focal-loss'\n",
"# base_dir = './2nd_Battery/unet-sgd'\n",
"# base_dir = './2nd_Battery/unet-rmsprop'\n",
"# base_dir = './2nd_Battery/unet-l1'\n",
"# base_dir = './2nd_Battery/unet-l2'\n",
"result_dir = os.path.join(base_dir, 'result')\n",
"\n",
"\n",
"def safe_ratio(num, den):\n",
"    \"\"\"Return num / den, or NaN when den is 0 (metric undefined for this image).\"\"\"\n",
"    return num / den if den else float('nan')\n",
"\n",
"\n",
"## Load the saved test results (gt_/img_/pr_<id>.npy files)\n",
"lst_data = os.listdir(os.path.join(result_dir, 'numpy'))\n",
"\n",
"# Sorting all three lists the same way keeps img / gt / pr aligned by id.\n",
"lst_img = sorted(f for f in lst_data if f.startswith('img'))\n",
"lst_gt = sorted(f for f in lst_data if f.startswith('gt'))\n",
"lst_pr = sorted(f for f in lst_data if f.startswith('pr'))\n",
"\n",
"length = len(lst_img)\n",
"if length == 0 or not (length == len(lst_gt) == len(lst_pr)):\n",
"    raise RuntimeError('expected matching non-empty img/gt/pr files, got %d/%d/%d'\n",
"                       % (len(lst_img), len(lst_gt), len(lst_pr)))\n",
"\n",
"# Per-image scores. NaN marks a metric that is undefined for that image\n",
"# (e.g. precision when nothing is predicted); np.nanmean skips it.\n",
"lst_precision, lst_recall, lst_accuracy, lst_f1 = [], [], [], []\n",
"\n",
"for f_img, f_gt, f_pr in zip(lst_img, lst_gt, lst_pr):\n",
"    img = np.load(os.path.join(result_dir, \"numpy\", f_img))\n",
"    gt = np.load(os.path.join(result_dir, \"numpy\", f_gt))\n",
"    pr = np.load(os.path.join(result_dir, \"numpy\", f_pr))\n",
"\n",
"    # Binarize into boolean masks so ~ and & are logical NOT / AND.\n",
"    # On uint8 they are bitwise and only correct for exact 0/255 values.\n",
"    # Assumes gt / pr are stored in [0, 1] (they are scaled by 255 for display).\n",
"    gt_mask = gt > 0.5\n",
"    pr_mask = pr > 0.5\n",
"\n",
"    tp = gt_mask & pr_mask    # True Positive: gt and pr are both 1\n",
"    fp = pr_mask & ~gt_mask   # False Positive: pr is 1 but gt is 0\n",
"    tn = ~gt_mask & ~pr_mask  # True Negative: gt and pr are both 0\n",
"    fn = ~pr_mask & gt_mask   # False Negative: pr is 0 but gt is 1\n",
"\n",
"    n_tp = np.count_nonzero(tp)\n",
"    n_fp = np.count_nonzero(fp)\n",
"    n_tn = np.count_nonzero(tn)\n",
"    n_fn = np.count_nonzero(fn)\n",
"\n",
"    lst_precision.append(safe_ratio(n_tp, n_tp + n_fp))  # TP / (TP + FP)\n",
"    lst_recall.append(safe_ratio(n_tp, n_tp + n_fn))     # TP / (TP + FN)\n",
"    lst_accuracy.append(safe_ratio(n_tp + n_tn, tp.size))\n",
"    # F1 from counts: equals 2PR / (P + R) when both are defined, and is 0\n",
"    # (not NaN) when there are errors but no true positives.\n",
"    lst_f1.append(safe_ratio(2 * n_tp, 2 * n_tp + n_fp + n_fn))\n",
"\n",
"print(f\"precision: {np.nanmean(lst_precision)}\")\n",
"print(f\"recall: {np.nanmean(lst_recall)}\")\n",
"print(f\"accuracy: {np.nanmean(lst_accuracy)}\")\n",
"print(f\"f1: {np.nanmean(lst_f1)}\")\n",
"\n",
"## Plot the last image of the loop.\n",
"# Per-pixel maps: 1 = correct, 0 = wrong, NaN = pixel not counted by that metric.\n",
"img = np.uint8(img * 255)\n",
"gt = np.uint8(gt * 255)\n",
"pr = np.uint8(pr * 255)\n",
"precision = np.where(tp | fp, tp, np.nan)  # defined on predicted-positive pixels\n",
"recall = np.where(tp | fn, tp, np.nan)     # defined on ground-truth-positive pixels\n",
"accuracy = (tp | tn).astype(float)\n",
"\n",
"plt.figure(figsize=(8,6))\n",
"plt.subplot(161)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.title('img')\n",
"\n",
"plt.subplot(162)\n",
"plt.imshow(gt, cmap='gray')\n",
"plt.title('gt')\n",
"\n",
"plt.subplot(163)\n",
"plt.imshow(pr, cmap='gray')\n",
"plt.title('pr')\n",
"\n",
"plt.subplot(164)\n",
"plt.imshow(precision, cmap='gray')\n",
"plt.title('precision')\n",
"\n",
"plt.subplot(165)\n",
"plt.imshow(recall, cmap='gray')\n",
"plt.title('recall')\n",
"\n",
"plt.subplot(166)\n",
"plt.imshow(accuracy, cmap='gray')\n",
"plt.title('accuracy')\n",
"\n",
"plt.show()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
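"Average test loss per experiment. Runs trained with a different loss function (Dice, Focal, L1, L2, L1 + L2) report that loss, so LOSS values are only comparable between runs that use the same loss function.\n",
"\n",
"# UNet\n",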
|
||
|
"LOSS 0.2072\n",
|
||
|
"\n",
|
||
|
"# UNet - Mini\n",
|
||
|
"LOSS 0.1324\n",
|
||
|
"\n",
|
||
|
"# UNet - Dice Loss\n",
|
||
|
"LOSS 0.3879\n",
|
||
|
"\n",
|
||
|
"# UNet - Focal Loss\n",
|
||
|
"LOSS 0.0112\n",
|
||
|
"\n",
|
||
|
"# UNet - SGD Opt\n",
|
||
|
"LOSS 0.1787\n",
|
||
|
"\n",
|
||
|
"# UNet - RMSProp Opt\n",
|
||
|
"LOSS 0.1666\n",
|
||
|
"\n",
|
||
|
"# UNet - L1 Loss\n",
|
||
|
"LOSS 0.0357\n",
|
||
|
"\n",
|
||
|
"# UNet - L2 Loss\n",
|
||
|
"LOSS 0.0241\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# UNet - L1 + L2 Loss\n",
|
||
|
"LOSS 0.0550\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.11"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|