{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# What is semantic segmentation?\n",
    "\n",
    "![semantic-segmentation](./images/unet1.png)\n",
    "\n",
    "U-Net is a model for semantic segmentation, one of the tasks studied in computer vision.\n",
    "\n",
    "Semantic segmentation is classification at the pixel level: the model predicts, for every pixel in the image, which class that pixel belongs to.\n",
    "\n",
    "Compared with image classification, which only predicts whether an object is present in the image, segmentation must preserve object boundaries and also capture the context of the whole image.\n",
    "\n",
    "Because it requires this higher level of image understanding, it is considered a harder problem."
   ]
  },
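  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy NumPy sketch of per-pixel classification (not part of the original notebook; the class scores are made up): a segmentation model outputs one score map per class, and each pixel is assigned the class with the highest score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up class scores for a 3-class problem on a 2x3 image, shape (C, H, W)\n",
    "scores = np.array([\n",
    "    [[0.90, 0.10, 0.20], [0.30, 0.80, 0.10]],  # class 0\n",
    "    [[0.05, 0.70, 0.10], [0.60, 0.10, 0.20]],  # class 1\n",
    "    [[0.05, 0.20, 0.70], [0.10, 0.10, 0.70]],  # class 2\n",
    "])\n",
    "\n",
    "# One class decision per pixel: index of the highest score along the class axis\n",
    "label_map = scores.argmax(axis=0)  # shape (H, W)\n",
    "print(label_map)\n",
    "# [[0 1 2]\n",
    "#  [1 0 2]]"
   ]
  },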
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# What is U-Net?\n",
    "\n",
    "![unet](./images/unet2.png)\n",
    "\n",
    "U-Net consists of a contracting path, which compresses the image into feature maps, and an expansive path, which restores them to the resolution of the input image.\n",
    "\n",
    "These two modules are also called the encoder and the decoder, and the model is named U-Net because its architecture is shaped like the letter U.\n",
    "\n",
    "Semantic segmentation models are used in many fields, such as autonomous driving and biomedical engineering.\n",
    "\n",
    "U-Net performs well in biomedical imaging, for example in diagnosing lesions on MRI and CT scans or in segmenting organs and cell tissue, and models built on the U-Net architecture keep improving on a wide range of problems every year.\n",
    "\n",
    "According to Papers with Code, more than 10% of the tasks U-Net is applied to are related to medicine, and U-Net has the largest number of papers among semantic segmentation models.\n",
    "\n",
    "Although semantic segmentation is a difficult task for machines, state-of-the-art (SOTA) models in this area have reached a fairly high level of accuracy.\n",
    "\n",
    "According to Papers with Code, on two representative benchmark datasets for semantic segmentation, the top-ranked models (at the time of writing) reached an mIoU (mean Intersection over Union) of 84.5% / 86.95% on Cityscapes test / val and 90.5% / 90.0% on PASCAL VOC 2012 test / val."
   ]
  },
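  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "mIoU averages the per-class Intersection over Union, where for each class c, IoU = |pred == c AND target == c| / |pred == c OR target == c|. A minimal NumPy sketch for integer label maps, assuming classes are numbered 0..num_classes-1; the function name `mean_iou` is illustrative, and official benchmarks differ in details such as accumulating intersections and unions over the whole dataset and ignoring certain labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def mean_iou(pred, target, num_classes):\n",
    "    # pred, target: integer label maps of the same shape\n",
    "    ious = []\n",
    "    for c in range(num_classes):\n",
    "        pred_c = pred == c\n",
    "        target_c = target == c\n",
    "        union = np.logical_or(pred_c, target_c).sum()\n",
    "        if union == 0:\n",
    "            continue  # class absent from both maps: skip it rather than divide by zero\n",
    "        intersection = np.logical_and(pred_c, target_c).sum()\n",
    "        ious.append(intersection / union)\n",
    "    return float(np.mean(ious))\n",
    "\n",
    "pred = np.array([[0, 0, 1], [1, 1, 1]])\n",
    "target = np.array([[0, 1, 1], [1, 1, 1]])\n",
    "print(mean_iou(pred, target, num_classes=2))  # class 0: 1/2, class 1: 4/5 -> 0.65"
   ]
  },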
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# What is the overlap-tile strategy?\n",
    "\n",
    "As the title of the paper, \"U-Net: Convolutional Networks for Biomedical Image Segmentation\", indicates, U-Net was designed to segment biomedical images.\n",
    "\n",
    "It can help diagnose lesions, for example by locating nodules in CT scans (nodule detection) or by separating retinal blood vessels from the background (retinal vessel segmentation).\n",
    "\n",
    "Unlike everyday images, most biomedical images are high-resolution, so processing them requires a lot of computation.\n",
    "\n",
    "The overlap-tile strategy was devised to process such large medical images efficiently.\n",
    "\n",
    "![overlap-tile](./images/unet3.png)\n",
    "\n",
    "To predict the yellow region (one tile), the network uses a larger surrounding region of the input as context. When it moves on to the next tile, the input region needed for that prediction overlaps with the region used for the previous one.\n",
    "\n",
    "This overlap is why the method is called overlap-tile. Where the required context extends beyond the image border, the paper fills in the missing pixels by mirroring the input image.\n",
    "\n",
    "According to the paper, this tiling strategy is important for applying the network to large images, since the resolution would otherwise be limited by the available GPU memory.\n",
    "\n",
    "# What is data augmentation?\n",
    "\n",
    "Besides overlap-tile, U-Net is trained with data augmentation.\n",
    "\n",
    "In biomedical imaging, only a few training images are usually available, while variations and deformations of tissue are very common, so augmenting the available data is especially important.\n",
    "\n",
    "Data augmentation produces more training images by flipping, rotating, deforming, or shifting the images you already have; the U-Net paper puts particular emphasis on random elastic deformations.\n",
    "\n",
    "Data augmentation is also widely used with models other than U-Net, for example to reduce labeling costs.\n",
    "\n",
    "# References\n",
    "\n",
    "- O. Ronneberger, P. Fischer, and T. Brox, \"U-Net: Convolutional Networks for Biomedical Image Segmentation\", MICCAI 2015."
   ]
  },
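  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of overlap-tile inference, under stated assumptions: `predict_tile` is a hypothetical stand-in for the network that, like U-Net with unpadded convolutions, looks at a larger input patch but predicts only its center; the tile size 64 and margin 16 are arbitrary; and `np.pad(..., mode='reflect')` plays the role of mirroring at the image border. With a pixel-wise stand-in, the tiled result equals the whole-image result, which is the property the strategy relies on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def predict_tile(patch, margin):\n",
    "    # Hypothetical stand-in for the network: it sees the whole patch but only predicts\n",
    "    # its center (the patch minus `margin` pixels on each side; margin must be > 0)\n",
    "    center = patch[margin:-margin, margin:-margin]\n",
    "    return (center > 0.5).astype(np.uint8)\n",
    "\n",
    "def overlap_tile_predict(image, tile=64, margin=16):\n",
    "    h, w = image.shape\n",
    "    # Pad up to a multiple of the tile size, and mirror-pad by `margin` for border context\n",
    "    pad_h, pad_w = (-h) % tile, (-w) % tile\n",
    "    padded = np.pad(image, ((margin, margin + pad_h), (margin, margin + pad_w)), mode='reflect')\n",
    "    out = np.zeros((h + pad_h, w + pad_w), dtype=np.uint8)\n",
    "    for y in range(0, h + pad_h, tile):\n",
    "        for x in range(0, w + pad_w, tile):\n",
    "            # Input patch = output tile + `margin` pixels of context on every side,\n",
    "            # so neighboring input patches overlap by 2 * margin pixels\n",
    "            patch = padded[y:y + tile + 2 * margin, x:x + tile + 2 * margin]\n",
    "            out[y:y + tile, x:x + tile] = predict_tile(patch, margin)\n",
    "    return out[:h, :w]\n",
    "\n",
    "image = np.random.rand(100, 150)\n",
    "mask = overlap_tile_predict(image)\n",
    "print(mask.shape)  # (100, 150)\n",
    "# With this pixel-wise stand-in, tiling reproduces the whole-image result exactly\n",
    "print(np.array_equal(mask, (image > 0.5).astype(np.uint8)))  # True"
   ]
  },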
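  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Redefining the loss function\n",
    "\n",
    "![weight-map-loss](./images/unet4.png)\n",
    "\n",
    "In biomedical images, segmentation becomes difficult where many cells are packed together, i.e., where objects of the same class touch each other.\n",
    "\n",
    "Separating cells from the background is relatively easy, but when cells touch, as in the example above, telling the individual instances apart is not.\n",
    "\n",
    "The paper therefore requires a background region between the borders of neighboring instances: even when two cells touch, the thin gap between them must be recognized as background.\n",
    "\n",
    "To enforce this, the loss is redefined as a pixel-wise weighted cross-entropy, in which each pixel's term is multiplied by a precomputed weight map $w(\\mathbf{x})$:\n",
    "\n",
    "$$w(\\mathbf{x}) = w_c(\\mathbf{x}) + w_0 \\cdot \\exp\\left(-\\frac{\\left(d_1(\\mathbf{x}) + d_2(\\mathbf{x})\\right)^2}{2\\sigma^2}\\right)$$\n",
    "\n",
    "Here $w_c$ is a weight map that balances the class frequencies, $d_1$ is the distance to the border of the nearest cell, and $d_2$ is the distance to the border of the second-nearest cell; the paper uses $w_0 = 10$ and $\\sigma \\approx 5$ pixels.\n",
    "\n",
    "The smaller $d_1 + d_2$ is, i.e., the narrower the gap between two cells, the larger the weight. Misclassifying these narrow background gaps is therefore penalized much more heavily, so minimizing the loss forces the model to learn the separating background between touching instances.\n",
    "\n",
    "As a result, individual instances can be separated accurately even where cells or tissue are clustered together.\n",
    "\n",
    "This is why, in the weight map above, high weights are assigned to the background gaps between cells."
   ]
  },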
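  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simplified NumPy/SciPy sketch of the weight map, under stated assumptions: the input is an instance-label image (0 = background, 1..N = cell IDs), the class-balancing term $w_c$ is replaced by a constant 1, the border term is added to background pixels only, and w0 = 10, sigma = 5 follow the values reported in the paper. The function name `unet_weight_map` is hypothetical. In PyTorch, such a map could then be passed as the `weight` argument of `torch.nn.functional.binary_cross_entropy_with_logits`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.ndimage import distance_transform_edt\n",
    "\n",
    "def unet_weight_map(instances, w0=10.0, sigma=5.0):\n",
    "    # instances: 2-D int array, 0 = background, 1..N = cell IDs (assumed input format)\n",
    "    ids = [i for i in np.unique(instances) if i != 0]\n",
    "    weights = np.ones(instances.shape, dtype=np.float64)  # w_c(x) simplified to a constant 1\n",
    "    if len(ids) < 2:\n",
    "        return weights  # d2 is undefined with fewer than two cells\n",
    "    # Distance from every pixel to the nearest pixel of each cell, shape (N, H, W)\n",
    "    dists = np.stack([distance_transform_edt(instances != i) for i in ids])\n",
    "    dists.sort(axis=0)\n",
    "    d1, d2 = dists[0], dists[1]  # nearest and second-nearest cell\n",
    "    border = w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))\n",
    "    # Simplification (assumption): add the border term to background pixels only\n",
    "    weights += border * (instances == 0)\n",
    "    return weights\n",
    "\n",
    "# Two 10x10 cells separated by a 2-pixel background gap (columns 13 and 14)\n",
    "inst = np.zeros((20, 30), dtype=np.int32)\n",
    "inst[5:15, 3:13] = 1\n",
    "inst[5:15, 15:25] = 2\n",
    "w = unet_weight_map(inst)\n",
    "print(w[10, 13:15])  # gap pixels: d1 + d2 = 3, weight about 9.35\n",
    "print(w[0, 0])       # background far from both cells: weight close to 1"
   ]
  },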
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset overview\n",
    "\n",
    "![dataset](./images/unet5.png)\n",
    "\n",
    "The membrane dataset used in the ISBI 2012 EM Segmentation Challenge.\n",
    "\n",
    "The cell image on the left is a 512x512 grayscale image; the right shows the segmentation of the membranes (background) between the cells.\n",
    "\n",
    "In the label images, cell pixels have the value 255 and membrane (background) pixels the value 0; the data loader below divides by 255, turning them into 1 and 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `dataset/`\n",
    "  - `train-volume.tif`: training images\n",
    "  - `train-labels.tif`: segmentation labels for the training images\n",
    "  - `test-volume.tif`: test images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preprocessing code below splits the 30 training frames into `.npy` files with the following layout:\n",
    "\n",
    "- `dataset/`\n",
    "  - `train/`: `input_000.npy` ~ `input_023.npy`, `label_000.npy` ~ `label_023.npy`\n",
    "  - `val/`: `input_000.npy` ~ `input_002.npy`, `label_000.npy` ~ `label_002.npy`\n",
    "  - `test/`: `input_000.npy` ~ `input_002.npy`, `label_000.npy` ~ `label_002.npy`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Import libraries\n",
    "import os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "## Load the data\n",
    "dir_data = './dataset'\n",
    "\n",
    "name_label = 'train-labels.tif'\n",
    "name_input = 'train-volume.tif'\n",
    "\n",
    "img_label = Image.open(os.path.join(dir_data, name_label))\n",
    "img_input = Image.open(os.path.join(dir_data, name_input))\n",
    "\n",
    "nx, ny = img_label.size  # PIL returns (width, height)\n",
    "nframe = img_label.n_frames  # number of frames in the multi-page TIFF\n",
    "\n",
    "## Create the train/val/test folders\n",
    "nframe_train = 24\n",
    "nframe_val = 3\n",
    "nframe_test = 3\n",
    "\n",
    "dir_save_train = os.path.join(dir_data, 'train')\n",
    "dir_save_val = os.path.join(dir_data, 'val')\n",
    "dir_save_test = os.path.join(dir_data, 'test')\n",
    "\n",
    "os.makedirs(dir_save_train, exist_ok=True)\n",
    "os.makedirs(dir_save_val, exist_ok=True)\n",
    "os.makedirs(dir_save_test, exist_ok=True)\n",
    "\n",
    "## Shuffle the indices of all 30 frames\n",
    "id_frame = np.arange(nframe)\n",
    "np.random.shuffle(id_frame)\n",
    "\n",
    "## Save the selected training frames as .npy files\n",
    "offset_nframe = 0\n",
    "\n",
    "for i in range(nframe_train):\n",
    "    img_label.seek(id_frame[i + offset_nframe])\n",
    "    img_input.seek(id_frame[i + offset_nframe])\n",
    "\n",
    "    label_ = np.asarray(img_label)\n",
    "    input_ = np.asarray(img_input)\n",
    "\n",
    "    np.save(os.path.join(dir_save_train, 'label_%03d.npy' % i), label_)\n",
    "    np.save(os.path.join(dir_save_train, 'input_%03d.npy' % i), input_)\n",
    "\n",
    "## Save the selected validation frames as .npy files\n",
    "offset_nframe = nframe_train\n",
    "\n",
    "for i in range(nframe_val):\n",
    "    img_label.seek(id_frame[i + offset_nframe])\n",
    "    img_input.seek(id_frame[i + offset_nframe])\n",
    "\n",
    "    label_ = np.asarray(img_label)\n",
    "    input_ = np.asarray(img_input)\n",
    "\n",
    "    np.save(os.path.join(dir_save_val, 'label_%03d.npy' % i), label_)\n",
    "    np.save(os.path.join(dir_save_val, 'input_%03d.npy' % i), input_)\n",
    "\n",
    "## Save the selected test frames as .npy files\n",
    "offset_nframe = nframe_train + nframe_val\n",
    "\n",
    "for i in range(nframe_test):\n",
    "    img_label.seek(id_frame[i + offset_nframe])\n",
    "    img_input.seek(id_frame[i + offset_nframe])\n",
    "\n",
    "    label_ = np.asarray(img_label)\n",
    "    input_ = np.asarray(img_input)\n",
    "\n",
    "    np.save(os.path.join(dir_save_test, 'label_%03d.npy' % i), label_)\n",
    "    np.save(os.path.join(dir_save_test, 'input_%03d.npy' % i), input_)\n",
    "\n",
    "## Visualize the last saved frame (from the test split)\n",
    "plt.subplot(122)\n",
    "plt.imshow(label_, cmap='gray')\n",
    "plt.title('label')\n",
    "\n",
    "plt.subplot(121)\n",
    "plt.imshow(input_, cmap='gray')\n",
    "plt.title('input')\n",
    "\n",
    "plt.show()"
   ]
  },
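  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above shuffles with NumPy's global random state and no seed, so each run produces a different train/val/test split. If a reproducible split is wanted, one option is a seeded `np.random.default_rng` generator, e.g. `id_frame = np.random.default_rng(0).permutation(nframe)`. A sketch (the function name `split_indices` and the seed 0 are arbitrary choices, not from the original notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def split_indices(n, n_train, n_val, n_test, seed=0):\n",
    "    # Seeded permutation: the same seed always gives the same split\n",
    "    assert n_train + n_val + n_test <= n\n",
    "    ids = np.random.default_rng(seed).permutation(n)\n",
    "    return (ids[:n_train],\n",
    "            ids[n_train:n_train + n_val],\n",
    "            ids[n_train + n_val:n_train + n_val + n_test])\n",
    "\n",
    "train_ids, val_ids, test_ids = split_indices(30, 24, 3, 3)\n",
    "print(len(train_ids), len(val_ids), len(test_ids))  # 24 3 3\n",
    "print(len(set(train_ids) | set(val_ids) | set(test_ids)))  # 30: the splits do not overlap"
   ]
  },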
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Pixel-value distribution of one frame\n",
    "plt.subplot(122)\n",
    "plt.hist(label_.flatten(), bins=20)\n",
    "plt.title('label')\n",
    "\n",
    "plt.subplot(121)\n",
    "plt.hist(input_.flatten(), bins=20)\n",
    "plt.title('input')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# UNet Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "## Import libraries\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "## Build the network\n",
    "# Note: unlike the original paper, this version uses padded 3x3 convolutions (so the output\n",
    "# has the same height and width as the input) and adds batch normalization.\n",
    "class UNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(UNet, self).__init__()\n",
    "\n",
    "        # Conv2d + BatchNorm2d + ReLU block\n",
    "        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):\n",
    "            layers = []\n",
    "            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
    "                                 kernel_size=kernel_size, stride=stride, padding=padding,\n",
    "                                 bias=bias)]\n",
    "            layers += [nn.BatchNorm2d(num_features=out_channels)]\n",
    "            layers += [nn.ReLU()]\n",
    "\n",
    "            cbr = nn.Sequential(*layers)\n",
    "\n",
    "            return cbr\n",
    "\n",
    "        # Contracting path (encoder)\n",
    "        self.enc1_1 = CBR2d(in_channels=1, out_channels=64)\n",
    "        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)\n",
    "\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
    "\n",
    "        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)\n",
    "        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)\n",
    "\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
    "\n",
    "        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)\n",
    "        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)\n",
    "\n",
    "        self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
    "\n",
    "        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)\n",
    "        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)\n",
    "\n",
    "        self.pool4 = nn.MaxPool2d(kernel_size=2)\n",
    "\n",
    "        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)\n",
    "\n",
    "        # Expansive path (decoder)\n",
    "        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)\n",
    "\n",
    "        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,\n",
    "                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
    "\n",
    "        # Input channels are doubled because of the skip-connection concatenation\n",
    "        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)\n",
    "        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)\n",
    "\n",
    "        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,\n",
    "                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
    "\n",
    "        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)\n",
    "        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)\n",
    "\n",
    "        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,\n",
    "                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
    "\n",
    "        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)\n",
    "        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)\n",
    "\n",
    "        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,\n",
    "                                          kernel_size=2, stride=2, padding=0, bias=True)\n",
    "\n",
    "        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)\n",
    "        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)\n",
    "\n",
    "        # 1x1 convolution: 64 feature channels -> 1 output channel (raw logits)\n",
    "        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n",
    "\n",
    "    # Forward pass\n",
    "    def forward(self, x):\n",
    "        enc1_1 = self.enc1_1(x)\n",
    "        enc1_2 = self.enc1_2(enc1_1)\n",
    "        pool1 = self.pool1(enc1_2)\n",
    "\n",
    "        enc2_1 = self.enc2_1(pool1)\n",
    "        enc2_2 = self.enc2_2(enc2_1)\n",
    "        pool2 = self.pool2(enc2_2)\n",
    "\n",
    "        enc3_1 = self.enc3_1(pool2)\n",
    "        enc3_2 = self.enc3_2(enc3_1)\n",
    "        pool3 = self.pool3(enc3_2)\n",
    "\n",
    "        enc4_1 = self.enc4_1(pool3)\n",
    "        enc4_2 = self.enc4_2(enc4_1)\n",
    "        pool4 = self.pool4(enc4_2)\n",
    "\n",
    "        enc5_1 = self.enc5_1(pool4)\n",
    "\n",
    "        dec5_1 = self.dec5_1(enc5_1)\n",
    "\n",
    "        # Decoder: upsample, concatenate the matching encoder features (skip connection)\n",
    "        # along the channel axis, then apply two CBR2d blocks\n",
    "        unpool4 = self.unpool4(dec5_1)\n",
    "        cat4 = torch.cat((unpool4, enc4_2), dim=1)\n",
    "        dec4_2 = self.dec4_2(cat4)\n",
    "        dec4_1 = self.dec4_1(dec4_2)\n",
    "\n",
    "        unpool3 = self.unpool3(dec4_1)\n",
    "        cat3 = torch.cat((unpool3, enc3_2), dim=1)\n",
    "        dec3_2 = self.dec3_2(cat3)\n",
    "        dec3_1 = self.dec3_1(dec3_2)\n",
    "\n",
    "        unpool2 = self.unpool2(dec3_1)\n",
    "        cat2 = torch.cat((unpool2, enc2_2), dim=1)\n",
    "        dec2_2 = self.dec2_2(cat2)\n",
    "        dec2_1 = self.dec2_1(dec2_2)\n",
    "\n",
    "        unpool1 = self.unpool1(dec2_1)\n",
    "        cat1 = torch.cat((unpool1, enc1_2), dim=1)\n",
    "        dec1_2 = self.dec1_2(cat1)\n",
    "        dec1_1 = self.dec1_1(dec1_2)\n",
    "\n",
    "        x = self.fc(dec1_1)\n",
    "\n",
    "        return x"
   ]
  },
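  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because every CBR2d block uses padding=1, only the four 2x2 max-pooling layers change the spatial size, so the input height and width should be divisible by 2**4 = 16 for the `torch.cat` skip connections to line up (512 satisfies this). A self-contained sketch of one pooling / up-convolution round trip, with the same layer settings as the network above, shows what goes wrong otherwise. With the `UNet` class above, `UNet()(torch.zeros(1, 1, 512, 512)).shape` should likewise be `torch.Size([1, 1, 512, 512])`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# One pooling / up-convolution round trip, as at each level of the UNet above\n",
    "pool = nn.MaxPool2d(kernel_size=2)\n",
    "up = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)\n",
    "\n",
    "for size in (512, 511):\n",
    "    x = torch.zeros(1, 1, size, size)\n",
    "    y = up(pool(x))  # halve (rounding down), then double\n",
    "    print(size, '->', tuple(y.shape[-2:]))\n",
    "# 512 -> (512, 512): matches the encoder feature map, so torch.cat works\n",
    "# 511 -> (510, 510): size mismatch, so torch.cat(..., dim=1) would raise an error"
   ]
  },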
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset class used by the DataLoader\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data_dir, transform=None):\n",
    "        self.data_dir = data_dir\n",
    "        self.transform = transform\n",
    "\n",
    "        lst_data = os.listdir(self.data_dir)\n",
    "\n",
    "        lst_label = [f for f in lst_data if f.startswith('label')]\n",
    "        lst_input = [f for f in lst_data if f.startswith('input')]\n",
    "\n",
    "        # Sort so that label_XXX.npy and input_XXX.npy line up by index\n",
    "        lst_label.sort()\n",
    "        lst_input.sort()\n",
    "\n",
    "        self.lst_label = lst_label\n",
    "        self.lst_input = lst_input\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.lst_label)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))\n",
    "        input = np.load(os.path.join(self.data_dir, self.lst_input[index]))\n",
    "\n",
    "        # Scale pixel values from [0, 255] to [0, 1]\n",
    "        label = label / 255.0\n",
    "        input = input / 255.0\n",
    "\n",
    "        # Grayscale images are 2-D (H, W): add a channel axis -> (H, W, 1)\n",
    "        if label.ndim == 2:\n",
    "            label = label[:, :, np.newaxis]\n",
    "        if input.ndim == 2:\n",
    "            input = input[:, :, np.newaxis]\n",
    "\n",
    "        data = {'input': input, 'label': label}\n",
    "\n",
    "        # Apply the transform, if one was given\n",
    "        if self.transform:\n",
    "            data = self.transform(data)\n",
    "\n",
    "        return data"
   ]
  },
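  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `__getitem__` returns a dict, PyTorch's default collate function batches each key separately when the dataset is wrapped in a `DataLoader`. A self-contained sketch with a hypothetical in-memory dataset (`ToyDictDataset`, random data, not part of the original notebook) that mimics the same dict format, with arrays already in the (C, H, W) float32 layout that `ToTensor` produces. It subclasses `torch.utils.data.Dataset` directly so that the notebook's own `Dataset` name is not overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "class ToyDictDataset(torch.utils.data.Dataset):\n",
    "    # Hypothetical stand-in for the notebook's Dataset: returns {'input', 'label'} dicts\n",
    "    def __init__(self, n=6, size=64):\n",
    "        rng = np.random.default_rng(0)\n",
    "        self.inputs = rng.random((n, 1, size, size), dtype=np.float32)\n",
    "        self.labels = (rng.random((n, 1, size, size)) > 0.5).astype(np.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.inputs)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return {'input': torch.from_numpy(self.inputs[index]),\n",
    "                'label': torch.from_numpy(self.labels[index])}\n",
    "\n",
    "toy_loader = DataLoader(ToyDictDataset(), batch_size=4, shuffle=True)\n",
    "toy_batch = next(iter(toy_loader))\n",
    "# The default collate function stacks each dict key separately\n",
    "print(toy_batch['input'].shape, toy_batch['label'].shape)\n",
    "# torch.Size([4, 1, 64, 64]) torch.Size([4, 1, 64, 64])"
   ]
  },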
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Check that the Dataset class works as expected\n",
    "dataset_train = Dataset(data_dir=dir_save_train)\n",
    "data = dataset_train[0]  # load one sample\n",
    "input = data['input']\n",
    "label = data['label']\n",
    "\n",
    "# Visualize the loaded sample; squeeze() drops the channel axis: (512, 512, 1) -> (512, 512)\n",
    "plt.subplot(122)\n",
    "plt.imshow(label.squeeze(), cmap='gray')\n",
    "plt.title('label')\n",
    "\n",
    "plt.subplot(121)\n",
    "plt.imshow(input.squeeze(), cmap='gray')\n",
    "plt.title('input')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transforms operating on {'input', 'label'} dicts\n",
    "class ToTensor(object):\n",
    "    def __call__(self, data):\n",
    "        label, input = data['label'], data['input']\n",
    "\n",
    "        # NumPy (H, W, C) -> PyTorch (C, H, W), cast to float32\n",
    "        label = label.transpose((2, 0, 1)).astype(np.float32)\n",
    "        input = input.transpose((2, 0, 1)).astype(np.float32)\n",
    "\n",
    "        data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}\n",
    "\n",
    "        return data\n",
    "\n",
    "class Normalization(object):\n",
    "    def __init__(self, mean=0.5, std=0.5):\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "\n",
    "    def __call__(self, data):\n",
    "        label, input = data['label'], data['input']\n",
    "\n",
    "        # Normalize the input only; the label stays a 0/1 mask\n",
    "        input = (input - self.mean) / self.std\n",
    "\n",
    "        data = {'label': label, 'input': input}\n",
    "\n",
    "        return data\n",
    "\n",
    "class RandomFlip(object):\n",
    "    def __call__(self, data):\n",
    "        label, input = data['label'], data['input']\n",
    "\n",
    "        # Apply the same flips to input and label, each with probability 0.5\n",
    "        if np.random.rand() > 0.5:\n",
    "            label = np.fliplr(label)  # left-right (width axis)\n",
    "            input = np.fliplr(input)\n",
    "\n",
    "        if np.random.rand() > 0.5:\n",
    "            label = np.flipud(label)  # up-down (height axis)\n",
    "            input = np.flipud(input)\n",
    "\n",
    "        data = {'label': label, 'input': input}\n",
    "\n",
    "        return data"
   ]
  },
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64 PNG data truncated in this view (figure: 'input' and 'label' histograms)>",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Check that the transforms work as expected\n",
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), ToTensor()])\n",
"dataset_train = Dataset(data_dir=dir_save_train, transform=transform)\n",
"data = dataset_train.__getitem__(0)  # load a single sample\n",
"input = data['input']\n",
"label = data['label']\n",
"\n",
"# Plot value histograms of the transformed sample\n",
"plt.subplot(122)\n",
"plt.hist(label.flatten(), bins=20)\n",
"plt.title('label')\n",
"\n",
"plt.subplot(121)\n",
"plt.hist(input.flatten(), bins=20)\n",
"plt.title('input')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Save / Load"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"## Save the network\n",
"def save(ckpt_dir, net, optim, epoch):\n",
"    if not os.path.exists(ckpt_dir):\n",
"        os.makedirs(ckpt_dir)\n",
"\n",
"    torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},\n",
"               \"%s/model_epoch%d.pth\" % (ckpt_dir, epoch))\n",
"\n",
"## Load the network from the most recent checkpoint\n",
"def load(ckpt_dir, net, optim):\n",
"    # nothing to resume from: missing or empty checkpoint directory\n",
"    if not os.path.exists(ckpt_dir) or not os.listdir(ckpt_dir):\n",
"        epoch = 0\n",
"        return net, optim, epoch\n",
"\n",
"    ckpt_lst = os.listdir(ckpt_dir)\n",
"    # sort by epoch number, numerically rather than alphabetically\n",
"    ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))\n",
"\n",
"    dict_model = torch.load('%s/%s' % (ckpt_dir, ckpt_lst[-1]))\n",
"\n",
"    net.load_state_dict(dict_model['net'])\n",
"    optim.load_state_dict(dict_model['optim'])\n",
"    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])\n",
"\n",
"    return net, optim, epoch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAIN: EPOCH 0001 / 0020 | BATCH 0001 / 0006 | LOSS 0.6337\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0002 / 0006 | LOSS 0.5923\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0003 / 0006 | LOSS 0.5694\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0004 / 0006 | LOSS 0.5456\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0005 / 0006 | LOSS 0.5214\n",
"TRAIN: EPOCH 0001 / 0020 | BATCH 0006 / 0006 | LOSS 0.5036\n",
"VALID: EPOCH 0001 / 0020 | BATCH 0001 / 0001 | LOSS 0.6128\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0001 / 0006 | LOSS 0.4027\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0002 / 0006 | LOSS 0.3897\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0003 / 0006 | LOSS 0.3905\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0004 / 0006 | LOSS 0.3856\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0005 / 0006 | LOSS 0.3807\n",
"TRAIN: EPOCH 0002 / 0020 | BATCH 0006 / 0006 | LOSS 0.3745\n",
"VALID: EPOCH 0002 / 0020 | BATCH 0001 / 0001 | LOSS 0.5134\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0001 / 0006 | LOSS 0.3422\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0002 / 0006 | LOSS 0.3350\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0003 / 0006 | LOSS 0.3372\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0004 / 0006 | LOSS 0.3337\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0005 / 0006 | LOSS 0.3293\n",
"TRAIN: EPOCH 0003 / 0020 | BATCH 0006 / 0006 | LOSS 0.3286\n",
"VALID: EPOCH 0003 / 0020 | BATCH 0001 / 0001 | LOSS 0.4308\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0001 / 0006 | LOSS 0.3094\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0002 / 0006 | LOSS 0.3079\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0003 / 0006 | LOSS 0.3090\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0004 / 0006 | LOSS 0.3078\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0005 / 0006 | LOSS 0.3049\n",
"TRAIN: EPOCH 0004 / 0020 | BATCH 0006 / 0006 | LOSS 0.3003\n",
"VALID: EPOCH 0004 / 0020 | BATCH 0001 / 0001 | LOSS 0.3922\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0001 / 0006 | LOSS 0.2909\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0002 / 0006 | LOSS 0.2870\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0003 / 0006 | LOSS 0.2874\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0004 / 0006 | LOSS 0.2851\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0005 / 0006 | LOSS 0.2821\n",
"TRAIN: EPOCH 0005 / 0020 | BATCH 0006 / 0006 | LOSS 0.2812\n",
"VALID: EPOCH 0005 / 0020 | BATCH 0001 / 0001 | LOSS 0.3173\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0001 / 0006 | LOSS 0.2671\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0002 / 0006 | LOSS 0.2769\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0003 / 0006 | LOSS 0.2720\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0004 / 0006 | LOSS 0.2758\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0005 / 0006 | LOSS 0.2745\n",
"TRAIN: EPOCH 0006 / 0020 | BATCH 0006 / 0006 | LOSS 0.2750\n",
"VALID: EPOCH 0006 / 0020 | BATCH 0001 / 0001 | LOSS 0.2644\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0001 / 0006 | LOSS 0.2504\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0002 / 0006 | LOSS 0.2509\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0003 / 0006 | LOSS 0.2562\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0004 / 0006 | LOSS 0.2540\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0005 / 0006 | LOSS 0.2571\n",
"TRAIN: EPOCH 0007 / 0020 | BATCH 0006 / 0006 | LOSS 0.2563\n",
"VALID: EPOCH 0007 / 0020 | BATCH 0001 / 0001 | LOSS 0.3108\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0001 / 0006 | LOSS 0.2595\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0002 / 0006 | LOSS 0.2523\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0003 / 0006 | LOSS 0.2517\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0004 / 0006 | LOSS 0.2495\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0005 / 0006 | LOSS 0.2467\n",
"TRAIN: EPOCH 0008 / 0020 | BATCH 0006 / 0006 | LOSS 0.2460\n",
"VALID: EPOCH 0008 / 0020 | BATCH 0001 / 0001 | LOSS 0.2661\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0001 / 0006 | LOSS 0.2405\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0002 / 0006 | LOSS 0.2372\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0003 / 0006 | LOSS 0.2444\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0004 / 0006 | LOSS 0.2422\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0005 / 0006 | LOSS 0.2392\n",
"TRAIN: EPOCH 0009 / 0020 | BATCH 0006 / 0006 | LOSS 0.2389\n",
"VALID: EPOCH 0009 / 0020 | BATCH 0001 / 0001 | LOSS 0.2370\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0001 / 0006 | LOSS 0.2330\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0002 / 0006 | LOSS 0.2343\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0003 / 0006 | LOSS 0.2297\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0004 / 0006 | LOSS 0.2306\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0005 / 0006 | LOSS 0.2329\n",
"TRAIN: EPOCH 0010 / 0020 | BATCH 0006 / 0006 | LOSS 0.2315\n",
"VALID: EPOCH 0010 / 0020 | BATCH 0001 / 0001 | LOSS 0.2316\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0001 / 0006 | LOSS 0.2298\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0002 / 0006 | LOSS 0.2244\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0003 / 0006 | LOSS 0.2226\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0004 / 0006 | LOSS 0.2295\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0005 / 0006 | LOSS 0.2272\n",
"TRAIN: EPOCH 0011 / 0020 | BATCH 0006 / 0006 | LOSS 0.2301\n",
"VALID: EPOCH 0011 / 0020 | BATCH 0001 / 0001 | LOSS 0.2179\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0001 / 0006 | LOSS 0.2201\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0002 / 0006 | LOSS 0.2226\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0003 / 0006 | LOSS 0.2231\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0004 / 0006 | LOSS 0.2208\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0005 / 0006 | LOSS 0.2222\n",
"TRAIN: EPOCH 0012 / 0020 | BATCH 0006 / 0006 | LOSS 0.2237\n",
"VALID: EPOCH 0012 / 0020 | BATCH 0001 / 0001 | LOSS 0.2273\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0001 / 0006 | LOSS 0.2027\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0002 / 0006 | LOSS 0.2203\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0003 / 0006 | LOSS 0.2186\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0004 / 0006 | LOSS 0.2188\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0005 / 0006 | LOSS 0.2191\n",
"TRAIN: EPOCH 0013 / 0020 | BATCH 0006 / 0006 | LOSS 0.2185\n",
"VALID: EPOCH 0013 / 0020 | BATCH 0001 / 0001 | LOSS 0.2283\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0001 / 0006 | LOSS 0.2275\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0002 / 0006 | LOSS 0.2271\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0003 / 0006 | LOSS 0.2190\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0004 / 0006 | LOSS 0.2165\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0005 / 0006 | LOSS 0.2160\n",
"TRAIN: EPOCH 0014 / 0020 | BATCH 0006 / 0006 | LOSS 0.2165\n",
"VALID: EPOCH 0014 / 0020 | BATCH 0001 / 0001 | LOSS 0.2171\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0001 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0002 / 0006 | LOSS 0.2162\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0003 / 0006 | LOSS 0.2149\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0004 / 0006 | LOSS 0.2129\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0005 / 0006 | LOSS 0.2120\n",
"TRAIN: EPOCH 0015 / 0020 | BATCH 0006 / 0006 | LOSS 0.2126\n",
"VALID: EPOCH 0015 / 0020 | BATCH 0001 / 0001 | LOSS 0.2184\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0001 / 0006 | LOSS 0.2167\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0002 / 0006 | LOSS 0.2071\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0003 / 0006 | LOSS 0.2096\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0004 / 0006 | LOSS 0.2100\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0005 / 0006 | LOSS 0.2084\n",
"TRAIN: EPOCH 0016 / 0020 | BATCH 0006 / 0006 | LOSS 0.2077\n",
"VALID: EPOCH 0016 / 0020 | BATCH 0001 / 0001 | LOSS 0.2243\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0001 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0002 / 0006 | LOSS 0.2075\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0003 / 0006 | LOSS 0.2112\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0004 / 0006 | LOSS 0.2114\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0005 / 0006 | LOSS 0.2097\n",
"TRAIN: EPOCH 0017 / 0020 | BATCH 0006 / 0006 | LOSS 0.2091\n",
"VALID: EPOCH 0017 / 0020 | BATCH 0001 / 0001 | LOSS 0.2012\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0001 / 0006 | LOSS 0.2049\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0002 / 0006 | LOSS 0.1998\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0003 / 0006 | LOSS 0.2004\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0004 / 0006 | LOSS 0.2084\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0005 / 0006 | LOSS 0.2070\n",
"TRAIN: EPOCH 0018 / 0020 | BATCH 0006 / 0006 | LOSS 0.2046\n",
"VALID: EPOCH 0018 / 0020 | BATCH 0001 / 0001 | LOSS 0.1956\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0001 / 0006 | LOSS 0.1971\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0002 / 0006 | LOSS 0.1906\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0003 / 0006 | LOSS 0.1992\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0004 / 0006 | LOSS 0.2029\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0005 / 0006 | LOSS 0.2021\n",
"TRAIN: EPOCH 0019 / 0020 | BATCH 0006 / 0006 | LOSS 0.2061\n",
"VALID: EPOCH 0019 / 0020 | BATCH 0001 / 0001 | LOSS 0.2161\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0001 / 0006 | LOSS 0.2169\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0002 / 0006 | LOSS 0.2125\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0003 / 0006 | LOSS 0.2093\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0004 / 0006 | LOSS 0.2127\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0005 / 0006 | LOSS 0.2115\n",
"TRAIN: EPOCH 0020 / 0020 | BATCH 0006 / 0006 | LOSS 0.2120\n",
"VALID: EPOCH 0020 / 0020 | BATCH 0001 / 0001 | LOSS 0.2045\n"
]
}
],
"source": [
"# Training hyperparameters\n",
"lr = 1e-3\n",
"batch_size = 4\n",
"num_epoch = 20\n",
"\n",
"base_dir = './drive/MyDrive/DACrew/unet'\n",
"data_dir = dir_data\n",
"ckpt_dir = os.path.join(base_dir, \"checkpoint\")\n",
"log_dir = os.path.join(base_dir, \"log\")\n",
"\n",
"\n",
"# Transform and DataLoader for training\n",
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), ToTensor()])\n",
"\n",
"dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=transform)\n",
"loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)\n",
"\n",
"dataset_val = Dataset(data_dir=os.path.join(data_dir, 'val'), transform=transform)\n",
"loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=0)\n",
"\n",
"# Build the network\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"net = UNet().to(device)\n",
"\n",
"# Loss function (takes raw logits and applies the sigmoid internally)\n",
"fn_loss = nn.BCEWithLogitsLoss().to(device)\n",
"\n",
"# Optimizer\n",
"optim = torch.optim.Adam(net.parameters(), lr=lr)\n",
"\n",
"# Other auxiliary variables\n",
"num_data_train = len(dataset_train)\n",
"num_data_val = len(dataset_val)\n",
"\n",
"num_batch_train = np.ceil(num_data_train / batch_size)\n",
"num_batch_val = np.ceil(num_data_val / batch_size)\n",
"\n",
"# Other auxiliary functions\n",
"fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)\n",
"fn_denorm = lambda x, mean, std: (x * std) + mean\n",
"# the network outputs logits, so convert to probabilities before thresholding\n",
"fn_class = lambda x: 1.0 * (torch.sigmoid(x) > 0.5)\n",
"\n",
"# SummaryWriters for TensorBoard\n",
"writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))\n",
"writer_val = SummaryWriter(log_dir=os.path.join(log_dir, 'val'))\n",
"\n",
"# Train the network\n",
"st_epoch = 0\n",
"# Resume from the latest checkpoint if one exists\n",
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)\n",
"\n",
"for epoch in range(st_epoch + 1, num_epoch + 1):\n",
"    net.train()\n",
"    loss_arr = []\n",
"\n",
"    for batch, data in enumerate(loader_train, 1):\n",
"        # forward pass\n",
"        label = data['label'].to(device)\n",
"        input = data['input'].to(device)\n",
"\n",
"        output = net(input)\n",
"\n",
"        # backward pass\n",
"        optim.zero_grad()\n",
"\n",
"        loss = fn_loss(output, label)\n",
"        loss.backward()\n",
"\n",
"        optim.step()\n",
"\n",
"        # record the loss\n",
"        loss_arr += [loss.item()]\n",
"\n",
"        print(\"TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f\" %\n",
"              (epoch, num_epoch, batch, num_batch_train, np.mean(loss_arr)))\n",
"\n",
"        # log images to TensorBoard\n",
"        label = fn_tonumpy(label)\n",
"        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
"        output = fn_tonumpy(fn_class(output))\n",
"\n",
"        writer_train.add_image('label', label, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
"        writer_train.add_image('input', input, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
"        writer_train.add_image('output', output, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')\n",
"\n",
"    writer_train.add_scalar('loss', np.mean(loss_arr), epoch)\n",
"\n",
"    with torch.no_grad():\n",
"        net.eval()\n",
"        loss_arr = []\n",
"\n",
"        for batch, data in enumerate(loader_val, 1):\n",
"            # forward pass\n",
"            label = data['label'].to(device)\n",
"            input = data['input'].to(device)\n",
"\n",
"            output = net(input)\n",
"\n",
"            # compute the loss\n",
"            loss = fn_loss(output, label)\n",
"\n",
"            loss_arr += [loss.item()]\n",
"\n",
"            print(\"VALID: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f\" %\n",
"                  (epoch, num_epoch, batch, num_batch_val, np.mean(loss_arr)))\n",
"\n",
"            # log images to TensorBoard\n",
"            label = fn_tonumpy(label)\n",
"            input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
"            output = fn_tonumpy(fn_class(output))\n",
"\n",
"            writer_val.add_image('label', label, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
"            writer_val.add_image('input', input, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
"            writer_val.add_image('output', output, num_batch_val * (epoch - 1) + batch, dataformats='NHWC')\n",
"\n",
"        writer_val.add_scalar('loss', np.mean(loss_arr), epoch)\n",
"\n",
"    # save a checkpoint every 50 epochs and after the final epoch\n",
"    # (with num_epoch = 20, the `epoch % 50` test alone would never save anything)\n",
"    if epoch % 50 == 0 or epoch == num_epoch:\n",
"        save(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch)\n",
"\n",
"# close the writers once, after training has finished\n",
"writer_train.close()\n",
"writer_val.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TEST: BATCH 0001 / 0001 | LOSS 0.2526\n",
"AVERAGE TEST: BATCH 0001 / 0001 | LOSS 0.2526\n"
]
}
],
"source": [
"transform = transforms.Compose([Normalization(mean=0.5, std=0.5), ToTensor()])\n",
"\n",
"dataset_test = Dataset(data_dir=os.path.join(data_dir, 'test'), transform=transform)\n",
"loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)\n",
"\n",
"# Other auxiliary variables\n",
"num_data_test = len(dataset_test)\n",
"num_batch_test = np.ceil(num_data_test / batch_size)\n",
"\n",
"# Create the result directories\n",
"result_dir = os.path.join(base_dir, 'result')\n",
"if not os.path.exists(result_dir):\n",
"    os.makedirs(os.path.join(result_dir, 'png'))\n",
"    os.makedirs(os.path.join(result_dir, 'numpy'))\n",
"\n",
"\n",
"net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)\n",
"\n",
"with torch.no_grad():\n",
"    net.eval()\n",
"    loss_arr = []\n",
"\n",
"    for batch, data in enumerate(loader_test, 1):\n",
"        # forward pass\n",
"        label = data['label'].to(device)\n",
"        input = data['input'].to(device)\n",
"\n",
"        output = net(input)\n",
"\n",
"        # compute the loss\n",
"        loss = fn_loss(output, label)\n",
"\n",
"        loss_arr += [loss.item()]\n",
"\n",
"        print(\"TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"              (batch, num_batch_test, np.mean(loss_arr)))\n",
"\n",
"        # convert to NumPy arrays (NHWC) for saving\n",
"        label = fn_tonumpy(label)\n",
"        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))\n",
"        output = fn_tonumpy(fn_class(output))\n",
"\n",
"        # save the test results\n",
"        for j in range(label.shape[0]):\n",
"            # running sample index: every batch holds batch_size samples\n",
"            id = batch_size * (batch - 1) + j\n",
"\n",
"            plt.imsave(os.path.join(result_dir, 'png', 'label_%04d.png' % id), label[j].squeeze(), cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'png', 'input_%04d.png' % id), input[j].squeeze(), cmap='gray')\n",
"            plt.imsave(os.path.join(result_dir, 'png', 'output_%04d.png' % id), output[j].squeeze(), cmap='gray')\n",
"\n",
"            np.save(os.path.join(result_dir, 'numpy', 'label_%04d.npy' % id), label[j].squeeze())\n",
"            np.save(os.path.join(result_dir, 'numpy', 'input_%04d.npy' % id), input[j].squeeze())\n",
"            np.save(os.path.join(result_dir, 'numpy', 'output_%04d.npy' % id), output[j].squeeze())\n",
"\n",
"print(\"AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f\" %\n",
"      (batch, num_batch_test, np.mean(loss_arr)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Result (Visualize)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64 PNG data truncated in this view (figure: 'input', 'label' and 'output' panels)>",
"text/plain": [
"<Figure size 800x600 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## Collect the saved result files\n",
"lst_data = os.listdir(os.path.join(result_dir, 'numpy'))\n",
"\n",
"lst_label = [f for f in lst_data if f.startswith('label')]\n",
"lst_input = [f for f in lst_data if f.startswith('input')]\n",
"lst_output = [f for f in lst_data if f.startswith('output')]\n",
"\n",
"lst_label.sort()\n",
"lst_input.sort()\n",
"lst_output.sort()\n",
"\n",
"## Load one sample\n",
"id = 0\n",
"\n",
"label = np.load(os.path.join(result_dir,\"numpy\", lst_label[id]))\n",
"input = np.load(os.path.join(result_dir,\"numpy\", lst_input[id]))\n",
"output = np.load(os.path.join(result_dir,\"numpy\", lst_output[id]))\n",
"\n",
"## Plot input, label and prediction side by side\n",
"plt.figure(figsize=(8,6))\n",
"plt.subplot(131)\n",
"plt.imshow(input, cmap='gray')\n",
"plt.title('input')\n",
"\n",
"plt.subplot(132)\n",
"plt.imshow(label, cmap='gray')\n",
"plt.title('label')\n",
"\n",
"plt.subplot(133)\n",
"plt.imshow(output, cmap='gray')\n",
"plt.title('output')\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
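Shape bookkeeping between the Transform cell and the TensorBoard logging code. PyTorch convolutions expect channels-first tensors, while the stored NumPy arrays and matplotlib use channels-last. `ToTensor` moves the channel axis to the front, the `DataLoader` stacks samples along a new batch axis, and `fn_tonumpy` moves the channel axis back so that `add_image(..., dataformats='NHWC')` receives `(N, H, W, C)`.

The sketch below traces that round trip with NumPy only, on a made-up 2x3 single-channel array. Here `np.stack` stands in for the DataLoader's batching; it is an illustration, not the DataLoader itself.

```python
import numpy as np

# One grayscale sample laid out as on disk: (H, W, C) with C = 1 (toy values).
sample = np.arange(6, dtype=np.float32).reshape(2, 3, 1)

chw = sample.transpose((2, 0, 1))    # ToTensor: (H, W, C) -> (C, H, W)
batch = np.stack([chw, chw])         # batching: (N, C, H, W)
nhwc = batch.transpose(0, 2, 3, 1)   # fn_tonumpy: (N, C, H, W) -> (N, H, W, C)

print(chw.shape, batch.shape, nhwc.shape)  # (1, 2, 3) (2, 1, 2, 3) (2, 2, 3, 1)

# The round trip restores the original per-sample layout.
assert np.array_equal(nhwc[0], sample)
```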
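`RandomFlip` in the Transform cell must apply the same flips to the input and the label. Otherwise the mask no longer lines up with the image.

A minimal standalone check follows. `random_flip` is a hypothetical re-implementation of that transform that takes an explicit NumPy `Generator`, so the run is reproducible. The toy label is defined as `input > 0.5`, so any mismatch between the two flips would break that relationship.

```python
import numpy as np


def random_flip(data, rng):
    """Hypothetical copy of the notebook's RandomFlip, with an explicit RNG."""
    label, inp = data['label'], data['input']
    if rng.random() > 0.5:  # left-right flip, applied to both arrays
        label, inp = np.fliplr(label), np.fliplr(inp)
    if rng.random() > 0.5:  # up-down flip, applied to both arrays
        label, inp = np.flipud(label), np.flipud(inp)
    return {'label': label, 'input': inp}


# Toy 4x4 sample with a channel axis (H, W, C); the label is exactly "input > 0.5".
rng = np.random.default_rng(0)
inp = rng.random((4, 4, 1)).astype(np.float32)
label = (inp > 0.5).astype(np.float32)

for _ in range(10):
    out = random_flip({'label': label, 'input': inp}, rng)
    # Flipping input and label differently would break this relationship.
    assert np.array_equal(out['label'], (out['input'] > 0.5).astype(np.float32))

# The flips return views with a negative stride. ToTensor's .astype(np.float32)
# allocates a fresh array with positive strides before torch.from_numpy sees it.
flipped = np.fliplr(inp)
print(flipped.strides[1] < 0, min(flipped.astype(np.float32).strides) > 0)
```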
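In the Model Save / Load cell, `load()` sorts the checkpoint files with a numeric key before taking the last one. A plain string sort would put `model_epoch100.pth` before `model_epoch50.pth`.

A quick check of both the sort key and the epoch parsing, using made-up file names that follow the `model_epoch%d.pth` pattern written by `save()`:

```python
# Hypothetical checkpoint names following save()'s "model_epoch%d.pth" pattern.
ckpt_lst = ["model_epoch50.pth", "model_epoch100.pth", "model_epoch20.pth"]

# Lexicographic order compares character by character, so "1" < "2" < "5".
print(sorted(ckpt_lst))
# ['model_epoch100.pth', 'model_epoch20.pth', 'model_epoch50.pth']

# The key used in load(): keep only the digits and compare them as integers.
ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
print(ckpt_lst)
# ['model_epoch20.pth', 'model_epoch50.pth', 'model_epoch100.pth']

# Epoch parsing used in load() for the newest checkpoint.
latest = ckpt_lst[-1]
epoch = int(latest.split('epoch')[1].split('.pth')[0])
print(epoch)  # 100
```

The digit-only key assumes the directory holds nothing but these checkpoint files. Any other file name containing digits would be mixed into the ordering.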
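The Train cell uses `nn.BCEWithLogitsLoss`, which takes the raw network output (a logit `x`) and folds the sigmoid into the loss. Per pixel with target `y`, it computes `-(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x)))` and, with the default reduction, averages over all pixels.

The pure-Python sketch below needs no PyTorch. It compares that naive formula with the algebraically equivalent, numerically stable form `max(x, 0) - x*y + log(1 + exp(-|x|))`. Treat it as an illustration of the math, not as PyTorch's actual kernel.

```python
import math


def bce_naive(x, y):
    """Sigmoid followed by binary cross-entropy (breaks down for large |x|)."""
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(x, y):
    """Stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def mean_bce_with_logits(logits, targets):
    """Average over all pixels, like the default reduction='mean'."""
    return sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)


# A flattened toy "image": four pixel logits and their 0/1 mask values.
logits = [2.0, -1.0, 0.5, -3.0]
targets = [1.0, 0.0, 0.0, 0.0]

for x, y in zip(logits, targets):
    assert math.isclose(bce_naive(x, y), bce_with_logits(x, y), rel_tol=1e-9)

print(round(mean_bce_with_logits(logits, targets), 4))  # 0.3657

# For a confidently wrong logit, bce_naive(40.0, 0.0) fails with log(0),
# because sigmoid(40) rounds to exactly 1.0; the stable form still works.
print(bce_with_logits(40.0, 0.0))  # 40.0
```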
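`fn_class` in the Train cell turns the network output into a 0/1 mask. Because the loss is `BCEWithLogitsLoss`, that output is a logit, not a probability.

Thresholding the probability `sigmoid(x)` at 0.5 is the same as thresholding the logit at 0. Comparing the raw logit with 0.5 instead corresponds to a probability threshold of about 0.62. A small pure-Python check of that equivalence, on made-up logits:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


logits = [-2.0, -0.1, 0.0, 0.2, 0.4, 0.6, 3.0]  # toy network outputs

by_probability = [1.0 * (sigmoid(x) > 0.5) for x in logits]  # sigmoid first
by_logit_zero = [1.0 * (x > 0.0) for x in logits]            # equivalent shortcut
by_logit_half = [1.0 * (x > 0.5) for x in logits]            # raw logit vs 0.5

print(by_probability)  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(by_logit_half)   # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]

# A logit threshold of 0.5 is really a probability threshold of sigmoid(0.5).
print(round(sigmoid(0.5), 3))  # 0.622
```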
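The Test cell names each saved file by a running sample index. With `batch_size` samples per batch, sample `j` of batch `batch` (counting batches from 1) is sample number `batch_size * (batch - 1) + j`. Using `num_batch_test` as the multiplier gives the same numbers only when the whole test set fits in a single batch, as it happens to here. With more batches, file names collide and earlier results are overwritten.

A pure-Python sketch with a hypothetical test set of 10 images:

```python
import math

batch_size = 4
num_data_test = 10  # hypothetical size of the test set
num_batch_test = math.ceil(num_data_test / batch_size)  # 3 batches: 4 + 4 + 2

ids_batch_size, ids_num_batch = [], []
for batch in range(1, num_batch_test + 1):
    # the last batch is shorter when the data size is not a multiple of batch_size
    n = min(batch_size, num_data_test - batch_size * (batch - 1))
    for j in range(n):
        ids_batch_size.append(batch_size * (batch - 1) + j)
        ids_num_batch.append(num_batch_test * (batch - 1) + j)

print(ids_batch_size == list(range(num_data_test)))  # True: one file per sample
print(len(set(ids_num_batch)), "unique ids for", num_data_test, "samples")
# 8 unique ids for 10 samples
```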
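The Test cell reports only the BCE loss. Segmentation quality is usually summarized with overlap metrics such as IoU (the per-class quantity behind mIoU) or the Dice coefficient.

Below is a minimal NumPy sketch. It assumes `label` and `output` are masks of the same shape with values 0/1, like the arrays saved to `result/numpy`. The `eps` smoothing term is an assumption that makes two empty masks score 1.0 instead of dividing by zero.

```python
import numpy as np


def iou_and_dice(label, output, eps=1e-7):
    """Binary IoU and Dice between a ground-truth mask and a predicted mask."""
    label = label > 0.5
    output = output > 0.5
    intersection = np.logical_and(label, output).sum()
    union = np.logical_or(label, output).sum()
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (label.sum() + output.sum() + eps)
    return float(iou), float(dice)


# Toy 4x4 masks: 4 foreground pixels each, overlapping in 2 pixels.
label = np.zeros((4, 4))
label[0, 0:4] = 1
output = np.zeros((4, 4))
output[0, 2:4] = 1
output[1, 2:4] = 1

iou, dice = iou_and_dice(label, output)
print(round(iou, 3), round(dice, 3))  # 0.333 0.5

# Usage on the saved results (names as in the Result (Visualize) cell):
# label = np.load(os.path.join(result_dir, 'numpy', lst_label[0]))
# output = np.load(os.path.join(result_dir, 'numpy', lst_output[0]))
# print(iou_and_dice(label, output))
```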