You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
20_Final_Project/unet_battery_test.ipynb

1279 lines
73 KiB
Plaintext

7 months ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"from glob import glob\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import transforms, datasets\n",
"import random\n",
"import cv2\n",
"\n",
"class CustomDataset(Dataset):\n",
"    \"\"\"Paired grayscale image / mask dataset for segmentation.\n",
"\n",
"    Item ``i`` reads ``list_imgs[i]`` and ``list_masks[i]`` as grayscale,\n",
"    resizes both to ``size``, scales them to float32 in [0, 1] with shape\n",
"    (H, W, 1) and returns ``{'input': img, 'label': mask}``, passed through\n",
"    ``transform`` when one is given.\n",
"\n",
"    Args:\n",
"        list_imgs: input image paths, aligned index-by-index with ``list_masks``.\n",
"        list_masks: mask image paths.\n",
"        transform: optional callable applied to the sample dict.\n",
"        size: ``(width, height)`` passed to ``cv2.resize``; default (512, 512).\n",
"\n",
"    Raises:\n",
"        ValueError: if the two path lists have different lengths.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, list_imgs, list_masks, transform=None, size=(512, 512)):\n",
"        if len(list_imgs) != len(list_masks):\n",
"            raise ValueError('Got %d images but %d masks' % (len(list_imgs), len(list_masks)))\n",
"        self.list_imgs = list_imgs\n",
"        self.list_masks = list_masks\n",
"        self.transform = transform\n",
"        self.size = tuple(size)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.list_imgs)\n",
"\n",
"    @staticmethod\n",
"    def _read_gray(path):\n",
"        # Load one file as a single-channel array. cv2.imread does not raise on\n",
"        # a missing/unreadable file - it returns None, which would otherwise\n",
"        # fail later inside cv2.resize with an unhelpful message.\n",
"        arr = cv2.imread(path, cv2.IMREAD_GRAYSCALE)\n",
"        if arr is None:\n",
"            raise FileNotFoundError('Could not read image file: %s' % path)\n",
"        return arr\n",
"\n",
"    def __getitem__(self, index):\n",
"        img = self._read_gray(self.list_imgs[index])\n",
"        mask = self._read_gray(self.list_masks[index])\n",
"\n",
"        # Bilinear for the image; nearest-neighbour keeps mask values unblended.\n",
"        img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)\n",
"        mask = cv2.resize(mask, self.size, interpolation=cv2.INTER_NEAREST)\n",
"\n",
"        img = img.astype(np.float32) / 255.0\n",
"        mask = mask.astype(np.float32) / 255.0\n",
"\n",
"        # (H, W) -> (H, W, 1) so downstream transforms always see a channel axis.\n",
"        if img.ndim == 2:\n",
"            img = img[:, :, np.newaxis]\n",
"        if mask.ndim == 2:\n",
"            mask = mask[:, :, np.newaxis]\n",
"\n",
"        data = {'input': img, 'label': mask}\n",
"        if self.transform:\n",
"            data = self.transform(data)\n",
"        return data\n",
"\n",
"def create_datasets(img_dir, mask_dir, train_ratio=0.7, val_ratio=0.2, transform=None):\n",
"    \"\"\"Build the test split from paired ``.bmp`` images and masks.\n",
"\n",
"    ``.bmp`` files are collected recursively under ``img_dir`` and ``mask_dir``\n",
"    and paired by sorted path order, so both trees must contain matching files\n",
"    in the same order. The first ``train_ratio`` and the next ``val_ratio``\n",
"    fractions of the list are skipped (train / val portions, not built here);\n",
"    the rest becomes the test set. Pass ``0, 0`` to use every pair for testing.\n",
"\n",
"    Returns:\n",
"        CustomDataset over the test portion.\n",
"\n",
"    Raises:\n",
"        ValueError: if the image and mask counts differ, since sorted-order\n",
"            pairing would then silently misalign images and masks.\n",
"    \"\"\"\n",
"    list_imgs = sorted(glob(os.path.join(img_dir, '**', '*.bmp'), recursive=True))\n",
"    list_masks = sorted(glob(os.path.join(mask_dir, '**', '*.bmp'), recursive=True))\n",
"\n",
"    if len(list_imgs) != len(list_masks):\n",
"        raise ValueError('Found %d images in %s but %d masks in %s'\n",
"                         % (len(list_imgs), img_dir, len(list_masks), mask_dir))\n",
"\n",
"    num_imgs = len(list_imgs)\n",
"    num_train = int(num_imgs * train_ratio)\n",
"    num_val = int(num_imgs * val_ratio)\n",
"    start = num_train + num_val\n",
"\n",
"    test_set = CustomDataset(list_imgs[start:], list_masks[start:], transform)\n",
"    return test_set\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Augmentation (Transforms)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# 트렌스폼 구현하기\n",
"class ToTensor(object):\n",
"    \"\"\"Turn the numpy sample dict into float32 torch tensors laid out (C, H, W).\n",
"\n",
"    A 2-D (H, W) array first gets a trailing channel axis, so both (H, W) and\n",
"    (H, W, C) arrays are accepted. Returned keys are 'label' then 'input'.\n",
"    \"\"\"\n",
"\n",
"    @staticmethod\n",
"    def _to_chw(arr):\n",
"        # (H, W) -> (H, W, 1), then move channels first and cast to float32.\n",
"        if arr.ndim == 2:\n",
"            arr = arr[:, :, np.newaxis]\n",
"        chw = arr.transpose((2, 0, 1)).astype(np.float32)\n",
"        return torch.from_numpy(chw)\n",
"\n",
"    def __call__(self, data):\n",
"        label, image = data['label'], data['input']\n",
"        return {'label': self._to_chw(label), 'input': self._to_chw(image)}\n",
"\n",
"class Normalization(object):\n",
"    \"\"\"Standardize the input image as (input - mean) / std.\n",
"\n",
"    Only 'input' is changed; 'label' is passed through as-is so mask values\n",
"    are not shifted.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, mean=0.5, std=0.5):\n",
"        self.mean, self.std = mean, std\n",
"\n",
"    def __call__(self, data):\n",
"        label = data['label']\n",
"        normalized = (data['input'] - self.mean) / self.std\n",
"        return {'label': label, 'input': normalized}\n",
"\n",
"class RandomFlip(object):\n",
"    \"\"\"Randomly mirror input and label together.\n",
"\n",
"    Makes two draws from NumPy's global RNG, in this order: the first decides\n",
"    a left-right flip, the second an up-down flip (each taken when the draw\n",
"    is > 0.5). Identical flips are applied to both arrays so they stay aligned.\n",
"    \"\"\"\n",
"\n",
"    def __call__(self, data):\n",
"        label, image = data['label'], data['input']\n",
"\n",
"        flip_lr = np.random.rand() > 0.5\n",
"        if flip_lr:\n",
"            label, image = np.fliplr(label), np.fliplr(image)\n",
"\n",
"        flip_ud = np.random.rand() > 0.5\n",
"        if flip_ud:\n",
"            label, image = np.flipud(label), np.flipud(image)\n",
"\n",
"        return {'label': label, 'input': image}\n",
" \n",
"# class Resize(object):\n",
"# def __init__(self, output_size):\n",
"# assert isinstance(output_size, (int, tuple))\n",
"# self.output_size = output_size\n",
"\n",
"# def __call__(self, data):\n",
"# label, input = data['label'], data['input']\n",
"\n",
"# h, w = input.shape[:2]\n",
"# if isinstance(self.output_size, int):\n",
"# if h > w:\n",
"# new_h, new_w = self.output_size * h / w, self.output_size\n",
"# else:\n",
"# new_h, new_w = self.output_size, self.output_size * w / h\n",
"# else:\n",
"# new_h, new_w = self.output_size\n",
"\n",
"# new_h, new_w = int(new_h), int(new_w)\n",
"\n",
"# input = cv2.resize(input, (new_w, new_h))\n",
"# label = cv2.resize(label, (new_w, new_h))\n",
"\n",
"# return {'label': label, 'input': input}\n",
"\n",
"class Rotate(object):\n",
"    \"\"\"Rotate input and label together by one random angle.\n",
"\n",
"    Args:\n",
"        angle_range: ``(min_deg, max_deg)``; each call draws an angle uniformly\n",
"            from this range using NumPy's global RNG.\n",
"\n",
"    A trailing singleton channel, (H, W, 1), is squeezed to (H, W) and the\n",
"    rotated arrays are returned 2-D (ToTensor adds the channel axis back).\n",
"    Pixels rotated in from outside the frame are filled with 0 (OpenCV's\n",
"    default constant border).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, angle_range):\n",
"        assert isinstance(angle_range, (tuple, list)) and len(angle_range) == 2\n",
"        self.angle_min, self.angle_max = angle_range\n",
"\n",
"    def __call__(self, data):\n",
"        # Convert non-ndarray inputs and copy views with negative strides (e.g.\n",
"        # from RandomFlip) so OpenCV receives a plain C-contiguous buffer.\n",
"        label = np.ascontiguousarray(data['label'])\n",
"        input = np.ascontiguousarray(data['input'])\n",
"\n",
"        # Drop a trailing singleton channel so both arrays are 2-D images.\n",
"        if input.ndim == 3 and input.shape[2] == 1:\n",
"            input = input.squeeze(2)\n",
"        if label.ndim == 3 and label.shape[2] == 1:\n",
"            label = label.squeeze(2)\n",
"\n",
"        angle = np.random.uniform(self.angle_min, self.angle_max)\n",
"        h, w = input.shape[:2]\n",
"        rot_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)\n",
"\n",
"        # Bilinear for the image, nearest-neighbour for the mask: interpolating\n",
"        # the mask would create fractional label values along object edges.\n",
"        input = cv2.warpAffine(input, rot_matrix, (w, h), flags=cv2.INTER_LINEAR)\n",
"        label = cv2.warpAffine(label, rot_matrix, (w, h), flags=cv2.INTER_NEAREST)\n",
"\n",
"        return {'label': label, 'input': input}\n",
" \n",
"# class Crop(object):\n",
"# def __init__(self, output_size):\n",
"# assert isinstance(output_size, (int, tuple))\n",
"# if isinstance(output_size, int):\n",
"# self.output_size = (output_size, output_size)\n",
"# else:\n",
"# assert len(output_size) == 2\n",
"# self.output_size = output_size\n",
"\n",
"# def __call__(self, data):\n",
"# label, input = data['label'], data['input']\n",
"\n",
"# h, w = input.shape[:2]\n",
"# new_h, new_w = self.output_size\n",
"\n",
"# top = np.random.randint(0, h - new_h)\n",
"# left = np.random.randint(0, w - new_w)\n",
"\n",
"# input = input[top: top + new_h, left: left + new_w]\n",
"# label = label[top: top + new_h, left: left + new_w]\n",
"\n",
"# return {'label': label, 'input': input}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# UNet Model (Original)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"## 라이브러리 불러오기\n",
"import os\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"## 네트워크 구축하기\n",
"class UNet(nn.Module):\n",
"    \"\"\"U-Net encoder/decoder for segmentation (two Conv-BN-ReLU blocks per level).\n",
"\n",
"    Args:\n",
"        in_channels: channels of the input image (default 1, grayscale).\n",
"        out_channels: channels of the output map (default 1).\n",
"\n",
"    ``forward(x)`` takes x of shape (N, in_channels, H, W) and returns raw,\n",
"    un-activated scores of shape (N, out_channels, H, W); no sigmoid is\n",
"    applied here. H and W must be divisible by 16: there are four 2x\n",
"    max-pools and skip connections are concatenated without cropping.\n",
"\n",
"    Keep submodule names and layer order stable: the state_dict keys of\n",
"    saved checkpoints depend on them.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=1, out_channels=1):\n",
"        super(UNet, self).__init__()\n",
"\n",
"        # Conv2d -> BatchNorm2d -> ReLU wrapped in nn.Sequential (state_dict\n",
"        # indices 0 / 1 / 2). The conv bias is redundant before BatchNorm but\n",
"        # is kept because checkpoints saved from this model contain it.\n",
"        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):\n",
"            layers = []\n",
"            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
"                                 kernel_size=kernel_size, stride=stride, padding=padding,\n",
"                                 bias=bias)]\n",
"            layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
"            layers += [nn.ReLU()]\n",
"\n",
"            cbr = nn.Sequential(*layers)\n",
"\n",
"            return cbr\n",
"\n",
"        # Contracting path: channels grow 64 -> 1024 while each max-pool halves H and W.\n",
"        self.enc1_1 = CBR2d(in_channels=in_channels, out_channels=64)\n",
"        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)\n",
"\n",
"        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
"        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)\n",
"\n",
"        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
"        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)\n",
"\n",
"        self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
"        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)\n",
"\n",
"        self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
"\n",
"        # Expansive path: 2x transposed-conv upsampling, concatenation with the\n",
"        # matching encoder output (hence 2 * C input channels), then two CBR2d.\n",
"        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
"\n",
"        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,\n",
"                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
"        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)\n",
"        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)\n",
"\n",
"        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,\n",
"                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
"        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)\n",
"        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)\n",
"\n",
"        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,\n",
"                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
"        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)\n",
"        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)\n",
"\n",
"        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,\n",
"                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
"        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)\n",
"        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)\n",
"\n",
"        # 1x1 conv to the requested number of output channels.\n",
"        self.fc = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=True)\n",
"\n",
"    def forward(self, x):\n",
"        # Encoder; the enc*_2 outputs are reused as skip connections.\n",
"        enc1_1 = self.enc1_1(x)\n",
"        enc1_2 = self.enc1_2(enc1_1)\n",
"        pool1 = self.pool1(enc1_2)\n",
"\n",
"        enc2_1 = self.enc2_1(pool1)\n",
"        enc2_2 = self.enc2_2(enc2_1)\n",
"        pool2 = self.pool2(enc2_2)\n",
"\n",
"        enc3_1 = self.enc3_1(pool2)\n",
"        enc3_2 = self.enc3_2(enc3_1)\n",
"        pool3 = self.pool3(enc3_2)\n",
"\n",
"        enc4_1 = self.enc4_1(pool3)\n",
"        enc4_2 = self.enc4_2(enc4_1)\n",
"        pool4 = self.pool4(enc4_2)\n",
"\n",
"        enc5_1 = self.enc5_1(pool4)\n",
"\n",
"        # Decoder: upsample, concatenate the skip along channels, convolve.\n",
"        dec5_1 = self.dec5_1(enc5_1)\n",
"\n",
"        unpool4 = self.unpool4(dec5_1)\n",
"        cat4 = torch.cat((unpool4, enc4_2), dim=1)\n",
"        dec4_2 = self.dec4_2(cat4)\n",
"        dec4_1 = self.dec4_1(dec4_2)\n",
"\n",
"        unpool3 = self.unpool3(dec4_1)\n",
"        cat3 = torch.cat((unpool3, enc3_2), dim=1)\n",
"        dec3_2 = self.dec3_2(cat3)\n",
"        dec3_1 = self.dec3_1(dec3_2)\n",
"\n",
"        unpool2 = self.unpool2(dec3_1)\n",
"        cat2 = torch.cat((unpool2, enc2_2), dim=1)\n",
"        dec2_2 = self.dec2_2(cat2)\n",
"        dec2_1 = self.dec2_1(dec2_2)\n",
"\n",
"        unpool1 = self.unpool1(dec2_1)\n",
"        cat1 = torch.cat((unpool1, enc1_2), dim=1)\n",
"        dec1_2 = self.dec1_2(cat1)\n",
"        dec1_1 = self.dec1_1(dec1_2)\n",
"\n",
"        # Raw scores (no activation).\n",
"        x = self.fc(dec1_1)\n",
"\n",
"        return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# UNet Model (Mini)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"## 라이브러리 불러오기\n",
"import os\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"## 네트워크 구축하기\n",
"class UNet(nn.Module):\n",
"    \"\"\"Lighter U-Net: one Conv-BN-ReLU block per encoder/decoder level.\n",
"\n",
"    Args:\n",
"        in_channels: channels of the input image (default 1, grayscale).\n",
"        out_channels: channels of the output map (default 1).\n",
"\n",
"    ``forward(x)`` takes x of shape (N, in_channels, H, W) and returns raw,\n",
"    un-activated scores of shape (N, out_channels, H, W); no sigmoid is\n",
"    applied here. H and W must be divisible by 16: there are four 2x\n",
"    max-pools and skip connections are concatenated without cropping.\n",
"\n",
"    Keep submodule names and layer order stable: the state_dict keys of\n",
"    saved checkpoints depend on them.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=1, out_channels=1):\n",
"        super(UNet, self).__init__()\n",
"\n",
"        # Conv2d -> BatchNorm2d -> ReLU wrapped in nn.Sequential (state_dict\n",
"        # indices 0 / 1 / 2). The conv bias is redundant before BatchNorm but\n",
"        # is kept because checkpoints saved from this model contain it.\n",
"        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):\n",
"            layers = []\n",
"            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
"                                 kernel_size=kernel_size, stride=stride, padding=padding,\n",
"                                 bias=bias)]\n",
"            layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
"            layers += [nn.ReLU()]\n",
"\n",
"            cbr = nn.Sequential(*layers)\n",
"\n",
"            return cbr\n",
"\n",
"        # Contracting path: channels grow 64 -> 1024 while each max-pool halves H and W.\n",
"        self.enc1_1 = CBR2d(in_channels=in_channels, out_channels=64)\n",
"        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
"        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
"        self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
"        self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
"\n",
"        # Expansive path with reduced depth: 2x transposed-conv upsampling, then\n",
"        # a single CBR2d over [upsampled, encoder skip] concatenated on channels.\n",
"        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
"        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)\n",
"\n",
"        self.dec4_1 = CBR2d(in_channels=512 + 512, out_channels=256)\n",
"        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)\n",
"\n",
"        self.dec3_1 = CBR2d(in_channels=256 + 256, out_channels=128)\n",
"        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)\n",
"\n",
"        self.dec2_1 = CBR2d(in_channels=128 + 128, out_channels=64)\n",
"        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)\n",
"\n",
"        self.dec1_1 = CBR2d(in_channels=64 + 64, out_channels=64)\n",
"        # 1x1 conv to the requested number of output channels.\n",
"        self.fc = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=True)\n",
"\n",
"    def forward(self, x):\n",
"        # Encoder; each enc*_1 output is reused as a skip connection.\n",
"        enc1_1 = self.enc1_1(x)\n",
"        pool1 = self.pool1(enc1_1)\n",
"\n",
"        enc2_1 = self.enc2_1(pool1)\n",
"        pool2 = self.pool2(enc2_1)\n",
"\n",
"        enc3_1 = self.enc3_1(pool2)\n",
"        pool3 = self.pool3(enc3_1)\n",
"\n",
"        enc4_1 = self.enc4_1(pool3)\n",
"        pool4 = self.pool4(enc4_1)\n",
"\n",
"        enc5_1 = self.enc5_1(pool4)\n",
"\n",
"        # Decoder: upsample, concatenate the skip along channels, convolve.\n",
"        dec5_1 = self.dec5_1(enc5_1)\n",
"\n",
"        unpool4 = self.unpool4(dec5_1)\n",
"        cat4 = torch.cat((unpool4, enc4_1), dim=1)\n",
"        dec4_1 = self.dec4_1(cat4)\n",
"\n",
"        unpool3 = self.unpool3(dec4_1)\n",
"        cat3 = torch.cat((unpool3, enc3_1), dim=1)\n",
"        dec3_1 = self.dec3_1(cat3)\n",
"\n",
"        unpool2 = self.unpool2(dec3_1)\n",
"        cat2 = torch.cat((unpool2, enc2_1), dim=1)\n",
"        dec2_1 = self.dec2_1(cat2)\n",
"\n",
"        unpool1 = self.unpool1(dec2_1)\n",
"        cat1 = torch.cat((unpool1, enc1_1), dim=1)\n",
"        dec1_1 = self.dec1_1(cat1)\n",
"\n",
"        # Raw scores (no activation).\n",
"        x = self.fc(dec1_1)\n",
"\n",
"        return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model - Load, Save"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"## 네트워크 저장하기\n",
"def save(ckpt_dir, net, optim, epoch):\n",
"    \"\"\"Write model and optimizer state dicts to ``<ckpt_dir>/model_epoch<epoch>.pth``.\n",
"\n",
"    ``ckpt_dir`` is created when it does not exist yet. The checkpoint is a\n",
"    dict with keys 'net' and 'optim', the format ``load`` reads back.\n",
"    \"\"\"\n",
"    if not os.path.exists(ckpt_dir):\n",
"        os.makedirs(ckpt_dir)\n",
"\n",
"    checkpoint = {'net': net.state_dict(), 'optim': optim.state_dict()}\n",
"    ckpt_path = \"%s/model_epoch%d.pth\" % (ckpt_dir, epoch)\n",
"    torch.save(checkpoint, ckpt_path)\n",
"\n",
"## 네트워크 불러오기\n",
"def load(ckpt_dir, net, optim):\n",
"    \"\"\"Restore the newest ``model_epoch<N>.pth`` checkpoint from ``ckpt_dir``.\n",
"\n",
"    Only files named ``model_epoch<N>.pth`` (the names ``save`` writes) are\n",
"    considered; anything else in the folder is ignored. The file with the\n",
"    largest N is loaded into ``net`` and ``optim`` in place, with tensors\n",
"    mapped to the device of net's parameters so a checkpoint written on a\n",
"    GPU also loads on a CPU-only machine.\n",
"\n",
"    Returns:\n",
"        ``(net, optim, epoch)`` where ``epoch`` is N of the loaded file, or 0\n",
"        (with net / optim untouched) when ``ckpt_dir`` is missing or holds no\n",
"        matching checkpoint.\n",
"    \"\"\"\n",
"    if not os.path.exists(ckpt_dir):\n",
"        return net, optim, 0\n",
"\n",
"    prefix, suffix = 'model_epoch', '.pth'\n",
"    candidates = []\n",
"    for fname in os.listdir(ckpt_dir):\n",
"        number = fname[len(prefix):-len(suffix)]\n",
"        if fname.startswith(prefix) and fname.endswith(suffix) and number.isdecimal():\n",
"            candidates.append((int(number), fname))\n",
"\n",
"    # Empty folder or no checkpoint files: nothing to restore.\n",
"    if not candidates:\n",
"        return net, optim, 0\n",
"\n",
"    epoch, fname = max(candidates)\n",
"\n",
"    # Map tensors onto the model's current device, not the one they were saved from.\n",
"    first_param = next(net.parameters(), None)\n",
"    device = first_param.device if first_param is not None else torch.device('cpu')\n",
"    dict_model = torch.load(os.path.join(ckpt_dir, fname), map_location=device)\n",
"\n",
"    net.load_state_dict(dict_model['net'])\n",
"    optim.load_state_dict(dict_model['optim'])\n",
"\n",
"    return net, optim, epoch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Training hyper-parameters (batch_size is also used by the test DataLoader).\n",
"lr = 1e-3\n",
"batch_size = 4\n",
"num_epoch = 10\n",
"\n",
"# Experiment folder to evaluate. Runs kept under ./2nd_Battery/:\n",
"#   unet, unet-mini, unet-dice-loss, unet-focal-loss,\n",
"#   unet-sgd, unet-rmsprop, unet-l1, unet-l2\n",
"base_dir = './2nd_Battery/unet-dice-loss'\n",
"ckpt_dir = os.path.join(base_dir, \"checkpoint\")\n",
"log_dir = os.path.join(base_dir, \"log\")\n",
"\n",
"# Model: GPU when available, otherwise CPU.\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda' if use_cuda else 'cpu')\n",
"net = UNet().to(device)\n",
"\n",
"# Defaults; the 'TC - ...' cells below replace fn_loss / optim per experiment.\n",
"fn_loss = nn.BCEWithLogitsLoss().to(device)\n",
"optim = torch.optim.Adam(net.parameters(), lr=lr)\n",
"\n",
"\n",
"def fn_tonumpy(x):\n",
"    # (N, C, H, W) tensor -> (N, H, W, C) numpy array on the CPU.\n",
"    return x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)\n",
"\n",
"\n",
"def fn_denorm(x, mean, std):\n",
"    # Inverse of Normalization: x * std + mean.\n",
"    return (x * std) + mean\n",
"\n",
"\n",
"def fn_class(x):\n",
"    # Binarize: 1.0 where x > 0.95, else 0.0.\n",
"    # NOTE(review): x is the raw network output. With BCEWithLogits / Dice /\n",
"    # Focal loss those are logits, so 0.95 is not a probability cut-off\n",
"    # (sigmoid(0.95) is about 0.72) - confirm the intended threshold.\n",
"    return 1.0 * (x > 0.95)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - Dice Loss"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class DiceLoss(nn.Module):\n",
"    \"\"\"Soft Dice loss on logits: 1 - (2 * sum(P * T) + s) / (sum(P) + sum(T) + s).\n",
"\n",
"    P = sigmoid(preds), T = targets, s = ``smooth`` (guards against 0 / 0).\n",
"    The sums run over the whole batch, giving one global Dice score rather\n",
"    than a per-image average.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, smooth=1e-6):\n",
"        super(DiceLoss, self).__init__()\n",
"        self.smooth = smooth\n",
"\n",
"    def forward(self, preds, targets):\n",
"        probs = preds.sigmoid()\n",
"        overlap = (probs * targets).sum()\n",
"        total = probs.sum() + targets.sum() + self.smooth\n",
"        return 1 - (overlap * 2. + self.smooth) / total\n",
"\n",
"# Select Dice loss for this experiment (replaces the default fn_loss).\n",
"fn_loss = DiceLoss().to(device)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - Focal Loss"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class FocalLoss(nn.Module):\n",
"    \"\"\"Binary focal loss on logits: mean(alpha * (1 - p_t) ** gamma * BCE).\n",
"\n",
"    BCE is the per-element binary cross-entropy of the logits and\n",
"    p_t = exp(-BCE), which for 0/1 targets is the predicted probability of the\n",
"    true class. ``alpha`` is one scalar weight for every element (no separate\n",
"    foreground / background weighting); ``gamma`` down-weights easy pixels.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, alpha=0.8, gamma=2.0):\n",
"        super(FocalLoss, self).__init__()\n",
"        self.alpha, self.gamma = alpha, gamma\n",
"\n",
"    def forward(self, preds, targets):\n",
"        bce = nn.functional.binary_cross_entropy_with_logits(preds, targets, reduction='none')\n",
"        p_t = torch.exp(-bce)\n",
"        modulated = self.alpha * (1 - p_t) ** self.gamma * bce\n",
"        return modulated.mean()\n",
"\n",
"# Select focal loss for this experiment (replaces the default fn_loss).\n",
"fn_loss = FocalLoss().to(device)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - SGD"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Test case: SGD with momentum 0.9 replaces the optimizer created earlier (Adam by default).\n",
"optim = torch.optim.SGD(params=net.parameters(), lr=lr, momentum=0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - RMSProp"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"# Test case: RMSprop (smoothing constant alpha = 0.9) replaces the optimizer created earlier.\n",
"optim = torch.optim.RMSprop(params=net.parameters(), lr=lr, alpha=0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - L1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class L1Loss(nn.Module):\n",
"    \"\"\"Mean absolute error between raw predictions and targets (no sigmoid).\"\"\"\n",
"\n",
"    def forward(self, preds, targets):\n",
"        return (preds - targets).abs().mean()\n",
"\n",
"# Select L1 loss for this experiment (replaces the default fn_loss).\n",
"fn_loss = L1Loss().to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TC - L2"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class L2Loss(nn.Module):\n",
"    \"\"\"Mean squared error between raw predictions and targets (no sigmoid).\"\"\"\n",
"\n",
"    def forward(self, preds, targets):\n",
"        diff = preds - targets\n",
"        return (diff ** 2).mean()\n",
"\n",
"# Select L2 loss for this experiment (replaces the default fn_loss).\n",
"fn_loss = L2Loss().to(device)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Test data folders. Defaults are the original author's local paths; set the\n",
"# BATTERY_TESTSET_DIR / BATTERY_MASKSET_DIR environment variables to override.\n",
"dir_testset = os.environ.get('BATTERY_TESTSET_DIR', 'C:/Users/pinb/Desktop/testimages/testset')\n",
"dir_groundtruth = os.environ.get('BATTERY_MASKSET_DIR', 'C:/Users/pinb/Desktop/testimages/maskset')\n",
"\n",
"# Evaluation uses only deterministic transforms (no RandomFlip / Rotate augmentation).\n",
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), ToTensor()])\n",
"\n",
"# train_ratio = val_ratio = 0: every image pair in the folders goes to the test set.\n",
"test_set = create_datasets(dir_testset, dir_groundtruth, 0, 0, transform=transform)\n",
"\n",
"# Fail fast on a wrong path instead of silently evaluating zero images.\n",
"if len(test_set) == 0:\n",
"    raise FileNotFoundError('No .bmp images found under %s' % dir_testset)\n",
"\n",
"# Quick visual check of one sample (tensors are (1, H, W) after ToTensor):\n",
"# sample = test_set[0]\n",
"# plt.subplot(121); plt.imshow(sample['input'][0], cmap='gray'); plt.title('Input Image')\n",
"# plt.subplot(122); plt.imshow(sample['label'][0], cmap='gray'); plt.title('Label')\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"loader_test = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)\n",
"\n",
"# Dataset / batch counts, used for progress printing.\n",
"num_data_test = len(test_set)\n",
"num_batch_test = int(np.ceil(num_data_test / batch_size))\n",
"\n",
"# Output folders for ground truth, input, prediction PNGs and raw .npy arrays.\n",
"# exist_ok=True creates whichever sub-folders are missing, including when\n",
"# result_dir already exists from an earlier partial run.\n",
"result_dir = os.path.join(base_dir, 'result')\n",
"for sub_dir in ('gt', 'img', 'pr', 'numpy'):\n",
"    os.makedirs(os.path.join(result_dir, sub_dir), exist_ok=True)\n",
"\n",
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)\n",
"# load() reports epoch 0 when it restored nothing; results would then come\n",
"# from randomly initialised weights.\n",
"if st_epoch == 0:\n",
"    print('WARNING: no checkpoint loaded from %s; the network is likely untrained.' % ckpt_dir)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TEST: BATCH 0001 / 0250 | LOSS 0.3965\n",
"TEST: BATCH 0002 / 0250 | LOSS 0.3255\n",
"TEST: BATCH 0003 / 0250 | LOSS 0.3926\n",
"TEST: BATCH 0004 / 0250 | LOSS 0.3913\n",
"TEST: BATCH 0005 / 0250 | LOSS 0.3963\n",
"TEST: BATCH 0006 / 0250 | LOSS 0.3929\n",
"TEST: BATCH 0007 / 0250 | LOSS 0.4026\n",
"TEST: BATCH 0008 / 0250 | LOSS 0.3988\n",
"TEST: BATCH 0009 / 0250 | LOSS 0.4022\n",
"TEST: BATCH 0010 / 0250 | LOSS 0.3956\n",
"TEST: BATCH 0011 / 0250 | LOSS 0.3933\n",
"TEST: BATCH 0012 / 0250 | LOSS 0.3834\n",
"TEST: BATCH 0013 / 0250 | LOSS 0.3889\n",
"TEST: BATCH 0014 / 0250 | LOSS 0.3885\n",
"TEST: BATCH 0015 / 0250 | LOSS 0.3923\n",
"TEST: BATCH 0016 / 0250 | LOSS 0.3851\n",
"TEST: BATCH 0017 / 0250 | LOSS 0.3819\n",
"TEST: BATCH 0018 / 0250 | LOSS 0.3872\n",
"TEST: BATCH 0019 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0020 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0021 / 0250 | LOSS 0.3858\n",
"TEST: BATCH 0022 / 0250 | LOSS 0.3819\n",
"TEST: BATCH 0023 / 0250 | LOSS 0.3796\n",
"TEST: BATCH 0024 / 0250 | LOSS 0.3749\n",
"TEST: BATCH 0025 / 0250 | LOSS 0.3713\n",
"TEST: BATCH 0026 / 0250 | LOSS 0.3668\n",
"TEST: BATCH 0027 / 0250 | LOSS 0.3637\n",
"TEST: BATCH 0028 / 0250 | LOSS 0.3670\n",
"TEST: BATCH 0029 / 0250 | LOSS 0.3629\n",
"TEST: BATCH 0030 / 0250 | LOSS 0.3630\n",
"TEST: BATCH 0031 / 0250 | LOSS 0.3604\n",
"TEST: BATCH 0032 / 0250 | LOSS 0.3624\n",
"TEST: BATCH 0033 / 0250 | LOSS 0.3675\n",
"TEST: BATCH 0034 / 0250 | LOSS 0.3665\n",
"TEST: BATCH 0035 / 0250 | LOSS 0.3683\n",
"TEST: BATCH 0036 / 0250 | LOSS 0.3713\n",
"TEST: BATCH 0037 / 0250 | LOSS 0.3750\n",
"TEST: BATCH 0038 / 0250 | LOSS 0.3744\n",
"TEST: BATCH 0039 / 0250 | LOSS 0.3734\n",
"TEST: BATCH 0040 / 0250 | LOSS 0.3742\n",
"TEST: BATCH 0041 / 0250 | LOSS 0.3724\n",
"TEST: BATCH 0042 / 0250 | LOSS 0.3735\n",
"TEST: BATCH 0043 / 0250 | LOSS 0.3712\n",
"TEST: BATCH 0044 / 0250 | LOSS 0.3719\n",
"TEST: BATCH 0045 / 0250 | LOSS 0.3730\n",
"TEST: BATCH 0046 / 0250 | LOSS 0.3756\n",
"TEST: BATCH 0047 / 0250 | LOSS 0.3745\n",
"TEST: BATCH 0048 / 0250 | LOSS 0.3750\n",
"TEST: BATCH 0049 / 0250 | LOSS 0.3743\n",
"TEST: BATCH 0050 / 0250 | LOSS 0.3746\n",
"TEST: BATCH 0051 / 0250 | LOSS 0.3741\n",
"TEST: BATCH 0052 / 0250 | LOSS 0.3739\n",
"TEST: BATCH 0053 / 0250 | LOSS 0.3728\n",
"TEST: BATCH 0054 / 0250 | LOSS 0.3740\n",
"TEST: BATCH 0055 / 0250 | LOSS 0.3737\n",
"TEST: BATCH 0056 / 0250 | LOSS 0.3734\n",
"TEST: BATCH 0057 / 0250 | LOSS 0.3737\n",
"TEST: BATCH 0058 / 0250 | LOSS 0.3753\n",
"TEST: BATCH 0059 / 0250 | LOSS 0.3751\n",
"TEST: BATCH 0060 / 0250 | LOSS 0.3742\n",
"TEST: BATCH 0061 / 0250 | LOSS 0.3749\n",
"TEST: BATCH 0062 / 0250 | LOSS 0.3773\n",
"TEST: BATCH 0063 / 0250 | LOSS 0.3777\n",
"TEST: BATCH 0064 / 0250 | LOSS 0.3785\n",
"TEST: BATCH 0065 / 0250 | LOSS 0.3801\n",
"TEST: BATCH 0066 / 0250 | LOSS 0.3787\n",
"TEST: BATCH 0067 / 0250 | LOSS 0.3781\n",
"TEST: BATCH 0068 / 0250 | LOSS 0.3788\n",
"TEST: BATCH 0069 / 0250 | LOSS 0.3800\n",
"TEST: BATCH 0070 / 0250 | LOSS 0.3795\n",
"TEST: BATCH 0071 / 0250 | LOSS 0.3797\n",
"TEST: BATCH 0072 / 0250 | LOSS 0.3793\n",
"TEST: BATCH 0073 / 0250 | LOSS 0.3797\n",
"TEST: BATCH 0074 / 0250 | LOSS 0.3800\n",
"TEST: BATCH 0075 / 0250 | LOSS 0.3795\n",
"TEST: BATCH 0076 / 0250 | LOSS 0.3802\n",
"TEST: BATCH 0077 / 0250 | LOSS 0.3791\n",
"TEST: BATCH 0078 / 0250 | LOSS 0.3789\n",
"TEST: BATCH 0079 / 0250 | LOSS 0.3807\n",
"TEST: BATCH 0080 / 0250 | LOSS 0.3817\n",
"TEST: BATCH 0081 / 0250 | LOSS 0.3813\n",
"TEST: BATCH 0082 / 0250 | LOSS 0.3818\n",
"TEST: BATCH 0083 / 0250 | LOSS 0.3824\n",
"TEST: BATCH 0084 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0085 / 0250 | LOSS 0.3825\n",
"TEST: BATCH 0086 / 0250 | LOSS 0.3827\n",
"TEST: BATCH 0087 / 0250 | LOSS 0.3823\n",
"TEST: BATCH 0088 / 0250 | LOSS 0.3820\n",
"TEST: BATCH 0089 / 0250 | LOSS 0.3821\n",
"TEST: BATCH 0090 / 0250 | LOSS 0.3822\n",
"TEST: BATCH 0091 / 0250 | LOSS 0.3821\n",
"TEST: BATCH 0092 / 0250 | LOSS 0.3827\n",
"TEST: BATCH 0093 / 0250 | LOSS 0.3826\n",
"TEST: BATCH 0094 / 0250 | LOSS 0.3826\n",
"TEST: BATCH 0095 / 0250 | LOSS 0.3841\n",
"TEST: BATCH 0096 / 0250 | LOSS 0.3844\n",
"TEST: BATCH 0097 / 0250 | LOSS 0.3822\n",
"TEST: BATCH 0098 / 0250 | LOSS 0.3822\n",
"TEST: BATCH 0099 / 0250 | LOSS 0.3838\n",
"TEST: BATCH 0100 / 0250 | LOSS 0.3844\n",
"TEST: BATCH 0101 / 0250 | LOSS 0.3837\n",
"TEST: BATCH 0102 / 0250 | LOSS 0.3838\n",
"TEST: BATCH 0103 / 0250 | LOSS 0.3841\n",
"TEST: BATCH 0104 / 0250 | LOSS 0.3848\n",
"TEST: BATCH 0105 / 0250 | LOSS 0.3858\n",
"TEST: BATCH 0106 / 0250 | LOSS 0.3860\n",
"TEST: BATCH 0107 / 0250 | LOSS 0.3857\n",
"TEST: BATCH 0108 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0109 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0110 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0111 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0112 / 0250 | LOSS 0.3877\n",
"TEST: BATCH 0113 / 0250 | LOSS 0.3886\n",
"TEST: BATCH 0114 / 0250 | LOSS 0.3885\n",
"TEST: BATCH 0115 / 0250 | LOSS 0.3892\n",
"TEST: BATCH 0116 / 0250 | LOSS 0.3893\n",
"TEST: BATCH 0117 / 0250 | LOSS 0.3906\n",
"TEST: BATCH 0118 / 0250 | LOSS 0.3905\n",
"TEST: BATCH 0119 / 0250 | LOSS 0.3903\n",
"TEST: BATCH 0120 / 0250 | LOSS 0.3891\n",
"TEST: BATCH 0121 / 0250 | LOSS 0.3886\n",
"TEST: BATCH 0122 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0123 / 0250 | LOSS 0.3876\n",
"TEST: BATCH 0124 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0125 / 0250 | LOSS 0.3861\n",
"TEST: BATCH 0126 / 0250 | LOSS 0.3864\n",
"TEST: BATCH 0127 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0128 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0129 / 0250 | LOSS 0.3869\n",
"TEST: BATCH 0130 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0131 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0132 / 0250 | LOSS 0.3872\n",
"TEST: BATCH 0133 / 0250 | LOSS 0.3864\n",
"TEST: BATCH 0134 / 0250 | LOSS 0.3869\n",
"TEST: BATCH 0135 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0136 / 0250 | LOSS 0.3864\n",
"TEST: BATCH 0137 / 0250 | LOSS 0.3864\n",
"TEST: BATCH 0138 / 0250 | LOSS 0.3862\n",
"TEST: BATCH 0139 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0140 / 0250 | LOSS 0.3863\n",
"TEST: BATCH 0141 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0142 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0143 / 0250 | LOSS 0.3868\n",
"TEST: BATCH 0144 / 0250 | LOSS 0.3866\n",
"TEST: BATCH 0145 / 0250 | LOSS 0.3860\n",
"TEST: BATCH 0146 / 0250 | LOSS 0.3858\n",
"TEST: BATCH 0147 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0148 / 0250 | LOSS 0.3861\n",
"TEST: BATCH 0149 / 0250 | LOSS 0.3863\n",
"TEST: BATCH 0150 / 0250 | LOSS 0.3861\n",
"TEST: BATCH 0151 / 0250 | LOSS 0.3863\n",
"TEST: BATCH 0152 / 0250 | LOSS 0.3864\n",
"TEST: BATCH 0153 / 0250 | LOSS 0.3853\n",
"TEST: BATCH 0154 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0155 / 0250 | LOSS 0.3852\n",
"TEST: BATCH 0156 / 0250 | LOSS 0.3852\n",
"TEST: BATCH 0157 / 0250 | LOSS 0.3855\n",
"TEST: BATCH 0158 / 0250 | LOSS 0.3847\n",
"TEST: BATCH 0159 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0160 / 0250 | LOSS 0.3835\n",
"TEST: BATCH 0161 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0162 / 0250 | LOSS 0.3844\n",
"TEST: BATCH 0163 / 0250 | LOSS 0.3842\n",
"TEST: BATCH 0164 / 0250 | LOSS 0.3830\n",
"TEST: BATCH 0165 / 0250 | LOSS 0.3832\n",
"TEST: BATCH 0166 / 0250 | LOSS 0.3833\n",
"TEST: BATCH 0167 / 0250 | LOSS 0.3833\n",
"TEST: BATCH 0168 / 0250 | LOSS 0.3838\n",
"TEST: BATCH 0169 / 0250 | LOSS 0.3848\n",
"TEST: BATCH 0170 / 0250 | LOSS 0.3849\n",
"TEST: BATCH 0171 / 0250 | LOSS 0.3848\n",
"TEST: BATCH 0172 / 0250 | LOSS 0.3847\n",
"TEST: BATCH 0173 / 0250 | LOSS 0.3845\n",
"TEST: BATCH 0174 / 0250 | LOSS 0.3841\n",
"TEST: BATCH 0175 / 0250 | LOSS 0.3843\n",
"TEST: BATCH 0176 / 0250 | LOSS 0.3841\n",
"TEST: BATCH 0177 / 0250 | LOSS 0.3842\n",
"TEST: BATCH 0178 / 0250 | LOSS 0.3844\n",
"TEST: BATCH 0179 / 0250 | LOSS 0.3841\n",
"TEST: BATCH 0180 / 0250 | LOSS 0.3836\n",
"TEST: BATCH 0181 / 0250 | LOSS 0.3840\n",
"TEST: BATCH 0182 / 0250 | LOSS 0.3843\n",
"TEST: BATCH 0183 / 0250 | LOSS 0.3849\n",
"TEST: BATCH 0184 / 0250 | LOSS 0.3855\n",
"TEST: BATCH 0185 / 0250 | LOSS 0.3857\n",
"TEST: BATCH 0186 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0187 / 0250 | LOSS 0.3862\n",
"TEST: BATCH 0188 / 0250 | LOSS 0.3857\n",
"TEST: BATCH 0189 / 0250 | LOSS 0.3854\n",
"TEST: BATCH 0190 / 0250 | LOSS 0.3859\n",
"TEST: BATCH 0191 / 0250 | LOSS 0.3862\n",
"TEST: BATCH 0192 / 0250 | LOSS 0.3868\n",
"TEST: BATCH 0193 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0194 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0195 / 0250 | LOSS 0.3863\n",
"TEST: BATCH 0196 / 0250 | LOSS 0.3869\n",
"TEST: BATCH 0197 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0198 / 0250 | LOSS 0.3877\n",
"TEST: BATCH 0199 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0200 / 0250 | LOSS 0.3869\n",
"TEST: BATCH 0201 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0202 / 0250 | LOSS 0.3869\n",
"TEST: BATCH 0203 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0204 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0205 / 0250 | LOSS 0.3862\n",
"TEST: BATCH 0206 / 0250 | LOSS 0.3867\n",
"TEST: BATCH 0207 / 0250 | LOSS 0.3871\n",
"TEST: BATCH 0208 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0209 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0210 / 0250 | LOSS 0.3872\n",
"TEST: BATCH 0211 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0212 / 0250 | LOSS 0.3878\n",
"TEST: BATCH 0213 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0214 / 0250 | LOSS 0.3873\n",
"TEST: BATCH 0215 / 0250 | LOSS 0.3877\n",
"TEST: BATCH 0216 / 0250 | LOSS 0.3881\n",
"TEST: BATCH 0217 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0218 / 0250 | LOSS 0.3879\n",
"TEST: BATCH 0219 / 0250 | LOSS 0.3872\n",
"TEST: BATCH 0220 / 0250 | LOSS 0.3865\n",
"TEST: BATCH 0221 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0222 / 0250 | LOSS 0.3873\n",
"TEST: BATCH 0223 / 0250 | LOSS 0.3876\n",
"TEST: BATCH 0224 / 0250 | LOSS 0.3872\n",
"TEST: BATCH 0225 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0226 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0227 / 0250 | LOSS 0.3870\n",
"TEST: BATCH 0228 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0229 / 0250 | LOSS 0.3878\n",
"TEST: BATCH 0230 / 0250 | LOSS 0.3881\n",
"TEST: BATCH 0231 / 0250 | LOSS 0.3879\n",
"TEST: BATCH 0232 / 0250 | LOSS 0.3878\n",
"TEST: BATCH 0233 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0234 / 0250 | LOSS 0.3877\n",
"TEST: BATCH 0235 / 0250 | LOSS 0.3876\n",
"TEST: BATCH 0236 / 0250 | LOSS 0.3875\n",
"TEST: BATCH 0237 / 0250 | LOSS 0.3876\n",
"TEST: BATCH 0238 / 0250 | LOSS 0.3874\n",
"TEST: BATCH 0239 / 0250 | LOSS 0.3877\n",
"TEST: BATCH 0240 / 0250 | LOSS 0.3883\n",
"TEST: BATCH 0241 / 0250 | LOSS 0.3881\n",
"TEST: BATCH 0242 / 0250 | LOSS 0.3882\n",
"TEST: BATCH 0243 / 0250 | LOSS 0.3882\n",
"TEST: BATCH 0244 / 0250 | LOSS 0.3878\n",
"TEST: BATCH 0245 / 0250 | LOSS 0.3884\n",
"TEST: BATCH 0246 / 0250 | LOSS 0.3887\n",
"TEST: BATCH 0247 / 0250 | LOSS 0.3890\n",
"TEST: BATCH 0248 / 0250 | LOSS 0.3886\n",
"TEST: BATCH 0249 / 0250 | LOSS 0.3883\n",
"TEST: BATCH 0250 / 0250 | LOSS 0.3879\n",
"AVERAGE TEST: BATCH 0250 / 0250 | LOSS 0.3879\n"
]
}
],
"source": [
"# Evaluate the trained network on the test set and save, for every test sample,\n",
"# the ground truth, the input image and the prediction as PNG and .npy files.\n",
"# Uses objects defined in earlier cells: net, loader_test, device, fn_loss,\n",
"# fn_tonumpy, fn_denorm, fn_class, num_batch_test, result_dir.\n",
"for sub_dir in ('gt', 'img', 'pr', 'numpy'):\n",
"    os.makedirs(os.path.join(result_dir, sub_dir), exist_ok=True)\n",
"\n",
"with torch.no_grad():\n",
"    net.eval()\n",
"    loss_arr = []\n",
"    # Running index of saved samples. The previous formula\n",
"    # `num_batch_test * (batch - 1) + j` multiplied by the number of batches\n",
"    # instead of the batch size, leaving gaps between batches (or overwriting\n",
"    # files whenever the batch size exceeds the number of batches).\n",
"    # NOTE: files saved under the old naming scheme are not removed; clear\n",
"    # result_dir before re-running so the Visualize cell sees only fresh results.\n",
"    sample_id = 0\n",
"\n",
"    for batch, data in enumerate(loader_test, 1):\n",
"        # forward pass (`inputs` avoids shadowing the builtin `input`)\n",
"        label = data['label'].to(device)\n",
"        inputs = data['input'].to(device)\n",
"\n",
"        output = net(inputs)\n",
"\n",
"        # compute the loss\n",
"        loss = fn_loss(output, label)\n",
"        loss_arr += [loss.item()]\n",
"\n",
"        print(\"TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"              (batch, num_batch_test, np.mean(loss_arr)))\n",
"\n",
"        # Convert to NumPy for saving. fn_denorm / fn_class come from an earlier\n",
"        # cell; presumably they undo the input normalization and threshold the\n",
"        # network output -- confirm there.\n",
"        label = fn_tonumpy(label)\n",
"        inputs = fn_tonumpy(fn_denorm(inputs, mean=0.5, std=0.5))\n",
"        output = fn_tonumpy(fn_class(output))\n",
"\n",
"        # Save every sample of the batch\n",
"        for j in range(label.shape[0]):\n",
"            gt = label[j].squeeze()\n",
"            img = inputs[j].squeeze()\n",
"            pr = output[j].squeeze()\n",
"\n",
"            plt.imsave(os.path.join(result_dir, 'gt', 'gt_%04d.png' % sample_id), gt, cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'img', 'img_%04d.png' % sample_id), img, cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'pr', 'pr_%04d.png' % sample_id), pr, cmap='gray')\n",
"            np.save(os.path.join(result_dir, 'numpy', 'gt_%04d.npy' % sample_id), gt)\n",
"            np.save(os.path.join(result_dir, 'numpy', 'img_%04d.npy' % sample_id), img)\n",
"            np.save(os.path.join(result_dir, 'numpy', 'pr_%04d.npy' % sample_id), pr)\n",
"            sample_id += 1\n",
"\n",
"print(\"AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"      (batch, num_batch_test, np.mean(loss_arr)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pinb\\AppData\\Local\\Temp\\ipykernel_19912\\3510449017.py:45: RuntimeWarning: invalid value encountered in divide\n",
" precision = tp / (tp + fp) # precision = TP / (TP + FP)\n",
"C:\\Users\\pinb\\AppData\\Local\\Temp\\ipykernel_19912\\3510449017.py:46: RuntimeWarning: invalid value encountered in divide\n",
" recall = tp / (tp + fn) # recall = TP / (TP + FN)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"precision: 0.7764027600652067\n",
"recall: 0.7843549615385272\n",
"accuracy: 0.9770164763057941\n",
"f1: 0.7741721124958945\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq8AAACaCAYAAACHSaGqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABP6UlEQVR4nO2deXgUVdbG396XdKcT0iELJAHCEjEsAqIBFNSwiaDsuLHIiEJQtmEchm9YRIn7KMigoyOgMoAsMi4IssoAQZA9YQtICBCykaSz9n6/P5iq6e50J91JdzqdnN/z1JP0rVtVp+7pqnr71rnnChhjDARBEARBEAQRAAj9bQBBEARBEARBuAuJV4IgCIIgCCJgIPFKEARBEARBBAwkXgmCIAiCIIiAgcQrQRAEQRAEETCQeCUIgiAIgiACBhKvBEEQBEEQRMBA4pUgCIIgCIIIGEi8EgRBEARBEAEDiVcPWbt2LQQCAbKysvxtCkEQRKNHIBBgyZIlHm0zefJktGnTxif2EN7F2TNxwIABGDBggN9sIpo+JF4JogaWL1+O7du3+9sMgiAIgiD+i9jfBgQazz//PCZMmACZTOZvU4gGYPny5RgzZgyeeuopf5tCEAFJVVUVxGLPHjWfffYZrFarjywiCCLQoZ5XDxGJRJDL5RAIBP42hSAIF1itVuj1en+bEVD4qs3kcrnH4lUikVAHgYdUVFT42wSiEdBcvgckXj3EMb6nTZs2eOKJJ3DgwAH06tULCoUCXbp0wYEDBwAA27ZtQ5cuXSCXy9GzZ0+cOnWq2j43b96Mzp07Qy6XIzExEd9++y3FfDUAnM/kcjni4+Px6aefYsmSJfwPE4FAgIqKCqxbtw4CgQACgQCTJ0/2r9HNDM4fFy9exLhx4xAcHIywsDDMmjXLTmgJBALMnDkT69evx7333guZTIadO3f60XL/4Y02u3XrFl544QVERERAJpPh3nvvxRdffFHtWHq9HkuWLEHHjh0hl8sRFRWFUaNG4erVq3bHsY15LSsrw+zZs9GmTRvIZDK0bNkSAwcOxMmTJ/k6zu5/FRUVmDdvHmJiYiCTydCpUye89957YIzZ1ePOa/v27UhMTOTtb0rfB87H58+fxzPPPIPQ0FD069cPAPD111+jZ8+eUCgUaNGiBSZMmIAbN25U28evv/6Kxx9/HKGhoQgKCkLXrl3x0Ucf8evPnj2LyZMno127dpDL5YiMjMQLL7yAO3fuNNh5NgauX7+OGTNmoFOnTlAoFAgLC8PYsWOdjnspKSnBnDlz+O9269atMXHiRBQWFvJ1artmDhw4AIFAwGsIjqysLAgEAqxdu5Yvmzx5MlQqFa5evYrHH38carUazz77LADgP//5D8aOHYvY2FjIZDLExMRgzpw5qKqqqmY3d68IDw+HQqFAp06dsHDhQgDA/v37IRAI8O2331bb7l//+hcEAgHS0tI8bdZ6Q2EDXuDKlSt45pln8NJLL+G5557De++9h+HDh+OTTz7BX/7yF8yYMQMAkJqainHjxuHSpUsQCu/+bvjxxx8xfvx4dOnSBampqSguLsbUqVPRqlUrf55Sk+fUqVMYMmQIoqKisHTpUlgsFrz++usIDw/n63z11Vf4wx/+gN69e2PatGkAgPj4eH+Z3KwZN24c2rRpg9TUVBw9ehQrVqxAcXExvvzyS77Ovn378M0332DmzJnQarXN/sdfXdssLy8PDz74IC8Cw8PD8dNPP2Hq1KkoLS3F7NmzAQAWiwVPPPEE9u7diwkTJmDWrFkoKyvD7t27kZ6e7vJaefnll7FlyxbMnDkTnTt3xp07d3Do0CFcuHABPXr0cLoNYwwjRozA/v37MXXqVHTv3h27du3C/PnzcevWLfztb3+zq3/o0CFs27YNM2bMgFqtxooVKzB69GhkZ2cjLCzMOw3cCBg7diw6dOiA5cuXgzGGN998E3/9618xbtw4/OEPf0BBQQFWrlyJhx
9+GKdOnUJISAgAYPfu3XjiiScQFRWFWbNmITIyEhcuXMAPP/yAWbNm8XV+//13TJkyBZGRkcjIyMA//vEPZGRk4OjRo83m7ePx48dx5MgRTJgwAa1bt0ZWVhZWr16NAQMG4Pz581AqlQCA8vJyPPTQQ7hw4QJeeOEF9OjRA4WFhfjuu+9w8+ZNaLXaOl8zNWE2mzF48GD069cP7733Hm/P5s2bUVlZienTpyMsLAzHjh3DypUrcfPmTWzevJnf/uzZs3jooYcgkUgwbdo0tGnTBlevXsX333+PN998EwMGDEBMTAzWr1+PkSNH2h17/fr1iI+PR1JSUj1auI4wwiPWrFnDALBr164xxhiLi4tjANiRI0f4Ort27WIAmEKhYNevX+fLP/30UwaA7d+/ny/r0qULa926NSsrK+PLDhw4wACwuLg4X59Os2X48OFMqVSyW7du8WWZmZlMLBYz28siKCiITZo0yQ8WEowxtnjxYgaAjRgxwq58xowZDAA7c+YMY4wxAEwoFLKMjAx/mNmoqG+bTZ06lUVFRbHCwkK78gkTJjCNRsMqKysZY4x98cUXDAD74IMPqtlgtVr5/wGwxYsX8581Gg1LSUmp8RwmTZpkd//bvn07A8DeeOMNu3pjxoxhAoGAXblyxe54UqnUruzMmTMMAFu5cmWNxw0UOB8//fTTfFlWVhYTiUTszTfftKt77tw5JhaL+XKz2czatm3L4uLiWHFxsV1dW79xfrZlw4YNDAA7ePAgX+b4TGSMsf79+7P+/fvX4wwbD87aIS0tjQFgX375JV+2aNEiBoBt27atWn2uXd25Zvbv319NJzDG2LVr1xgAtmbNGr5s0qRJDAD785//7JbdqampTCAQ2OmShx9+mKnVarsyW3sYY2zBggVMJpOxkpISviw/P5+JxWK7a7shobABL9C5c2e7Xx4PPPAAAODRRx9FbGxstfLff/8dAJCTk4Nz585h4sSJUKlUfL3+/fujS5cuDWF6s8RisWDPnj146qmnEB0dzZe3b98eQ4cO9aNlhCtSUlLsPr/yyisAgB07dvBl/fv3R+fOnRvUrsZMXdqMMYatW7di+PDhYIyhsLCQXwYPHgydTse/3t+6dSu0Wi2/X1tq6pULCQnBr7/+ipycHLfPZceOHRCJRHj11VftyufNmwfGGH766Se78uTkZLterK5duyI4OJi/9zYVXn75Zf7/bdu2wWq1Yty4cXZ+i4yMRIcOHbB//34Ad986Xbt2DbNnz+Z7Yjls/aZQKPj/9Xo9CgsL8eCDDwKAXYhHU8e2HUwmE+7cuYP27dsjJCTErh22bt2Kbt26VeudBP7XrnW9Zmpj+vTpNdpdUVGBwsJC9OnTB4wxPnyxoKAABw8exAsvvGCnVRztmThxIgwGA7Zs2cKXbdq0CWazGc8991yd7a4PJF69gKPTNRoNACAmJsZpeXFxMYC7sTTAXdHkiLMywjvk5+ejqqqK2j2A6NChg93n+Ph4CIVCu7iztm3bNrBVjZu6tFlBQQFKSkrwj3/8A+Hh4XbLlClTANy9fgDg6tWr6NSpk8eDsd555x2kp6cjJiYGvXv3xpIlS2oVldevX0d0dDTUarVd+T333MOvt8XxngwAoaGh/L23qWDrv8zMTDDG0KFDh2q+u3Dhgp3fACAxMbHGfRcVFWHWrFmIiIiAQqFAeHg4fzydTuejM2p8VFVVYdGiRXystVarRXh4OEpKSuza4erVq7W2aV2vmZoQi8Vo3bp1tfLs7GxMnjwZLVq0gEqlQnh4OPr37w/gf/7jrrva7E5ISMD999+P9evX82Xr16/Hgw8+6LdnJsW8egGRSORROXMYYEAQhGc466Ww7WkgquNOm3HpqZ577jlMmjTJ6X66du1aLzvGjRuHhx56CN9++y1+/vlnvPvuu3j77bexbds2r735aC73Xlv/Wa1WCAQC/PTTT07P3/btnjuMGzcOR44cwfz589G9e3eoVCpYrVYMGT
KkWaUxe+WVV7BmzRrMnj0bSUlJ0Gg0EAgEmDBhgk/awVUPrMVicVouk8n4MTS2dQcOHIiioiK89tprSEhIQFBQEG7
"text/plain": [
"<Figure size 800x600 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the test results saved by the test loop above: per-image precision,\n",
"# recall, pixel accuracy and F1 averaged over all test images, plus a visual\n",
"# check of the last image and its error maps.\n",
"\n",
"# Experiment to evaluate. Other runs under ./2nd_Battery: unet, unet-mini,\n",
"# unet-focal-loss, unet-sgd, unet-rmsprop, unet-l1, unet-l2\n",
"experiment = 'unet-dice-loss'\n",
"base_dir = './2nd_Battery/' + experiment\n",
"result_dir = os.path.join(base_dir, 'result')\n",
"numpy_dir = os.path.join(result_dir, 'numpy')\n",
"\n",
"\n",
"def safe_ratio(num, den):\n",
"    \"\"\"Return num / den, or NaN when den == 0 (metric undefined for that image).\"\"\"\n",
"    return num / den if den else float('nan')\n",
"\n",
"\n",
"# Pair each image with its gt/pr files through the shared id suffix\n",
"# (e.g. 'img_0007.npy' -> 'gt_0007.npy', 'pr_0007.npy') instead of relying on\n",
"# three independently sorted lists lining up.\n",
"lst_img = sorted(f for f in os.listdir(numpy_dir) if f.startswith('img'))\n",
"assert lst_img, f'no saved test results found in {numpy_dir}'\n",
"\n",
"scores = {'precision': [], 'recall': [], 'accuracy': [], 'f1': []}\n",
"\n",
"for name in lst_img:\n",
"    suffix = name[len('img'):]\n",
"    img = np.load(os.path.join(numpy_dir, name))\n",
"    gt = np.load(os.path.join(numpy_dir, 'gt' + suffix))\n",
"    pr = np.load(os.path.join(numpy_dir, 'pr' + suffix))\n",
"\n",
"    # Boolean masks. The previous version scaled to uint8 (x255) and used bitwise\n",
"    # &/~, which is only correct when gt/pr are exactly 0 or 1, and divided the\n",
"    # masks element-wise (0/0 -> the 'invalid value encountered' warnings).\n",
"    gt_mask = gt > 0.5\n",
"    pr_mask = pr > 0.5\n",
"\n",
"    tp_map = gt_mask & pr_mask   # predicted 1, truth 1\n",
"    fp_map = pr_mask & ~gt_mask  # predicted 1, truth 0\n",
"    fn_map = ~pr_mask & gt_mask  # predicted 0, truth 1\n",
"\n",
"    tp = np.count_nonzero(tp_map)\n",
"    fp = np.count_nonzero(fp_map)\n",
"    fn = np.count_nonzero(fn_map)\n",
"    tn = gt_mask.size - tp - fp - fn\n",
"\n",
"    scores['precision'].append(safe_ratio(tp, tp + fp))\n",
"    scores['recall'].append(safe_ratio(tp, tp + fn))\n",
"    scores['accuracy'].append(safe_ratio(tp + tn, gt_mask.size))\n",
"    # Same value as 2PR / (P + R), but 0 instead of NaN when tp == 0 and there are errors\n",
"    scores['f1'].append(safe_ratio(2 * tp, 2 * tp + fp + fn))\n",
"\n",
"# Images on which a metric is undefined (NaN) are skipped instead of turning\n",
"# the whole average into NaN.\n",
"for metric, values in scores.items():\n",
"    print(f'{metric}: {np.nanmean(values)}')\n",
"\n",
"## Plot the last evaluated image next to its error maps\n",
"panels = [('img', img), ('gt', gt_mask), ('pr', pr_mask),\n",
"          ('TP', tp_map), ('FP', fp_map), ('FN', fn_map)]\n",
"fig, axes = plt.subplots(1, len(panels), figsize=(8, 6))\n",
"for ax, (title, image) in zip(axes, panels):\n",
"    is_mask = image.dtype == bool\n",
"    # fixed 0..1 range for masks so an all-True map shows white instead of being autoscaled\n",
"    ax.imshow(image, cmap='gray', vmin=0 if is_mask else None, vmax=1 if is_mask else None)\n",
"    ax.set_title(title)\n",
"    ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
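"Note: runs that use different loss functions (Dice, Focal, L1, L2, ...) report losses on different scales, so these numbers are only comparable between runs that share the same loss; compare models with the precision / recall / F1 computed in the Visualize section.\n",
"\n",
"# UNet\n",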
"LOSS 0.2072\n",
"\n",
"# UNet - Mini\n",
"LOSS 0.1324\n",
"\n",
"# UNet - Dice Loss\n",
"LOSS 0.3879\n",
"\n",
"# UNet - Focal Loss\n",
"LOSS 0.0112\n",
"\n",
"# UNet - SGD Opt\n",
"LOSS 0.1787\n",
"\n",
"# UNet - RMSProp Opt\n",
"LOSS 0.1666\n",
"\n",
"# UNet - L1 Loss\n",
"LOSS 0.0357\n",
"\n",
"# UNet - L2 Loss\n",
"LOSS 0.0241\n",
"\n",
"\n",
"# UNet - L1 + L2 Loss\n",
"LOSS 0.0550\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}