You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
20_Final_Project/jeju_thesis_model_test.ipynb

386 lines
433 KiB
Plaintext

6 months ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, precision_recall_curve\n",
"\n",
"# One result directory per trained U-Net run, named after its training loss\n",
"# (shown as BCE / Dice / Focal / L1 / L2 in the plots below).\n",
"base_dirs = [f'./2nd_Battery/{run}'\n",
"             for run in ('unet', 'unet-dice-loss', 'unet-focal-loss', 'unet-l1', 'unet-l2')]\n",
"\n",
"# Line/marker colours, used in base_dirs order (there are more colours than runs).\n",
"colors = ['red', 'orange', 'yellow', 'pink', 'green',\n",
"          'gold', 'magenta', 'cyan', 'violet']\n",
"\n",
"# Final loss of each run, in base_dirs order. Hard-coded; presumably copied from the\n",
"# training logs -- TODO confirm the source and which split it was measured on.\n",
"losses = [0.2072, 0.3879, 0.0112, 0.0357, 0.0241]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"unet - precision: 0.582, recall: 0.895, accuracy: 0.958, f1: 0.685, iou: 0.550\n",
"unet-dice-loss - precision: 0.509, recall: 0.934, accuracy: 0.946, f1: 0.642, iou: 0.492\n",
"unet-focal-loss - precision: 0.894, recall: 0.762, accuracy: 0.984, f1: 0.818, iou: 0.711\n",
"unet-l1 - precision: 0.670, recall: 0.885, accuracy: 0.971, f1: 0.752, iou: 0.618\n",
"unet-l2 - precision: 0.778, recall: 0.788, accuracy: 0.977, f1: 0.777, iou: 0.649\n"
]
}
],
"source": [
"# Pixel-level evaluation of each model on its saved test predictions.\n",
"# Per image: count TP/TN/FP/FN between the gt_* and pr_* masks and derive precision,\n",
"# recall, accuracy, F1 and IoU; those per-image values are averaged per model. The pooled\n",
"# counts are expanded into (y_true, y_score) arrays for the PR/ROC cells below.\n",
"MAX_SAMPLES = 1000  # evaluate at most this many test images per model\n",
"assert len(base_dirs) == len(losses), 'need exactly one loss value per model directory'\n",
"\n",
"lst_rst, lst_cdata = [], []\n",
"\n",
"for base_dir, loss in zip(base_dirs, losses):\n",
"    numpy_dir = os.path.join(base_dir, 'result', 'numpy')\n",
"    lst_data = os.listdir(numpy_dir)\n",
"\n",
"    # Kept as globals: the qualitative-figure cell indexes into lst_img / lst_gt.\n",
"    lst_img = sorted(f for f in lst_data if f.startswith('img'))\n",
"    lst_gt = sorted(f for f in lst_data if f.startswith('gt'))\n",
"    lst_pr = sorted(f for f in lst_data if f.startswith('pr'))\n",
"    if len(lst_gt) != len(lst_pr):\n",
"        raise ValueError(f'{numpy_dir}: {len(lst_gt)} gt files but {len(lst_pr)} pr files')\n",
"\n",
"    # Clamp to the files that exist; a fixed count raises IndexError on a smaller test set.\n",
"    length = min(MAX_SAMPLES, len(lst_gt))\n",
"\n",
"    # counts[i] = (tp, tn, fp, fn) for image i. Input images are not needed here, so they\n",
"    # are no longer loaded.\n",
"    counts = np.zeros((length, 4), dtype=np.int64)\n",
"    for i in range(length):\n",
"        gt = np.load(os.path.join(numpy_dir, lst_gt[i]))\n",
"        pr = np.load(os.path.join(numpy_dir, lst_pr[i]))\n",
"        counts[i] = (np.sum(np.logical_and(gt == 1, pr == 1)),\n",
"                     np.sum(np.logical_and(gt == 0, pr == 0)),\n",
"                     np.sum(np.logical_and(gt == 0, pr == 1)),\n",
"                     np.sum(np.logical_and(gt == 1, pr == 0)))\n",
"\n",
"    tp, tn, fp, fn = counts.T.astype(np.float64)\n",
"    # A metric whose denominator is 0 (e.g. precision on an image with no pr == 1 pixels)\n",
"    # evaluates to NaN without a warning; nanmean then averages only over the images where\n",
"    # that metric is defined instead of turning the whole average into NaN.\n",
"    with np.errstate(divide='ignore', invalid='ignore'):\n",
"        precision = tp / (tp + fp)  # TP / (TP + FP)\n",
"        recall = tp / (tp + fn)  # TP / (TP + FN), a.k.a. sensitivity / hit rate\n",
"        accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
"        f1 = 2 * precision * recall / (precision + recall)\n",
"        iou = tp / (tp + fn + fp)\n",
"    avg_precision, avg_recall, avg_accuracy, avg_f1, avg_iou = (\n",
"        np.nanmean(m) for m in (precision, recall, accuracy, f1, iou))\n",
"    print(f'{os.path.basename(base_dir)} - precision: {avg_precision:.3f}, recall: {avg_recall:.3f}, accuracy: {avg_accuracy:.3f}, f1: {avg_f1:.3f}, iou: {avg_iou:.3f}')\n",
"\n",
"    # Expand the pooled confusion counts into per-pixel labels/scores for sklearn's curve\n",
"    # functions. uint8 instead of the default float64 cuts memory 8x; the 0/1 values, and\n",
"    # therefore the curves, are unchanged.\n",
"    total_tp, total_tn, total_fp, total_fn = counts.sum(axis=0)\n",
"    y_true = np.concatenate([np.ones(total_tp + total_fn, dtype=np.uint8),\n",
"                             np.zeros(total_tn + total_fp, dtype=np.uint8)])\n",
"    y_score = np.concatenate([np.ones(total_tp, dtype=np.uint8), np.zeros(total_fn, dtype=np.uint8),\n",
"                              np.ones(total_fp, dtype=np.uint8), np.zeros(total_tn, dtype=np.uint8)])\n",
"\n",
"    lst_cdata.append((y_true, y_score))\n",
"    lst_rst.append((loss, avg_precision, avg_recall, avg_accuracy, avg_f1, avg_iou))"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"# Transpose the per-model rows (loss, precision, recall, accuracy, f1, iou) into one list per metric.\n",
"(avg_loss_list, avg_precision_list, avg_recall_list,\n",
" avg_accuracy_list, avg_f1_list, avg_iou_list) = ([row[col] for row in lst_rst] for col in range(6))\n",
"\n",
"# x-axis labels, in base_dirs order.\n",
"model_list = ['BCE', 'Dice', 'Focal', 'L1', 'L2']"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGzCAYAAAD9pBdvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACqcElEQVR4nOzdd3hURRfA4d+m9wQIhARCbwFpglRpAgLShdA7WBAVxIqFpoJ+SlEsqIQqKE06UqRIEenV0Hvv6T073x+TbLKkkECSTch5n2cfsvfevXd22WTPnpk5Y1BKKYQQQgghLMTK0g0QQgghRP4mwYgQQgghLEqCESGEEEJYlAQjQgghhLAoCUaEEEIIYVESjAghhBDCoiQYEUIIIYRFSTAihBBCCIuSYEQIIYQQFiXBiMgxX331FWXKlMHa2poaNWpYujkiGzVt2pSmTZvm2PXCwsIYMmQIRYsWxWAwMGLEiBy7dk4ZO3YsBoOBO3fuWLopQmQ5CUbysdmzZ2MwGEw3BwcHKlSowOuvv87Nmzez9FobNmzgvffeo2HDhsyaNYsJEyZk6flF9rh27Rpjx47l0KFDlm5KuiZMmMDs2bMZOnQo8+bNo2/fvpZuUq6zYMECpk6dmq3XyCvvF5H72Fi6AcLyxo8fT+nSpYmKimLHjh38+OOPrF27lmPHjuHk5JQl19i8eTNWVlYEBARgZ2eXJecU2e/atWuMGzeOUqVKZSqbtWHDhuxrVCo2b95MvXr1GDNmTI5eNy9ZsGABx44dy9as0aO+X4SQYETQpk0bateuDcCQIUMoVKgQkydPZsWKFfTs2fOxzh0REYGTkxO3bt3C0dExywIRpRRRUVE4OjpmyflE1kj8/87pgPPWrVtUrlw5y84XFxeH0WiUwFmIHCLdNCKF5557DoDz58+btv3666/UqlULR0dHChYsSI8ePbh8+bLZ45o2bcpTTz3F/v37ady4MU5OTnz44YcYDAZmzZpFeHi4qUto9uzZgP6j/+mnn1K2bFns7e0pVaoUH374IdHR0WbnLlWqFO3atWP9+vXUrl0bR0dHfvrpJ7Zu3YrBYGDRokWMGzeOYsWK4erqSteuXQkODiY6OpoRI0ZQpEgRXFxcGDhwYIpzz5o1i+eee44iRYpgb29P5cqV+fHHH1O8Lolt2LFjB3Xq1MHBwYEyZcowd+7cFMcGBQXx1ltvUapUKezt7SlevDj9+vUz6++Pjo5mzJgxlCtXDnt7e3x9fXnvvfdStC81ia/1kSNHaNKkCU5OTpQrV44lS5YA8Pfff1O3bl0cHR2pWLEif/31V4pzXL16lUGDBuHl5YW9vT1VqlRh5syZpv1bt27lmWeeAWDgwIEp/u/S+v9O3PfgmJGoqCjGjh1LhQoVcHBwwNvbmxdffJGzZ8+ajvn999+pVasWrq6uuLm5UbVqVb755ps0X4fE///z58+zZs0aUxsvXLgA6CBl8ODBeHl54eDgQPXq1ZkzZ47ZOS5cuIDBYODrr79m6tSppvdiYGBguv8HGfmd2L59O/7+/pQoUcL0f/zWW28RGRmZ4nwnTpygW7duFC5c2PT/9tFHH6U4LigoiAEDBuDh4YG7uzsDBw4kIiIi3bY2bdqUNWvWcPHiRdNrVKpUKdP+jL4XN27cyLPPPouHhwcuLi5UrFjR9H/+sPdLWq5evcrgwYPx8fHB3t6e0qVLM3ToUGJiYgC4d+8e77zzDlWrVsXFxQU3NzfatGnD4cOHU5xr2rRpVKlSBScnJwoUKEDt2rVZsGBBiuul974XliGZEZFC4odDoUKFAPj888/55JNP6NatG0OGDOH27dtMmzaNxo0bc/DgQTw8PEyPvXv3Lm3atKFHjx706dMHLy8vateuzc8//8yePXuYMWMGAA0aNAB0JmbOnDl07dqVt99+m927dzNx4kSOHz/OsmXLzNp18uRJevbsySuvvMJLL71ExYoVTfsmTpyIo6MjH3zwAW
fOnGHatGnY2tpiZWXF/fv3GTt2LP/++y+zZ8+mdOnSjB492vTYH3/8kSpVqtChQwdsbGxYtWoVr732GkajkWHDhpm14cyZM3Tt2pXBgwfTv39/Zs6cyYABA6hVqxZVqlQB9GDKRo0acfz4cQYNGsTTTz/NnTt3WLlyJVeuXMHT0xOj0UiHDh3YsWMHL7/8Mn5+fhw9epQpU6Zw6tQpli9f/tD/p/v379OuXTt69OiBv78/P/74Iz169GD+/PmMGDGCV199lV69evHVV1/RtWtXLl++jKurKwA3b96kXr16GAwGXn/9dQoXLsyff/7J4MGDCQkJYcSIEfj5+TF+/HhGjx7Nyy+/TKNGjcz+79L6/05NfHw87dq1Y9OmTfTo0YPhw4cTGhrKxo0bOXbsGGXLlmXjxo307NmT5s2b8+WXXwJw/Phxdu7cyfDhw1M9r5+fH/PmzeOtt96iePHivP322wAULlyYyMhImjZtypkzZ3j99dcpXbo0ixcvZsCAAQQFBaU456xZs4iKiuLll1/G3t6eggULpvnaZ/R3YvHixURERDB06FAKFSrEnj17mDZtGleuXGHx4sWm8x05coRGjRpha2vLyy+/TKlSpTh79iyrVq3i888/N7t2t27dKF26NBMnTuTAgQPMmDGDIkWKmF6z1Hz00UcEBwdz5coVpkyZAoCLiwtAht+L//33H+3ataNatWqMHz8ee3t7zpw5w86dO03/Fw97vzzo2rVr1KlTh6CgIF5++WUqVarE1atXWbJkCREREdjZ2XHu3DmWL1+Ov78/pUuX5ubNm/z00080adKEwMBAfHx8APjll19488036dq1K8OHDycqKoojR46we/duevXqBWTsfS8sRIl8a9asWQpQf/31l7p9+7a6fPmy+v3331WhQoWUo6OjunLlirpw4YKytrZWn3/+udljjx49qmxsbMy2N2nSRAFq+vTpKa7Vv39/5ezsbLbt0KFDClBDhgwx2/7OO+8oQG3evNm0rWTJkgpQ69atMzt2y5YtClBPPfWUiomJMW3v2bOnMhgMqk2bNmbH169fX5UsWdJsW0RERIr2tmrVSpUpU8ZsW2Ibtm3bZtp269YtZW9vr95++23TttGjRytA/fHHHynOazQalVJKzZs3T1lZWant27eb7Z8+fboC1M6dO1M8NrnE13rBggWmbSdOnFCAsrKyUv/++69p+/r16xWgZs2aZdo2ePBg5e3tre7cuWN23h49eih3d3fTa7J3794Uj32wDan9fzdp0kQ1adLEdH/mzJkKUJMnT05xbOJrMnz4cOXm5qbi4uLSfe6pKVmypGrbtq3ZtqlTpypA/frrr6ZtMTExqn79+srFxUWFhIQopZQ6f/68ApSbm5u6devWQ6+Vmd+J1N5bEydOVAaDQV28eNG0rXHjxsrV1dVsm1JJr41SSo0ZM0YBatCgQWbHdO7cWRUqVOih7W7btm2K975SGX8vTpkyRQHq9u3baV4jvfdLavr166esrKzU3r17U+xLfO5RUVEqPj7ebN/58+eVvb29Gj9+vGlbx44dVZUqVdK9Xkbf9yLnSTeNoEWLFhQuXBhfX1969OiBi4sLy5Yto1ixYvzxxx8YjUa6devGnTt3TLeiRYtSvnx5tmzZYnYue3t7Bg4cmKHrrl27FoCRI0eabU/8drtmzRqz7aVLl6ZVq1apnqtfv37Y2tqa7tetWxelFIMGDTI7rm7duly+fJm4uDjTtuTjToKDg7lz5w5NmjTh3LlzBAcHmz2+cuXKpm98oL+BV6xYkXPnzpm2LV26lOrVq9O5c+cU7TQYDID+xuzn50elSpXMXtfELrIHX9fUuLi40KNHD9P9ihUr4uHhgZ+fH3Xr1jV7zoCpjUopli5dSvv27VFKmV2/VatWBAcHc+DAgYdeHzL+/7106VI8PT154403UuxLfE08PDwIDw9n48aNGbr2w6xdu5aiRYuajXuytbXlzTffJCwsjL///tvs+C
5dulC4cOGHnjczvxPJ31vh4eHcuXOHBg0aoJTi4MGDANy+fZtt27YxaNAgSpQoYXatxNcmuVdffdXsfqNGjbh79y4
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Loss and every averaged metric per model, drawn on one shared y-axis.\n",
"metric_series = [\n",
"    (avg_loss_list, 'blue', 'Avg Loss'),\n",
"    (avg_precision_list, 'green', 'Avg Precision'),\n",
"    (avg_recall_list, 'red', 'Avg Recall'),\n",
"    (avg_accuracy_list, 'gold', 'Avg Accuracy'),\n",
"    (avg_f1_list, 'orange', 'Avg F1-Score'),\n",
"    (avg_iou_list, 'pink', 'mIoU'),\n",
"]\n",
"for values, color, label in metric_series:\n",
"    plt.plot(model_list, values, '-', color=color, label=label)\n",
"\n",
"plt.title('Performance metrics for each test case')\n",
"plt.legend(loc='right', bbox_to_anchor=(0.99, 0.3))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"unet PR AUC: 0.704, minDist: 0.497, (0.885424512257604, 0.51682025846915)\n",
"unet-dice-loss PR AUC: 0.692, minDist: 0.555, (0.931617865125619, 0.44876511258753)\n",
"unet-focal-loss PR AUC: 0.822, minDist: 0.277, (0.7487617009547043, 0.8840717896362793)\n",
"unet-l1 PR AUC: 0.754, minDist: 0.393, (0.8766366499467493, 0.626387811957221)\n",
"unet-l2 PR AUC: 0.760, minDist: 0.349, (0.7813983872652583, 0.7285129156027101)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAADpcElEQVR4nOydd3jUVdqG75n0nhAIhCQk9N57L6ICgtIUKQJSbOiKoH7qumvZtS0WUBCUIl2aiGIBkRJ6772EFEqAQCrpM7/vj5PMZEghCZP+3tc1F+H82pmImSfnvO/z6DRN0xAEQRAEQaiA6Et6AoIgCIIgCCWFCCFBEARBECosIoQEQRAEQaiwiBASBEEQBKHCIkJIEARBEIQKiwghQRAEQRAqLCKEBEEQBEGosIgQEgRBEAShwiJCSBAEQRCECosIIUEQBEEQKiwihARBKDQLFy5Ep9OZXo6OjtSrV4+XX36ZGzdumM7btm2bxXk2Njb4+PgwdOhQzpw5U6BnXrp0ieeff55atWrh6OiIu7s7nTt3ZsaMGSQlJVn7LQqCUM6xLekJCIJQ9vnwww+pWbMmycnJ7Ny5k9mzZ/PHH39w8uRJnJ2dTef94x//oG3btqSlpXH8+HHmzJnDtm3bOHnyJNWqVbvvc37//XeefPJJHBwcGD16NE2aNCE1NZWdO3fyxhtvcOrUKb7//vuifKuCIJQzRAgJgvDA9O3blzZt2gAwYcIEvL29+fLLL/nll18YPny46byuXbsydOhQ09/r16/Piy++yOLFi3nzzTfzfMbly5d5+umnCQwMZMuWLfj6+pqOTZo0iYsXL/L7779b5f3cvXsXFxcXq9xLEITSjWyNCYJgdXr16gUo8ZIXXbt2BdR21/343//+R0JCAvPnz7cQQZnUqVOHV199FYDQ0FB0Oh0LFy7Mdp5Op+P99983/f39999Hp9Nx+vRpRowYgZeXF126dOHzzz9Hp9MRFhaW7R5vv/029vb2REdHm8b27dtHnz598PDwwNnZme7du7Nr1677vi9BEEoWEUKCIFidTGHj7e2d53mhoaEAeHl53fee69evp1atWnTq1OmB55cTTz75JImJiXz88cdMnDiRp556Cp1Ox6pVq7Kdu2rVKh555BHTvLds2UK3bt2Ii4vjvffe4+OPPyYmJoZevXqxf//+IpmvIAjWQbbGBEF4YGJjY4mKiiI5OZldu3bx4Ycf4uTkRP/+/S3Oi4+PJyoqylQjNHnyZHQ6HUOGDMnz/nFxcVy9epUnnniiyN5D8+bNWb58ucVYhw4dWLlyJW+88YZp7MCBA4SEhJhWlTRN44UXXqBnz578+eef6HQ6AJ5//nkaN27Mu+++y19//VVk8xYE4cEQISQIwgPTu3dvi78HBgaybNky/Pz8LMbHjRtn8fcqVaqwZMkS2rZtm+f94+LiAHBzc7PCbHPmhRdeyDY2bNgwJk+ezKVLl6hduzYAK1euxMHBwSTKjh49yoULF3j33Xe5ffu2xfUPPfQQS5YswWg0otfLArwglEZECAmC8MDMmjWLevXqYWtrS9WqValfv36OH/z//ve/6dq1KwkJCfz888+sWLEiXwLB3d0dUCtKRUXNmjWzjT355JNMmTKFlStX8s4776BpGqtXr6Zv376mOV24cAGAMWPG5Hrv2NjYfG3/CYJQ/IgQEgThgWnXrp2paywvmjZtalo9GjhwIImJiUycOJEuXboQEBCQ63Xu7u5Ur16dkydP5ms+mdtT92IwGHK9xsnJKdtY9erV6dq1K6tWreKdd95h7969hIeH89lnn5nOMRqNAEybNo0WLVrkeG9XV9d8zVsQhOJH1moFQSgxPv30U5KTk/noo4/ue27//v25dOkSe/bsue+5masvMTExFuM5dYDdj2HDhnHs2DHOnTvHypUrcXZ2ZsCAAabjmVtm7u7u9O7dO8eXnZ1dgZ8rCELxIEJIEIQSo3bt2gwZMoSFCxcSGR
mZ57lvvvkmLi4uTJgwwcK1OpNLly4xY8YMQImSypUrs337dotzvv322wLPcciQIdjY2PDjjz+yevVq+vfvb+Ex1Lp1a2rXrs3nn39OQkJCtutv3bpV4GcKglB8yNaYIAglyhtvvMGqVauYPn06n376aa7n1a5dm+XLlzNs2DAaNmxo4Sy9e/duVq9ezdixY03nT5gwgU8//ZQJEybQpk0btm/fzvnz5ws8Px8fH3r27MmXX35JfHw8w4YNsziu1+uZN28effv2pXHjxjz77LP4+flx9epVtm7diru7O+vXry/wcwVBKB5kRUgQhBKlTZs29OjRg9mzZxMbG5vnuY8//jjHjx9n6NCh/PLLL0yaNIm33nqL0NBQvvjiC77++mvTuf/+978ZP348a9as4c0338RgMPDnn38Wao7Dhg0jPj4eNzc3+vXrl+14jx492LNnD23atGHmzJm88sorLFy4kGrVqvHaa68V6pmCIBQPOk3TtJKehCAIgiAIQkkgK0KCIAiCIFRYRAgJgiAIglBhESEkCIIgCEKFRYSQIAiCIAgVFhFCgiAIgiBUWEQICYIgCIJQYalwhopGo5Fr167h5uaWax6RIAiCIAilC03TiI+Pp3r16vkKa84vFU4IXbt2Lc9wR0EQBEEQSi8RERH4+/tb7X4VTgi5ubkB6hvp7u5ewrMRBEEQBCE/xMXFERAQYPoctxYVTghlboe5u7uLEBIEQRCEMoa1y1qkWFoQBEEQhAqLCCFBEARBECosIoQEQRAEQaiwiBASBEEQBKHCIkJIEARBEIQKiwghQRAEQRAqLCKEBEEQBEGosIgQEgRBEAShwiJCSBAEQRCECosIIUEQBEEQKiwlKoS2b9/OgAEDqF69OjqdjnXr1t33mm3bttGqVSscHByoU6cOCxcuLPJ5CoIgCIJQPilRIXT37l2aN2/OrFmz8nX+5cuXeeyxx+jZsydHjx5l8uTJTJgwgY0bNxbxTAVBEARBKI+UaOhq37596du3b77PnzNnDjVr1uSLL74AoGHDhuzcuZOvvvqKRx99tKimKQiCIAhCOaVM1Qjt2bOH3r17W4w9+uij7Nmzp8D3ioneDNwCNOtMThAEQRCEIuPGjYQiuW+JrggVlMjISKpWrWoxVrVqVeLi4khKSsLJySnbNSkpKaSkpJj+HhcXB0DgnMG0rOZHd38nutcz0q1OVSo51QT8gYCMPzNfPpQxzSgIgiAI5QajUePxx38sknuXKSFUGD755BM++OCDHI8dibnKkRiYfhIghKZeZ+kRZKR7nVi6BUIVl8wzbQE/zMLoXqHkD1QDbIryrQiCIAhChUSv1/H++z14+mnr37tMCaFq1apx48YNi7EbN27g7u6e42oQwNtvv82UKVNMf4+LiyMgIICz9SdxJO08wcZ0guNCOZN4mRPR0ZyIhm+OqHMbVXKle81Eugel0z0ojGquYXnMzgbwJXeh5J9x3K5wb14QBEEQKhCHD1/n5s279OlTB4C+fesWyXPKlBDq2LEjf/zxh8XYpk2b6NixY67XODg44ODgkG3cd/DH1Hd35+mzf8Cuadys0YDt6ZUJTrzLtthjnLx7idN3Ejh9B2YfUtfUq1SdHkE16B7kRfdAG/zc7wBXgKuAIePrK8DeXGajQ60c5SaU/IHqQPb5CoIgCEJFwGjU+Pzz3bz77hZcXe05fvxF/P3di+x5JSqEEhISuHjxounvly9f5ujRo1SqVIkaNWrw9ttvc/XqVRYvXgzACy+8wMyZM3nzzTcZN24cW7ZsYdWqVfz++++Fn0SDftCgHz7Jtxj697sMTf4NajUmyqY/O1LcCI67SHDsYY4lXOD8nWucv3ON7w+rS2t71aZ7YG+6B3Wle2BjAj01zGIop1cacD3jtT+PSVUld6GU+XIs/HsWBEEQhFJIREQsY8asY+vWUAB69AjCyalopYpO07QSa5vatm0bPXv2zDY+ZswYFi5cyNixYwkNDWXbtm0W17
z22mucPn0af39//vWvfzF27Nh8PzMuLg4PDw9iY2Nxd89BYRoNcGEl7P0v2J0Fuwbg2I1oWrEzNYFtMYcIjjnCkYR
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import precision_recall_curve, auc\n",
"\n",
"# Precision-recall curve per model from the pooled pixel labels/scores. Because y_score is\n",
"# binary, each curve has a single interior operating point.\n",
"for color, base_dir, (y_true, y_score) in zip(colors, base_dirs, lst_cdata):\n",
"    model_name = os.path.basename(base_dir)\n",
"\n",
"    precision, recall, _ = precision_recall_curve(y_true, y_score)\n",
"    # Trapezoidal area under the PR curve.\n",
"    pr_auc = auc(recall, precision)\n",
"    plt.plot(recall, precision, '-', color=color, label=f'{model_name}(AUC = {pr_auc:.3f})')\n",
"\n",
"    # Operating point closest to the ideal corner (recall, precision) = (1, 1);\n",
"    # argmin returns the first of equally close points.\n",
"    distances = ((1 - recall)**2 + (1 - precision)**2)**0.5\n",
"    closest_point = int(np.argmin(distances))\n",
"    min_distance = distances[closest_point]\n",
"    plt.scatter(recall[closest_point], precision[closest_point], color=color, marker='o')\n",
"\n",
"    print(f'{model_name} PR AUC: {pr_auc:.3f}, minDist: {min_distance:.3f}, {(recall[closest_point], precision[closest_point])}')\n",
"\n",
"# A no-skill classifier has precision equal to the positive-pixel fraction at every recall,\n",
"# so the PR baseline is a horizontal line (the y = x diagonal is the ROC chance line).\n",
"# Assumes all models were scored against the same ground-truth masks -- confirm.\n",
"if lst_cdata:\n",
"    positive_rate = lst_cdata[0][0].mean()\n",
"    plt.plot([0.0, 1.0], [positive_rate, positive_rate], '--', color='navy',\n",
"             label=f'baseline (positive rate = {positive_rate:.3f})')\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"plt.title('PR Curve')\n",
"plt.legend()\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"unet ROC AUC: 0.760, minDist: 0.121, (0.03855843283619862, 0.885424512257604)\n",
"unet-dice-loss ROC AUC: 0.760, minDist: 0.087, (0.05330329573045489, 0.931617865125619)\n",
"unet-focal-loss ROC AUC: 0.760, minDist: 0.251, (0.004573449270014527, 0.7487617009547043)\n",
"unet-l1 ROC AUC: 0.760, minDist: 0.126, (0.024355441418179306, 0.8766366499467493)\n",
"unet-l2 ROC AUC: 0.760, minDist: 0.219, (0.013563852969138058, 0.7813983872652583)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC3f0lEQVR4nOzdd3iTZffA8W+SNuletOxCK1BakI0gIIqALOUV0BdkyHSDCycuQH+CCiooCspUQREUEceLAwFREBRk771aoNBJd3L//kiTJk3SAS1pm/O5rlxNnjxPcqdKe3ruc99Ho5RSCCGEEEJ4IK27ByCEEEII4S4SCAkhhBDCY0kgJIQQQgiPJYGQEEIIITyWBEJCCCGE8FgSCAkhhBDCY0kgJIQQQgiPJYGQEEIIITyWBEJCCCGE8FgSCAkhhBDCY0kgJIQo0qJFi9BoNNabl5cXderUYeTIkZw5c8bpNUopPvvsM26++WZCQkLw8/OjWbNmvPrqq1y+fNnle33zzTf07t2b8PBw9Ho9tWvXZuDAgfz2228lGmtWVhbvvvsu7du3Jzg4GB8fH2JiYhg3bhwHDx68os8vhKjaNNJrTAhRlEWLFjFq1CheffVVoqOjycrK4q+//mLRokVERUWxe/dufHx8rOcbjUaGDBnCsmXL6Ny5MwMGDMDPz48NGzbw+eef06RJE3799Vdq1KhhvUYpxejRo1m0aBGtWrXi7rvvpmbNmsTHx/PNN9+wdetW/vzzTzp27OhynImJifTq1YutW7dyxx130L17dwICAjhw4ABLly4lISGBnJyccv1eCSEqISWEEEVYuHChAtTff/9td/y5555TgPryyy/tjk+ZMkUB6umnn3Z4rVWrVimtVqt69epld3zatGkKUE888YQymUwO13366adq8+bNRY7z9ttvV1qtVn311VcOz2VlZamnnnqqyOtLKjc3V2VnZ5fJawkh3E8CISFEkVwFQt9//70C1JQpU6zHMjIyVGhoqIqJiVG5ublOX2/UqFEKUJs2bbJeExYWpmJjY1VeXt4VjfGvv/5SgLr//vtLdP4tt9yibrnlFofjI0aMUPXr17c+PnbsmALUtGnT1Lvvvquuu+46pdVq1V9//aV0Op2aNGmSw2vs379fAer999+3HktKSlKPP/64qlu3rtLr9apBgwbqjTfeUEajsdSfVQhRtqRGSAhxRY4fPw5AaGio9dgff/xBUlISQ4YMwcvLy+l1w4cPB+D777+3XnPp0iWGDBmCTqe7orGsWrUKgHvvvfeKri/OwoULef/993nggQd4++23qVWrFrfccgvLli1zOPfLL79Ep9Px3//+F4CMjAxuueUWFi9ezPDhw3nvvffo1KkTEyZMYPz48eUyXiFEyTn/SSWEEIWkpKSQmJhIVlYWmzdvZvLkyRgMBu644w7rOXv37gWgRYsWLl/H8ty+ffvsvjZr1uyKx1YWr1GU06dPc/jwYSIiIqzHBg0axIMPPsju3bu5/vrrrce//PJLbrnlFmsN1DvvvMORI0f4999/adSoEQAPPvggtWvXZtq0aTz11FNERkaWy7iFEMWTjJAQokS6d+9OREQEkZGR3H333fj7+7Nq1Srq1q1rPSctLQ2AwMBAl69jeS41NdXua1HXFKcsXqMod911l10QBDBgwAC8vLz48ssvrcd2797N3r17GTRokPXY8uXL6dy5M6GhoSQmJlpv3bt3x2g08vvvv5fLmIUQJSMZISFEiXzwwQfExMSQkpLCggUL+P333zEYDHbnWAIRS0DkTOFgKSgoqNhrimP7GiEhIVf8Oq5ER0c7HAsPD6dbt24sW7aM1157DTBng7y8vBgwYID1vEOHDrFz506HQMri/PnzZT5eIUTJSSAkhCiRdu3a0bZtWwD69evHTTfdxJAhQzhw4AABAQEAxMXFAbBz50769evn9HV27twJQJMmTQCIjY0FYNeuXS6vKY7ta3Tu3LnY8z
UaDcrJziFGo9Hp+b6+vk6P33PPPYwaNYrt27fTsmVLli1bRrdu3QgPD7eeYzKZuO2223j22WedvkZMTEyx4xVClB+ZGhNClJpOp2Pq1KmcPXuWWbNmWY/fdNNNhISE8Pnnn7sMKj799FMAa23RTTfdRGhoKF988YXLa4rTt29fABYvXlyi80NDQ0lOTnY4fuLEiVK9b79+/dDr9Xz55Zds376dgwcPcs8999id06BBA9LT0+nevbvTW7169Ur1nkKIsiWBkBDiinTp0oV27doxY8YMsrKyAPDz8+Ppp5/mwIEDvPjiiw7X/PDDDyxatIiePXty4403Wq957rnn2LdvH88995zTTM3ixYvZsmWLy7F06NCBXr16MW/ePFauXOnwfE5ODk8//bT1cYMGDdi/fz8XLlywHtuxYwd//vlniT8/QEhICD179mTZsmUsXboUvV7vkNUaOHAgmzZt4qeffnK4Pjk5mby8vFK9pxCibMnO0kKIIll2lv7777+tU2MWX331Ff/973+ZPXs2Dz30EGCeXho0aBBff/01N998M3fddRe+vr788ccfLF68mLi4ONasWWO3s7TJZGLkyJF89tlntG7d2rqzdEJCAitXrmTLli1s3LiRDh06uBznhQsX6NGjBzt27KBv375069YNf39/Dh06xNKlS4mPjyc7OxswrzK7/vrradGiBWPGjOH8+fPMmTOHGjVqkJqaat0a4Pjx40RHRzNt2jS7QMrWkiVLGDZsGIGBgXTp0sW6lN8iIyODzp07s3PnTkaOHEmbNm24fPkyu3bt4quvvuL48eN2U2lCiGvMvdsYCSEqOlcbKiqllNFoVA0aNFANGjSw2wzRaDSqhQsXqk6dOqmgoCDl4+OjmjZtqiZPnqzS09NdvtdXX32levToocLCwpSXl5eqVauWGjRokFq3bl2JxpqRkaGmT5+ubrjhBhUQEKD0er1q1KiRevTRR9Xhw4ftzl28eLG67rrrlF6vVy1btlQ//fRTkRsqupKamqp8fX0VoBYvXuz0nLS0NDVhwgTVsGFDpdfrVXh4uOrYsaOaPn26ysnJKdFnE0KUD8kICSGEEMJjSY2QEEIIITyWBEJCCCGE8FgSCAkhhBDCY0kgJIQQQgiPJYGQEEIIITyWBEJCCCGE8Fge12vMZDJx9uxZAgMD0Wg07h6OEEIIIUpAKUVaWhq1a9dGqy27PI7HBUJnz54lMjLS3cMQQgghxBU4deoUdevWLbPX87hAKDAwEDB/I4OCgtw8GiGEEEKURGpqKpGRkdbf42XF4wIhy3RYUFCQBEJCCCFEJVPWZS1SLC2EEEIIjyWBkBBCCCE8lgRCQgghhPBYEggJIYQQwmNJICSEEEIIjyWBkBBCCCE8lgRCQgghhPBYEggJIYQQwmNJICSEEEIIjyWBkBBCCCE8llsDod9//52+fftSu3ZtNBoNK1euLPaadevW0bp1awwGAw0bNmTRokXlPk4hhBBCVE1uDYQuX75MixYt+OCDD0p0/rFjx7j99tu59dZb2b59O0888QT33XcfP/30UzmPVAghhBBVkVubrvbu3ZvevXuX+Pw5c+YQHR3N22+/DUBcXBx//PEH7777Lj179iyvYQohhBCiiqpU3ec3bdpE9+7d7Y717NmTJ554wj0DEkIIIUTZM2ZDbhpkXoLzZ0g5d4hfD2wpl7eqVIFQQkICNWrUsDtWo0YNUlNTyczMxNfX1+Ga7OxssrOzrY9TU1PLfZxCCCGER1EmyLsMuanmACYvzfnXnBRIS4TLGZCdB0YNKG/Q+KI0/pxWsN+Yxf6cdPZlXGR/RgL7M04Qn3URPhhTLkOvVIHQlZg6dSqTJ0929zCEEEKIisWSdbELVgoFMsU9n5cGOWlgzARtKOjCQFst/2uh+7pI0LYgGz8Om86yP/cE+y4fY3/GcfZnnGB/xnEumzKdj1ULIbf9TfLSsv82VKpAqGbNmpw7d87u2Llz5wgKCnKaDQKYMG
EC48ePtz5OTU0lMjKyXMcphBBClDllgrx0FwFKfpDiMoBxEsiYcot5Q2+bgCbMJshpZH5sCAPfavnPhThcnZSbyr6
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, roc_auc_score\n",
"\n",
"# ROC curve per model from the pooled pixel labels/scores.\n",
"for color, base_dir, (y_true, y_score) in zip(colors, base_dirs, lst_cdata):\n",
"    model_name = os.path.basename(base_dir)\n",
"\n",
"    # Named roc_auc (not auc) so it does not shadow sklearn.metrics.auc, which the PR cell\n",
"    # calls, and is not confused with that cell's pr_auc.\n",
"    roc_auc = roc_auc_score(y_true, y_score)\n",
"\n",
"    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
"    plt.plot(fpr, tpr, color=color, label=f'{model_name}(AUC = {roc_auc:.3f})')\n",
"\n",
"    # Operating point closest to the ideal corner (FPR, TPR) = (0, 1);\n",
"    # argmin returns the first of equally close points.\n",
"    distances = (fpr**2 + (1 - tpr)**2)**0.5\n",
"    closest_point = int(np.argmin(distances))\n",
"    min_distance = distances[closest_point]\n",
"    plt.scatter(fpr[closest_point], tpr[closest_point], color=color, marker='o')\n",
"\n",
"    # Report the ROC AUC computed above, not the PR cell's leftover pr_auc.\n",
"    print(f'{model_name} ROC AUC: {roc_auc:.3f}, minDist: {min_distance:.3f}, {(fpr[closest_point], tpr[closest_point])}')\n",
"\n",
"plt.plot([0.0, 1.05], [0.0, 1.05], '--', color='navy', label='baseline')\n",
"plt.xlabel('FPR')\n",
"plt.ylabel('TPR')\n",
"plt.title('ROC Curve')\n",
"plt.legend()\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxYAAAOpCAYAAACHIrEHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3gU1f7/39trNr03AiEBRHpRROkCCopelSaCIirq1evXfq2IooL1Knot2KkiXURAEER6LyGQhFTSk0229/n9we/MnZ3dTTZ1s+G8nmef7E49Myczc97zaQKGYRhQKBQKhUKhUCgUSgsQBroBFAqFQqFQKBQKJfihwoJCoVAoFAqFQqG0GCosKBQKhUKhUCgUSouhwoJCoVAoFAqFQqG0GCosKBQKhUKhUCgUSouhwoJCoVAoFAqFQqG0GCosKBQKhUKhUCgUSouhwoJCoVAoFAqFQqG0GCosKBQKhUKhUCgUSouhwoJCoVAobrz++usQCASBbgalg1FQUACBQIDvvvsu0E2hUCgdlA4jLL777jsIBAIcPXo00E2hBIj8/Hw8/vjjyMjIgFKphFKpRK9evfDYY4/h9OnTGDlyJAQCQaOf119/PdCH0mkh1yn3ExMTg1GjRuG3337zWL6iogLPPPMMevToAaVSCZVKhYEDB+LNN99EXV0du1xDfdujR492PMLOCb/f5HI5EhISMH78ePznP/+BXq8PdBMpPLxda+TzwgsvBLp5lFbGnzHQ559/jrvvvhspKSkQCASYM2dO+zWQ0iCN9V9xcTEWLFiAIUOGIDw8HFFRURg5ciR27tzZzi1te8SBbgCFAgBbtmzB1KlTIRaLMXPmTPTt2xdCoRDZ2dlYt24dPv/8c3z77bd48MEH2XWOHDmC//znP/j3v/+Nnj17stP79OkTiEO4qnjjjTeQlpYGhmFQUVGB7777Drfccgs2b96MSZMmAbjSP7fccgsMBgPuvfdeDBw4EABw9OhRvPPOO9i7dy+2b9/ObjMpKQlvv/22x75CQ0Pb56CuAki/2e12lJeX488//8S//vUvfPDBB9i0aRN77bz88st08NpBIH3GpXfv3gFqDSWQvPvuu9Dr9RgyZAjKysoC3RxKE9i4cSPeffddTJkyBbNnz4bD4cAPP/yAcePG4ZtvvsH9998f6Ca2GlRYUAJOXl4epk2bhtTUVPzxxx+Ij493m//uu+/is88+w+jRo5GcnMxOl8vl+M9//oNx48Zh5MiR7dzqq5uJEydi0KBB7O+5c+ciNjYWK1euxKRJk1BXV4c77rgDIpEIJ06c8LA6vPXWW/jqq6/cpoWGhuLee+9tl/ZfrfD77cUXX8SuXbswadIk3HbbbTh//jwUCgXEYjHEYvp46Ajw+4xy9bJnzx7WWqFWqwPdHEoTGDVqFIqKihAVFcVOe+SRR9CvXz+8+uqrnUpYdBhXKD5z5syBWq1GUVERJk2aBLVajcTERCxduhQAcObMGYwePRoqlQqpqalYsWKF2/q1tbV45plncO2110KtVkOj0WDixIk4deqUx74KCwtx2223QaVSISYmBk899RR+//13CAQC/Pnnn27LHjp0CBMmTEBoaCiUSiVGjBiBv//+u83Ow9XA4sWLYTQa8e2333qICgAQi8V44okn3EQFpWMRFhbGDkgB4IsvvsDly5fxwQcfeHVlio2Nxcsvv9zezaR4YfTo0XjllVdQWFiIn376CYDvGIuffvoJQ4YMgVKpRHh4OG666SY3qxMA/Pbbb7jxxhuhUqkQEhKCW2+9FefOnWuXY7na2LVrF3uuw8LCcPvtt+P8+fMey12+fBlz585FQkICZDIZ0tLSMH/+fNhsNgBNe15SAkdqaiqNfQpSrrnmGjdRAQAymQy33HILSkpKOpU7aocVFgDgdDoxceJEJCcnY/HixejSpQsef/xxfPfdd5gwYQIGDRqEd999FyEhIbjvvvuQn5/Prnvp0iVs2L
ABkyZNwgcffIBnn30WZ86cwYgRI1BaWsouZzQaMXr0aOzcuRNPPPEEXnrpJezfvx/PP/+8R3t27dqFm266CTqdDq+99hoWLVqEuro6jB49GocPH26Xc9IZ2bJlC9LT0zF06NBAN4XiJ/X19aiurkZVVRXOnTuH+fPnsy5PALBp0yYoFArcddddfm/T6XSiurra42M0GtvqMCj/n1mzZgGAh0jgsmDBAsyaNQsSiQRvvPEGFixYgOTkZOzatYtd5scff8Stt94KtVqNd999F6+88gqysrIwfPhwFBQUtPVhdErItcb9AMDOnTsxfvx4VFZW4vXXX8f//d//Yf/+/bjhhhvcznVpaSmGDBmCVatWYerUqfjPf/6DWbNmYc+ePTCZTAD8f15SKJTWpby8nI0p7TQwHYRvv/2WAcAcOXKEYRiGmT17NgOAWbRoEbuMVqtlFAoFIxAImFWrVrHTs7OzGQDMa6+9xk6zWCyM0+l020d+fj4jk8mYN954g532/vvvMwCYDRs2sNPMZjPTo0cPBgCze/duhmEYxuVyMd27d2fGjx/PuFwudlmTycSkpaUx48aNa5XzcLVRX1/PAGCmTJniMU+r1TJVVVXsx2Qyuc3/+eef3fqI0vaQ65T/kclkzHfffccuFx4ezvTt29fv7Y4YMcLrdgEwDz/8cBscydUF//7qjdDQUKZ///4MwzDMa6+9xnAfDzk5OYxQKGTuuOMOj/squR/q9XomLCyMmTdvntv88vJyJjQ01GM6pWF8XWukX/r168fExMQwNTU17DqnTp1ihEIhc99997HT7rvvPkYoFHrte9J3/j4v8/PzGQDMt99+25qHSmH8u0a5qFQqZvbs2W3bKIrfNLX/GObKfVUulzOzZs1qw5a1Px3eiZYbrBsWFobMzEzk5ubinnvuYadnZmYiLCwMly5dYqfJZDL2u9PpRF1dHdRqNTIzM3H8+HF23rZt25CYmIjbbruNnSaXyzFv3jw8/fTT7LSTJ08iJycHL7/8MmpqatzaOGbMGPz4449wuVwQCju0EajDodPpAMCrv+jIkSPdTPFLlizBM888025to/hm6dKlyMjIAHAl89NPP/2EBx98ECEhIbjzzjuh0+kQEhLSpG126dLFI+4CuBLUTWl71Gq1T3P8hg0b4HK58Oqrr3rc44hrxo4dO1BXV4fp06ezb9UBQCQSYejQodi9e3fbNb4Tw73WCGVlZTh58iSee+45REREsNP79OmDcePGYevWrQAAl8uFDRs2YPLkyV7jNEjf+fu8pFAorYPJZMLdd98NhUKBd955J9DNaVU6tLCQy+WIjo52mxYaGoqkpCQPP8PQ0FBotVr2t8vlwscff4zPPvsM+fn5cDqd7LzIyEj2e2FhIbp16+axvfT0dLffOTk5AIDZs2f7bG99fT3Cw8P9PDoKAHbwaTAYPOZ98cUX0Ov1qKiooEG9HYwhQ4a4DVSmT5+O/v374/HHH8ekSZOg0Wia7DOqUqkwduzY1m4qxU8MBgNiYmK8zsvLy4NQKESvXr18rk/ukaNHj/Y6X6PRtLyRVyH8aw0ADh48CODKSzU+PXv2xO+//w6j0QiDwQCdTtdoFil/n5cUCqXlOJ1OTJs2DVlZWfjtt9+QkJAQ6Ca1Kh1aWIhEoiZNZxiG/b5o0SK88soreOCBB7Bw4UJERERAKBTiX//6F1wuV5PbQtZZsmQJ+vXr53UZmqWh6YSGhiI+Ph5nz571mEdiLqhvdsdHKBRi1KhR+Pjjj5GTk4MePXrg5MmTsNlskEqlgW4epRFKSkpQX1/v8UKlKZB75I8//oi4uDiP+TTLVMeltZ+XFArFN/PmzcOWLVuwfPlyny9igplOe6dfu3YtRo0ahWXLlrlNr6urc4vMT01NRVZWFhiGcbNa5Obmuq3XrVs3AFfeutG3qq3Lrbfeiq+//hqHDx/GkCFDAt0cSjNxOBwArrz5njx5Mg4cOIBffvkF06dPD3DLKI3x44
8/AgDGjx/vdX63bt3gcrmQlZXl88UKuUfGxMTQe2Qbk5qaCgC4cOGCx7zs7GxERUVBpVJBoVBAo9F4fXHDxd/nJYV
"text/plain": [
"<Figure size 800x950 with 56 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Qualitative comparison for hand-picked test samples, one row per sample:\n",
"# input image, ground truth, then each model's predicted mask.\n",
"lst_idx = [857, 739, 986, 2, 592, 333, 487, 155]\n",
"names = ['BCE', 'Dice', 'Focal', 'L1', 'L2']  # column titles, in base_dirs order\n",
"\n",
"# Sorted file names in base_dir/result/numpy that start with `prefix`.\n",
"def _sorted_npy(base_dir, prefix):\n",
"    numpy_dir = os.path.join(base_dir, 'result', 'numpy')\n",
"    return sorted(f for f in os.listdir(numpy_dir) if f.startswith(prefix))\n",
"\n",
"# List each directory once up front. Image and GT come from base_dirs[0] using that\n",
"# directory's own listing, rather than lists left over from the evaluation cell.\n",
"ref_dir = os.path.join(base_dirs[0], 'result', 'numpy')\n",
"ref_img, ref_gt = _sorted_npy(base_dirs[0], 'img'), _sorted_npy(base_dirs[0], 'gt')\n",
"pred_files = [_sorted_npy(d, 'pr') for d in base_dirs]\n",
"\n",
"n_cols = 2 + len(base_dirs)  # image + GT + one column per model\n",
"fig, axs = plt.subplots(len(lst_idx), n_cols, figsize=(8, 9.5), squeeze=False)\n",
"\n",
"for row, sample in enumerate(lst_idx):\n",
"    panels = [('Image', os.path.join(ref_dir, ref_img[sample])),\n",
"              ('GT', os.path.join(ref_dir, ref_gt[sample]))]\n",
"    panels += [(name, os.path.join(d, 'result', 'numpy', files[sample]))\n",
"               for name, d, files in zip(names, base_dirs, pred_files)]\n",
"    for col, (title, path) in enumerate(panels):\n",
"        ax = axs[row, col]\n",
"        ax.imshow(np.load(path), cmap='gray')\n",
"        ax.axis('off')\n",
"        if row == 0:\n",
"            ax.set_title(title)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}