{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ResNet이란?\n",
"\n",
"![resnet](./images/resnet-layers.png)\n",
"\n",
"ResNet(Residual Network)는 딥러닝에서 사용되는 컨볼루션 신경망(CNN) 구조이다.\n",
"\n",
"2015년 마이크로소프트 연구팀에 의해 개발되었으며, 깊은 신경망을 효율적으로 학습시키기 위해 \"잔차 학습(Residual Learning)\" 개념을 도입했다. \n",
"\n",
"이 아이디어는 신경망의 층을 거쳐가는 동안 신호가 약화되거나 왜곡되는 것을 방지하기 위해, 입력을 층의 출력에 직접 추가한 것이다."
]
},
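{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal illustration of the idea above (added to this notebook as a sketch, not part of the original), the cell below adds a block's input `x` back onto its transformed output `F(x)`. The names `F` and `x` are placeholders chosen for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Hypothetical transformation F(x): two 3x3 convolutions that keep the shape.\n",
"F = nn.Sequential(\n",
"    nn.Conv2d(16, 16, kernel_size=3, padding=1),\n",
"    nn.ReLU(),\n",
"    nn.Conv2d(16, 16, kernel_size=3, padding=1),\n",
")\n",
"\n",
"x = torch.randn(1, 16, 8, 8)  # example input\n",
"out = F(x) + x                # residual connection: add the input to the output\n",
"print(out.shape)              # torch.Size([1, 16, 8, 8])"
]
},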
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ResNet의 구조\n",
"\n",
"![resnet](./images/residual.png)\n",
"\n",
"ResNet의 핵심 구조는 \"잔차 블록(Residual Block)\"이다. \n",
"\n",
"이 블록은 입력을 블록의 출력에 더하는 스킵 연결(skip connection)을 포함한다. \n",
"\n",
"이를 통해 네트워크는 학습해야 할 목표 함수를 보다 쉽게 최적화할 수 있다. \n",
"\n",
"![resnet](./images/resnet.png)\n",
"\n",
"ResNet은 깊이에 따라 여러 버전이 있으며, ResNet-34, ResNet-50, ResNet-101, ResNet-152 등이 일반적이다. \n",
"\n",
"여기서 숫자는 네트워크에 있는 층의 수를 나타낸다.\n",
"\n",
"![resnet](./images/resnet-network.png)"
]
},
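{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check on that naming convention (an added illustration, assuming the standard ResNet-50 configuration of [3, 4, 6, 3] bottleneck blocks), the count below shows how 50 layers arise: one initial 7x7 convolution, three convolutions per bottleneck block, and one final fully connected layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Standard ResNet-50 stage configuration: bottleneck blocks per stage.\n",
"blocks_per_stage = [3, 4, 6, 3]\n",
"\n",
"conv_layers = 1 + 3 * sum(blocks_per_stage)  # initial 7x7 conv + 3 convs per bottleneck\n",
"total_layers = conv_layers + 1               # + final fully connected layer\n",
"print(total_layers)                          # 50"
]
},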
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pinb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.init as init\n",
"\n",
"import torchvision\n",
"import torchvision.datasets as datasets\n",
"import torchvision.transforms as transforms\n",
"\n",
"from torch.utils.data import DataLoader\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tqdm\n",
"from tqdm.auto import trange\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 하이퍼 파라미터"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 50\n",
"learning_rate = 0.0002\n",
"num_epoch = 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 데이터셋 구성\n",
"\n",
"ResNet 학습을 위한 데이터셋 구성으로 CIFAR10를 사용한다."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"# define dataset\n",
"cifar10_train = datasets.CIFAR10(root=\"./Data/\", train=True, transform=transform, target_transform=None, download=True)\n",
"cifar10_test = datasets.CIFAR10(root=\"./Data/\", train=False, transform=transform, target_transform=None, download=True)\n",
"\n",
"# define loader\n",
"train_loader = DataLoader(cifar10_train,batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)\n",
"test_loader = DataLoader(cifar10_test,batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True)\n",
"\n",
"# define classes\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
]
},
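{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at one batch (added for illustration; `images` and `labels` are throwaway names for this check) confirms the tensor shape and that the `Normalize` transform maps pixel values into roughly [-1, 1]."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one training batch to confirm shapes and the normalized value range.\n",
"images, labels = next(iter(train_loader))\n",
"print(images.shape)                # torch.Size([50, 3, 32, 32])\n",
"print(images.min(), images.max())  # roughly -1.0 and 1.0 after normalization\n",
"print([classes[l.item()] for l in labels[:5]])"
]
},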
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic Module"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def conv_block_1(in_dim,out_dim, activation,stride=1):\n",
" model = nn.Sequential(\n",
" nn.Conv2d(in_dim,out_dim, kernel_size=1, stride=stride),\n",
" nn.BatchNorm2d(out_dim),\n",
" activation,\n",
" )\n",
" return model\n",
"\n",
"\n",
"def conv_block_3(in_dim,out_dim, activation, stride=1):\n",
" model = nn.Sequential(\n",
" nn.Conv2d(in_dim,out_dim, kernel_size=3, stride=stride, padding=1),\n",
" nn.BatchNorm2d(out_dim),\n",
" activation,\n",
" )\n",
" return model"
]
},
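{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small usage example (added for illustration; the input tensor and the name `act` are made up): `conv_block_1` applies a 1x1 convolution and `conv_block_3` a padded 3x3 convolution, so both keep the spatial size at stride 1, while stride 2 halves it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act = nn.ReLU()\n",
"x = torch.randn(1, 64, 8, 8)                         # made-up feature map\n",
"\n",
"print(conv_block_1(64, 256, act)(x).shape)           # torch.Size([1, 256, 8, 8])\n",
"print(conv_block_3(64, 64, act, stride=2)(x).shape)  # torch.Size([1, 64, 4, 4])"
]
},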
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bottleneck Module"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class BottleNeck(nn.Module):\n",
" def __init__(self,in_dim,mid_dim,out_dim,activation,down=False):\n",
" super(BottleNeck,self).__init__()\n",
" self.down=down\n",
" \n",
" # 특성지도의 크기가 감소하는 경우\n",
" if self.down:\n",
" self.layer = nn.Sequential(\n",
" conv_block_1(in_dim,mid_dim,activation,stride=2),\n",
" conv_block_3(mid_dim,mid_dim,activation,stride=1),\n",
" conv_block_1(mid_dim,out_dim,activation,stride=1),\n",
" )\n",
" \n",
" # 특성지도 크기 + 채널을 맞춰주는 부분\n",
" self.downsample = nn.Conv2d(in_dim,out_dim,kernel_size=1,stride=2)\n",
" \n",
" # 특성지도의 크기가 그대로인 경우\n",
" else:\n",
" self.layer = nn.Sequential(\n",
" conv_block_1(in_dim,mid_dim,activation,stride=1),\n",
" conv_block_3(mid_dim,mid_dim,activation,stride=1),\n",
" conv_block_1(mid_dim,out_dim,activation,stride=1),\n",
" )\n",
" \n",
" # 채널을 맞춰주는 부분\n",
" self.dim_equalizer = nn.Conv2d(in_dim,out_dim,kernel_size=1)\n",
" \n",
" def forward(self,x):\n",
" if self.down:\n",
" downsample = self.downsample(x)\n",
" out = self.layer(x)\n",
" out = out + downsample\n",
" else:\n",
" out = self.layer(x)\n",
" if x.size() is not out.size():\n",
" x = self.dim_equalizer(x)\n",
" out = out + x\n",
" return out"
]
},
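{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief shape check (added for illustration, with made-up sizes): with `down=False` the block keeps the spatial size and widens the channels through `dim_equalizer`, while `down=True` halves the spatial size via the stride-2 convolution and the matching `downsample` path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act = nn.ReLU()\n",
"x = torch.randn(2, 64, 8, 8)                     # made-up feature map\n",
"\n",
"keep = BottleNeck(64, 64, 256, act)              # same spatial size\n",
"down = BottleNeck(256, 128, 512, act, down=True) # halves the spatial size\n",
"\n",
"print(keep(x).shape)        # torch.Size([2, 256, 8, 8])\n",
"print(down(keep(x)).shape)  # torch.Size([2, 512, 4, 4])"
]
},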
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ResNet-50 Network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# 50-layer\n",
"class ResNet(nn.Module):\n",
"\n",
" def __init__(self, base_dim, num_classes=10):\n",
" super(ResNet, self).__init__()\n",
" self.activation = nn.ReLU()\n",
" self.layer_1 = nn.Sequential(\n",
" nn.Conv2d(3,base_dim,7,2,3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(3,2,1),\n",
" )\n",
" self.layer_2 = nn.Sequential(\n",
" BottleNeck(base_dim,base_dim,base_dim*4,self.activation),\n",
" BottleNeck(base_dim*4,base_dim,base_dim*4,self.activation),\n",
" BottleNeck(base_dim*4,base_dim,base_dim*4,self.activation,down=True),\n",
" ) \n",
" self.layer_3 = nn.Sequential(\n",
" BottleNeck(base_dim*4,base_dim*2,base_dim*8,self.activation),\n",
" BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation),\n",
" BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation),\n",
" BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation,down=True),\n",
" )\n",
" self.layer_4 = nn.Sequential(\n",
" BottleNeck(base_dim*8,base_dim*4,base_dim*16,self.activation),\n",
" BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),\n",
" BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation), \n",
" BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),\n",
" BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),\n",
" BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation,down=True),\n",
" )\n",
" self.layer_5 = nn.Sequential(\n",
" BottleNeck(base_dim*16,base_dim*8,base_dim*32,self.activation),\n",
" BottleNeck(base_dim*32,base_dim*8,base_dim*32,self.activation),\n",
" BottleNeck(base_dim*32,base_dim*8,base_dim*32,self.activation),\n",
" )\n",
" self.avgpool = nn.AvgPool2d(1,1) \n",
" self.fc_layer = nn.Linear(base_dim*32,num_classes)\n",
" \n",
" def forward(self, x):\n",
" out = self.layer_1(x)\n",
" out = self.layer_2(out)\n",
" out = self.layer_3(out)\n",
" out = self.layer_4(out)\n",
" out = self.layer_5(out)\n",
" out = self.avgpool(out)\n",
" out = out.view(batch_size,-1)\n",
" out = self.fc_layer(out)\n",
" \n",
" return out"
]
},
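{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added for illustration; `net` and `dummy` are placeholder names) passes one random CIFAR10-sized batch through an untrained network to confirm that the output has one score per class. `batch_size` comes from the hyperparameter cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = ResNet(base_dim=64)\n",
"dummy = torch.randn(batch_size, 3, 32, 32)  # random CIFAR10-sized batch\n",
"print(net(dummy).shape)                     # torch.Size([50, 10])"
]
},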
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 학습"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"model = ResNet(base_dim=64).to(device)\n",
"loss_func = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [0/100] Train Loss: 2.0092\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|█ | 10/100 [13:35<2:00:50, 80.56s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [10/100] Train Loss: 0.6987\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 20%|██ | 20/100 [27:26<1:54:48, 86.10s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [20/100] Train Loss: 0.2491\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 30%|███ | 30/100 [41:02<1:33:23, 80.05s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [30/100] Train Loss: 0.2013\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 40%|████ | 40/100 [57:29<1:44:20, 104.35s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [40/100] Train Loss: 0.1048\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 50/100 [1:13:20<1:15:53, 91.07s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [50/100] Train Loss: 0.0849\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████ | 60/100 [1:28:19<59:53, 89.83s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [60/100] Train Loss: 0.0724\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 70%|███████ | 70/100 [1:42:47<43:55, 87.86s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [70/100] Train Loss: 0.0653\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 80/100 [1:57:39<29:47, 89.39s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [80/100] Train Loss: 0.0630\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 90%|█████████ | 90/100 [2:12:17<14:39, 87.94s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [90/100] Train Loss: 0.0605\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [2:27:01<00:00, 88.22s/it]\n"
]
}
],
"source": [
"for i in trange(num_epoch):\n",
" model.train() # 모델을 학습 모드로 설정\n",
" train_loss = 0.0\n",
" for j, [image, label] in enumerate(train_loader):\n",
" x = image.to(device)\n",
" y_ = label.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" output = model(x)\n",
" loss = loss_func(output, y_)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" train_loss += loss.item()\n",
" \n",
" train_loss /= len(train_loader)\n",
1 year ago
" print(f\"Epoch [{i}/{num_epoch}] Train Loss: {train_loss:.4f}\")\n",
" \n",
1 year ago
" if i % 10 == 0:\n",
" torch.save(model.state_dict(), f'model_epoch_{i}.pth')"
]
},
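{
"cell_type": "markdown",
"metadata": {},
"source": [
"The original notebook does not report test accuracy; the cell below is a minimal evaluation sketch added for completeness. It reuses `model`, `test_loader`, and `device` from the cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate classification accuracy on the test set (no gradients needed).\n",
"model.eval()\n",
"correct = 0\n",
"total = 0\n",
"with torch.no_grad():\n",
"    for image, label in test_loader:\n",
"        x = image.to(device)\n",
"        y = label.to(device)\n",
"        output = model(x)\n",
"        _, predicted = torch.max(output, 1)\n",
"        total += label.size(0)\n",
"        correct += (predicted == y).sum().item()\n",
"\n",
"print(f\"Test accuracy: {100 * correct / total:.2f}%\")"
]
},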
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 결과 이미지 표시"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# 테스트 이미지 로드 및 변환\n",
"test_image = Image.open('test.jpg')\n",
"test_image = transform(test_image).unsqueeze(0) # 차원 추가\n",
"\n",
"# 예측 수행\n",
"model.eval()\n",
"with torch.no_grad():\n",
" output = model(test_image.to(device))\n",
"\n",
"# 결과 시각화\n",
"plt.imshow(test_image.squeeze().permute(1, 2, 0)) # 차원 변경 및 시각화\n",
"plt.title(\"Predicted Label\")\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAJYAAB47CAYAAACuYpRBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9e3zU1bX/j79mMpnMxEkYkhCSEEKMMRBKMSCEaysYPl5DAYutHg8JlGKhQv21eqh8rRD7MXL0CJ+KrXJTIJ62frQU5FMVWm6nECXcDBASEi+ECAEDhCHEcZhMZv/+oI6sNUlm3iEbvKzn48Efa97v937v9ztr9rxYe+21TUopBUHoYszXugPCNxNxLEEL4liCFsSxBC2IYwlaEMcStCCOJWhBHEvQgjiWoIWr6lirV6+G0+m84nZMJhPWr19/xe18lSkqKkJOTk7Anjp1KiZOnHjN+mMUQ471dXu4y/nBD36AtLQ02Gw2JCcnY8qUKaivrzfURlFREUwmE0wmEywWC9LT0/HLX/4Szc3Nmnr9Jc8//zxWr14d1rm1tbUwmUwoLy83fJ+Wlhb89re/xQ033ACbzYabbroJGzduNNzOt+ancOzYsXj99ddRXV2NtWvX4qOPPsLkyZMNt/Od73wHJ0+eRG1tLZ555hksX74cjzzySJvner3eK+12gG7dunXJaB+K3/zmN1i2bBleeOEFVFZWYubMmZg0aRLef/99Yw0pAxQWFqoJEya0e3zRokVqwIABKjo6WqWmpqpZs2apCxcuBI6vWrVKdevWTa1bt05lZmaqqKgoddttt6m6ujrSzvr169WgQYNUVFSUuv7661VRUZFqaWkJHAeg1q1bZ6TrQbz55pvKZDIpr9cb9jULFixQN910E/lsxowZKikpiRxfsWKFSk9PVyaTSSml1Llz59T06dNVQkKCiomJUWPHjlXl5eWknYULF6rExETlcDjUT37yE/XrX/+a3Iu/+9bWVvXMM8+oG264QVmtVtW7d2/11FNPKaUuvZ/L/91yyy1hP2NycrL6/e9/Tz6755571AMPPBB2G0op1aUjltlsxpIlS3D48GGsWbMGW7duxdy5c8k5brcbxcXFKCkpQWlpKVwuF+67777A8R07dqCgoAAPP/wwKisrsWzZMqxevRrFxcXt3nfMmDGYOnVq2P1sbGzEH//4R4wcORKRkZGGn/Ny7HY7GZk+/PBDrF27Fn/9618DP0X33nsvGhoa8M4772Dfvn0YPHgw8vLy0NjYCAB4/fXXUVRUhKeffhp79+5FcnIyXnzxxQ7vO2/ePPznf/4nnnjiCVRWVuJPf/oTevbsCQDYvXs3AGDz5s04efIk/vrXvwIAtm/fDpPJhNra2nbbvXjxImw2W9Az7ty509B76dIRi/PGG2+o+Pj4gL1q1SoFQO3atSvwWVVVlQKgysrKlFJK5eXlqaeffpq08+qrr6rk5OSADTZiTZkyRT322GMh+zN37lwVHR2tAKjhw4erM2fOhP0sSgWPWHv37lUJCQlq8uTJgeORkZGqoaEhcM6OHTtUbGys8ng8pK0bbrhBLVu2TCml1IgRI9TPf/5zcnzYsGHtjlhNTU0qKipKrVixos1+Hj16VAFQ77//Pvm8rKxM9e3bVx0/frzdZ7z//vtV//79VU1NjWptbVV///vfld1uV1artd1r2qJLR6zNmzcjLy8PvXr1QkxMDKZMmYKzZ8/C7XYHzrFYLBg6dGjA7tevH5xOJ6qqqgAABw4cwG9/+1s4HI7AvxkzZuDkyZOkncspKSnBwoULQ/bvP/7jP/D+++/j73//OyIiIlBQUABlMB3t0KFDcDgcsNvtyM3NxYgRI/D73/8+cLxPnz7o0aNHwD5w4ACam5sRHx9Pnuno0aP46KOPAABVVVUYNmwYuc+IESPa7UNVVRUuXryIvLw8Q33Pzc3FkSNH0KtXr3bPef7553HjjTeiX79+sFqtmD17NqZNmwaz2ZirWAyd3QG1tbXIz8/HrFmzUFxcjLi4OOzcuRPTp0+H1+tFdHR0WO00NzfjySefxD333BN0jA/RRklISEBCQgKysrKQnZ2N3r17Y9euXR3+ETl9+/bFhg0bYLFYkJKSAqvVSo5fd911xG5ubkZycjK2b98e1FZnxbjdbu/UdeHQo0cPrF+/Hh6PB2fPnkVKSgoee+wxZGRkGGqnyxxr37598Pv9WLRoUcC7X3/99aDzfD4f9u7di9zcXABAdXU1XC4XsrOzAQCDBw9GdXU1MjMzu6prbeL3+wFc0hRGsFqthvo2ePBgnDp1KhCeaIvs7GyUlZWhoKAg8NmuXbvabfPGG2+E3W7Hli1b8NOf/rTNPgJAa2tr2P3k2Gw29OrVCy0tLVi7di1+9KMfGbresGOdP38+KD4SHx+PzMxMtLS04IUXXsD48eNRWlqKpUuXBl0fGRmJOXPmYMmSJbBYLJg9ezaGDx8ecLT58+cjPz8faWlpmDx5MsxmMw4cOICKigo89dRTbfapoKAAvXr1avfnsKysDHv27MHo0aPRvXt3fPTRR3jiiSdwww03GBqtOsO4ceMwYsQITJw4Ec8++yyysrJQX1+Pt956C5MmTcKQIUPw8MMPY+rUqRgyZAhGjRqFP/7xjzh8+HC7o4TNZsOvf/1rzJ07F1arFaNGjcLp06dx+PBhTJ8+HYmJibDb7di4cSNSU1Nhs9nQrVs37N69GwUFBdiyZUu7P4dlZWU4ceIEcnJycOLECRQVFcHv9wf9JywkRgRZYWFh0H9lAajp06crpZRavHixSk5OVna7Xd1+++2qpKREAVDnzp1TSn0Zbli7dq3KyMhQUVFRaty4cerYsWPkPhs3blQjR45UdrtdxcbGqtzcXLV8+fLAcTDxfsstt6jCwsJ2+33w4EE1duxYFRcXp6KiolR6erqaOXNmkIgFoFatWtVuO22FG8I53tTUpObMmaNSUlJUZGSk6t27t3rggQdImKW4uFglJCQoh8OhCgsL1dy5c0OGG5566inVp08fFRkZqdLS0sh/elasWKF69+6tzGZzINywbds2BUAdPXq03WfYvn27ys7OVlFRUSo+Pl5NmTJFnThxot3z28OklCymAICjR48iKysLlZWVuPHGG691d772fGsi76F4++238eCDD4pTdREyYglakBFL0II4lqAFcSxBC+JYghbEsQQtGIq8f99kIjb3Sgez+XGeZ+kJcT1Pk/OHaJ8/TKjz+eylldn8/LbS9vi0uK+Ncy6HP3NTiHvy9kK9U97HUO+IX38WHRNuEEFGLEEL4liCFsSxBC0Y0lj895hrGE5siJvx7KpQneHfAq5PuH7h7XMNxY+H+pa11T+ugUJpHn6c27wP3HYyO9Q7DLWcI9TfsLPIiCVoQRxL0II4lqAFQxor1O9xqLgR1yOh4kb8OL+ex724xuLnc5ufHyrm09a3kD9zI7M/auOaK+EEs2OYHcds/sz8mfg75ovhWsLsF0dGLEEL4liCFsSxBC0Y0lihVvXxxngcKxShOsM1EYfP/eWwL
OPjtdQ+wgQEvz7UvBqgX1OF4kIIm69AdDI71HzvuU70qa12BaFLEMcStCCOJWjhipbYh5rX4poolGYxmo/FcdB0MfziP+8k9vLn3iH23vfo+Q2sPd7/T0Lc/6vI58zmz5TAbP7OIzp5XxmxBC2IYwlaEMcStGBoJXR/lvNex47z32deKyWV2UbzsUJpNj4vltGD2gdPU/tddv63cUl4MrNDzac2Sc67cC0RxxK0II4laMGQxjIxjWWUm5nNNVio3CEex+Kajl//MbM/ZHbnCymGj50lOKUm0dm7Dz7hkaarC59L5PlcfF2BaCzhmiKOJWhBHEvQgqG5wihmGytkDZxhNo9rhap7wPOzXcyuZPZJZvN87jT2QEeNPlAYfM5yvjISncQ+wzRWZ/OfOgtXeFy3dnYyWUYsQQviWIIWxLEELRj6Cf1B3vXEfmPLUUM347k/PB8r1Do+/vuf0o3aW853fP9cthmD30LjckeP6Z8t3L+PKr8E9gznQjyDbni9r87uXiQjlqAFcSxBC+JYghY
"text/plain": [
"<Figure size 10000x10000 with 50 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# 모델을 평가 모드로 설정\n",
"model.eval()\n",
"\n",
"# 테스트 데이터셋의 첫 번째 배치를 가져옴\n",
"images, labels = next(iter(test_loader))\n",
"images, labels = images.to(device), labels.to(device)\n",
"\n",
"# 모델 예측\n",
"with torch.no_grad():\n",
" outputs = model(images)\n",
"\n",
"# 예측 결과 처리\n",
"_, predicted = torch.max(outputs, 1)\n",
"\n",
"# 이미지 출력 설정\n",
1 year ago
"fig, axs = plt.subplots(len(images), 1, figsize=(100, 100))\n",
1 year ago
"\n",
"for i, img in enumerate(images.cpu()):\n",
" img = img.numpy().transpose((1, 2, 0))\n",
" axs[i].imshow(img)\n",
1 year ago
" axs[i].set_title(f'Label: {labels[i].item()}, Predict: {predicted[i].item()}', fontsize=10)\n",
1 year ago
" axs[i].axis('off')\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}