{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# What is Semantic Segmentation?\n",
"\n",
"![semantic-segmentation](./images/unet1.png)\n",
"\n",
"U-Net is a model for semantic segmentation, one of the tasks tackled in computer vision. \n",
"\n",
"Semantic segmentation is classification at the pixel level within an image:\n",
"\n",
"for every pixel, the model predicts which class that pixel belongs to.\n",
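"\n",
"In code, this per-pixel classification amounts to taking, for every pixel, the class with the highest score. A minimal sketch (the `logits` array below is a stand-in for a real model's output, not part of U-Net itself):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# logits: per-pixel class scores with shape (num_classes, H, W)\n",
"logits = np.array([[[2.0, 0.1], [0.3, 0.2]],\n",
"                   [[0.5, 1.5], [0.9, 2.5]]])\n",
"\n",
"# Semantic segmentation prediction: argmax over the class axis per pixel\n",
"pred = logits.argmax(axis=0)  # (H, W) map of class indices, here [[0, 1], [1, 1]]\n",
"```\n",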
"\n",
"Compared to predicting whether an object is present in an image (image classification), segmentation must preserve object boundary information\n",
"\n",
"and grasp the overall context of the image, so it demands a higher level of image understanding and is considered the harder problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# What is U-Net?\n",
"\n",
"![unet](./images/unet2.png)\n",
"\n",
"U-Net consists of a contracting path, which compresses the image, and an expansive path, which restores it to the original image size.\n",
"\n",
"The two modules are called the encoder and the decoder, and the model is named U-Net because its architecture forms a U shape. \n",
"\n",
"Semantic segmentation models can be used in many fields, such as autonomous driving and biomedical engineering. \n",
"\n",
"U-Net performs well on biomedical imaging tasks such as diagnosing lesions in MRI and CT scans and segmenting organs and cell tissue, \n",
"\n",
"and models built on the U-Net architecture keep solving a variety of problems better every year. \n",
"\n",
"According to Papers with Code, more than 10% of the problems U-Net is applied to are related to the medical field, \n",
"\n",
"and U-Net has the largest number of papers among semantic segmentation models.\n",
"\n",
"Although semantic segmentation is a difficult problem for machines, the state-of-the-art (SOTA) models in this area have reached a fairly high level.\n",
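"\n",
"The mIoU metric quoted in these benchmark numbers can be sketched as follows; the tiny label maps at the bottom are illustrative only:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def mean_iou(pred, target, num_classes):\n",
"    # mIoU = average over classes of |intersection| / |union|\n",
"    ious = []\n",
"    for c in range(num_classes):\n",
"        inter = np.logical_and(pred == c, target == c).sum()\n",
"        union = np.logical_or(pred == c, target == c).sum()\n",
"        if union > 0:  # skip classes absent from both maps\n",
"            ious.append(inter / union)\n",
"    return float(np.mean(ious))\n",
"\n",
"pred = np.array([[0, 0], [1, 1]])\n",
"target = np.array([[0, 1], [1, 1]])\n",
"print(mean_iou(pred, target, num_classes=2))  # (1/2 + 2/3) / 2 = 0.583...\n",
"```\n",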
"\n",
"According to Papers with Code, for Cityscapes test/val and PASCAL VOC 2012 test/val,\n",
"\n",
"two representative datasets for training and testing semantic segmentation, the first-place models on the benchmarks achieve 84.5/86.95% and 90.5/90.0% mIoU (mean Intersection over Union), respectively."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# What is overlap-tile?\n",
"\n",
"As the title of the U-Net paper, \"U-Net: Convolutional Networks for Biomedical Image Segmentation\", suggests,\n",
"\n",
"U-Net is a model designed for medical image segmentation.\n",
" \n",
"It can help diagnose lesions, for example finding nodules in CT scans (nodule detection) or separating retinal vessels from the background (retinal vessel segmentation). \n",
"\n",
"Unlike ordinary images, images in biomedical engineering are mostly high-resolution and therefore require a large amount of computation. \n",
"\n",
"The overlap-tile strategy was devised as a way to process these large medical images efficiently. \n",
"\n",
"\n",
"![overlap-tile](./images/unet3.png)\n",
"\n",
"\n",
"Once the yellow region, i.e., a tile, has been predicted, the region needed to move on to the next tile overlaps the region used for the previous prediction.\n",
"\n",
"That is why the method is called overlap-tile. \n",
"\n",
"The paper says the overlap-tile strategy \"is advantageous for applying a neural network to large images when GPU memory is limited\". \n",
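"\n",
"A minimal sketch of the tiling bookkeeping, with mirror padding for border context as in the paper; the tile and margin sizes, and the identity stand-in for the network, are illustrative assumptions (the sketch also assumes the tile size divides the image size and margin > 0):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def predict_tiled(image, tile, margin, predict):\n",
"    # Mirror-pad the border so that every tile sees full surrounding context\n",
"    padded = np.pad(image, margin, mode='reflect')\n",
"    out = np.zeros_like(image, dtype=float)\n",
"    for y in range(0, image.shape[0], tile):\n",
"        for x in range(0, image.shape[1], tile):\n",
"            # Input patch = the tile plus overlapping context on all sides\n",
"            patch = padded[y:y + tile + 2 * margin, x:x + tile + 2 * margin]\n",
"            out[y:y + tile, x:x + tile] = predict(patch)[margin:-margin, margin:-margin]\n",
"    return out\n",
"\n",
"img = np.arange(64, dtype=float).reshape(8, 8)\n",
"result = predict_tiled(img, tile=4, margin=2, predict=lambda p: p)  # identity stand-in\n",
"```\n",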
"\n",
"\n",
"\n",
"# What is data augmentation?\n",
"\n",
"\n",
"Besides overlap-tile, U-Net also uses data augmentation during training. \n",
"\n",
"In biomedical engineering, few training images are available while tissue variation and deformation are very common, so augmenting the available data is an essential step. \n",
"\n",
"Data augmentation means producing more images from the ones you have by flipping, rotating, warping, or shifting them. \n",
"\n",
"Beyond U-Net, data augmentation is widely used in other models as well, for instance to reduce labeling costs.\n",
"\n",
"\n",
"# References\n",
"\n",
"The U-Net paper, \"U-Net: Convolutional Networks for Biomedical Image Segmentation\" (2015) [Google Scholar]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Redefining the loss function\n",
"\n",
"![weight-map-loss](./images/unet4.png)\n",
"\n",
"In medical imaging, segmentation becomes difficult when many cells are clustered together, i.e., when instances of the same class are adjacent. \n",
"\n",
"Separating an isolated cell from the background is easy, but when cells touch, as in the example above, telling the individual instances apart is not. \n",
"\n",
"So this paper processes the labels so that background always exists between the borders of adjacent instances.\n",
"\n",
"In other words, even when two cells touch, a gap that must be recognized as background is created between them.\n",
"\n",
"The loss function is redefined to take this into account. \n",
"\n",
"A weight-map loss term is added: for each pixel it uses the distance to the border of the nearest cell and the distance to the border of the second-nearest cell, and assigns a high weight where the sum of those two distances is small, i.e., on the narrow background between touching cells. \n",
"\n",
"Because the model learns in the direction of a lower loss, it is pushed to keep the two cells apart, i.e., to preserve the background between the two instances. \n",
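"\n",
"The weight map just described can be sketched with distance transforms; w0 = 10 and sigma = 5 are the paper's values, the class-balancing term w_c is omitted for brevity, and the sketch assumes at least two labeled instances:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.ndimage import distance_transform_edt\n",
"\n",
"def unet_weight_map(labels, w0=10.0, sigma=5.0):\n",
"    # labels: (H, W) integer map, 0 = background, 1..N = cell instances\n",
"    ids = [i for i in np.unique(labels) if i > 0]\n",
"    # Distance from every pixel to each instance (assumes len(ids) >= 2)\n",
"    dists = np.stack([distance_transform_edt(labels != i) for i in ids])\n",
"    dists.sort(axis=0)\n",
"    d1, d2 = dists[0], dists[1]  # nearest and second-nearest cell\n",
"    # High weight where d1 + d2 is small: the thin gap between touching cells\n",
"    w = w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))\n",
"    return w * (labels == 0)  # the extra weight applies to background pixels\n",
"\n",
"labels = np.zeros((8, 8), dtype=int)\n",
"labels[2:6, 1:4] = 1  # two cells separated by a one-pixel background gap\n",
"labels[2:6, 5:8] = 2\n",
"w = unet_weight_map(labels)  # w peaks on the gap column between the cells\n",
"```\n",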
"\n",
"With this, instances can be segmented accurately even when cells or tissue are clustered together. \n",
"\n",
"This is why, in the figure, a high weight is assigned to the background gaps between cell objects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset\n",
"\n",
"![dataset](./images/unet5.png)\n",
"\n",
"The membrane dataset used in the ISBI 2012 EM Segmentation Challenge.\n",
"\n",
"The cell image on the left is 512x512 (grayscale); on the right, the walls (background) between the cells have been segmented.\n",
"\n",
"In the actual label data, cells are set to 255 and the background to 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"dataset/\n",
" train-volume.tif # training images\n",
" train-labels.tif # segmentation labels for the training images\n",
" test-volume.tif # test images"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset/\n",
" train/\n",
" input_000.npy ~ input_023.npy\n",
" label_000.npy ~ label_023.npy\n",
" val/\n",
" input_000.npy ~ input_002.npy\n",
" label_000.npy ~ label_002.npy\n",
" test/\n",
" input_000.npy ~ input_002.npy\n",
" label_000.npy ~ label_002.npy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## Import libraries\n",
"import os\n",
"import numpy as np\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"## Load the data\n",
"dir_data = './dataset' \n",
"\n",
"name_label = 'train-labels.tif'\n",
"name_input = 'train-volume.tif'\n",
"\n",
"img_label = Image.open(os.path.join(dir_data, name_label))\n",
"img_input = Image.open(os.path.join(dir_data, name_input))\n",
"\n",
"nx, ny = img_label.size  # PIL's size is (width, height)\n",
"nframe = img_label.n_frames\n",
"\n",
"## Create the train/val/test folders\n",
"nframe_train = 24\n",
"nframe_val = 3\n",
"nframe_test = 3\n",
"\n",
"dir_save_train = os.path.join(dir_data, 'train')\n",
"dir_save_val = os.path.join(dir_data, 'val')\n",
"dir_save_test = os.path.join(dir_data, 'test')\n",
"\n",
"if not os.path.exists(dir_save_train):\n",
" os.makedirs(dir_save_train)\n",
"\n",
"if not os.path.exists(dir_save_val):\n",
" os.makedirs(dir_save_val)\n",
"\n",
"if not os.path.exists(dir_save_test):\n",
" os.makedirs(dir_save_test)\n",
"\n",
"## Shuffle all 30 frames\n",
"id_frame = np.arange(nframe)\n",
"np.random.shuffle(id_frame)\n",
"\n",
"## Save the selected train frames as npy files\n",
"offset_nframe = 0\n",
"\n",
"for i in range(nframe_train):\n",
" img_label.seek(id_frame[i + offset_nframe])\n",
" img_input.seek(id_frame[i + offset_nframe])\n",
"\n",
" label_ = np.asarray(img_label)\n",
" input_ = np.asarray(img_input)\n",
"\n",
" np.save(os.path.join(dir_save_train, 'label_%03d.npy' % i), label_)\n",
" np.save(os.path.join(dir_save_train, 'input_%03d.npy' % i), input_)\n",
"\n",
"## Save the selected val frames as npy files\n",
"offset_nframe = nframe_train\n",
"\n",
"for i in range(nframe_val):\n",
" img_label.seek(id_frame[i + offset_nframe])\n",
" img_input.seek(id_frame[i + offset_nframe])\n",
"\n",
" label_ = np.asarray(img_label)\n",
" input_ = np.asarray(img_input)\n",
"\n",
" np.save(os.path.join(dir_save_val, 'label_%03d.npy' % i), label_)\n",
" np.save(os.path.join(dir_save_val, 'input_%03d.npy' % i), input_)\n",
"\n",
"## Save the selected test frames as npy files\n",
"offset_nframe = nframe_train + nframe_val\n",
"\n",
"for i in range(nframe_test):\n",
" img_label.seek(id_frame[i + offset_nframe])\n",
" img_input.seek(id_frame[i + offset_nframe])\n",
"\n",
" label_ = np.asarray(img_label)\n",
" input_ = np.asarray(img_input)\n",
"\n",
" np.save(os.path.join(dir_save_test, 'label_%03d.npy' % i), label_)\n",
" np.save(os.path.join(dir_save_test, 'input_%03d.npy' % i), input_)\n",
"\n",
"## Visualize one image\n",
"plt.subplot(122)\n",
"plt.imshow(label_, cmap='gray')\n",
"plt.title('label')\n",
"\n",
"plt.subplot(121)\n",
"plt.imshow(input_, cmap='gray')\n",
"plt.title('input')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## Pixel-value distributions of one image\n",
"plt.subplot(122)\n",
"plt.hist(label_.flatten(), bins=20)\n",
"plt.title('label')\n",
"\n",
"plt.subplot(121)\n",
"plt.hist(input_.flatten(), bins=20)\n",
"plt.title('input')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# UNet Network"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"## Import libraries\n",
"import os\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from torchvision import transforms, datasets\n",
"\n",
"## Build the network\n",
"class UNet(nn.Module):\n",
" def __init__(self):\n",
" super(UNet, self).__init__()\n",
"\n",
"# Define a Convolution + BatchNorm + ReLU block\n",
" def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True): \n",
" layers = []\n",
" layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
" kernel_size=kernel_size, stride=stride, padding=padding,\n",
" bias=bias)]\n",
" layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
" layers += [nn.ReLU()]\n",
"\n",
" cbr = nn.Sequential(*layers)\n",
"\n",
" return cbr\n",
"\n",
"# Contracting path\n",
" self.enc1_1 = CBR2d(in_channels=1, out_channels=64)\n",
" self.enc1_2 = CBR2d(in_channels=64, out_channels=64)\n",
"\n",
" self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
" self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
" self.enc2_2 = CBR2d(in_channels=128, out_channels=128)\n",
"\n",
" self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
" self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
" self.enc3_2 = CBR2d(in_channels=256, out_channels=256)\n",
"\n",
" self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
" self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
" self.enc4_2 = CBR2d(in_channels=512, out_channels=512)\n",
"\n",
" self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
" self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
"\n",
"# Expansive path\n",
" self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
"\n",
" self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,\n",
" kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
" self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)\n",
" self.dec4_1 = CBR2d(in_channels=512, out_channels=256)\n",
"\n",
" self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,\n",
" kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
" self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)\n",
" self.dec3_1 = CBR2d(in_channels=256, out_channels=128)\n",
"\n",
" self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,\n",
" kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
" self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)\n",
" self.dec2_1 = CBR2d(in_channels=128, out_channels=64)\n",
"\n",
" self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,\n",
" kernel_size=2, stride=2, padding=0, bias=True)\n",
"\n",
" self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)\n",
" self.dec1_1 = CBR2d(in_channels=64, out_channels=64)\n",
"\n",
" self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n",
" \n",
"# Define the forward pass\n",
" def forward(self, x):\n",
" enc1_1 = self.enc1_1(x)\n",
" enc1_2 = self.enc1_2(enc1_1)\n",
" pool1 = self.pool1(enc1_2)\n",
"\n",
" enc2_1 = self.enc2_1(pool1)\n",
" enc2_2 = self.enc2_2(enc2_1)\n",
" pool2 = self.pool2(enc2_2)\n",
"\n",
" enc3_1 = self.enc3_1(pool2)\n",
" enc3_2 = self.enc3_2(enc3_1)\n",
" pool3 = self.pool3(enc3_2)\n",
"\n",
" enc4_1 = self.enc4_1(pool3)\n",
" enc4_2 = self.enc4_2(enc4_1)\n",
" pool4 = self.pool4(enc4_2)\n",
"\n",
" enc5_1 = self.enc5_1(pool4)\n",
"\n",
" dec5_1 = self.dec5_1(enc5_1)\n",
"\n",
" unpool4 = self.unpool4(dec5_1)\n",
" cat4 = torch.cat((unpool4, enc4_2), dim=1)\n",
" dec4_2 = self.dec4_2(cat4)\n",
" dec4_1 = self.dec4_1(dec4_2)\n",
"\n",
" unpool3 = self.unpool3(dec4_1)\n",
" cat3 = torch.cat((unpool3, enc3_2), dim=1)\n",
" dec3_2 = self.dec3_2(cat3)\n",
" dec3_1 = self.dec3_1(dec3_2)\n",
"\n",
" unpool2 = self.unpool2(dec3_1)\n",
" cat2 = torch.cat((unpool2, enc2_2), dim=1)\n",
" dec2_2 = self.dec2_2(cat2)\n",
" dec2_1 = self.dec2_1(dec2_2)\n",
"\n",
" unpool1 = self.unpool1(dec2_1)\n",
" cat1 = torch.cat((unpool1, enc1_2), dim=1)\n",
" dec1_2 = self.dec1_2(cat1)\n",
" dec1_1 = self.dec1_1(dec1_2)\n",
"\n",
" x = self.fc(dec1_1)\n",
"\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Loader"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Implement the custom Dataset\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, data_dir, transform=None):\n",
" self.data_dir = data_dir\n",
" self.transform = transform\n",
"\n",
" lst_data = os.listdir(self.data_dir)\n",
"\n",
" lst_label = [f for f in lst_data if f.startswith('label')]\n",
" lst_input = [f for f in lst_data if f.startswith('input')]\n",
"\n",
" lst_label.sort()\n",
" lst_input.sort()\n",
"\n",
" self.lst_label = lst_label\n",
" self.lst_input = lst_input\n",
"\n",
" def __len__(self):\n",
" return len(self.lst_label)\n",
"\n",
" def __getitem__(self, index):\n",
" label = np.load(os.path.join(self.data_dir, self.lst_label[index]))\n",
" input = np.load(os.path.join(self.data_dir, self.lst_input[index]))\n",
"\n",
"# Scale to [0, 1]\n",
" label = label/255.0\n",
" input = input/255.0\n",
"\n",
"# If the image or label is 2-D (no channel axis, i.e., grayscale), add a channel axis\n",
" if label.ndim == 2:\n",
" label = label[:, :, np.newaxis]\n",
" if input.ndim == 2:\n",
" input = input[:, :, np.newaxis]\n",
"\n",
" data = {'input': input, 'label': label}\n",
"\n",
"# If a transform is defined, return the transformed data\n",
" if self.transform:\n",
" data = self.transform(data)\n",
"\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Check that the Dataset works\n",
"dataset_train = Dataset(data_dir=dir_save_train)\n",
"data = dataset_train[0]  # load one sample\n",
"input = data['input']\n",
"label = data['label']\n",
"\n",
"# Visualize the loaded sample\n",
"plt.subplot(122)\n",
"plt.imshow(label.reshape(512,512), cmap='gray')\n",
"plt.title('label')\n",
"\n",
"plt.subplot(121)\n",
"plt.imshow(input.reshape(512,512), cmap='gray')\n",
"plt.title('input')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transform"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Implement the transforms\n",
"class ToTensor(object):\n",
" def __call__(self, data):\n",
" label, input = data['label'], data['input']\n",
"\n",
" label = label.transpose((2, 0, 1)).astype(np.float32)\n",
" input = input.transpose((2, 0, 1)).astype(np.float32)\n",
"\n",
" data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}\n",
"\n",
" return data\n",
"\n",
"class Normalization(object):\n",
" def __init__(self, mean=0.5, std=0.5):\n",
" self.mean = mean\n",
" self.std = std\n",
"\n",
" def __call__(self, data):\n",
" label, input = data['label'], data['input']\n",
"\n",
" input = (input - self.mean) / self.std\n",
"\n",
" data = {'label': label, 'input': input}\n",
"\n",
" return data\n",
"\n",
"class RandomFlip(object):\n",
" def __call__(self, data):\n",
" label, input = data['label'], data['input']\n",
"\n",
" if np.random.rand() > 0.5:\n",
" label = np.fliplr(label)\n",
" input = np.fliplr(input)\n",
"\n",
" if np.random.rand() > 0.5:\n",
" label = np.flipud(label)\n",
" input = np.flipud(input)\n",
"\n",
" data = {'label': label, 'input': input}\n",
"\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 트랜스폼 잘 구현되었는지 확인\n",
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), ToTensor()])\n",
"dataset_train = Dataset(data_dir=dir_save_train, transform=transform)\n",
"data = dataset_train.__getitem__(0) # 한 이미지 불러오기\n",
"input = data['input']\n",
"label = data['label']\n",
"\n",
"# 불러온 이미지 시각화\n",
"plt.subplot(122)\n",
"plt.hist(label.flatten(), bins=20)\n",
"plt.title('label')\n",
"\n",
"plt.subplot(121)\n",
"plt.hist(input.flatten(), bins=20)\n",
"plt.title('input')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Save / Load"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"## 네트워크 저장하기\n",
"def save(ckpt_dir, net, optim, epoch):\n",
" if not os.path.exists(ckpt_dir):\n",
" os.makedirs(ckpt_dir)\n",
"\n",
" torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},\n",
" \"%s/model_epoch%d.pth\" % (ckpt_dir, epoch))\n",
"\n",
"## 네트워크 불러오기\n",
"def load(ckpt_dir, net, optim):\n",
" if not os.path.exists(ckpt_dir):\n",
" epoch = 0\n",
" return net, optim, epoch\n",
"\n",
" ckpt_lst = os.listdir(ckpt_dir)\n",
" ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))\n",
"\n",
" dict_model = torch.load('%s/%s' % (ckpt_dir, ckpt_lst[-1]))\n",
"\n",
" net.load_state_dict(dict_model['net'])\n",
" optim.load_state_dict(dict_model['optim'])\n",
" epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])\n",
"\n",
" return net, optim, epoch"
]
},
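 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "The `load` helper above picks the newest checkpoint by sorting filenames on the integer they contain. A minimal, standalone sketch of that sorting rule (the filenames below are made up for illustration):\n",
   "\n",
   "```python\n",
   "files = ['model_epoch2.pth', 'model_epoch10.pth', 'model_epoch1.pth']\n",
   "files.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))\n",
   "print(files)  # ['model_epoch1.pth', 'model_epoch2.pth', 'model_epoch10.pth']\n",
   "```\n",
   "\n",
   "A plain lexicographic sort would place `model_epoch10.pth` before `model_epoch2.pth`, which is why the numeric key is needed."
  ]
 },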
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAIN: EPOCH 0001 / 0020 | BATCH 0001 / 0006 | LOSS 0.6337\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0002 / 0006 | LOSS 0.5923\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0003 / 0006 | LOSS 0.5694\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0004 / 0006 | LOSS 0.5456\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0005 / 0006 | LOSS 0.5214\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0006 / 0006 | LOSS 0.5036\n",
"VALID: EPOCH 0001 / 0020 | BATCH 0001 / 0001 | LOSS 0.6128\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0001 / 0006 | LOSS 0.4027\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0002 / 0006 | LOSS 0.3897\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0003 / 0006 | LOSS 0.3905\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0004 / 0006 | LOSS 0.3856\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0005 / 0006 | LOSS 0.3807\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0006 / 0006 | LOSS 0.3745\n",
"VALID: EPOCH 0002 / 0020 | BATCH 0001 / 0001 | LOSS 0.5134\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0001 / 0006 | LOSS 0.3422\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0002 / 0006 | LOSS 0.3350\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0003 / 0006 | LOSS 0.3372\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0004 / 0006 | LOSS 0.3337\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0005 / 0006 | LOSS 0.3293\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0006 / 0006 | LOSS 0.3286\n",
"VALID: EPOCH 0003 / 0020 | BATCH 0001 / 0001 | LOSS 0.4308\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0001 / 0006 | LOSS 0.3094\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0002 / 0006 | LOSS 0.3079\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0003 / 0006 | LOSS 0.3090\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0004 / 0006 | LOSS 0.3078\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0005 / 0006 | LOSS 0.3049\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0006 / 0006 | LOSS 0.3003\n",
"VALID: EPOCH 0004 / 0020 | BATCH 0001 / 0001 | LOSS 0.3922\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0001 / 0006 | LOSS 0.2909\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0002 / 0006 | LOSS 0.2870\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0003 / 0006 | LOSS 0.2874\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0004 / 0006 | LOSS 0.2851\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0005 / 0006 | LOSS 0.2821\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0006 / 0006 | LOSS 0.2812\n",
"VALID: EPOCH 0005 / 0020 | BATCH 0001 / 0001 | LOSS 0.3173\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0001 / 0006 | LOSS 0.2671\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0002 / 0006 | LOSS 0.2769\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0003 / 0006 | LOSS 0.2720\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0004 / 0006 | LOSS 0.2758\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0005 / 0006 | LOSS 0.2745\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0006 / 0006 | LOSS 0.2750\n",
"VALID: EPOCH 0006 / 0020 | BATCH 0001 / 0001 | LOSS 0.2644\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0001 / 0006 | LOSS 0.2504\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0002 / 0006 | LOSS 0.2509\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0003 / 0006 | LOSS 0.2562\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0004 / 0006 | LOSS 0.2540\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0005 / 0006 | LOSS 0.2571\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0006 / 0006 | LOSS 0.2563\n",
"VALID: EPOCH 0007 / 0020 | BATCH 0001 / 0001 | LOSS 0.3108\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0001 / 0006 | LOSS 0.2595\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0002 / 0006 | LOSS 0.2523\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0003 / 0006 | LOSS 0.2517\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0004 / 0006 | LOSS 0.2495\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0005 / 0006 | LOSS 0.2467\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0006 / 0006 | LOSS 0.2460\n",
"VALID: EPOCH 0008 / 0020 | BATCH 0001 / 0001 | LOSS 0.2661\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0001 / 0006 | LOSS 0.2405\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0002 / 0006 | LOSS 0.2372\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0003 / 0006 | LOSS 0.2444\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0004 / 0006 | LOSS 0.2422\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0005 / 0006 | LOSS 0.2392\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0006 / 0006 | LOSS 0.2389\n",
"VALID: EPOCH 0009 / 0020 | BATCH 0001 / 0001 | LOSS 0.2370\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0001 / 0006 | LOSS 0.2330\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0002 / 0006 | LOSS 0.2343\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0003 / 0006 | LOSS 0.2297\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0004 / 0006 | LOSS 0.2306\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0005 / 0006 | LOSS 0.2329\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0006 / 0006 | LOSS 0.2315\n",
"VALID: EPOCH 0010 / 0020 | BATCH 0001 / 0001 | LOSS 0.2316\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0001 / 0006 | LOSS 0.2298\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0002 / 0006 | LOSS 0.2244\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0003 / 0006 | LOSS 0.2226\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0004 / 0006 | LOSS 0.2295\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0005 / 0006 | LOSS 0.2272\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0006 / 0006 | LOSS 0.2301\n",
"VALID: EPOCH 0011 / 0020 | BATCH 0001 / 0001 | LOSS 0.2179\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0001 / 0006 | LOSS 0.2201\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0002 / 0006 | LOSS 0.2226\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0003 / 0006 | LOSS 0.2231\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0004 / 0006 | LOSS 0.2208\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0005 / 0006 | LOSS 0.2222\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0006 / 0006 | LOSS 0.2237\n",
"VALID: EPOCH 0012 / 0020 | BATCH 0001 / 0001 | LOSS 0.2273\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0001 / 0006 | LOSS 0.2027\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0002 / 0006 | LOSS 0.2203\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0003 / 0006 | LOSS 0.2186\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0004 / 0006 | LOSS 0.2188\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0005 / 0006 | LOSS 0.2191\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0006 / 0006 | LOSS 0.2185\n",
"VALID: EPOCH 0013 / 0020 | BATCH 0001 / 0001 | LOSS 0.2283\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0001 / 0006 | LOSS 0.2275\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0002 / 0006 | LOSS 0.2271\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0003 / 0006 | LOSS 0.2190\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0004 / 0006 | LOSS 0.2165\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0005 / 0006 | LOSS 0.2160\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0006 / 0006 | LOSS 0.2165\n",
"VALID: EPOCH 0014 / 0020 | BATCH 0001 / 0001 | LOSS 0.2171\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0001 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0002 / 0006 | LOSS 0.2162\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0003 / 0006 | LOSS 0.2149\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0004 / 0006 | LOSS 0.2129\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0005 / 0006 | LOSS 0.2120\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0006 / 0006 | LOSS 0.2126\n",
"VALID: EPOCH 0015 / 0020 | BATCH 0001 / 0001 | LOSS 0.2184\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0001 / 0006 | LOSS 0.2167\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0002 / 0006 | LOSS 0.2071\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0003 / 0006 | LOSS 0.2096\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0004 / 0006 | LOSS 0.2100\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0005 / 0006 | LOSS 0.2084\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0006 / 0006 | LOSS 0.2077\n",
"VALID: EPOCH 0016 / 0020 | BATCH 0001 / 0001 | LOSS 0.2243\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0001 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0002 / 0006 | LOSS 0.2075\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0003 / 0006 | LOSS 0.2112\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0004 / 0006 | LOSS 0.2114\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0005 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0006 / 0006 | LOSS 0.2091\n",
"VALID: EPOCH 0017 / 0020 | BATCH 0001 / 0001 | LOSS 0.2012\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0001 / 0006 | LOSS 0.2049\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0002 / 0006 | LOSS 0.1998\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0003 / 0006 | LOSS 0.2004\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0004 / 0006 | LOSS 0.2084\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0005 / 0006 | LOSS 0.2070\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0006 / 0006 | LOSS 0.2046\n",
"VALID: EPOCH 0018 / 0020 | BATCH 0001 / 0001 | LOSS 0.1956\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0001 / 0006 | LOSS 0.1971\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0002 / 0006 | LOSS 0.1906\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0003 / 0006 | LOSS 0.1992\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0004 / 0006 | LOSS 0.2029\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0005 / 0006 | LOSS 0.2021\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0006 / 0006 | LOSS 0.2061\n",
"VALID: EPOCH 0019 / 0020 | BATCH 0001 / 0001 | LOSS 0.2161\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0001 / 0006 | LOSS 0.2169\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0002 / 0006 | LOSS 0.2125\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0003 / 0006 | LOSS 0.2093\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0004 / 0006 | LOSS 0.2127\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0005 / 0006 | LOSS 0.2115\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0006 / 0006 | LOSS 0.2120\n",
"VALID: EPOCH 0020 / 0020 | BATCH 0001 / 0001 | LOSS 0.2045\n"
]
}
],
"source": [
"# 훈련 파라미터 설정하기\n",
"lr = 1e-3\n",
"batch_size = 4\n",
"num_epoch = 20\n",
"\n",
"base_dir = './drive/MyDrive/DACrew/unet'\n",
"data_dir = dir_data\n",
"ckpt_dir = os.path.join(base_dir, \"checkpoint\")\n",
"log_dir = os.path.join(base_dir, \"log\")\n",
"\n",
"\n",
"# 훈련을 위한 Transform과 DataLoader\n",
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), ToTensor()])\n",
"\n",
"dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=transform)\n",
"loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)\n",
"\n",
"dataset_val = Dataset(data_dir=os.path.join(data_dir, 'val'), transform=transform)\n",
"loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=0)\n",
"\n",
"# 네트워크 생성하기\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"net = UNet().to(device)\n",
"\n",
"# 손실함수 정의하기\n",
"fn_loss = nn.BCEWithLogitsLoss().to(device)\n",
"\n",
"# Optimizer 설정하기\n",
"optim = torch.optim.Adam(net.parameters(), lr=lr)\n",
"\n",
"# 그밖에 부수적인 variables 설정하기\n",
"num_data_train = len(dataset_train)\n",
"num_data_val = len(dataset_val)\n",
"\n",
"num_batch_train = np.ceil(num_data_train / batch_size)\n",
"num_batch_val = np.ceil(num_data_val / batch_size)\n",
"\n",
"# 그 밖에 부수적인 functions 설정하기\n",
"fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)\n",
"fn_denorm = lambda x, mean, std: (x * std) + mean\n",
"fn_class = lambda x: 1.0 * (x > 0.5)\n",
"\n",
"# Tensorboard 를 사용하기 위한 SummaryWriter 설정\n",
"writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))\n",
"writer_val = SummaryWriter(log_dir=os.path.join(log_dir, 'val'))\n",
"\n",
"# 네트워크 학습시키기\n",
"st_epoch = 0\n",
"# 학습한 모델이 있을 경우 모델 로드하기\n",
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim) \n",
"\n",
"for epoch in range(st_epoch + 1, num_epoch + 1):\n",
" net.train()\n",
" loss_arr = []\n",
"\n",
" for batch, data in enumerate(loader_train, 1):\n",
" # forward pass\n",
" label = data['label'].to(device)\n",
" input = data['input'].to(device)\n",
"\n",
" output = net(input)\n",
"\n",
" # backward pass\n",
" optim.zero_grad()\n",
"\n",
" loss = fn_loss(output, label)\n",
" loss.backward()\n",
"\n",
" optim.step()\n",
"\n",
" # 손실함수 계산\n",
" loss_arr += [loss.item()]\n",
"\n",
" print(\"TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f\" %\n",
" (epoch, num_epoch, batch, num_batch_train, np.mean(loss_arr)))\n",
"\n",
" # Tensorboard 저장하기\n",
" label = fn_tonumpy(label)\n",
" input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
" output = fn_tonumpy(fn_class(output))\n",
"\n",
" writer_train.add_image('label', label, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
" writer_train.add_image('input', input, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
" writer_train.add_image('output', output, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
"\n",
" writer_train.add_scalar('loss', np.mean(loss_arr), epoch)\n",
"\n",
" with torch.no_grad():\n",
" net.eval()\n",
" loss_arr = []\n",
"\n",
" for batch, data in enumerate(loader_val, 1):\n",
" # forward pass\n",
" label = data['label'].to(device)\n",
" input = data['input'].to(device)\n",
"\n",
" output = net(input)\n",
"\n",
" # 손실함수 계산하기\n",
" loss = fn_loss(output, label)\n",
"\n",
" loss_arr += [loss.item()]\n",
"\n",
" print(\"VALID: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f\" %\n",
" (epoch, num_epoch, batch, num_batch_val, np.mean(loss_arr)))\n",
"\n",
" # Tensorboard 저장하기\n",
" label = fn_tonumpy(label)\n",
" input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
" output = fn_tonumpy(fn_class(output))\n",
"\n",
" writer_val.add_image('label', label, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
" writer_val.add_image('input', input, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
" writer_val.add_image('output', output, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
"\n",
" writer_val.add_scalar('loss', np.mean(loss_arr), epoch)\n",
"\n",
" # epoch 50마다 모델 저장하기\n",
" if epoch % 50 == 0:\n",
" save(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch)\n",
"\n",
" writer_train.close()\n",
" writer_val.close()"
]
},
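 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "Since the network is trained with `BCEWithLogitsLoss`, its raw outputs are logits rather than probabilities: a probability threshold of 0.5 corresponds to a logit threshold of 0. A quick standalone check of that equivalence (NumPy only, independent of the notebook's variables):\n",
   "\n",
   "```python\n",
   "import numpy as np\n",
   "\n",
   "def sigmoid(z):\n",
   "    return 1.0 / (1.0 + np.exp(-z))\n",
   "\n",
   "logits = np.array([-2.0, -0.1, 0.0, 0.3, 5.0])\n",
   "# thresholding probabilities at 0.5 gives the same mask as thresholding logits at 0\n",
   "assert np.array_equal(sigmoid(logits) > 0.5, logits > 0)\n",
   "```\n"
  ]
 },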
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TEST: BATCH 0001 / 0001 | LOSS 0.2526\n",
"AVERAGE TEST: BATCH 0001 / 0001 | LOSS 0.2526\n"
]
}
],
"source": [
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), ToTensor()])\n",
"\n",
"dataset_test = Dataset(data_dir=os.path.join(data_dir, 'test'), transform=transform)\n",
"loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)\n",
"\n",
"# 그밖에 부수적인 variables 설정하기\n",
"num_data_test = len(dataset_test)\n",
"num_batch_test = np.ceil(num_data_test / batch_size)\n",
"\n",
"# 결과 디렉토리 생성하기\n",
"result_dir = os.path.join(base_dir, 'result')\n",
"if not os.path.exists(result_dir):\n",
" os.makedirs(os.path.join(result_dir, 'png'))\n",
" os.makedirs(os.path.join(result_dir, 'numpy'))\n",
"\n",
"\n",
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)\n",
"\n",
"with torch.no_grad():\n",
" net.eval()\n",
" loss_arr = []\n",
"\n",
" for batch, data in enumerate(loader_test, 1):\n",
" # forward pass\n",
" label = data['label'].to(device)\n",
" input = data['input'].to(device)\n",
"\n",
" output = net(input)\n",
"\n",
" # 손실함수 계산하기\n",
" loss = fn_loss(output, label)\n",
"\n",
" loss_arr += [loss.item()]\n",
"\n",
" print(\"TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
" (batch, num_batch_test, np.mean(loss_arr)))\n",
"\n",
" # Tensorboard 저장하기\n",
" label = fn_tonumpy(label)\n",
" input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
" output = fn_tonumpy(fn_class(output))\n",
"\n",
" # 테스트 결과 저장하기\n",
" for j in range(label.shape[0]):\n",
" id = num_batch_test * (batch - 1) + j\n",
"\n",
" plt.imsave(os.path.join(result_dir, 'png', 'label_%04d.png' % id), label[j].squeeze(), cmap='gray')\n",
" plt.imsave(os.path.join(result_dir, 'png', 'input_%04d.png' % id), input[j].squeeze(), cmap='gray')\n",
" plt.imsave(os.path.join(result_dir, 'png', 'output_%04d.png' % id), output[j].squeeze(), cmap='gray')\n",
"\n",
" np.save(os.path.join(result_dir, 'numpy', 'label_%04d.npy' % id), label[j].squeeze())\n",
" np.save(os.path.join(result_dir, 'numpy', 'input_%04d.npy' % id), input[j].squeeze())\n",
" np.save(os.path.join(result_dir, 'numpy', 'output_%04d.npy' % id), output[j].squeeze())\n",
"\n",
"print(\"AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
" (batch, num_batch_test, np.mean(loss_arr)))"
]
},
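 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "With the `label_*.npy` / `output_*.npy` pairs saved above, the IoU (Intersection over Union) metric mentioned earlier can be computed per image. A minimal sketch on hypothetical binary masks (the arrays below are invented stand-ins for a real label/output pair):\n",
   "\n",
   "```python\n",
   "import numpy as np\n",
   "\n",
   "label = np.array([[1, 1, 0], [0, 1, 0]])   # hypothetical ground-truth mask\n",
   "output = np.array([[1, 0, 0], [0, 1, 1]])  # hypothetical predicted mask\n",
   "\n",
   "intersection = np.logical_and(label, output).sum()\n",
   "union = np.logical_or(label, output).sum()\n",
   "print(intersection / union)  # 0.5\n",
   "```\n"
  ]
 },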
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Result (Visualize)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqQAAAD4CAYAAAA6lfQMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d3hUVff+/ZmeMpn0XkkgQOhNilQBFUREpSgWBBXEB1BQKRaKUgQVVJQqSFEUQVBEkV6kiIJ0SCAkIb1OMslMps+8f/Ce/SOG6oPt+c59XVyamTPnzDlz9tlrr3Wv+5a53W43HnjggQceeOCBBx548DdB/nd/AQ888MADDzzwwAMP/m/DE5B64IEHHnjggQceePC3whOQeuCBBx544IEHHnjwt8ITkHrggQceeOCBBx548LfCE5B64IEHHnjggQceePC3whOQeuCBBx544IEHHnjwt8ITkHrggQceeOCBBx548LfCE5B64IEHHnjggQceePC3whOQeuCBBx544IEHHnjwt8ITkP6PYsWKFchkMrKysv7ur+KBB//z+KPjrWvXrjRu3Pi2fpeEhASeeuqp27pPDzzwwIM/G56A1IM/HWfPnmXq1Kme4NgDDzzwwIO/FAsWLGDFihV/ybE8c91/B09A+j+KJ554ArPZTHx8/N/9VTh79izTpk3zDFIPPPDAAw/+UvzVAalnrvvjUP7dX8CDPwcKhQKFQvF3fw0PPPDAAw888MCDG8KTIf0fxe85bQkJCfTp04f9+/dzxx134OXlRWJiIqtWrbrq5/bt28eIESMIDg5Gp9Px5JNPUl5eXmNbmUzG1KlTax37Sg7bihUrGDBgAADdunVDJpMhk8nYs2fP7T5lDzz4x+Dbb7/lvvvuIyoqCo1GQ1JSEm+99RZOp/Oq2x89epQOHTrg7e1NnTp1WLRoUa1trFYrU6ZMoW7dumg0GmJjYxk/fjxWq/XPPh0PPPhLcezYMXr16oVOp0Or1dK9e3d+/vln8f7UqVORyWS1Pne1ee/MmTPs3btXzD1du3atsa1nrvvnwJMh/T+E9PR0+vfvz9NPP82QIUNYvnw5Tz31FK1ataJRo0Y1th01ahQBAQFMnTqVtLQ0Fi5cyKVLl9izZ89VHwTXQufOnRkzZgwffvghr776Kg0bNgQQ//XAg/9FrFixAq1Wy7hx49BqtezatYvJkydTWVnJO++8U2Pb8vJyevfuzcCBA3n00Uf56quvGDlyJGq1mmHDhgHgcrno27cv+/fvZ/jw4TRs2JBTp04xb948zp8/zzfffPM3nKUHHtx+nDlzhk6dOqHT6Rg/fjwqlYrFixfTtWtX9u7dS9u2bW96X++//z6jR49Gq9Xy2muvARAeHl5jG89c9w+C24P/SXz66aduwJ2Zmel2u93u+Ph4N+Det2+f2Ka4uNit0WjcL730Uq3PtWrVym2z2cTrc+bMcQPub7/9VrwGuKdMmVLr2PHx8e4hQ4aIv9etW+cG3Lt3775t5+eBB/8k/H68VVdX19pmxIgRbh8fH7fFYhGvdenSxQ2433vvPfGa1Wp1N2/e3B0WFibG4OrVq91yudz9008/1djnokWL3ID7wIED4rXfjz8PPPg3oV+/fm61Wu2+ePGieC0/P9/t5+fn7ty5s9vtdrunTJnivlr48vtx6Ha73Y0aNXJ36dLlmtt65rp/Djwl+/9DSElJoVOnTuLv0NBQ6tevT0ZGRq1thw8fjkqlEn+PHDkSpVLJDz/88Jd8Vw88+DfD29tb/H9VVRWlpaV06tSJ6upqUlNTa2yrVCoZMWKE+FutVjNixAiKi4s5evQoAOvWraNhw4Y0aNCA0tJS8e+uu+4CYPfu3X/BWXngwZ8Lp9PJtm3b6NevH4mJieL1yMhIBg8ezP79+6msrLytx/TMdf8ceEr2/4cQFxdX67XAwMBafBmAevXq1fhbq9USGRnp6R70wIObwJkzZ3j99dfZtWtXrQ
nUYDDU+DsqKgpfX98aryUnJwOQlZVFu3btuHDhAufOnSM0NPSqxysuLr6N394DD/4elJSUUF1dTf369Wu917BhQ1wuFzk5Obf1mJ657p8DT0D6fwjX6rp3u9239TjXatzwwIP/C6ioqKBLly7odDrefPNNkpKS8PLy4rfffmPChAm4XK5b3qfL5aJJkybMnTv3qu/Hxsb+t1/bAw/+NbgWt/Ovnns8c93thScg9eCquHDhAt26dRN/G41GCgoK6N27t3gtMDCQioqKGp+z2WwUFBTUeO1WiOEeePBvx549eygrK2PDhg107txZvJ6ZmXnV7fPz8zGZTDWypOfPnwcud/ECJCUlceLECbp37+4ZTx78zyI0NBQfHx/S0tJqvZeamopcLic2NpbAwEDg8uIvICBAbHPp0qVan7vRePHMdf8ceDikHlwVS5YswW63i78XLlyIw+GgV69e4rWkpCT27dtX63O/XzVKE+3vB7QHHvwvQqpEXFl5sNlsLFiw4KrbOxwOFi9eXGPbxYsXExoaSqtWrQAYOHAgeXl5LF26tNbnzWYzJpPpdp6CBx78LVAoFNx99918++23NUrmRUVFrFmzho4dO6LT6UhKSgKoMf+YTCZWrlxZa5++vr7XnXs8c90/B54MqQdXhc1mo3v37gwcOJC0tDQWLFhAx44d6du3r9jmmWee4bnnnuPhhx+mZ8+enDhxgq1btxISElJjX82bN0ehUDB79mwMBgMajYa77rqLsLCwv/q0PPDgT0eHDh0IDAxkyJAhjBkzBplMxurVq69JjYmKimL27NlkZWWRnJzM2rVrOX78OEuWLBHNFk888QRfffUVzz33HLt37+bOO+/E6XSSmprKV199xdatW2nduvVfeZoeePCnYPr06Wzfvp2OHTvy/PPPo1QqWbx4MVarlTlz5gBw9913ExcXx9NPP80rr7yCQqFg+fLlhIaGkp2dXWN/rVq1YuHChUyfPp26desSFhYmmgHBM9f9o/A3d/l78CfharJP9913X63tunTpUkMSQ/rc3r173cOHD3cHBga6tVqt+7HHHnOXlZXV+KzT6XRPmDDBHRIS4vbx8XHfc8897vT09KvKzixdutSdmJjoVigUHlkMD/7n8PvxduDAAXe7du3c3t7e7qioKPf48ePdW7durXXvd+nSxd2oUSP3kSNH3O3bt3d7eXm54+Pj3R999FGtY9hsNvfs2bPdjRo1cms0GndgYKC7VatW7mnTprkNBoPYziP75MG/Hb/99pv7nnvucWu1WrePj4+7W7du7oMHD9bY5ujRo+62bdu61Wq1Oy4uzj137tyryj4VFha677vvPrefn58bEPOdZ67750Hmdt/mjhYP/tVYsWIFQ4cO5ddff/VkXDzwwAMPPPifhGeu++fBwyH1wAMPPPDAAw888OBvhScg9cADDzzwwAMPPPDgb4UnIPXAAw888MADDzzw4G/F3xaQfvzxxyQkJODl5UXbtm355Zdf/q6v4sEVeOqpp3C73R5Ozf8QPGPNAw/+GnjG2r8Hnrnun4e/JSBdu3Yt48aNY8qUKfz22280a9aMe+65x2N/54EHtxmeseaBB38NPGPNAw/+O/wtXfZt27alTZs2fPTRR8BlW7zY2FhGjx7NxIkT/+qv44EH/7PwjDUPPPhr4BlrHnjw3+EvF8a32WwcPXqUSZMmidfkcjk9evTg0KFDV/2M1WrFarWKv10uF3q9nuDgYI9Vlwf/M3C73VRVVREVFYVc/t8XLzxjzQMPrg7PWPPAg78GtzLW/vKAtLS0FKfTSXh4eI3Xw8PDSU1NvepnZs2axbRp0/6Kr+eBB387cnJyiImJ+a/34xlrHnhwfXjGmgce/DW4mbH2r7AOnTRpEuPGjRN/GwwG4uLiWLhwIXa7ncrKSgwGAydOnKC6uhqHw4FWq8ViseDj40NgYCBKpRKn04nRaMTX1xeZTIbZbOb8+fOEhITQvXt3dDodFy5cICoqiv
j4eCoqKkhNTSU3N5eioiL8/PwICgrCy8uL+vXrU6dOHUwmExaLBY1Gg91ux8vLi5KSElQqFdHR0VRVVeF0OikrK8N
"text/plain": [
"<Figure size 800x600 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"##\n",
"lst_data = os.listdir(os.path.join(result_dir, 'numpy'))\n",
"\n",
"lst_label = [f for f in lst_data if f.startswith('label')]\n",
"lst_input = [f for f in lst_data if f.startswith('input')]\n",
"lst_output = [f for f in lst_data if f.startswith('output')]\n",
"\n",
"lst_label.sort()\n",
"lst_input.sort()\n",
"lst_output.sort()\n",
"\n",
"##\n",
"id = 0\n",
"\n",
"label = np.load(os.path.join(result_dir,\"numpy\", lst_label[id]))\n",
"input = np.load(os.path.join(result_dir,\"numpy\", lst_input[id]))\n",
"output = np.load(os.path.join(result_dir,\"numpy\", lst_output[id]))\n",
"\n",
"## 플롯 그리기\n",
"plt.figure(figsize=(8,6))\n",
"plt.subplot(131)\n",
"plt.imshow(input, cmap='gray')\n",
"plt.title('input')\n",
"\n",
"plt.subplot(132)\n",
"plt.imshow(label, cmap='gray')\n",
"plt.title('label')\n",
"\n",
"plt.subplot(133)\n",
"plt.imshow(output, cmap='gray')\n",
"plt.title('output')\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}