{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Packages"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# import package\n",
"\n",
"# model\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchsummary import summary\n",
"from torch import optim\n",
"\n",
"# dataset and transformation\n",
"from torchvision import datasets\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader\n",
"import os\n",
"\n",
"# display images\n",
"from torchvision import utils\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# utils\n",
"import numpy as np\n",
"from torchsummary import summary\n",
"import time\n",
"import copy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 데이터셋 불러오기\n",
"\n",
"데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 사용한다. \n",
"\n",
"STL10 dataset은 10개의 label을 갖으며 train dataset 5000개, test dataset 8000개로 구성된다."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./dataset\\stl10_binary.tar.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2640397119/2640397119 [05:50<00:00, 7527746.51it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./dataset\\stl10_binary.tar.gz to ./dataset\n",
"Files already downloaded and verified\n",
"5000\n",
"8000\n"
]
}
],
"source": [
"# specift the data path\n",
"path2data = './dataset'\n",
"\n",
"# if not exists the path, make the directory\n",
"if not os.path.exists(path2data):\n",
" os.mkdir(path2data)\n",
"\n",
"# load dataset\n",
"train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())\n",
"val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())\n",
"\n",
"print(len(train_ds))\n",
"print(len(val_ds))"
]
},
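{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the next cell inspects a single sample: each item should be a 3x96x96 tensor (STL10 images are 96x96 RGB) paired with an integer class label between 0 and 9."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# inspect one sample: image tensor and label\n",
"img, label = train_ds[0]\n",
"print(img.shape) # torch.Size([3, 96, 96]) before any resizing\n",
"print(label) # class index between 0 and 9"
]
},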
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transformation & DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# define image transformation\n",
"transformation = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Resize(299)\n",
"])\n",
"\n",
"train_ds.transform = transformation\n",
"val_ds.transform = transformation\n",
"\n",
"# create dataloader\n",
"train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)\n",
"val_dl = DataLoader(val_ds, batch_size=8, shuffle=True)"
]
},
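{
"cell_type": "markdown",
"metadata": {},
"source": [
"With ToTensor followed by Resize(299), each batch from the DataLoader should arrive as an 8 x 3 x 299 x 299 tensor, which matches the 299x299 input expected by the Inception-ResNet-v2 stem. The short check below pulls one batch to confirm the shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pull one batch and check its shape\n",
"x_batch, y_batch = next(iter(train_dl))\n",
"print(x_batch.shape) # expected: torch.Size([8, 3, 299, 299])\n",
"print(y_batch.shape) # expected: torch.Size([8])"
]
},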
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAEECAYAAAD0y9+hAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9edA27VnXiX+Oc+nua7m3Z3uXLCSGxImCLEFSIrKIoAaZgqoBsQoGoyKZKkJqYhFELTQOUw6CgCVQgBYEUEoZR8FxxgGhYDatnwKC/sIvGoGY5V2e/V6upbvP5ffHeZ7dfT3vS8wbEuKr/X3qeu77vq6+us8+u/s4ju+xnRJjjMyYMWPGjBkzZsyYMWPGjOeF+lgPYMaMGTNmzJgxY8aMGTP+U8ZMmmbMmDFjxowZM2bMmDHjg2AmTTNmzJgxY8aMGTNmzJjxQTCTphkzZsyYMWPGjBkzZsz4IJhJ04wZM2bMmDFjxowZM2Z8EMykacaMGTNmzJgxY8aMGTM+CGbSNGPGjBkzZsyYMWPGjBkfBDNpmjFjxowZM2bMmDFjxowPgpk0zZgxY8aMGTNmzJgxY8YHwUyaZsyYMeNFgne84x2ICO95z3te8Hc/53M+h0/4hE/4iI7nFa94BX/8j//xj8i+RGR4fdu3fdtHZJ//JeHHf/zHD+bw53/+5z/WQ5oxY8aM/6wwk6YZM2bMmPGfBL7kS76EH/mRH+ELv/ALh/fe8573HJCB6evv/t2/+2Ed533vex9vf/vb+fRP/3TOzs64ceMGn/M5n8NP//RP/6bP4Z/9s3/GZ37mZ7JcLnn88cf5uq/7Oq6urn5T+/yu7/ouXvva11LXNS95yUt461vfymazOdjm0z7t0/iRH/kR/vSf/tO/qWPNmDFjxoznh/lYD2DGjBkzZswA+F2/63fxFV/xFc/72R/7Y3+MN7zhDQfv/Z7f83s+rOP8xE/8BN/yLd/CF3/xF/NVX/VVOOf44R/+YT7/8z+fH/iBH+CNb3zjh7XfX/qlX+LzPu/zeO1rX8u3f/u38/73v59v+7Zv493vfjf/5J/8kw9rn9/wDd/AX/2rf5X/5r/5b3jLW97Cr/zKr/A3/sbf4J3vfCc/+ZM/OWz30pe+lK/4iq/AOcf3f//3f1jHmjFjxowZvzFm0jRjxowZM/6Tx6d+6qf+hoTqheJzP/dzee9738uNGzeG9970pjfxyZ/8yXzTN33Th02a/tyf+3OcnZ3xcz/3cxwfHwMphfGrv/qr+amf+im+4Au+4AXt7+mnn+bbv/3b+cqv/Ep++Id/eHj/Na95DW9+85v5X//X/5Uv+qIv+rDGOmPGjBkzXhjm9LwZM2bMeBHjJ37iJ/jCL/xCnnzySeq65lWvehX/w//wP+C9f97tf+EXfoHP+IzPYLFY8MpXvpLv/d7vfc42bdvyF//iX+TjP/7jqeual73sZbztbW+jbdsPOpa+73n729/Oq1/9apqm4fr163zmZ34m//Sf/tODbd71rnfx9NNPv+Bz3Ww2dF33gr/3KH7n7/ydB4QJoK5r3vCGN/D+97+fy8vLF7zPi4sL/uk//ad8xVd8xUCYAP7b//a/Zb1e82M/9mMveJ///J//c5xzfPmXf/nB++XvDzc9ccaMGTNmvHDMpGnGjBkzXsR4xzvewXq95q1vfSt//a//dV73utfxTd/0TfzZP/tnn7PtgwcPeMMb3sDrXvc6/upf/au89KUv5b/77/47fuAHfmDYJoTAf/1f/9d827d9G1/0RV/E3/gbf4Mv/uIv5ju+4zv4o3/0j37Qsfylv/SXePvb387nfu7n8l3f9V38+T//53n5y1/OL/7iLw7bfOADH+C1r30t3/iN3/iCzvPtb3876/Wapmn43b/7d/NTP/VTL+j7HwqeeeYZlssly+XyBX/33/ybf4Nzjk/7tE87eL+qKj75kz+Zf/Wv/tUL3mchqYvF4uD9Mr5f+IVfeMH7nDFjxowZHx7m9LwZM2bMeBHjR3/0Rw+M6je96U286U1v4nu+53v45m/+Zuq6Hj576qmn+Gt/7a/x1re+FYCv+Zqv4fWvfz3f+I3fyFd+5VdireVHf/RH+emf/mn+z//z/+QzP/Mzh+9+wid8Am9605v4Z//sn/EZn/EZzzuW/+1/+994wxve8BGtqVFK8QVf8AV8yZd8CS95yUv4tV/7Nb7927+dP/yH/zD/6B/9o4OmEb8Z/Pt//+/5B//gH/ClX/qlaK1f8PdL5OyJJ554zmdPPPEE//f//X+/4H3+9t/+2wH4f//f/5fP/dzPHd4v+/rABz7wgvc5Y8aMGTM+PMyRphkzZsx4EWNKmC4vL7l79y6/7/f9PrbbLe9617sOtjXG8DVf8zXD31VV8TVf8zXcvn17iFr8z//z/8xrX/ta/qv/6r/i7t27w+v3//7fD8DP/uzP/oZjOT095Z3vfCfvfve7f8NtXvGKVxBj5B3veMeHdH4vf/nL+cmf/Ene9KY38UVf9EW85S1v4V/9q3/FzZs3+TN/5s98SPv4j2G73fKlX/qlLBYL/qf/6X/6sPax2+0ADkhqQdM0w+cvBJ/6qZ/K61//er7lW76FH/zBH+Q973kP/+Sf/BO+5mu+Bmvth7XPGTNmzJjx4WEmTTNmzJjxIsY73/lOvuRLvoSTkxOOj4+5efPm0DDh/Pz8YNsnn3yS1Wp18N5rXvMagGHtp3e/+928853v5ObNmwevst3t27d/w7H85b/8l3n48CGvec1r+MRP/ES+/uu/nn/9r//1R+pUB1y7do03vvGN/Nt/+295//vf/5val/eeL//yL+dXfuVX+Pt//+/z5JNPflj7KeT1+eq+9vv9c1LsPlT8L//L/8InfdIn8Sf+xJ/gla98JV/0RV/El33Zl/Epn/IprNfrD2ufM2bMmDHjhWNOz5sxY8aMFykePnzIZ3/2Z3N8fMxf/st/mVe96lU0TcMv/uIv8g3f8A2EEF7wPkMIfOInfiLf/u3f/ryfv+xlL/sNv/tZn/VZ/Oqv/io/8RM/wU/91E/xt/7W3+I7vuM7+N7v/V7+1J/6Uy94LB8MZRz379/npS996Ye9n6/+6q/mH//jf8zf+Tt/Z4imfTgoaXnP1+Di6aef/rDJ2Ete8hL+n//n/+Hd7343zzzzDK9+9at5/PHHefLJJwciO2PGjBkzPvqYSdOMGTNmvEjxcz/3c9y7d49/8A/+AZ/1WZ81vP/rv/7rz7v9U089xWazOYg2/bt/9++AlDYH8KpXvYpf/uVf5vM+7/MQkRc8phIFeuMb38jV1RWf9VmfxV/6S3/pI06afu3Xfg2Amzdvftj7+Pqv/3p+8Ad/kO/8zu/kj/2xP/abGs8nfMInYIzh53/+5/myL/uy4f2u6/ilX/qlg/c+HLz61a/m1a9+NQC/8iu/wtNPP80f/+N//De1zxkzZsyY8aFjTs+bMWPGjBcpSsOCGOPwXtd1fM/3fM/zbu+c4/u+7/sOtv2+7/s+bt68yete9zoAvuzLvowPfOAD/M2/+Tef8/3db
sdms/kNx3Pv3r2Dv9frNR//8R9/kLL2QluO37lz5znvfeADH+AHfuAH+F2/63c9b+OFDwXf+q3fyrd927fx5/7cn+Mtb3nLh7WPKU5OTvgDf+AP8Lf/9t8+aFn+Iz/yI1xdXfGlX/qlv+ljQIoEvu1tb2O5XPKmN73pI7LPGTNmzJjxH8ccaZoxY8aMFyk+4zM+g7OzM77qq76Kr/u6r0NE+JEf+ZEDEjXFk08+ybd8y7fwnve8h9e85jX8vb/39/ilX/olvv/7vx9rLQBf+ZVfyY/92I/xpje9iZ/92Z/l9/7e34v3nne961382I/9GD/5kz/5nLbaBb/jd/wOPudzPofXve51XLt2jZ//+Z/n7//9v8/Xfu3XDtuUluNf9VVf9SE1g3jb297Gr/7qr/J5n/d5PPnkk7znPe/h+77v+9hsNvz1v/7XD7Z9xzvewRvf+EZ+8Ad/8INGYf7hP/yHvO1tb+PVr341r33ta/nbf/tvH3z++Z//+Tz22GNAqvV65Stf+SGN93/8H/9HPuM
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# display sample images\n",
"def show(img, y=None, color=True): \n",
" npimg = img.numpy()\n",
" npimg_tr = np.transpose(npimg, (1, 2, 0))\n",
" plt.imshow(npimg_tr)\n",
"\n",
" if y is not None:\n",
" plt.title('labels:' + str(y))\n",
"\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)\n",
"\n",
"grid_size = 4\n",
"rnd_ind = np.random.randint(0, len(train_ds), grid_size)\n",
"\n",
"x_grid = [train_ds[i][0] for i in rnd_ind]\n",
"y_grid = [train_ds[i][1] for i in rnd_ind]\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"x_grid = utils.make_grid(x_grid, nrow=4, padding=2)\n",
"show(x_grid, y_grid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Network (Inception-ResNet-v2)\n",
"\n",
"Inception-ResNet-v2는 Inception-v4에 residual block을 사용하는 모델이다.\n",
"\n",
"[참고](https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/inceptionv4.py)"
]
},
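{
"cell_type": "markdown",
"metadata": {},
"source": [
"Every Inception-ResNet block below follows the same residual pattern: run the input through a few Inception-style branches, concatenate them, project back with a 1x1 convolution, and add the result to a 1x1-conv shortcut of the input before BatchNorm and ReLU (blocks B and C also scale the branch output by 0.1 before the addition). The next cell is a minimal, self-contained sketch of that pattern with placeholder channel sizes, just to make the structure explicit before the full implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# minimal sketch of the scaled residual pattern (placeholder channel sizes)\n",
"x = torch.randn(1, 16, 8, 8)\n",
"branches = nn.Conv2d(16, 32, 1) # stands in for the concatenated branches + 1x1 reduction\n",
"shortcut = nn.Conv2d(16, 32, 1) # 1x1 conv so the skip path matches the branch channels\n",
"bn = nn.BatchNorm2d(32)\n",
"out = F.relu(bn(shortcut(x) + 0.1 * branches(x))) # residual sum, scaled as in blocks B and C\n",
"print(out.shape) # torch.Size([1, 32, 8, 8])"
]
},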
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class BasicConv2d(nn.Module):\n",
" def __init__(self, in_channels, out_channels, kernel_size, **kwargs):\n",
" super().__init__()\n",
"\n",
" # bias=Fasle, because BN after conv includes bias.\n",
" self.conv = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.ReLU()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" return x\n",
"\n",
"\n",
"class Stem(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" self.conv1 = nn.Sequential(\n",
" BasicConv2d(3, 32, 3, stride=2, padding=0), # 149 x 149 x 32\n",
" BasicConv2d(32, 32, 3, stride=1, padding=0), # 147 x 147 x 32\n",
" BasicConv2d(32, 64, 3, stride=1, padding=1), # 147 x 147 x 64 \n",
" )\n",
"\n",
" self.branch3x3_conv = BasicConv2d(64, 96, 3, stride=2, padding=0) # 73x73x96\n",
"\n",
" # kernel_size=4: 피쳐맵 크기 73, kernel_size=3: 피쳐맵 크기 74\n",
" self.branch3x3_pool = nn.MaxPool2d(4, stride=2, padding=1) # 73x73x64\n",
"\n",
" self.branch7x7a = nn.Sequential(\n",
" BasicConv2d(160, 64, 1, stride=1, padding=0),\n",
" BasicConv2d(64, 96, 3, stride=1, padding=0)\n",
" ) # 71x71x96\n",
"\n",
" self.branch7x7b = nn.Sequential(\n",
" BasicConv2d(160, 64, 1, stride=1, padding=0),\n",
" BasicConv2d(64, 64, (7,1), stride=1, padding=(3,0)),\n",
" BasicConv2d(64, 64, (1,7), stride=1, padding=(0,3)),\n",
" BasicConv2d(64, 96, 3, stride=1, padding=0)\n",
" ) # 71x71x96\n",
"\n",
" self.branchpoola = BasicConv2d(192, 192, 3, stride=2, padding=0) # 35x35x192\n",
"\n",
" # kernel_size=4: 피쳐맵 크기 73, kernel_size=3: 피쳐맵 크기 74\n",
" self.branchpoolb = nn.MaxPool2d(4, 2, 1) # 35x35x192\n",
"\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = torch.cat((self.branch3x3_conv(x), self.branch3x3_pool(x)), dim=1)\n",
" x = torch.cat((self.branch7x7a(x), self.branch7x7b(x)), dim=1)\n",
" x = torch.cat((self.branchpoola(x), self.branchpoolb(x)), dim=1)\n",
" return x\n",
"\n",
"\n",
"class Inception_Resnet_A(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
"\n",
" self.branch1x1 = BasicConv2d(in_channels, 32, 1, stride=1, padding=0)\n",
"\n",
" self.branch3x3 = nn.Sequential(\n",
" BasicConv2d(in_channels, 32, 1, stride=1, padding=0),\n",
" BasicConv2d(32, 32, 3, stride=1, padding=1)\n",
" )\n",
"\n",
" self.branch3x3stack = nn.Sequential(\n",
" BasicConv2d(in_channels, 32, 1, stride=1, padding=0),\n",
" BasicConv2d(32, 48, 3, stride=1, padding=1),\n",
" BasicConv2d(48, 64, 3, stride=1, padding=1)\n",
" )\n",
" \n",
" self.reduction1x1 = nn.Conv2d(128, 384, 1, stride=1, padding=0)\n",
" self.shortcut = nn.Conv2d(in_channels, 384, 1, stride=1, padding=0)\n",
" self.bn = nn.BatchNorm2d(384)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x_shortcut = self.shortcut(x)\n",
" x = torch.cat((self.branch1x1(x), self.branch3x3(x), self.branch3x3stack(x)), dim=1)\n",
" x = self.reduction1x1(x)\n",
" x = self.bn(x_shortcut + x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"\n",
"class Inception_Resnet_B(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
"\n",
" self.branch1x1 = BasicConv2d(in_channels, 192, 1, stride=1, padding=0)\n",
" self.branch7x7 = nn.Sequential(\n",
" BasicConv2d(in_channels, 128, 1, stride=1, padding=0),\n",
" BasicConv2d(128, 160, (1,7), stride=1, padding=(0,3)),\n",
" BasicConv2d(160, 192, (7,1), stride=1, padding=(3,0))\n",
" )\n",
"\n",
" self.reduction1x1 = nn.Conv2d(384, 1152, 1, stride=1, padding=0)\n",
" self.shortcut = nn.Conv2d(in_channels, 1152, 1, stride=1, padding=0)\n",
" self.bn = nn.BatchNorm2d(1152)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x_shortcut = self.shortcut(x)\n",
" x = torch.cat((self.branch1x1(x), self.branch7x7(x)), dim=1)\n",
" x = self.reduction1x1(x) * 0.1\n",
" x = self.bn(x + x_shortcut)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"\n",
"class Inception_Resnet_C(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
"\n",
" self.branch1x1 = BasicConv2d(in_channels, 192, 1, stride=1, padding=0)\n",
" self.branch3x3 = nn.Sequential(\n",
" BasicConv2d(in_channels, 192, 1, stride=1, padding=0),\n",
" BasicConv2d(192, 224, (1,3), stride=1, padding=(0,1)),\n",
" BasicConv2d(224, 256, (3,1), stride=1, padding=(1,0))\n",
" )\n",
"\n",
" self.reduction1x1 = nn.Conv2d(448, 2144, 1, stride=1, padding=0)\n",
" self.shortcut = nn.Conv2d(in_channels, 2144, 1, stride=1, padding=0) # 2144\n",
" self.bn = nn.BatchNorm2d(2144)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x_shortcut = self.shortcut(x)\n",
" x = torch.cat((self.branch1x1(x), self.branch3x3(x)), dim=1)\n",
" x = self.reduction1x1(x) * 0.1\n",
" x = self.bn(x_shortcut + x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
" \n",
"class ReductionA(nn.Module):\n",
" def __init__(self, in_channels, k, l, m, n):\n",
" super().__init__()\n",
"\n",
" self.branchpool = nn.MaxPool2d(3, 2)\n",
" self.branch3x3 = BasicConv2d(in_channels, n, 3, stride=2, padding=0)\n",
" self.branch3x3stack = nn.Sequential(\n",
" BasicConv2d(in_channels, k, 1, stride=1, padding=0),\n",
" BasicConv2d(k, l, 3, stride=1, padding=1),\n",
" BasicConv2d(l, m, 3, stride=2, padding=0)\n",
" )\n",
"\n",
" self.output_channels = in_channels + n + m\n",
"\n",
" def forward(self, x):\n",
" x = torch.cat((self.branchpool(x), self.branch3x3(x), self.branch3x3stack(x)), dim=1)\n",
" return x\n",
"\n",
"\n",
"class ReductionB(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
"\n",
" self.branchpool = nn.MaxPool2d(3, 2)\n",
" self.branch3x3a = nn.Sequential(\n",
" BasicConv2d(in_channels, 256, 1, stride=1, padding=0),\n",
" BasicConv2d(256, 384, 3, stride=2, padding=0)\n",
" )\n",
" self.branch3x3b = nn.Sequential(\n",
" BasicConv2d(in_channels, 256, 1, stride=1, padding=0),\n",
" BasicConv2d(256, 288, 3, stride=2, padding=0)\n",
" )\n",
" self.branch3x3stack = nn.Sequential(\n",
" BasicConv2d(in_channels, 256, 1, stride=1, padding=0),\n",
" BasicConv2d(256, 288, 3, stride=1, padding=1),\n",
" BasicConv2d(288, 320, 3, stride=2, padding=0)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = torch.cat((self.branchpool(x), self.branch3x3a(x), self.branch3x3b(x), self.branch3x3stack(x)), dim=1)\n",
" return x\n",
"\n",
"\n",
"class InceptionResNetV2(nn.Module):\n",
" def __init__(self, A, B, C, k=256, l=256, m=384, n=384, num_classes=10, init_weights=True):\n",
" super().__init__()\n",
" blocks = []\n",
" blocks.append(Stem())\n",
" for i in range(A):\n",
" blocks.append(Inception_Resnet_A(384))\n",
" blocks.append(ReductionA(384, k, l, m, n))\n",
" for i in range(B):\n",
" blocks.append(Inception_Resnet_B(1152))\n",
" blocks.append(ReductionB(1152))\n",
" for i in range(C):\n",
" blocks.append(Inception_Resnet_C(2144))\n",
"\n",
" self.features = nn.Sequential(*blocks)\n",
"\n",
" self.avgpool = nn.AdaptiveAvgPool2d((1,1))\n",
" # drop out\n",
" self.dropout = nn.Dropout2d(0.2)\n",
" self.linear = nn.Linear(2144, num_classes)\n",
"\n",
" # weights inittialization\n",
" if init_weights:\n",
" self._initialize_weights()\n",
"\n",
" def forward(self, x):\n",
" x = self.features(x)\n",
" x = self.avgpool(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.dropout(x)\n",
" x = self.linear(x)\n",
" return x\n",
"\n",
" # define weight initialization function\n",
" def _initialize_weights(self):\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
" if m.bias is not None:\n",
" nn.init.constant_(m.bias, 0)\n",
" elif isinstance(m, nn.BatchNorm2d):\n",
" nn.init.constant_(m.weight, 1)\n",
" nn.init.constant_(m.bias, 0)\n",
" elif isinstance(m, nn.Linear):\n",
" nn.init.normal_(m.weight, 0, 0.01)\n",
" nn.init.constant_(m.bias, 0)"
]
},
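{
"cell_type": "markdown",
"metadata": {},
"source": [
"The inline comments in Stem track the spatial size of the feature map at each step using the usual conv/pool arithmetic, output = floor((size + 2*padding - kernel) / stride) + 1. The small helper below reproduces a few of those numbers and shows why the pooling branches use kernel_size=4 with padding=1: it matches the 73x73 size of the parallel conv branch, whereas kernel_size=3 with the same padding would give 74x74."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# conv/pool output-size arithmetic used in the Stem comments\n",
"def out_size(size, kernel, stride=1, padding=0):\n",
"    return (size + 2 * padding - kernel) // stride + 1\n",
"\n",
"print(out_size(299, 3, stride=2)) # 149: first Stem conv\n",
"print(out_size(149, 3)) # 147\n",
"print(out_size(147, 3, stride=2)) # 73: 3x3 conv branch\n",
"print(out_size(147, 4, stride=2, padding=1)) # 73: pool branch (kernel_size=3, padding=1 would give 74)"
]
},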
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Check"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input size: torch.Size([3, 3, 299, 299])\n",
"Stem output size: torch.Size([3, 384, 35, 35])\n"
]
}
],
"source": [
"# check Stem\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"x = torch.randn((3, 3, 299, 299)).to(device)\n",
"model = Stem().to(device)\n",
"output_Stem = model(x)\n",
"print('Input size:', x.size())\n",
"print('Stem output size:', output_Stem.size())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input size: torch.Size([3, 384, 35, 35])\n",
"output size: torch.Size([3, 384, 35, 35])\n"
]
}
],
"source": [
"# check Inception_Resnet_A\n",
"model = Inception_Resnet_A(output_Stem.size()[1]).to(device)\n",
"output_resA = model(output_Stem)\n",
"print('Input size:', output_Stem.size())\n",
"print('output size:', output_resA.size())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input size: torch.Size([3, 384, 35, 35])\n",
"output size: torch.Size([3, 1152, 17, 17])\n"
]
}
],
"source": [
"# check ReductionA\n",
"print('input size:', output_resA.size())\n",
"model = ReductionA(output_resA.size()[1], 256, 256, 384, 384).to(device)\n",
"output_rA = model(output_resA)\n",
"print('output size:', output_rA.size())"
]
},
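{
"cell_type": "markdown",
"metadata": {},
"source": [
"The jump from 384 to 1152 channels comes from ReductionA concatenating three branches: the max-pool branch keeps the 384 input channels, the single 3x3 branch adds n=384, and the 3x3 stack adds m=384, so 384 + 384 + 384 = 1152. The output_channels attribute records the same sum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# output channels of ReductionA: in_channels + n + m\n",
"print(ReductionA(384, 256, 256, 384, 384).output_channels) # 1152"
]
},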
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input size: torch.Size([3, 1152, 17, 17])\n",
"output size: torch.Size([3, 1152, 17, 17])\n"
]
}
],
"source": [
"# check Inception_Resnet_B\n",
"model = Inception_Resnet_B(output_rA.size()[1]).to(device)\n",
"output_resB = model(output_rA)\n",
"print('Input size:', output_rA.size())\n",
"print('output size:', output_resB.size())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input size: torch.Size([3, 1152, 17, 17])\n",
"output size: torch.Size([3, 2144, 8, 8])\n"
]
}
],
"source": [
"# check ReductionB\n",
"model = ReductionB(output_resB.size()[1]).to(device)\n",
"output_rB = model(output_resB)\n",
"print('Input size:', output_resB.size())\n",
"print('output size:', output_rB.size())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input size: torch.Size([3, 2144, 8, 8])\n",
"output size: torch.Size([3, 2144, 8, 8])\n"
]
}
],
"source": [
"# check Inception_Resnet_C\n",
"model = Inception_Resnet_C(output_rB.size()[1]).to(device)\n",
"output_resC = model(output_rB)\n",
"print('Input size:', output_rB.size())\n",
"print('output size:', output_resC.size())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 32, 149, 149] 864\n",
" BatchNorm2d-2 [-1, 32, 149, 149] 64\n",
" ReLU-3 [-1, 32, 149, 149] 0\n",
" BasicConv2d-4 [-1, 32, 149, 149] 0\n",
" Conv2d-5 [-1, 32, 147, 147] 9,216\n",
" BatchNorm2d-6 [-1, 32, 147, 147] 64\n",
" ReLU-7 [-1, 32, 147, 147] 0\n",
" BasicConv2d-8 [-1, 32, 147, 147] 0\n",
" Conv2d-9 [-1, 64, 147, 147] 18,432\n",
" BatchNorm2d-10 [-1, 64, 147, 147] 128\n",
" ReLU-11 [-1, 64, 147, 147] 0\n",
" BasicConv2d-12 [-1, 64, 147, 147] 0\n",
" Conv2d-13 [-1, 96, 73, 73] 55,296\n",
" BatchNorm2d-14 [-1, 96, 73, 73] 192\n",
" ReLU-15 [-1, 96, 73, 73] 0\n",
" BasicConv2d-16 [-1, 96, 73, 73] 0\n",
" MaxPool2d-17 [-1, 64, 73, 73] 0\n",
" Conv2d-18 [-1, 64, 73, 73] 10,240\n",
" BatchNorm2d-19 [-1, 64, 73, 73] 128\n",
" ReLU-20 [-1, 64, 73, 73] 0\n",
" BasicConv2d-21 [-1, 64, 73, 73] 0\n",
" Conv2d-22 [-1, 96, 71, 71] 55,296\n",
" BatchNorm2d-23 [-1, 96, 71, 71] 192\n",
" ReLU-24 [-1, 96, 71, 71] 0\n",
" BasicConv2d-25 [-1, 96, 71, 71] 0\n",
" Conv2d-26 [-1, 64, 73, 73] 10,240\n",
" BatchNorm2d-27 [-1, 64, 73, 73] 128\n",
" ReLU-28 [-1, 64, 73, 73] 0\n",
" BasicConv2d-29 [-1, 64, 73, 73] 0\n",
" Conv2d-30 [-1, 64, 73, 73] 28,672\n",
" BatchNorm2d-31 [-1, 64, 73, 73] 128\n",
" ReLU-32 [-1, 64, 73, 73] 0\n",
" BasicConv2d-33 [-1, 64, 73, 73] 0\n",
" Conv2d-34 [-1, 64, 73, 73] 28,672\n",
" BatchNorm2d-35 [-1, 64, 73, 73] 128\n",
" ReLU-36 [-1, 64, 73, 73] 0\n",
" BasicConv2d-37 [-1, 64, 73, 73] 0\n",
" Conv2d-38 [-1, 96, 71, 71] 55,296\n",
" BatchNorm2d-39 [-1, 96, 71, 71] 192\n",
" ReLU-40 [-1, 96, 71, 71] 0\n",
" BasicConv2d-41 [-1, 96, 71, 71] 0\n",
" Conv2d-42 [-1, 192, 35, 35] 331,776\n",
" BatchNorm2d-43 [-1, 192, 35, 35] 384\n",
" ReLU-44 [-1, 192, 35, 35] 0\n",
" BasicConv2d-45 [-1, 192, 35, 35] 0\n",
" MaxPool2d-46 [-1, 192, 35, 35] 0\n",
" Stem-47 [-1, 384, 35, 35] 0\n",
" Conv2d-48 [-1, 384, 35, 35] 147,840\n",
" Conv2d-49 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-50 [-1, 32, 35, 35] 64\n",
" ReLU-51 [-1, 32, 35, 35] 0\n",
" BasicConv2d-52 [-1, 32, 35, 35] 0\n",
" Conv2d-53 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-54 [-1, 32, 35, 35] 64\n",
" ReLU-55 [-1, 32, 35, 35] 0\n",
" BasicConv2d-56 [-1, 32, 35, 35] 0\n",
" Conv2d-57 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-58 [-1, 32, 35, 35] 64\n",
" ReLU-59 [-1, 32, 35, 35] 0\n",
" BasicConv2d-60 [-1, 32, 35, 35] 0\n",
" Conv2d-61 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-62 [-1, 32, 35, 35] 64\n",
" ReLU-63 [-1, 32, 35, 35] 0\n",
" BasicConv2d-64 [-1, 32, 35, 35] 0\n",
" Conv2d-65 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-66 [-1, 48, 35, 35] 96\n",
" ReLU-67 [-1, 48, 35, 35] 0\n",
" BasicConv2d-68 [-1, 48, 35, 35] 0\n",
" Conv2d-69 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-70 [-1, 64, 35, 35] 128\n",
" ReLU-71 [-1, 64, 35, 35] 0\n",
" BasicConv2d-72 [-1, 64, 35, 35] 0\n",
" Conv2d-73 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-74 [-1, 384, 35, 35] 768\n",
" ReLU-75 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-76 [-1, 384, 35, 35] 0\n",
" Conv2d-77 [-1, 384, 35, 35] 147,840\n",
" Conv2d-78 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-79 [-1, 32, 35, 35] 64\n",
" ReLU-80 [-1, 32, 35, 35] 0\n",
" BasicConv2d-81 [-1, 32, 35, 35] 0\n",
" Conv2d-82 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-83 [-1, 32, 35, 35] 64\n",
" ReLU-84 [-1, 32, 35, 35] 0\n",
" BasicConv2d-85 [-1, 32, 35, 35] 0\n",
" Conv2d-86 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-87 [-1, 32, 35, 35] 64\n",
" ReLU-88 [-1, 32, 35, 35] 0\n",
" BasicConv2d-89 [-1, 32, 35, 35] 0\n",
" Conv2d-90 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-91 [-1, 32, 35, 35] 64\n",
" ReLU-92 [-1, 32, 35, 35] 0\n",
" BasicConv2d-93 [-1, 32, 35, 35] 0\n",
" Conv2d-94 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-95 [-1, 48, 35, 35] 96\n",
" ReLU-96 [-1, 48, 35, 35] 0\n",
" BasicConv2d-97 [-1, 48, 35, 35] 0\n",
" Conv2d-98 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-99 [-1, 64, 35, 35] 128\n",
" ReLU-100 [-1, 64, 35, 35] 0\n",
" BasicConv2d-101 [-1, 64, 35, 35] 0\n",
" Conv2d-102 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-103 [-1, 384, 35, 35] 768\n",
" ReLU-104 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-105 [-1, 384, 35, 35] 0\n",
" Conv2d-106 [-1, 384, 35, 35] 147,840\n",
" Conv2d-107 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-108 [-1, 32, 35, 35] 64\n",
" ReLU-109 [-1, 32, 35, 35] 0\n",
" BasicConv2d-110 [-1, 32, 35, 35] 0\n",
" Conv2d-111 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-112 [-1, 32, 35, 35] 64\n",
" ReLU-113 [-1, 32, 35, 35] 0\n",
" BasicConv2d-114 [-1, 32, 35, 35] 0\n",
" Conv2d-115 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-116 [-1, 32, 35, 35] 64\n",
" ReLU-117 [-1, 32, 35, 35] 0\n",
" BasicConv2d-118 [-1, 32, 35, 35] 0\n",
" Conv2d-119 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-120 [-1, 32, 35, 35] 64\n",
" ReLU-121 [-1, 32, 35, 35] 0\n",
" BasicConv2d-122 [-1, 32, 35, 35] 0\n",
" Conv2d-123 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-124 [-1, 48, 35, 35] 96\n",
" ReLU-125 [-1, 48, 35, 35] 0\n",
" BasicConv2d-126 [-1, 48, 35, 35] 0\n",
" Conv2d-127 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-128 [-1, 64, 35, 35] 128\n",
" ReLU-129 [-1, 64, 35, 35] 0\n",
" BasicConv2d-130 [-1, 64, 35, 35] 0\n",
" Conv2d-131 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-132 [-1, 384, 35, 35] 768\n",
" ReLU-133 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-134 [-1, 384, 35, 35] 0\n",
" Conv2d-135 [-1, 384, 35, 35] 147,840\n",
" Conv2d-136 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-137 [-1, 32, 35, 35] 64\n",
" ReLU-138 [-1, 32, 35, 35] 0\n",
" BasicConv2d-139 [-1, 32, 35, 35] 0\n",
" Conv2d-140 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-141 [-1, 32, 35, 35] 64\n",
" ReLU-142 [-1, 32, 35, 35] 0\n",
" BasicConv2d-143 [-1, 32, 35, 35] 0\n",
" Conv2d-144 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-145 [-1, 32, 35, 35] 64\n",
" ReLU-146 [-1, 32, 35, 35] 0\n",
" BasicConv2d-147 [-1, 32, 35, 35] 0\n",
" Conv2d-148 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-149 [-1, 32, 35, 35] 64\n",
" ReLU-150 [-1, 32, 35, 35] 0\n",
" BasicConv2d-151 [-1, 32, 35, 35] 0\n",
" Conv2d-152 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-153 [-1, 48, 35, 35] 96\n",
" ReLU-154 [-1, 48, 35, 35] 0\n",
" BasicConv2d-155 [-1, 48, 35, 35] 0\n",
" Conv2d-156 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-157 [-1, 64, 35, 35] 128\n",
" ReLU-158 [-1, 64, 35, 35] 0\n",
" BasicConv2d-159 [-1, 64, 35, 35] 0\n",
" Conv2d-160 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-161 [-1, 384, 35, 35] 768\n",
" ReLU-162 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-163 [-1, 384, 35, 35] 0\n",
" Conv2d-164 [-1, 384, 35, 35] 147,840\n",
" Conv2d-165 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-166 [-1, 32, 35, 35] 64\n",
" ReLU-167 [-1, 32, 35, 35] 0\n",
" BasicConv2d-168 [-1, 32, 35, 35] 0\n",
" Conv2d-169 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-170 [-1, 32, 35, 35] 64\n",
" ReLU-171 [-1, 32, 35, 35] 0\n",
" BasicConv2d-172 [-1, 32, 35, 35] 0\n",
" Conv2d-173 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-174 [-1, 32, 35, 35] 64\n",
" ReLU-175 [-1, 32, 35, 35] 0\n",
" BasicConv2d-176 [-1, 32, 35, 35] 0\n",
" Conv2d-177 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-178 [-1, 32, 35, 35] 64\n",
" ReLU-179 [-1, 32, 35, 35] 0\n",
" BasicConv2d-180 [-1, 32, 35, 35] 0\n",
" Conv2d-181 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-182 [-1, 48, 35, 35] 96\n",
" ReLU-183 [-1, 48, 35, 35] 0\n",
" BasicConv2d-184 [-1, 48, 35, 35] 0\n",
" Conv2d-185 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-186 [-1, 64, 35, 35] 128\n",
" ReLU-187 [-1, 64, 35, 35] 0\n",
" BasicConv2d-188 [-1, 64, 35, 35] 0\n",
" Conv2d-189 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-190 [-1, 384, 35, 35] 768\n",
" ReLU-191 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-192 [-1, 384, 35, 35] 0\n",
" Conv2d-193 [-1, 384, 35, 35] 147,840\n",
" Conv2d-194 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-195 [-1, 32, 35, 35] 64\n",
" ReLU-196 [-1, 32, 35, 35] 0\n",
" BasicConv2d-197 [-1, 32, 35, 35] 0\n",
" Conv2d-198 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-199 [-1, 32, 35, 35] 64\n",
" ReLU-200 [-1, 32, 35, 35] 0\n",
" BasicConv2d-201 [-1, 32, 35, 35] 0\n",
" Conv2d-202 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-203 [-1, 32, 35, 35] 64\n",
" ReLU-204 [-1, 32, 35, 35] 0\n",
" BasicConv2d-205 [-1, 32, 35, 35] 0\n",
" Conv2d-206 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-207 [-1, 32, 35, 35] 64\n",
" ReLU-208 [-1, 32, 35, 35] 0\n",
" BasicConv2d-209 [-1, 32, 35, 35] 0\n",
" Conv2d-210 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-211 [-1, 48, 35, 35] 96\n",
" ReLU-212 [-1, 48, 35, 35] 0\n",
" BasicConv2d-213 [-1, 48, 35, 35] 0\n",
" Conv2d-214 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-215 [-1, 64, 35, 35] 128\n",
" ReLU-216 [-1, 64, 35, 35] 0\n",
" BasicConv2d-217 [-1, 64, 35, 35] 0\n",
" Conv2d-218 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-219 [-1, 384, 35, 35] 768\n",
" ReLU-220 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-221 [-1, 384, 35, 35] 0\n",
" Conv2d-222 [-1, 384, 35, 35] 147,840\n",
" Conv2d-223 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-224 [-1, 32, 35, 35] 64\n",
" ReLU-225 [-1, 32, 35, 35] 0\n",
" BasicConv2d-226 [-1, 32, 35, 35] 0\n",
" Conv2d-227 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-228 [-1, 32, 35, 35] 64\n",
" ReLU-229 [-1, 32, 35, 35] 0\n",
" BasicConv2d-230 [-1, 32, 35, 35] 0\n",
" Conv2d-231 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-232 [-1, 32, 35, 35] 64\n",
" ReLU-233 [-1, 32, 35, 35] 0\n",
" BasicConv2d-234 [-1, 32, 35, 35] 0\n",
" Conv2d-235 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-236 [-1, 32, 35, 35] 64\n",
" ReLU-237 [-1, 32, 35, 35] 0\n",
" BasicConv2d-238 [-1, 32, 35, 35] 0\n",
" Conv2d-239 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-240 [-1, 48, 35, 35] 96\n",
" ReLU-241 [-1, 48, 35, 35] 0\n",
" BasicConv2d-242 [-1, 48, 35, 35] 0\n",
" Conv2d-243 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-244 [-1, 64, 35, 35] 128\n",
" ReLU-245 [-1, 64, 35, 35] 0\n",
" BasicConv2d-246 [-1, 64, 35, 35] 0\n",
" Conv2d-247 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-248 [-1, 384, 35, 35] 768\n",
" ReLU-249 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-250 [-1, 384, 35, 35] 0\n",
" Conv2d-251 [-1, 384, 35, 35] 147,840\n",
" Conv2d-252 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-253 [-1, 32, 35, 35] 64\n",
" ReLU-254 [-1, 32, 35, 35] 0\n",
" BasicConv2d-255 [-1, 32, 35, 35] 0\n",
" Conv2d-256 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-257 [-1, 32, 35, 35] 64\n",
" ReLU-258 [-1, 32, 35, 35] 0\n",
" BasicConv2d-259 [-1, 32, 35, 35] 0\n",
" Conv2d-260 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-261 [-1, 32, 35, 35] 64\n",
" ReLU-262 [-1, 32, 35, 35] 0\n",
" BasicConv2d-263 [-1, 32, 35, 35] 0\n",
" Conv2d-264 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-265 [-1, 32, 35, 35] 64\n",
" ReLU-266 [-1, 32, 35, 35] 0\n",
" BasicConv2d-267 [-1, 32, 35, 35] 0\n",
" Conv2d-268 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-269 [-1, 48, 35, 35] 96\n",
" ReLU-270 [-1, 48, 35, 35] 0\n",
" BasicConv2d-271 [-1, 48, 35, 35] 0\n",
" Conv2d-272 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-273 [-1, 64, 35, 35] 128\n",
" ReLU-274 [-1, 64, 35, 35] 0\n",
" BasicConv2d-275 [-1, 64, 35, 35] 0\n",
" Conv2d-276 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-277 [-1, 384, 35, 35] 768\n",
" ReLU-278 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-279 [-1, 384, 35, 35] 0\n",
" Conv2d-280 [-1, 384, 35, 35] 147,840\n",
" Conv2d-281 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-282 [-1, 32, 35, 35] 64\n",
" ReLU-283 [-1, 32, 35, 35] 0\n",
" BasicConv2d-284 [-1, 32, 35, 35] 0\n",
" Conv2d-285 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-286 [-1, 32, 35, 35] 64\n",
" ReLU-287 [-1, 32, 35, 35] 0\n",
" BasicConv2d-288 [-1, 32, 35, 35] 0\n",
" Conv2d-289 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-290 [-1, 32, 35, 35] 64\n",
" ReLU-291 [-1, 32, 35, 35] 0\n",
" BasicConv2d-292 [-1, 32, 35, 35] 0\n",
" Conv2d-293 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-294 [-1, 32, 35, 35] 64\n",
" ReLU-295 [-1, 32, 35, 35] 0\n",
" BasicConv2d-296 [-1, 32, 35, 35] 0\n",
" Conv2d-297 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-298 [-1, 48, 35, 35] 96\n",
" ReLU-299 [-1, 48, 35, 35] 0\n",
" BasicConv2d-300 [-1, 48, 35, 35] 0\n",
" Conv2d-301 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-302 [-1, 64, 35, 35] 128\n",
" ReLU-303 [-1, 64, 35, 35] 0\n",
" BasicConv2d-304 [-1, 64, 35, 35] 0\n",
" Conv2d-305 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-306 [-1, 384, 35, 35] 768\n",
" ReLU-307 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-308 [-1, 384, 35, 35] 0\n",
" Conv2d-309 [-1, 384, 35, 35] 147,840\n",
" Conv2d-310 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-311 [-1, 32, 35, 35] 64\n",
" ReLU-312 [-1, 32, 35, 35] 0\n",
" BasicConv2d-313 [-1, 32, 35, 35] 0\n",
" Conv2d-314 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-315 [-1, 32, 35, 35] 64\n",
" ReLU-316 [-1, 32, 35, 35] 0\n",
" BasicConv2d-317 [-1, 32, 35, 35] 0\n",
" Conv2d-318 [-1, 32, 35, 35] 9,216\n",
" BatchNorm2d-319 [-1, 32, 35, 35] 64\n",
" ReLU-320 [-1, 32, 35, 35] 0\n",
" BasicConv2d-321 [-1, 32, 35, 35] 0\n",
" Conv2d-322 [-1, 32, 35, 35] 12,288\n",
" BatchNorm2d-323 [-1, 32, 35, 35] 64\n",
" ReLU-324 [-1, 32, 35, 35] 0\n",
" BasicConv2d-325 [-1, 32, 35, 35] 0\n",
" Conv2d-326 [-1, 48, 35, 35] 13,824\n",
" BatchNorm2d-327 [-1, 48, 35, 35] 96\n",
" ReLU-328 [-1, 48, 35, 35] 0\n",
" BasicConv2d-329 [-1, 48, 35, 35] 0\n",
" Conv2d-330 [-1, 64, 35, 35] 27,648\n",
" BatchNorm2d-331 [-1, 64, 35, 35] 128\n",
" ReLU-332 [-1, 64, 35, 35] 0\n",
" BasicConv2d-333 [-1, 64, 35, 35] 0\n",
" Conv2d-334 [-1, 384, 35, 35] 49,536\n",
" BatchNorm2d-335 [-1, 384, 35, 35] 768\n",
" ReLU-336 [-1, 384, 35, 35] 0\n",
"Inception_Resnet_A-337 [-1, 384, 35, 35] 0\n",
" MaxPool2d-338 [-1, 384, 17, 17] 0\n",
" Conv2d-339 [-1, 384, 17, 17] 1,327,104\n",
" BatchNorm2d-340 [-1, 384, 17, 17] 768\n",
" ReLU-341 [-1, 384, 17, 17] 0\n",
" BasicConv2d-342 [-1, 384, 17, 17] 0\n",
" Conv2d-343 [-1, 256, 35, 35] 98,304\n",
" BatchNorm2d-344 [-1, 256, 35, 35] 512\n",
" ReLU-345 [-1, 256, 35, 35] 0\n",
" BasicConv2d-346 [-1, 256, 35, 35] 0\n",
" Conv2d-347 [-1, 256, 35, 35] 589,824\n",
" BatchNorm2d-348 [-1, 256, 35, 35] 512\n",
" ReLU-349 [-1, 256, 35, 35] 0\n",
" BasicConv2d-350 [-1, 256, 35, 35] 0\n",
" Conv2d-351 [-1, 384, 17, 17] 884,736\n",
" BatchNorm2d-352 [-1, 384, 17, 17] 768\n",
" ReLU-353 [-1, 384, 17, 17] 0\n",
" BasicConv2d-354 [-1, 384, 17, 17] 0\n",
" ReductionA-355 [-1, 1152, 17, 17] 0\n",
" Conv2d-356 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-357 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-358 [-1, 192, 17, 17] 384\n",
" ReLU-359 [-1, 192, 17, 17] 0\n",
" BasicConv2d-360 [-1, 192, 17, 17] 0\n",
" Conv2d-361 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-362 [-1, 128, 17, 17] 256\n",
" ReLU-363 [-1, 128, 17, 17] 0\n",
" BasicConv2d-364 [-1, 128, 17, 17] 0\n",
" Conv2d-365 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-366 [-1, 160, 17, 17] 320\n",
" ReLU-367 [-1, 160, 17, 17] 0\n",
" BasicConv2d-368 [-1, 160, 17, 17] 0\n",
" Conv2d-369 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-370 [-1, 192, 17, 17] 384\n",
" ReLU-371 [-1, 192, 17, 17] 0\n",
" BasicConv2d-372 [-1, 192, 17, 17] 0\n",
" Conv2d-373 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-374 [-1, 1152, 17, 17] 2,304\n",
" ReLU-375 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-376 [-1, 1152, 17, 17] 0\n",
" Conv2d-377 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-378 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-379 [-1, 192, 17, 17] 384\n",
" ReLU-380 [-1, 192, 17, 17] 0\n",
" BasicConv2d-381 [-1, 192, 17, 17] 0\n",
" Conv2d-382 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-383 [-1, 128, 17, 17] 256\n",
" ReLU-384 [-1, 128, 17, 17] 0\n",
" BasicConv2d-385 [-1, 128, 17, 17] 0\n",
" Conv2d-386 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-387 [-1, 160, 17, 17] 320\n",
" ReLU-388 [-1, 160, 17, 17] 0\n",
" BasicConv2d-389 [-1, 160, 17, 17] 0\n",
" Conv2d-390 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-391 [-1, 192, 17, 17] 384\n",
" ReLU-392 [-1, 192, 17, 17] 0\n",
" BasicConv2d-393 [-1, 192, 17, 17] 0\n",
" Conv2d-394 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-395 [-1, 1152, 17, 17] 2,304\n",
" ReLU-396 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-397 [-1, 1152, 17, 17] 0\n",
" Conv2d-398 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-399 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-400 [-1, 192, 17, 17] 384\n",
" ReLU-401 [-1, 192, 17, 17] 0\n",
" BasicConv2d-402 [-1, 192, 17, 17] 0\n",
" Conv2d-403 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-404 [-1, 128, 17, 17] 256\n",
" ReLU-405 [-1, 128, 17, 17] 0\n",
" BasicConv2d-406 [-1, 128, 17, 17] 0\n",
" Conv2d-407 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-408 [-1, 160, 17, 17] 320\n",
" ReLU-409 [-1, 160, 17, 17] 0\n",
" BasicConv2d-410 [-1, 160, 17, 17] 0\n",
" Conv2d-411 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-412 [-1, 192, 17, 17] 384\n",
" ReLU-413 [-1, 192, 17, 17] 0\n",
" BasicConv2d-414 [-1, 192, 17, 17] 0\n",
" Conv2d-415 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-416 [-1, 1152, 17, 17] 2,304\n",
" ReLU-417 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-418 [-1, 1152, 17, 17] 0\n",
" Conv2d-419 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-420 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-421 [-1, 192, 17, 17] 384\n",
" ReLU-422 [-1, 192, 17, 17] 0\n",
" BasicConv2d-423 [-1, 192, 17, 17] 0\n",
" Conv2d-424 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-425 [-1, 128, 17, 17] 256\n",
" ReLU-426 [-1, 128, 17, 17] 0\n",
" BasicConv2d-427 [-1, 128, 17, 17] 0\n",
" Conv2d-428 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-429 [-1, 160, 17, 17] 320\n",
" ReLU-430 [-1, 160, 17, 17] 0\n",
" BasicConv2d-431 [-1, 160, 17, 17] 0\n",
" Conv2d-432 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-433 [-1, 192, 17, 17] 384\n",
" ReLU-434 [-1, 192, 17, 17] 0\n",
" BasicConv2d-435 [-1, 192, 17, 17] 0\n",
" Conv2d-436 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-437 [-1, 1152, 17, 17] 2,304\n",
" ReLU-438 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-439 [-1, 1152, 17, 17] 0\n",
" Conv2d-440 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-441 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-442 [-1, 192, 17, 17] 384\n",
" ReLU-443 [-1, 192, 17, 17] 0\n",
" BasicConv2d-444 [-1, 192, 17, 17] 0\n",
" Conv2d-445 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-446 [-1, 128, 17, 17] 256\n",
" ReLU-447 [-1, 128, 17, 17] 0\n",
" BasicConv2d-448 [-1, 128, 17, 17] 0\n",
" Conv2d-449 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-450 [-1, 160, 17, 17] 320\n",
" ReLU-451 [-1, 160, 17, 17] 0\n",
" BasicConv2d-452 [-1, 160, 17, 17] 0\n",
" Conv2d-453 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-454 [-1, 192, 17, 17] 384\n",
" ReLU-455 [-1, 192, 17, 17] 0\n",
" BasicConv2d-456 [-1, 192, 17, 17] 0\n",
" Conv2d-457 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-458 [-1, 1152, 17, 17] 2,304\n",
" ReLU-459 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-460 [-1, 1152, 17, 17] 0\n",
" Conv2d-461 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-462 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-463 [-1, 192, 17, 17] 384\n",
" ReLU-464 [-1, 192, 17, 17] 0\n",
" BasicConv2d-465 [-1, 192, 17, 17] 0\n",
" Conv2d-466 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-467 [-1, 128, 17, 17] 256\n",
" ReLU-468 [-1, 128, 17, 17] 0\n",
" BasicConv2d-469 [-1, 128, 17, 17] 0\n",
" Conv2d-470 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-471 [-1, 160, 17, 17] 320\n",
" ReLU-472 [-1, 160, 17, 17] 0\n",
" BasicConv2d-473 [-1, 160, 17, 17] 0\n",
" Conv2d-474 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-475 [-1, 192, 17, 17] 384\n",
" ReLU-476 [-1, 192, 17, 17] 0\n",
" BasicConv2d-477 [-1, 192, 17, 17] 0\n",
" Conv2d-478 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-479 [-1, 1152, 17, 17] 2,304\n",
" ReLU-480 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-481 [-1, 1152, 17, 17] 0\n",
" Conv2d-482 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-483 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-484 [-1, 192, 17, 17] 384\n",
" ReLU-485 [-1, 192, 17, 17] 0\n",
" BasicConv2d-486 [-1, 192, 17, 17] 0\n",
" Conv2d-487 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-488 [-1, 128, 17, 17] 256\n",
" ReLU-489 [-1, 128, 17, 17] 0\n",
" BasicConv2d-490 [-1, 128, 17, 17] 0\n",
" Conv2d-491 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-492 [-1, 160, 17, 17] 320\n",
" ReLU-493 [-1, 160, 17, 17] 0\n",
" BasicConv2d-494 [-1, 160, 17, 17] 0\n",
" Conv2d-495 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-496 [-1, 192, 17, 17] 384\n",
" ReLU-497 [-1, 192, 17, 17] 0\n",
" BasicConv2d-498 [-1, 192, 17, 17] 0\n",
" Conv2d-499 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-500 [-1, 1152, 17, 17] 2,304\n",
" ReLU-501 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-502 [-1, 1152, 17, 17] 0\n",
" Conv2d-503 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-504 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-505 [-1, 192, 17, 17] 384\n",
" ReLU-506 [-1, 192, 17, 17] 0\n",
" BasicConv2d-507 [-1, 192, 17, 17] 0\n",
" Conv2d-508 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-509 [-1, 128, 17, 17] 256\n",
" ReLU-510 [-1, 128, 17, 17] 0\n",
" BasicConv2d-511 [-1, 128, 17, 17] 0\n",
" Conv2d-512 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-513 [-1, 160, 17, 17] 320\n",
" ReLU-514 [-1, 160, 17, 17] 0\n",
" BasicConv2d-515 [-1, 160, 17, 17] 0\n",
" Conv2d-516 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-517 [-1, 192, 17, 17] 384\n",
" ReLU-518 [-1, 192, 17, 17] 0\n",
" BasicConv2d-519 [-1, 192, 17, 17] 0\n",
" Conv2d-520 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-521 [-1, 1152, 17, 17] 2,304\n",
" ReLU-522 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-523 [-1, 1152, 17, 17] 0\n",
" Conv2d-524 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-525 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-526 [-1, 192, 17, 17] 384\n",
" ReLU-527 [-1, 192, 17, 17] 0\n",
" BasicConv2d-528 [-1, 192, 17, 17] 0\n",
" Conv2d-529 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-530 [-1, 128, 17, 17] 256\n",
" ReLU-531 [-1, 128, 17, 17] 0\n",
" BasicConv2d-532 [-1, 128, 17, 17] 0\n",
" Conv2d-533 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-534 [-1, 160, 17, 17] 320\n",
" ReLU-535 [-1, 160, 17, 17] 0\n",
" BasicConv2d-536 [-1, 160, 17, 17] 0\n",
" Conv2d-537 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-538 [-1, 192, 17, 17] 384\n",
" ReLU-539 [-1, 192, 17, 17] 0\n",
" BasicConv2d-540 [-1, 192, 17, 17] 0\n",
" Conv2d-541 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-542 [-1, 1152, 17, 17] 2,304\n",
" ReLU-543 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-544 [-1, 1152, 17, 17] 0\n",
" Conv2d-545 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-546 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-547 [-1, 192, 17, 17] 384\n",
" ReLU-548 [-1, 192, 17, 17] 0\n",
" BasicConv2d-549 [-1, 192, 17, 17] 0\n",
" Conv2d-550 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-551 [-1, 128, 17, 17] 256\n",
" ReLU-552 [-1, 128, 17, 17] 0\n",
" BasicConv2d-553 [-1, 128, 17, 17] 0\n",
" Conv2d-554 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-555 [-1, 160, 17, 17] 320\n",
" ReLU-556 [-1, 160, 17, 17] 0\n",
" BasicConv2d-557 [-1, 160, 17, 17] 0\n",
" Conv2d-558 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-559 [-1, 192, 17, 17] 384\n",
" ReLU-560 [-1, 192, 17, 17] 0\n",
" BasicConv2d-561 [-1, 192, 17, 17] 0\n",
" Conv2d-562 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-563 [-1, 1152, 17, 17] 2,304\n",
" ReLU-564 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-565 [-1, 1152, 17, 17] 0\n",
" Conv2d-566 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-567 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-568 [-1, 192, 17, 17] 384\n",
" ReLU-569 [-1, 192, 17, 17] 0\n",
" BasicConv2d-570 [-1, 192, 17, 17] 0\n",
" Conv2d-571 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-572 [-1, 128, 17, 17] 256\n",
" ReLU-573 [-1, 128, 17, 17] 0\n",
" BasicConv2d-574 [-1, 128, 17, 17] 0\n",
" Conv2d-575 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-576 [-1, 160, 17, 17] 320\n",
" ReLU-577 [-1, 160, 17, 17] 0\n",
" BasicConv2d-578 [-1, 160, 17, 17] 0\n",
" Conv2d-579 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-580 [-1, 192, 17, 17] 384\n",
" ReLU-581 [-1, 192, 17, 17] 0\n",
" BasicConv2d-582 [-1, 192, 17, 17] 0\n",
" Conv2d-583 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-584 [-1, 1152, 17, 17] 2,304\n",
" ReLU-585 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-586 [-1, 1152, 17, 17] 0\n",
" Conv2d-587 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-588 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-589 [-1, 192, 17, 17] 384\n",
" ReLU-590 [-1, 192, 17, 17] 0\n",
" BasicConv2d-591 [-1, 192, 17, 17] 0\n",
" Conv2d-592 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-593 [-1, 128, 17, 17] 256\n",
" ReLU-594 [-1, 128, 17, 17] 0\n",
" BasicConv2d-595 [-1, 128, 17, 17] 0\n",
" Conv2d-596 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-597 [-1, 160, 17, 17] 320\n",
" ReLU-598 [-1, 160, 17, 17] 0\n",
" BasicConv2d-599 [-1, 160, 17, 17] 0\n",
" Conv2d-600 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-601 [-1, 192, 17, 17] 384\n",
" ReLU-602 [-1, 192, 17, 17] 0\n",
" BasicConv2d-603 [-1, 192, 17, 17] 0\n",
" Conv2d-604 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-605 [-1, 1152, 17, 17] 2,304\n",
" ReLU-606 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-607 [-1, 1152, 17, 17] 0\n",
" Conv2d-608 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-609 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-610 [-1, 192, 17, 17] 384\n",
" ReLU-611 [-1, 192, 17, 17] 0\n",
" BasicConv2d-612 [-1, 192, 17, 17] 0\n",
" Conv2d-613 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-614 [-1, 128, 17, 17] 256\n",
" ReLU-615 [-1, 128, 17, 17] 0\n",
" BasicConv2d-616 [-1, 128, 17, 17] 0\n",
" Conv2d-617 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-618 [-1, 160, 17, 17] 320\n",
" ReLU-619 [-1, 160, 17, 17] 0\n",
" BasicConv2d-620 [-1, 160, 17, 17] 0\n",
" Conv2d-621 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-622 [-1, 192, 17, 17] 384\n",
" ReLU-623 [-1, 192, 17, 17] 0\n",
" BasicConv2d-624 [-1, 192, 17, 17] 0\n",
" Conv2d-625 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-626 [-1, 1152, 17, 17] 2,304\n",
" ReLU-627 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-628 [-1, 1152, 17, 17] 0\n",
" Conv2d-629 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-630 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-631 [-1, 192, 17, 17] 384\n",
" ReLU-632 [-1, 192, 17, 17] 0\n",
" BasicConv2d-633 [-1, 192, 17, 17] 0\n",
" Conv2d-634 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-635 [-1, 128, 17, 17] 256\n",
" ReLU-636 [-1, 128, 17, 17] 0\n",
" BasicConv2d-637 [-1, 128, 17, 17] 0\n",
" Conv2d-638 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-639 [-1, 160, 17, 17] 320\n",
" ReLU-640 [-1, 160, 17, 17] 0\n",
" BasicConv2d-641 [-1, 160, 17, 17] 0\n",
" Conv2d-642 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-643 [-1, 192, 17, 17] 384\n",
" ReLU-644 [-1, 192, 17, 17] 0\n",
" BasicConv2d-645 [-1, 192, 17, 17] 0\n",
" Conv2d-646 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-647 [-1, 1152, 17, 17] 2,304\n",
" ReLU-648 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-649 [-1, 1152, 17, 17] 0\n",
" Conv2d-650 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-651 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-652 [-1, 192, 17, 17] 384\n",
" ReLU-653 [-1, 192, 17, 17] 0\n",
" BasicConv2d-654 [-1, 192, 17, 17] 0\n",
" Conv2d-655 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-656 [-1, 128, 17, 17] 256\n",
" ReLU-657 [-1, 128, 17, 17] 0\n",
" BasicConv2d-658 [-1, 128, 17, 17] 0\n",
" Conv2d-659 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-660 [-1, 160, 17, 17] 320\n",
" ReLU-661 [-1, 160, 17, 17] 0\n",
" BasicConv2d-662 [-1, 160, 17, 17] 0\n",
" Conv2d-663 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-664 [-1, 192, 17, 17] 384\n",
" ReLU-665 [-1, 192, 17, 17] 0\n",
" BasicConv2d-666 [-1, 192, 17, 17] 0\n",
" Conv2d-667 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-668 [-1, 1152, 17, 17] 2,304\n",
" ReLU-669 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-670 [-1, 1152, 17, 17] 0\n",
" Conv2d-671 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-672 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-673 [-1, 192, 17, 17] 384\n",
" ReLU-674 [-1, 192, 17, 17] 0\n",
" BasicConv2d-675 [-1, 192, 17, 17] 0\n",
" Conv2d-676 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-677 [-1, 128, 17, 17] 256\n",
" ReLU-678 [-1, 128, 17, 17] 0\n",
" BasicConv2d-679 [-1, 128, 17, 17] 0\n",
" Conv2d-680 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-681 [-1, 160, 17, 17] 320\n",
" ReLU-682 [-1, 160, 17, 17] 0\n",
" BasicConv2d-683 [-1, 160, 17, 17] 0\n",
" Conv2d-684 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-685 [-1, 192, 17, 17] 384\n",
" ReLU-686 [-1, 192, 17, 17] 0\n",
" BasicConv2d-687 [-1, 192, 17, 17] 0\n",
" Conv2d-688 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-689 [-1, 1152, 17, 17] 2,304\n",
" ReLU-690 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-691 [-1, 1152, 17, 17] 0\n",
" Conv2d-692 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-693 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-694 [-1, 192, 17, 17] 384\n",
" ReLU-695 [-1, 192, 17, 17] 0\n",
" BasicConv2d-696 [-1, 192, 17, 17] 0\n",
" Conv2d-697 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-698 [-1, 128, 17, 17] 256\n",
" ReLU-699 [-1, 128, 17, 17] 0\n",
" BasicConv2d-700 [-1, 128, 17, 17] 0\n",
" Conv2d-701 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-702 [-1, 160, 17, 17] 320\n",
" ReLU-703 [-1, 160, 17, 17] 0\n",
" BasicConv2d-704 [-1, 160, 17, 17] 0\n",
" Conv2d-705 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-706 [-1, 192, 17, 17] 384\n",
" ReLU-707 [-1, 192, 17, 17] 0\n",
" BasicConv2d-708 [-1, 192, 17, 17] 0\n",
" Conv2d-709 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-710 [-1, 1152, 17, 17] 2,304\n",
" ReLU-711 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-712 [-1, 1152, 17, 17] 0\n",
" Conv2d-713 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-714 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-715 [-1, 192, 17, 17] 384\n",
" ReLU-716 [-1, 192, 17, 17] 0\n",
" BasicConv2d-717 [-1, 192, 17, 17] 0\n",
" Conv2d-718 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-719 [-1, 128, 17, 17] 256\n",
" ReLU-720 [-1, 128, 17, 17] 0\n",
" BasicConv2d-721 [-1, 128, 17, 17] 0\n",
" Conv2d-722 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-723 [-1, 160, 17, 17] 320\n",
" ReLU-724 [-1, 160, 17, 17] 0\n",
" BasicConv2d-725 [-1, 160, 17, 17] 0\n",
" Conv2d-726 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-727 [-1, 192, 17, 17] 384\n",
" ReLU-728 [-1, 192, 17, 17] 0\n",
" BasicConv2d-729 [-1, 192, 17, 17] 0\n",
" Conv2d-730 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-731 [-1, 1152, 17, 17] 2,304\n",
" ReLU-732 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-733 [-1, 1152, 17, 17] 0\n",
" Conv2d-734 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-735 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-736 [-1, 192, 17, 17] 384\n",
" ReLU-737 [-1, 192, 17, 17] 0\n",
" BasicConv2d-738 [-1, 192, 17, 17] 0\n",
" Conv2d-739 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-740 [-1, 128, 17, 17] 256\n",
" ReLU-741 [-1, 128, 17, 17] 0\n",
" BasicConv2d-742 [-1, 128, 17, 17] 0\n",
" Conv2d-743 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-744 [-1, 160, 17, 17] 320\n",
" ReLU-745 [-1, 160, 17, 17] 0\n",
" BasicConv2d-746 [-1, 160, 17, 17] 0\n",
" Conv2d-747 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-748 [-1, 192, 17, 17] 384\n",
" ReLU-749 [-1, 192, 17, 17] 0\n",
" BasicConv2d-750 [-1, 192, 17, 17] 0\n",
" Conv2d-751 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-752 [-1, 1152, 17, 17] 2,304\n",
" ReLU-753 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-754 [-1, 1152, 17, 17] 0\n",
" Conv2d-755 [-1, 1152, 17, 17] 1,328,256\n",
" Conv2d-756 [-1, 192, 17, 17] 221,184\n",
" BatchNorm2d-757 [-1, 192, 17, 17] 384\n",
" ReLU-758 [-1, 192, 17, 17] 0\n",
" BasicConv2d-759 [-1, 192, 17, 17] 0\n",
" Conv2d-760 [-1, 128, 17, 17] 147,456\n",
" BatchNorm2d-761 [-1, 128, 17, 17] 256\n",
" ReLU-762 [-1, 128, 17, 17] 0\n",
" BasicConv2d-763 [-1, 128, 17, 17] 0\n",
" Conv2d-764 [-1, 160, 17, 17] 143,360\n",
" BatchNorm2d-765 [-1, 160, 17, 17] 320\n",
" ReLU-766 [-1, 160, 17, 17] 0\n",
" BasicConv2d-767 [-1, 160, 17, 17] 0\n",
" Conv2d-768 [-1, 192, 17, 17] 215,040\n",
" BatchNorm2d-769 [-1, 192, 17, 17] 384\n",
" ReLU-770 [-1, 192, 17, 17] 0\n",
" BasicConv2d-771 [-1, 192, 17, 17] 0\n",
" Conv2d-772 [-1, 1152, 17, 17] 443,520\n",
" BatchNorm2d-773 [-1, 1152, 17, 17] 2,304\n",
" ReLU-774 [-1, 1152, 17, 17] 0\n",
"Inception_Resnet_B-775 [-1, 1152, 17, 17] 0\n",
" MaxPool2d-776 [-1, 1152, 8, 8] 0\n",
" Conv2d-777 [-1, 256, 17, 17] 294,912\n",
" BatchNorm2d-778 [-1, 256, 17, 17] 512\n",
" ReLU-779 [-1, 256, 17, 17] 0\n",
" BasicConv2d-780 [-1, 256, 17, 17] 0\n",
" Conv2d-781 [-1, 384, 8, 8] 884,736\n",
" BatchNorm2d-782 [-1, 384, 8, 8] 768\n",
" ReLU-783 [-1, 384, 8, 8] 0\n",
" BasicConv2d-784 [-1, 384, 8, 8] 0\n",
" Conv2d-785 [-1, 256, 17, 17] 294,912\n",
" BatchNorm2d-786 [-1, 256, 17, 17] 512\n",
" ReLU-787 [-1, 256, 17, 17] 0\n",
" BasicConv2d-788 [-1, 256, 17, 17] 0\n",
" Conv2d-789 [-1, 288, 8, 8] 663,552\n",
" BatchNorm2d-790 [-1, 288, 8, 8] 576\n",
" ReLU-791 [-1, 288, 8, 8] 0\n",
" BasicConv2d-792 [-1, 288, 8, 8] 0\n",
" Conv2d-793 [-1, 256, 17, 17] 294,912\n",
" BatchNorm2d-794 [-1, 256, 17, 17] 512\n",
" ReLU-795 [-1, 256, 17, 17] 0\n",
" BasicConv2d-796 [-1, 256, 17, 17] 0\n",
" Conv2d-797 [-1, 288, 17, 17] 663,552\n",
" BatchNorm2d-798 [-1, 288, 17, 17] 576\n",
" ReLU-799 [-1, 288, 17, 17] 0\n",
" BasicConv2d-800 [-1, 288, 17, 17] 0\n",
" Conv2d-801 [-1, 320, 8, 8] 829,440\n",
" BatchNorm2d-802 [-1, 320, 8, 8] 640\n",
" ReLU-803 [-1, 320, 8, 8] 0\n",
" BasicConv2d-804 [-1, 320, 8, 8] 0\n",
" ReductionB-805 [-1, 2144, 8, 8] 0\n",
" Conv2d-806 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-807 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-808 [-1, 192, 8, 8] 384\n",
" ReLU-809 [-1, 192, 8, 8] 0\n",
" BasicConv2d-810 [-1, 192, 8, 8] 0\n",
" Conv2d-811 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-812 [-1, 192, 8, 8] 384\n",
" ReLU-813 [-1, 192, 8, 8] 0\n",
" BasicConv2d-814 [-1, 192, 8, 8] 0\n",
" Conv2d-815 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-816 [-1, 224, 8, 8] 448\n",
" ReLU-817 [-1, 224, 8, 8] 0\n",
" BasicConv2d-818 [-1, 224, 8, 8] 0\n",
" Conv2d-819 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-820 [-1, 256, 8, 8] 512\n",
" ReLU-821 [-1, 256, 8, 8] 0\n",
" BasicConv2d-822 [-1, 256, 8, 8] 0\n",
" Conv2d-823 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-824 [-1, 2144, 8, 8] 4,288\n",
" ReLU-825 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-826 [-1, 2144, 8, 8] 0\n",
" Conv2d-827 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-828 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-829 [-1, 192, 8, 8] 384\n",
" ReLU-830 [-1, 192, 8, 8] 0\n",
" BasicConv2d-831 [-1, 192, 8, 8] 0\n",
" Conv2d-832 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-833 [-1, 192, 8, 8] 384\n",
" ReLU-834 [-1, 192, 8, 8] 0\n",
" BasicConv2d-835 [-1, 192, 8, 8] 0\n",
" Conv2d-836 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-837 [-1, 224, 8, 8] 448\n",
" ReLU-838 [-1, 224, 8, 8] 0\n",
" BasicConv2d-839 [-1, 224, 8, 8] 0\n",
" Conv2d-840 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-841 [-1, 256, 8, 8] 512\n",
" ReLU-842 [-1, 256, 8, 8] 0\n",
" BasicConv2d-843 [-1, 256, 8, 8] 0\n",
" Conv2d-844 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-845 [-1, 2144, 8, 8] 4,288\n",
" ReLU-846 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-847 [-1, 2144, 8, 8] 0\n",
" Conv2d-848 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-849 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-850 [-1, 192, 8, 8] 384\n",
" ReLU-851 [-1, 192, 8, 8] 0\n",
" BasicConv2d-852 [-1, 192, 8, 8] 0\n",
" Conv2d-853 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-854 [-1, 192, 8, 8] 384\n",
" ReLU-855 [-1, 192, 8, 8] 0\n",
" BasicConv2d-856 [-1, 192, 8, 8] 0\n",
" Conv2d-857 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-858 [-1, 224, 8, 8] 448\n",
" ReLU-859 [-1, 224, 8, 8] 0\n",
" BasicConv2d-860 [-1, 224, 8, 8] 0\n",
" Conv2d-861 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-862 [-1, 256, 8, 8] 512\n",
" ReLU-863 [-1, 256, 8, 8] 0\n",
" BasicConv2d-864 [-1, 256, 8, 8] 0\n",
" Conv2d-865 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-866 [-1, 2144, 8, 8] 4,288\n",
" ReLU-867 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-868 [-1, 2144, 8, 8] 0\n",
" Conv2d-869 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-870 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-871 [-1, 192, 8, 8] 384\n",
" ReLU-872 [-1, 192, 8, 8] 0\n",
" BasicConv2d-873 [-1, 192, 8, 8] 0\n",
" Conv2d-874 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-875 [-1, 192, 8, 8] 384\n",
" ReLU-876 [-1, 192, 8, 8] 0\n",
" BasicConv2d-877 [-1, 192, 8, 8] 0\n",
" Conv2d-878 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-879 [-1, 224, 8, 8] 448\n",
" ReLU-880 [-1, 224, 8, 8] 0\n",
" BasicConv2d-881 [-1, 224, 8, 8] 0\n",
" Conv2d-882 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-883 [-1, 256, 8, 8] 512\n",
" ReLU-884 [-1, 256, 8, 8] 0\n",
" BasicConv2d-885 [-1, 256, 8, 8] 0\n",
" Conv2d-886 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-887 [-1, 2144, 8, 8] 4,288\n",
" ReLU-888 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-889 [-1, 2144, 8, 8] 0\n",
" Conv2d-890 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-891 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-892 [-1, 192, 8, 8] 384\n",
" ReLU-893 [-1, 192, 8, 8] 0\n",
" BasicConv2d-894 [-1, 192, 8, 8] 0\n",
" Conv2d-895 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-896 [-1, 192, 8, 8] 384\n",
" ReLU-897 [-1, 192, 8, 8] 0\n",
" BasicConv2d-898 [-1, 192, 8, 8] 0\n",
" Conv2d-899 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-900 [-1, 224, 8, 8] 448\n",
" ReLU-901 [-1, 224, 8, 8] 0\n",
" BasicConv2d-902 [-1, 224, 8, 8] 0\n",
" Conv2d-903 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-904 [-1, 256, 8, 8] 512\n",
" ReLU-905 [-1, 256, 8, 8] 0\n",
" BasicConv2d-906 [-1, 256, 8, 8] 0\n",
" Conv2d-907 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-908 [-1, 2144, 8, 8] 4,288\n",
" ReLU-909 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-910 [-1, 2144, 8, 8] 0\n",
" Conv2d-911 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-912 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-913 [-1, 192, 8, 8] 384\n",
" ReLU-914 [-1, 192, 8, 8] 0\n",
" BasicConv2d-915 [-1, 192, 8, 8] 0\n",
" Conv2d-916 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-917 [-1, 192, 8, 8] 384\n",
" ReLU-918 [-1, 192, 8, 8] 0\n",
" BasicConv2d-919 [-1, 192, 8, 8] 0\n",
" Conv2d-920 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-921 [-1, 224, 8, 8] 448\n",
" ReLU-922 [-1, 224, 8, 8] 0\n",
" BasicConv2d-923 [-1, 224, 8, 8] 0\n",
" Conv2d-924 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-925 [-1, 256, 8, 8] 512\n",
" ReLU-926 [-1, 256, 8, 8] 0\n",
" BasicConv2d-927 [-1, 256, 8, 8] 0\n",
" Conv2d-928 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-929 [-1, 2144, 8, 8] 4,288\n",
" ReLU-930 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-931 [-1, 2144, 8, 8] 0\n",
" Conv2d-932 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-933 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-934 [-1, 192, 8, 8] 384\n",
" ReLU-935 [-1, 192, 8, 8] 0\n",
" BasicConv2d-936 [-1, 192, 8, 8] 0\n",
" Conv2d-937 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-938 [-1, 192, 8, 8] 384\n",
" ReLU-939 [-1, 192, 8, 8] 0\n",
" BasicConv2d-940 [-1, 192, 8, 8] 0\n",
" Conv2d-941 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-942 [-1, 224, 8, 8] 448\n",
" ReLU-943 [-1, 224, 8, 8] 0\n",
" BasicConv2d-944 [-1, 224, 8, 8] 0\n",
" Conv2d-945 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-946 [-1, 256, 8, 8] 512\n",
" ReLU-947 [-1, 256, 8, 8] 0\n",
" BasicConv2d-948 [-1, 256, 8, 8] 0\n",
" Conv2d-949 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-950 [-1, 2144, 8, 8] 4,288\n",
" ReLU-951 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-952 [-1, 2144, 8, 8] 0\n",
" Conv2d-953 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-954 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-955 [-1, 192, 8, 8] 384\n",
" ReLU-956 [-1, 192, 8, 8] 0\n",
" BasicConv2d-957 [-1, 192, 8, 8] 0\n",
" Conv2d-958 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-959 [-1, 192, 8, 8] 384\n",
" ReLU-960 [-1, 192, 8, 8] 0\n",
" BasicConv2d-961 [-1, 192, 8, 8] 0\n",
" Conv2d-962 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-963 [-1, 224, 8, 8] 448\n",
" ReLU-964 [-1, 224, 8, 8] 0\n",
" BasicConv2d-965 [-1, 224, 8, 8] 0\n",
" Conv2d-966 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-967 [-1, 256, 8, 8] 512\n",
" ReLU-968 [-1, 256, 8, 8] 0\n",
" BasicConv2d-969 [-1, 256, 8, 8] 0\n",
" Conv2d-970 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-971 [-1, 2144, 8, 8] 4,288\n",
" ReLU-972 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-973 [-1, 2144, 8, 8] 0\n",
" Conv2d-974 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-975 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-976 [-1, 192, 8, 8] 384\n",
" ReLU-977 [-1, 192, 8, 8] 0\n",
" BasicConv2d-978 [-1, 192, 8, 8] 0\n",
" Conv2d-979 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-980 [-1, 192, 8, 8] 384\n",
" ReLU-981 [-1, 192, 8, 8] 0\n",
" BasicConv2d-982 [-1, 192, 8, 8] 0\n",
" Conv2d-983 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-984 [-1, 224, 8, 8] 448\n",
" ReLU-985 [-1, 224, 8, 8] 0\n",
" BasicConv2d-986 [-1, 224, 8, 8] 0\n",
" Conv2d-987 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-988 [-1, 256, 8, 8] 512\n",
" ReLU-989 [-1, 256, 8, 8] 0\n",
" BasicConv2d-990 [-1, 256, 8, 8] 0\n",
" Conv2d-991 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-992 [-1, 2144, 8, 8] 4,288\n",
" ReLU-993 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-994 [-1, 2144, 8, 8] 0\n",
" Conv2d-995 [-1, 2144, 8, 8] 4,598,880\n",
" Conv2d-996 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-997 [-1, 192, 8, 8] 384\n",
" ReLU-998 [-1, 192, 8, 8] 0\n",
" BasicConv2d-999 [-1, 192, 8, 8] 0\n",
" Conv2d-1000 [-1, 192, 8, 8] 411,648\n",
" BatchNorm2d-1001 [-1, 192, 8, 8] 384\n",
" ReLU-1002 [-1, 192, 8, 8] 0\n",
" BasicConv2d-1003 [-1, 192, 8, 8] 0\n",
" Conv2d-1004 [-1, 224, 8, 8] 129,024\n",
" BatchNorm2d-1005 [-1, 224, 8, 8] 448\n",
" ReLU-1006 [-1, 224, 8, 8] 0\n",
" BasicConv2d-1007 [-1, 224, 8, 8] 0\n",
" Conv2d-1008 [-1, 256, 8, 8] 172,032\n",
" BatchNorm2d-1009 [-1, 256, 8, 8] 512\n",
" ReLU-1010 [-1, 256, 8, 8] 0\n",
" BasicConv2d-1011 [-1, 256, 8, 8] 0\n",
" Conv2d-1012 [-1, 2144, 8, 8] 962,656\n",
" BatchNorm2d-1013 [-1, 2144, 8, 8] 4,288\n",
" ReLU-1014 [-1, 2144, 8, 8] 0\n",
"Inception_Resnet_C-1015 [-1, 2144, 8, 8] 0\n",
"AdaptiveAvgPool2d-1016 [-1, 2144, 1, 1] 0\n",
" Dropout2d-1017 [-1, 2144] 0\n",
" Linear-1018 [-1, 10] 21,450\n",
"================================================================\n",
"Total params: 127,289,898\n",
"Trainable params: 127,289,898\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 1.02\n",
"Forward/backward pass size (MB): 940.05\n",
"Params size (MB): 485.57\n",
"Estimated Total Size (MB): 1426.65\n",
"----------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\functional.py:1331: UserWarning: dropout2d: Received a 2-D input to dropout2d, which is deprecated and will result in an error in a future release. To retain the behavior and silence this warning, please use dropout instead. Note that dropout2d exists to provide channel-wise dropout on inputs with 2 spatial dimensions, a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\n",
" warnings.warn(warn_msg)\n"
]
}
],
"source": [
"# create InceptionResNetV2\n",
"model = InceptionResNetV2(10, 20, 10).to(device)\n",
"summary(model, (3, 299, 299), device=device.type)"
]
},
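  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the `UserWarning` above is raised because `nn.Dropout2d` receives an already-flattened 2-D tensor (batch x 2144), as the `Dropout2d-1017 [-1, 2144]` row in the summary shows. The cell below is a minimal sketch of a classifier head that avoids the warning by applying element-wise `nn.Dropout` after flattening; the 2144-channel input and 10 classes are taken from the summary, and the module is illustrative rather than the notebook's original head."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal sketch: a classifier head that applies dropout to flattened features\n",
    "# (assumes 2144 input channels and 10 classes, matching the summary above)\n",
    "class ClassifierHead(nn.Module):\n",
    "    def __init__(self, in_channels=2144, num_classes=10, p=0.2):\n",
    "        super().__init__()\n",
    "        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.dropout = nn.Dropout(p)   # element-wise dropout, no 2-D deprecation warning\n",
    "        self.fc = nn.Linear(in_channels, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.dropout(x)\n",
    "        return self.fc(x)"
   ]
  },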
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# define loss function and optimizer\n",
"loss_func = nn.CrossEntropyLoss(reduction='sum')\n",
"opt = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10)\n",
"\n",
"# function to get current learning rate\n",
"def get_lr(opt):\n",
" for param_group in opt.param_groups:\n",
" return param_group['lr']\n",
"\n",
"# function to calculate metric per mini-batch\n",
"def metric_batch(output, target):\n",
" pred = output.argmax(1, keepdim=True)\n",
" corrects = pred.eq(target.view_as(pred)).sum().item()\n",
" return corrects\n",
"\n",
"# function to calculate loss per mini-batch\n",
"def loss_batch(loss_func, output, target, opt=None):\n",
" loss_b = loss_func(output, target)\n",
" metric_b = metric_batch(output, target)\n",
"\n",
" if opt is not None:\n",
" opt.zero_grad()\n",
" loss_b.backward()\n",
" opt.step()\n",
"\n",
" return loss_b.item(), metric_b\n",
"\n",
"# function to calculate loss per epoch\n",
"def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):\n",
" running_loss = 0.0\n",
" running_metric = 0.0\n",
" len_data = len(dataset_dl.dataset)\n",
"\n",
" for xb, yb in dataset_dl:\n",
" xb = xb.to(device)\n",
" yb = yb.to(device)\n",
" output = model(xb)\n",
"\n",
" loss_b, metric_b = loss_batch(loss_func, output, yb, opt)\n",
"\n",
" running_loss += loss_b\n",
"\n",
" if metric_b is not None:\n",
" running_metric += metric_b\n",
"\n",
" if sanity_check is True:\n",
" break\n",
"\n",
" loss = running_loss / len_data\n",
" metric = running_metric / len_data\n",
"\n",
" return loss, metric\n",
"\n",
"# function to start training\n",
"def train_val(model, params):\n",
" num_epochs=params['num_epochs']\n",
" loss_func=params[\"loss_func\"]\n",
" opt=params[\"optimizer\"]\n",
" train_dl=params[\"train_dl\"]\n",
" val_dl=params[\"val_dl\"]\n",
" sanity_check=params[\"sanity_check\"]\n",
" lr_scheduler=params[\"lr_scheduler\"]\n",
" path2weights=params[\"path2weights\"]\n",
"\n",
" loss_history = {'train': [], 'val': []}\n",
" metric_history = {'train': [], 'val': []}\n",
"\n",
" best_loss = float('inf')\n",
"\n",
" start_time = time.time()\n",
"\n",
" for epoch in range(num_epochs):\n",
" current_lr = get_lr(opt)\n",
" print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs-1, current_lr))\n",
"\n",
" model.train()\n",
" train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)\n",
" loss_history['train'].append(train_loss)\n",
" metric_history['train'].append(train_metric)\n",
"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)\n",
" loss_history['val'].append(val_loss)\n",
" metric_history['val'].append(val_metric)\n",
"\n",
" if val_loss < best_loss:\n",
" best_loss = val_loss\n",
" print('Get best val_loss!')\n",
"\n",
" lr_scheduler.step(val_loss)\n",
"\n",
" print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, (time.time()-start_time)/60))\n",
" print('-'*10)\n",
"\n",
" return model, loss_history, metric_history"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# definc the training parameters\n",
"params_train = {\n",
" 'num_epochs':5,\n",
" 'optimizer':opt,\n",
" 'loss_func':loss_func,\n",
" 'train_dl':train_dl,\n",
" 'val_dl':val_dl,\n",
" 'sanity_check':False,\n",
" 'lr_scheduler':lr_scheduler,\n",
" 'path2weights':'./models/weights.pt',\n",
"}\n",
"\n",
"# create the directory that stores weights.pt\n",
"def createFolder(directory):\n",
" try:\n",
" if not os.path.exists(directory):\n",
" os.makedirs(directory)\n",
" except OSerror:\n",
" print('Error')\n",
"createFolder('./models')"
]
},
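  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before committing to the full run (about 30 minutes on this setup), the `sanity_check` flag handled in `loss_epoch` can be used for a quick single-batch smoke test. The cell below is a minimal sketch: `params_smoke` is a hypothetical copy of `params_train` with the flag flipped, and the actual call is left commented out because it would perform one optimizer step on the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal sketch: a single-batch smoke test of the training loop\n",
    "# params_smoke is a hypothetical variant of params_train with sanity_check enabled\n",
    "params_smoke = dict(params_train, num_epochs=1, sanity_check=True)\n",
    "\n",
    "# uncomment to run; note it performs one optimizer step per split\n",
    "# _ = train_val(model, params_smoke)"
   ]
  },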
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/4, current lr=0.001\n",
"Get best val_loss!\n",
"train loss: 2.306605, val loss: 2.231432, accuracy: 15.30, time: 5.8292 min\n",
"----------\n",
"Epoch 1/4, current lr=0.001\n",
"train loss: 2.189562, val loss: 2.556953, accuracy: 13.54, time: 11.6434 min\n",
"----------\n",
"Epoch 2/4, current lr=0.001\n",
"train loss: 2.219446, val loss: 3.126347, accuracy: 12.56, time: 17.9946 min\n",
"----------\n",
"Epoch 3/4, current lr=0.001\n",
"train loss: 2.309346, val loss: 2.388916, accuracy: 10.78, time: 23.9601 min\n",
"----------\n",
"Epoch 4/4, current lr=0.001\n",
"train loss: 2.317668, val loss: 2.364337, accuracy: 10.96, time: 29.7933 min\n",
"----------\n"
]
}
],
"source": [
"model, loss_hist, metric_hist = train_val(model, params_train)"
]
},
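  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how the best checkpoint can be restored later (for example in a fresh session), assuming `train_val` saved the weights to `params_train['path2weights']` whenever the validation loss improved:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal sketch: reload the best weights written during training\n",
    "# (assumes ./models/weights.pt exists; see path2weights in params_train)\n",
    "best_state = torch.load(params_train['path2weights'], map_location=device)\n",
    "model.load_state_dict(best_state)\n",
    "model.eval()"
   ]
  },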
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loss & Accuracy Graph"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABjrUlEQVR4nO3dd3hUZeL28e+kJ6RREgKEJmDoLRQBKYqA4iIgKAoKKLs2sLy77G/FiqJiQVdXXesKNkSl24BQAgJKCR0hIL0HCKQQSJvz/nHIQCAJSUhyptyf65qLJydnZu7DKLnznGYzDMNARERExE14WR1AREREpCyp3IiIiIhbUbkRERERt6JyIyIiIm5F5UZERETcisqNiIiIuBWVGxEREXErKjciIiLiVlRuRERExK2o3IhIuRk5ciT16tWzOkahpkyZgs1mY+/evVZHEZEypHIj4oFsNluxHvHx8VZHBSA7O5tq1apx/fXXF7qOYRjUrl2btm3blvn7jx8/HpvNxokTJ8r8tUWk7PlYHUBEKt6XX36Z7+svvviCuLi4y5Y3adLkqt7nk08+wW63X9VrAPj6+nLHHXfw0UcfsW/fPurWrXvZOsuWLePgwYP8v//3/676/UTEtanciHige+65J9/Xv//+O3FxcZctv1RGRgZBQUHFfh9fX99S5SvIsGHD+PDDD/nmm2948sknL/v+1KlT8fLy4q677iqz9xQR16TdUiJSoB49etC8eXMSEhLo1q0bQUFBPPXUUwDMmTOHW2+9lZo1a+Lv70+DBg2YMGECubm5+V7j0mNu9u7di81mY9KkSXz88cc0aNAAf39/2rdvz5o1a4rM06VLF+rVq8fUqVMv+152djbTp0/nhhtuoGbNmmzatImRI0dyzTXXEBAQQFRUFPfffz8nT568+r+YIixevJiuXbtSqVIlwsPD6d+/P9u2bcu3TlpaGk888QT16tXD39+fyMhIevXqxbp16xzr7Ny5k0GDBhEVFUVAQADR0dHcddddpKSklGt+EXehmRsRKdTJkye55ZZbuOuuu7jnnnuoXr06YB6IGxwczN///neCg4NZvHgxzz33HKmpqbzxxhtXfN2pU6eSlpbGgw8+iM1m4/XXX+f2229n9+7dhc722Gw2hg4dyiuvvMLWrVtp1qyZ43vz5s0jOTmZYcOGARAXF8fu3bu57777iIqKYuvWrXz88cds3bqV33//HZvNVgZ/O/ktXLiQW265hWuuuYbx48dz9uxZ3n33Xbp06cK6descJe+hhx5i+vTpjBkzhqZNm3Ly5EmWL1/Otm3baNu2LVlZWfTp04fMzEweffRRoqKiOHToED/++COnT58mLCyszLOLuB1DRDze6NGjjUv/OejevbsBGB9++OFl62dkZFy27MEHHzSCgoKMc+fOOZaNGDHCqFu3ruPrPXv2GIBRtWpVIzk52bF8zpw5BmD88MMPRebcunWrARjjxo3Lt/yuu+4yAgICjJSUlELzffPNNwZgLFu2zLFs8uTJBmDs2bOnyPd9/vnnDcA4fvx4oeu0bt3aiIyMNE6ePOlYtnHjRsPLy8sYPny4Y1lYWJgxevToQl9n/fr1BmB8//33RWYSkcJpt5SIFMrf35/77rvvsuWBgYGOcVpaGidOnKBr165kZGSwffv2K77ukCFDqFy5suPrrl27ArB79+4in9e0aVPatGnDtGnTHMvOnDnD3Llz+ctf/kJoaOhl+c6dO8eJEye47rrrAPLt/ikrR44cYcOGDYwcOZIqVao4lrds2ZJevXrx888/O5aFh4ezatUqDh8+XOBr5c3MzJ8/n4yMjDLPKuIJVG5EpFC1atXCz8/vsuVbt25l4MCBhIWFERoaSkREhONg5OIcF1KnTp18X+cVnVOnTgFw9uxZjh49mu+RZ9iwYezZs4eVK1cCMHv2bDIyMhy7pACSk5N5/PHHqV69OoGBgURERFC/fv1i5yupffv2ARATE3PZ95o0acKJEyc4c+YMAK+//jpbtmyhdu3adOjQgfHjx+crdfXr1+fvf/87n376KdWqVaNPnz68//77Ot5GpARUbkSkUBfPgOQ5ffo03bt3Z+PGjbz44ov88MMPxMXF8dprrwEU69Rvb2/vApcbhgHAt99+S40aNfI98tx99914eXk5DiyeOnUqlStXpm/fvo517rzzTj755BMeeughZs6cyYIFC5g3b16x85WnO++8k927d/Puu+9Ss2ZN3njjDZo1a8Yvv/ziWOfNN99k06ZNPPXUU5w9e5bHHnuMZs2acfDgQQuTi7gOHVAsIiUSHx/PyZMnmTlzJt26dXMs37NnT5m9R58+fYiLiyvwezVr1uSGG27g+++/59lnnyUuLo6RI0c6ZphOnTrFokWLeOGFF3juueccz9u5c2eZ5btU3nV3EhMTL/ve9u3bqVatGpUqVXIsq1GjBo888giPPPIISUlJtG3blpdffplbbrnFsU6LFi1o0aIFzzzzDCtXrqRLly58+OGHvPTSS+W2HSLuQuVGREokb9Ylb5YFICsri//+979l9h6XztZcatiwYdx///08+OCDZGdn59slVVA+gLfffrvM8l2qRo0atG7dms8//5xx48YRHh4OwJYtW1iwYIFjl11ubi7p6en5zniKjIykZs2aZGZmApCamkpQUBA+Phf+eW7RogVeXl6OdUSkaCo3IlIinTt3pnLlyowYMYLHHnsMm83Gl19+eVmZKE+DBg3ikUceYc6cOdSuXTvfDFJoaCjdunXj9ddfJzs7m1q1arFgwYIymVl66623LruIoZeXF0899RRvvPEGt9xyC506dWLUqFGOU8HDwsIYP348YB58HR0dzeDBg2nVqhXBwcEsXLiQNWvW8OabbwLmtXLGjBnDHXfcwbXXXktOTg5ffvkl3t7eDBo06Kq3QcQTqNyISIlUrVqVH3/8kX/84x8888wzVK5cmXvuuYeePXvSp0+fCskQGhpKv379+P7777n77rsvu27N1KlTefTRR3n//fcxDIPevXvzyy+/ULNmzat634kTJ162zNvbm6eeeoqbbrqJefPm8fzzz/Pcc8/h6+tL9+7dee211xwHMwcFBfHII4+wYMECZs6cid1up2HDhvz3v//l4YcfBqBVq1b06dOHH374gUOHDhEUFESrVq345ZdfHGd8iUjRbEZF/rolIiIiUs50tpSIiIi4FZUbERERcSsqNyIiIuJWVG5ERETErajciIiIiFtRuRERERG34nHXubHb7Rw+fJiQkJDLro0hIiIizskwDNLS0qhZsyZeXkXPzXhcuTl8+DC1a9e2OoaIiIiUwoEDB4iOji5yHY8rNyEhIYD5lxMaGmpxGhERESmO1NRUateu7fg5XhSPKzd5u6JCQ0NVbkRERFxMcQ4p0QHFIiIi4lZUbkRERMStqNyIiIiIW/G4Y25ERETKU25uLtnZ2VbHcEl+fn5XPM27OFRuREREyoBhGBw9epTTp09bHcVleXl5Ub9+ffz8/K7qdVRuREREykBesYmMjCQoKEgXii2hv
IvsHjlyhDp16lzV35/KjYiIyFXKzc11FJuqVataHcdlRUREcPjwYXJycvD19S316+iAYhERkauUd4xNUFCQxUlcW97uqNzc3Kt6HZUbERGRMqJdUVenrP7+VG5ERETErajciIiISJmoV68eb7/9ttUxdECxiIiIJ+vRowetW7cuk1KyZs0aKlWqdPWhrpLKjYi4D7sdcjPBN9DqJCJuwzAMcnNz8fG5cmWIiIiogERXpt1SIuI+5oyG1+rBkY1WJxFxCSNHjmTp0qW888472Gw2bDYbU6ZMwWaz8csvvxAbG4u/vz/Lly9n165d9O/fn+rVqxMcHEz79u1ZuHBhvte7dLeUzWbj008/ZeDAgQQFBdGoUSPmzp1b7tulciMi7mHfStg4FXLOwYp3rE4jHs4wDDKycix5GIZR7JzvvPMOnTp14m9/+xtHjhzhyJEj1K5dG4Ann3ySV199lW3bttGyZUvS09Pp27cvixYtYv369dx8883069eP/fv3F/keL7zwAnfeeSebNm2ib9++DBs
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHHCAYAAABXx+fLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB+GUlEQVR4nO3deXxM1//H8dfMZN9FIguRBbXHLvalVZQqapfW3qK09VVt6bel2l9L0Z1aW9qKfasqSu37HkuoIiFEIoLsss79/THk25BEEknuJPN5Ph7z6J3JnTvv66r55Jxzz9EoiqIghBBCCGFCtGoHEEIIIYQoaVIACSGEEMLkSAEkhBBCCJMjBZAQQgghTI4UQEIIIYQwOVIACSGEEMLkSAEkhBBCCJMjBZAQQgghTI4UQEIIIYQwOVIACSEKZMiQIfj4+KgdI1dLlixBo9Fw9epVtaMIIYyYFEBClBEajSZfj927d6sdFYD09HRcXFxo1apVrvsoioKXlxcNGzYs1izvvfceGo2Gfv36FevnCCGMh5naAYQQRePXX3/N9vyXX35h+/btj71es2bNp/qchQsXotfrn+oYAObm5vTp04f58+dz7do1vL29H9tn79693Lhxg//85z9P/Xm5URSF5cuX4+Pjw++//05CQgL29vbF9nlCCOMgBZAQZcQrr7yS7fnhw4fZvn37Y68/Kjk5GRsbm3x/jrm5eaHy5SQwMJB58+axfPlyJk6c+NjPly1bhlarpX///kX2mY/avXs3N27cYOfOnXTq1Il169YxePDgYvu8p1HQayWEyJ10gQlhQtq1a0edOnU4ceIEbdq0wcbGhg8++ACA3377ja5du+Lp6YmlpSVVqlTh008/JTMzM9sxHh0DdPXqVTQaDbNmzWLBggVUqVIFS0tLmjRpwrFjx/LM07JlS3x8fFi2bNljP0tPT2fNmjW0b98eT09Pzpw5w5AhQ/Dz88PKygp3d3eGDRvGnTt3nurPJCgoiFq1atG+fXs6dOhAUFBQjvtFREQwfPjwrD8fX19fRo8eTVpaWtY+sbGx/Oc//8HHxwdLS0sqVarEoEGDiImJAXIfn7R79+7HuieL4loBHDlyhC5dulCuXDlsbW3x9/fn22+/BWDx4sVoNBpOnTr12Ps+//xzdDodERERBfrzFKK0kBYgIUzMnTt3eOGFF+jfvz+vvPIKbm5ugOHL2c7OjvHjx2NnZ8fOnTuZPHky8fHxzJw584nHXbZsGQkJCYwcORKNRsOMGTN4+eWXCQ0NzbXVSKPRMHDgQD7//HNCQkKoXbt21s+2bt3K3bt3CQwMBGD79u2EhoYydOhQ3N3dCQkJYcGCBYSEhHD48GE0Gk2B/yxSU1NZu3Yt77zzDgADBgxg6NChREVF4e7unrXfzZs3adq0KbGxsbz++uvUqFGDiIgI1qxZQ3JyMhYWFiQmJtK6dWsuXLjAsGHDaNiwITExMWzcuJEbN27g4uJS4HxPe622b9/Oiy++iIeHB2+//Tbu7u5cuHCBTZs28fbbb9O7d2/GjBlDUFAQDRo0yPbZQUFBtGvXjooVKxY4txClgiKEKJPGjBmjPPq/eNu2bRVAmTdv3mP7JycnP/bayJEjFRsbGyUlJSXrtcGDByve3t5Zz8PCwhRAKV++vHL37t2s13/77TcFUH7//fc8c4aEhCiAMmnSpGyv9+/fX7GyslLi4uJyzbd8+XIFUPbu3Zv12uLFixVACQsLy/NzFUVR1qxZowDKpUuXFEVRlPj4eMXKykr5+uuvs+03aNAgRavVKseOHXvsGHq9XlEURZk8ebICKOvWrct1n9yy7dq1SwGUXbt2Zb32tNcqIyND8fX1Vby9vZV79+7lmEdRFGXAgAGKp6enkpmZmfXayZMnFUBZvHjxY58jRFkhXWBCmBhLS0uGDh362OvW1tZZ2wkJCcTExNC6dWuSk5P5+++/n3jcfv36Ua5cuaznrVu3BiA0NDTP99WqVYsGDRqwYsWKrNeSkpLYuHEjL774Ig4ODo/lS0lJISYmhmbNmgFw8uTJJ+bLSVBQEI0bN6Zq1aoA2Nvb07Vr12zdYHq9ng0bNtCtWzcaN2782DEetjytXbuWevXq0bNnz1z3KainuVanTp0iLCyMcePG4eTklGueQYMGcfPmTXbt2pX1WlBQENbW1vTq1atQuYUoDaQAEsLEVKxYEQsLi8deDwkJoWfPnjg6OuLg4ICrq2vWAOq4uLgnHrdy5crZnj8shu7duwfA/fv3iYqKyvZ4KDAwkLCwMA4ePAjAhg0bSE5Ozur+Arh79y5vv/02bm5uWFtb4+rqiq+vb77zPSo2NpbNmzfTtm1bLl++nPVo2bIlx48f559//gHg9u3bxMfHU6dOnTyPd+XKlSfuU1BPc62uXLkC8MRMzz//PB4eHllFn16vZ/ny5XTv3l3uhhNlmhRAQpiYf7cePBQbG0vbtm05ffo0n3zyCb///jvbt2/niy++AMjXbe86nS7H1xVFAWDlypV4eHhkezw0YMAAtFpt1mDoZcuWUa5cObp06ZK1T9++fVm4cCGjRo1i3bp1bNu2ja1bt+Y736NWr15NamoqX375JdWqVct6jB8/HiDXwdBPI7eWoJwGL0PxXat/0+l0DBw4kLVr15KSksKuXbu4efPmE+8eFKK0k0HQQgh2797NnTt3WLduHW3atMl6PSwsrMg+o1OnTmzfvj3Hn3l6etK+fXtWr17NRx99xPbt2xkyZEhW68e9e/fYsWMHU6dOZfLkyVnvu3TpUqHzBAUFUadOHaZMmfLYz+bPn8+yZcuYOnUqrq6uODg4cO7cuTyPV6VKlSfu87BVLDY2Ntvr165dy3fu/F6rKlWqAHDu3Dk6dOiQ5zEHDRrEl19+ye+//86WLVtwdXWlU6dO+c4kRGkkBZAQIqv15mFrDUBaWho//PBDkX3Go60+jwoMDGTYsGGMHDmS9PT0bN1fOeUD+OabbwqV5fr16+zdu5epU6fSu3fvx36elpZGYGAgR44cISAggB49erB06VKOHz/+2DggRVHQaDT06tWLTz75hPXr1z82DujhPg+Lkr1791K/fn3A0PqzYMGCfGfP77Vq2LAhvr6+fPPNNwwZMiTbOKCHeR7y9/fH39+fRYsWcfjwYQYPHoyZmXw9iLJN/oYLIWjRogXlypVj8ODBvPXWW2g0Gn799dfHCo7i1KtXL9544w1+++03vLy8srVuODg40KZNG2bMmEF6ejoVK1Zk27ZthW6hWrZsGYqi8NJLL+X48y5dumBmZkZQUBABAQF8/vnnbNu2jbZt2/L6669Ts2ZNIiMjWb16Nfv378fJyYl3332XNWvW0KdPH4YNG0ajRo24e/cuGzduZN68edSrV4/atWvTrFkzJk2axN27d3F2dmbFihVkZGTkO3t+r5VWq2Xu3Ll069aN+vXrM3ToUDw8PPj7778JCQnhzz//zLb/oEGDmDBhAvD4pJpClEUyBkgIQfny5dm0aRMeHh58+OGHzJo1i+eff54ZM2aUWAYHBwe6d
esGGMYEPTpeZtmyZXTq1Ik5c+YwadIkzM3N2bJlS6E+KygoiMqVK1OvXr0cf+7k5ESrVq1YuXIlGRkZVKxYkSNHjtC7d2+CgoJ46623+OWXX2jXrl3WzMx2dnbs27eP0aNHs3nzZt566y1++OEHqlevTqVKlbJ9dosWLZg+fTqff/457du3Z/r06fnOXpBr1alTJ3bt2sUzzzzDl19+yfjx49mxY0fWn/O/BQYGotPpeOaZZ2jatGm+8whRWmmUkvwVTwghhFGKiYnBw8ODyZMn89FHH6kdR4hiJy1AQgghWLJkCZmZmbz66qtqRxGiRMgYICGEMGE7d+7k/PnzfPbZZ/To0SPbOm9ClGXSBSaEECasXbt2HDx4kJYtW7J06VJZ+0uYDCmAhBBCCGFyZAyQEEIIIUyOFEBCCCGEMDkyCDoHer2emzdvYm9vX+hVnIUQQghRshRFISEhAU9PT7TavNt4pADKwc2bN/Hy8lI7hhBCCCEK4fr169kmIM2JFEA5sLe3Bwx/gA4ODiqnEUI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train-Validation Progress\n",
"num_epochs=params_train[\"num_epochs\"]\n",
"\n",
"# plot loss progress\n",
"plt.title(\"Train-Val Loss\")\n",
"plt.plot(range(1,num_epochs+1),loss_hist[\"train\"],label=\"train\")\n",
"plt.plot(range(1,num_epochs+1),loss_hist[\"val\"],label=\"val\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.xlabel(\"Training Epochs\")\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"# plot accuracy progress\n",
"plt.title(\"Train-Val Accuracy\")\n",
"plt.plot(range(1,num_epochs+1),metric_hist[\"train\"],label=\"train\")\n",
"plt.plot(range(1,num_epochs+1),metric_hist[\"val\"],label=\"val\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.xlabel(\"Training Epochs\")\n",
"plt.legend()\n",
"plt.show()"
]
},
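  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch to report the best validation accuracy recorded in `metric_hist` (the stored values are fractions, so they are scaled by 100 here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal sketch: print the best validation accuracy and the epoch it occurred at\n",
    "best_epoch = int(np.argmax(metric_hist['val']))\n",
    "print('best val accuracy: %.2f%% at epoch %d' % (100 * metric_hist['val'][best_epoch], best_epoch))"
   ]
  },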
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Result"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABHoAAAKSCAYAAACtCLygAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9edhsWVXf/1l773NODe97577d0GDTgAgCBoOA4gBogAcRQyIhahTRRx8nHIiISkScohimdsAxGhQxCQ6oMTiDifoozuQHEYLYyAzdd3yHqjpn771+f6x9TtV77+2mgb70cM8X3tvvW1VnrKq91/6u7/ouUVVlxIgRI0aMGDFixIgRI0aMGDFixF0e7o4+gREjRowYMWLEiBEjRowYMWLEiBG3D0aiZ8SIESNGjBgxYsSIESNGjBgx4m6CkegZMWLEiBEjRowYMWLEiBEjRoy4m2AkekaMGDFixIgRI0aMGDFixIgRI+4mGImeESNGjBgxYsSIESNGjBgxYsSIuwlGomfEiBEjRowYMWLEiBEjRowYMeJugpHoGTFixIgRI0aMGDFixIgRI0aMuJtgJHpGjBgxYsSIESNGjBgxYsSIESPuJhiJnhEjRowYMWLEiBEjRowYMWLEiLsJ7jCi5x3veAciwotf/OLbbZ9/9Ed/hIjwR3/0R7fbPu8OEBG++7u/e/j7Fa94BSLCO97xjjvsnEaMGLHGOB5+7DCOhyNG3HkxjoUfO9znPvfhmc985vD3eJ9GjLjzYBwLP3a4O8eFHxbR01/4X/3VX12u87lT4fGPfzwiwrOe9ayPeB/f/d3fjYgMP7PZjE/8xE/kO7/zOzl//vzteLaXH7/0S7/EDTfc8FHvZ/N+bP688IUv/OhPcsSIjxGuhPHwPe95D09/+tM5cuQIhw4d4l/+y3/JP/7jP37E+xvHw4sxjocj7uq4u4+Fb33rW3n2s5/Nox/9aCaTye2yAOjvWf8zmUx4wAMewLOe9Sw+8IEP3D4n/jHCa1/72gOLpI8U97nPfS45Fn7N13zNR3+SI0Z8DHB3HwsB/uAP/oDHPe5xnDhxgiNHjvDIRz6SV77ylR/x/sa48GLcnnFh+KjP5m6KX/u1X+PP/uzPbrf9/cRP/ARbW1vs7u7ye7/3e/zH//gfed3rXsef/umfIiK323FuC770S7+UL/zCL6Rpmg9ru1/6pV/iTW96E9/8zd/8UZ/D4x//eJ7xjGcceOyTP/mTP+r9jhgx4vbB7u4uj3vc4zh37hzPe97zqKqKl73sZTzmMY/h7/7u7zh+/PhHvO9xPDyIcTwcMeLOiz/7sz/jR37kR/jET/xEHvSgB/F3f/d3t9u+v/d7v5frr7+e5XLJn/zJn/ATP/ETvPa1r+VNb3oTs9nsdjvObcFnfdZnsVgsqOv6w9ruta99LS9/+ctvF7LnYQ97GN/yLd9y4LEHPOABH/V+R4wY8dHjN3/zN3nqU5/Kp33apw0Ezatf/Wqe8YxncPPNN/PsZz/7I973GBcexO0VF45EzyWwXC75lm/5Fr7t276N7/qu77pd9vm0pz2NEydOAPA1X/M1fMEXfAG/9mu/xp//+Z/zaZ/2aZfcZn9//7JM9N57vPe3+34/HDzgAQ/gS77kS+7QcxgxYsQt48d//Md529vexl/8xV/wiEc8AoAnPelJPOQhD+ElL3kJP/ADP/AR73scDw9iHA9HjLjz4vM///M5e/Ys29vbvPjFL75diZ4nPelJfMqnfAoAX/mVX8nx48d56Utfym/8xm/wRV/0RZfcZm9vj/l8frudQw/nHJPJ5Hbf74eDa6+9dhwLR4y4k+LHfuzHuMc97sHrXve6gQT56q/+ah74wAfyile84qMiesa48CBur7jwdvfoaduW7/qu7+LhD384hw8fZj6f85mf+Zm8/vWvv8VtXvayl3HdddcxnU55zGMew5ve9KaLXvOWt7yFpz3taRw7dozJZMKnfMqn8Ju/+Zsf8nz29/d5y1vews0333ybr+E//af/RM6Z5zznObd5mw8Xn/3Znw3AjTfeCMBjH/tYHvKQh/DXf/3XfNZnfRaz2YznPe95AKxWK17wghdw//vfn6ZpuPe9781zn/tcVqvVgX2uViue/exnc9VVV7G9vc3nf/7n8+53v/uiY99S7eFv//Zv85jHPIbt7W0OHTrEIx7xCH7pl35pOL//+T//J//0T/80SMjuc5/7DNu+853v5C1vecuHdQ8WiwXL5fLD2mbEiLsS7srj4a/8yq/wiEc8YiB5AB74wAfyOZ/zObz61a/+kNt/OBjHw3E8HHH3xl15LDx27Bjb29sf8nW3By4cC5/5zGeytbXF29/+dj73cz+X7e1t/t2/+3cA5Jy54YYbePCDH8xkMuHqq6/mq7/6qzlz5syBfaoq3//938+97nUvZrMZj3vc43jzm9980bFvyb/jDW94A5/7uZ/L0aNHmc/nfNInfRI//MM/PJzfy1/+cuBguUGP973vfbzlLW+h67rbfA/atmVvb+82v37EiLsS7spj4fnz5zl69OgBpUsIgRMnTjCdTj/k9h8Oxrjw9okLb3ei5/z58/zn//yfeexjH8sP/dAP8d3f/d3cdNNNPPGJT7xkFuQXfuEX+JEf+RG+/uu/nu/4ju/gTW96E5/92Z99oEb5zW9+M5/6qZ/K3//93/Pt3/7tvOQlL2E+n/PUpz6V17zmNbd6Pn/xF3/Bgx70IH7sx37sNp3/O9/5Tl74whfyQz/0Q7f7h3YTb3/72wEOlD+cOnWKJz3pSTzsYQ/jhhtu4HGPexw5Zz7/8z+fF7/4xTzlKU/hR3/0R3nqU5/Ky172Mv7tv/23B/b5lV/5ldxwww084QlP4IUvfCFVVfHkJz/5Np3PK17xCp785Cdz+vRpvuM7voMXvvCFPOxhD+N3fud3APgP/+E/8LCHPYwTJ07wyle+kle+8pUH6hCf8Yxn8KAHPeg2X/8rXvEK5vM50+mUT/zETxy+KCNG3J1wVx0Pc878n//zf4ZM8yYe+chH8va3v52dnZ3bdhNuA8bxcBwPR9y9cVcdCz/WuNRYGGPkiU98IidPnuTFL34xX/AFXwBYJv1bv/Vb+fRP/3R++Id/mC//8i/nVa96FU984hMPECvf9V3fxfOf/3z+2T/7Z7zoRS/ivve9L094whNuE5ny+7//+3zWZ30W//f//l++6Zu+iZe85CU87nGP47d+67eGc3j84x8PMIyFm34d3/Ed38GDHvQg3vOe99ym63/d617HbDZja2uL+9znPgOhNGLE3QV35bHwsY99LG9+85t5/vOfzz/8wz/w9re/ne/7vu/jr/7qr3juc5/7Yd+LW8MYF95OcaF+GPgv/+W/KKB/+Zd/eYuviTHqarU68NiZM2f06quv1q/4iq8YHrvxxhsV0Ol0qu9+97uHx9/whjcooM9+9rOHxz7ncz5HH/rQh+pyuRweyznrox/9aP34j//44bHXv/71CujrX//6ix57wQtecJuu8WlPe5o++tGPHv4G9Ou//utv07aXw
gte8AIF9K1vfavedNNNeuONN+pP/dRPadM0evXVV+ve3p6qqj7mMY9RQH/yJ3/ywPavfOUr1Tmnf/zHf3zg8Z/8yZ9UQP/0T/9UVVX/7u/+TgH9uq/7ugOv++Iv/uKLrr9/H2+88UZVVT179qxub2/rox71KF0sFge2zzkPvz/5yU/W66677pLX2Z//bcGjH/1oveGGG/Q3fuM39Cd+4if0IQ95iAL64z/+47dp+xEj7gy4O4+HN910kwL6vd/7vRc99/KXv1wBfctb3nKr+7gUxvHwYozj4Yi7Ou7OY+GFeNGLXnRgvPhI0d+zP/iDP9CbbrpJ3/Wud+l/+2//TY8fP37g2r/sy75MAf32b//2A9v/8R//sQL6qle96sDjv/M7v3Pg8Q9+8INa17U++clPPjB+Pe95z1NAv+zLvmx47ML7FGPU66+/Xq+77jo9c+bMgeNs7uvrv/7rb3G868//ttyvpzzlKfpDP/RD+uu//uv6sz/7s/qZn/mZCuhzn/vcD7ntiBF3Btzdx8Ld3V1
"text/plain": [
"<Figure size 1600x800 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"model.eval()\n",
"\n",
"result_images = []\n",
"result_preds = []\n",
"result_labels = []\n",
"\n",
"with torch.no_grad():\n",
" for i, (images, labels) in enumerate(val_dl):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" outputs = model(images)\n",
" _, preds = torch.max(outputs, 1)\n",
"\n",
" result_images.extend(images.cpu().numpy())\n",
" result_preds.extend(preds.cpu().numpy())\n",
" result_labels.extend(labels.cpu().numpy())\n",
"\n",
" if i == 2: # 3번째 배치까지만 시각화\n",
" break\n",
"\n",
"# 결과 시각화\n",
"plt.figure(figsize=(16, 8))\n",
"for i in range(16):\n",
" plt.subplot(4, 4, i+1)\n",
" img = result_images[i].transpose((1, 2, 0))\n",
" img = (img - img.min()) / (img.max() - img.min())\n",
" plt.imshow(img)\n",
" plt.title(f'Label: {result_labels[i]}, Predict: {result_preds[i]}')\n",
" plt.axis('off')\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}