{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "1857acd7", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "code", "execution_count": null, "id": "efb5db0b", "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class CNN_WDI(nn.Module):\n",
"    \"\"\"CNN classifier for 224x224 RGB wafer-defect-map images.\n",
"\n",
"    Four conv stages (16 -> 32 -> 64 -> 128 channels), each ending in a 2x2\n",
"    max-pool, then spatial dropout, a final 2x2 pool and two fully connected\n",
"    layers. forward() returns raw logits (no softmax) so the model pairs\n",
"    correctly with nn.CrossEntropyLoss.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, class_num=9):\n",
"        super(CNN_WDI, self).__init__()\n",
"\n",
"        # padding=0 here (all later convs use padding=1) shrinks 224 -> 222 so\n",
"        # that after five 2x2 pools the map is 6x6 (128 * 6 * 6 = 4608 = fc1 in).\n",
"        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=0)\n",
"        self.bn1 = nn.BatchNorm2d(16)\n",
"        self.pool1 = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)\n",
"        self.bn2 = nn.BatchNorm2d(16)\n",
"\n",
"        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)\n",
"        self.bn3 = nn.BatchNorm2d(32)\n",
"        self.pool2 = nn.MaxPool2d(2, 2)\n",
"        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n",
"        self.bn4 = nn.BatchNorm2d(32)\n",
"\n",
"        self.conv5 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
"        self.bn5 = nn.BatchNorm2d(64)\n",
"        self.pool3 = nn.MaxPool2d(2, 2)\n",
"        self.conv6 = nn.Conv2d(64, 64, kernel_size=3, padding=1)\n",
"        self.bn6 = nn.BatchNorm2d(64)\n",
"\n",
"        self.conv7 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
"        self.bn7 = nn.BatchNorm2d(128)\n",
"        self.pool4 = nn.MaxPool2d(2, 2)\n",
"        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)\n",
"        self.bn8 = nn.BatchNorm2d(128)\n",
"\n",
"        # Dropout2d zeroes entire feature maps (channels), not single pixels\n",
"        self.spatial_dropout = nn.Dropout2d(0.2)\n",
"        self.pool5 = nn.MaxPool2d(2, 2)\n",
"\n",
"        self.fc1 = nn.Linear(4608, 512)\n",
"        self.fc2 = nn.Linear(512, class_num)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return class logits of shape (batch, class_num) for input batch x.\"\"\"\n",
"        x = F.relu(self.bn1(self.conv1(x)))\n",
"        x = self.pool1(F.relu(self.bn2(self.conv2(x))))\n",
"\n",
"        x = F.relu(self.bn3(self.conv3(x)))\n",
"        x = self.pool2(F.relu(self.bn4(self.conv4(x))))\n",
"\n",
"        x = F.relu(self.bn5(self.conv5(x)))\n",
"        x = self.pool3(F.relu(self.bn6(self.conv6(x))))\n",
"\n",
"        x = F.relu(self.bn7(self.conv7(x)))\n",
"        x = self.pool4(F.relu(self.bn8(self.conv8(x))))\n",
"\n",
"        x = self.spatial_dropout(x)\n",
"        x = self.pool5(x)\n",
"\n",
"        x = x.view(x.size(0), -1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.fc2(x)\n",
"\n",
"        # Bug fix: return logits instead of F.softmax(x, dim=1). The training\n",
"        # loop uses nn.CrossEntropyLoss, which applies log-softmax internally;\n",
"        # feeding it softmax-ed outputs double-applies softmax and flattens\n",
"        # the gradients. Argmax-based predictions are unchanged by this fix;\n",
"        # call F.softmax(logits, dim=1) externally if probabilities are needed.\n",
"        return x\n",
"\n",
"cnn_wdi = CNN_WDI(class_num=9)\n"
] }, { "attachments": {}, "cell_type": "markdown", "id": "1c383602", "metadata": {}, "source": [ "# Load Data" ] }, { "cell_type": "code", "execution_count": null, "id": "a865c00c", "metadata": {}, "outputs": [], "source": [
"from torchvision import transforms, datasets\n",
"\n",
"# Augmentation: one fixed-angle rotation transform per 15-degree step.\n",
"# fill=0 (was fill=None) makes the black fill explicit; None is rejected by\n",
"# newer torchvision releases.\n",
"rotation_angles = list(range(0, 361, 15))\n",
"rotation_transforms = [transforms.RandomRotation(degrees=(angle, angle), expand=False, center=None, fill=0) for angle in rotation_angles]\n",
"\n",
"data_transforms = transforms.Compose([\n",
"    transforms.Pad(padding=224, fill=0, padding_mode='constant'),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomVerticalFlip(),\n",
"    # Bug fix: RandomChoice picks exactly ONE rotation per sample.\n",
"    # RandomApply(rotation_transforms, p=1) applied the WHOLE list in\n",
"    # sequence, i.e. rotated every image by the sum of all angles.\n",
"    transforms.RandomChoice(rotation_transforms),\n",
"    transforms.CenterCrop((224, 224)),\n",
"    transforms.ToTensor(),\n",
"])\n",
"\n",
"# Load the datasets with ImageFolder (class label = sub-directory name)\n",
"train_dataset = datasets.ImageFolder(root='E:/wm_images/train/', transform=data_transforms)\n",
"val_dataset = datasets.ImageFolder(root='E:/wm_images/val/', transform=data_transforms)\n",
"test_dataset = datasets.ImageFolder(root='E:/wm_images/test/', transform=data_transforms)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "36039ab9", "metadata": {}, "source": [ "# Settings" ] }, { "cell_type": "code", "execution_count": null, "id": "b466f397", "metadata": {}, "outputs": [], "source": [
"import torch.optim as optim\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"cnn_wdi.to(device)\n",
"print(str(device) + ' loaded.')\n",
"\n",
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(cnn_wdi.parameters(), lr=0.001)\n",
"\n",
"# Batch size\n",
"batch_size = 
112  # any value >= the sampled subset size (95 below) gives one full batch; was 18063360\n",
"\n",
"# Number of epochs; each epoch resamples a fresh random subset below\n",
"num_epochs = 100\n",
"\n",
"# Random sample sizes drawn per epoch from the train / val datasets\n",
"train_max_images = 95\n",
"val_max_images = 25\n"
] }, { "attachments": {}, "cell_type": "markdown", "id": "cd9ed634", "metadata": {}, "source": [ "# Train Function" ] }, { "cell_type": "code", "execution_count": null, "id": "8020581f", "metadata": {}, "outputs": [], "source": [
"def train(model, dataloader, criterion, optimizer, device):\n",
"    \"\"\"Run one training epoch over `dataloader`.\n",
"\n",
"    Returns (epoch_loss, epoch_acc): loss averaged per sample and the\n",
"    fraction of correct predictions over the whole dataset.\n",
"    \"\"\"\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    running_corrects = 0\n",
"\n",
"    for inputs, labels in dataloader:\n",
"        inputs = inputs.to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        outputs = model(inputs)\n",
"        _, preds = torch.max(outputs, 1)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Weight the (mean) batch loss by batch size so partial batches do\n",
"        # not skew the per-sample epoch average.\n",
"        running_loss += loss.item() * inputs.size(0)\n",
"        running_corrects += torch.sum(preds == labels.data)\n",
"\n",
"    epoch_loss = running_loss / len(dataloader.dataset)\n",
"    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
"\n",
"    return epoch_loss, epoch_acc" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2fa4e672", "metadata": {}, "source": [ "# Evaluate Function" ] }, { "cell_type": "code", "execution_count": null, "id": "674a1e25", "metadata": {}, "outputs": [], "source": [
"def evaluate(model, dataloader, criterion, device):\n",
"    \"\"\"Evaluate `model` on `dataloader` without gradient tracking.\n",
"\n",
"    Returns (epoch_loss, epoch_acc) computed the same way as train().\n",
"    \"\"\"\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    running_corrects = 0\n",
"\n",
"    with torch.no_grad():\n",
"        for inputs, labels in dataloader:\n",
"            inputs = inputs.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(inputs)\n",
"            _, preds = torch.max(outputs, 1)\n",
"            loss = criterion(outputs, labels)\n",
"\n",
"            running_loss += loss.item() * inputs.size(0)\n",
"            running_corrects += torch.sum(preds == labels.data)\n",
"\n",
"    epoch_loss = running_loss / len(dataloader.dataset)\n",
"    epoch_acc = 
running_corrects.double() / len(dataloader.dataset)\n",
"\n",
"    return epoch_loss, epoch_acc" ] }, { "attachments": {}, "cell_type": "markdown", "id": "42148a41", "metadata": {}, "source": [ "# Train" ] }, { "cell_type": "code", "execution_count": null, "id": "95074e64", "metadata": {}, "outputs": [], "source": [
"# Header row of the train/validation loss & accuracy log file\n",
"s_title = 'Epoch,\\tTrain Loss,\\tTrain Acc,\\tVal Loss,\\tVal Acc\\n'\n",
"with open('output.txt', 'a') as file:\n",
"    file.write(s_title)\n",
"print(s_title)\n",
"\n",
"# range(num_epochs + 1) so the final index (100) also satisfies the\n",
"# `epoch % 10 == 0` checkpoint rule below.\n",
"for epoch in range(num_epochs + 1):\n",
"    # Resample a fresh random training subset every epoch\n",
"    train_indices = torch.randperm(len(train_dataset))[:train_max_images]\n",
"    train_random_subset = torch.utils.data.Subset(train_dataset, train_indices)\n",
"    train_loader = torch.utils.data.DataLoader(train_random_subset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
"\n",
"    val_indices = torch.randperm(len(val_dataset))[:val_max_images]\n",
"    # Bug fix: sample the validation subset from val_dataset. It previously\n",
"    # used Subset(train_dataset, val_indices), so the reported validation\n",
"    # metrics were computed on training images.\n",
"    val_random_subset = torch.utils.data.Subset(val_dataset, val_indices)\n",
"    val_loader = torch.utils.data.DataLoader(val_random_subset, batch_size=batch_size, shuffle=False, num_workers=4)\n",
"\n",
"    # One epoch of training, then a validation pass\n",
"    train_loss, train_acc = train(cnn_wdi, train_loader, criterion, optimizer, device)\n",
"    val_loss, val_acc = evaluate(cnn_wdi, val_loader, criterion, device)\n",
"\n",
"    # Append one log line per epoch\n",
"    s_output = f'{epoch + 1}/{num_epochs},\\t{train_loss:.4f},\\t{train_acc:.4f},\\t{val_loss:.4f},\\t{val_acc:.4f}\\n'\n",
"    with open('output.txt', 'a') as file:\n",
"        file.write(s_output)\n",
"    print(s_output)\n",
"\n",
"    if epoch % 10 == 0:\n",
"        # Periodic checkpoint, e.g. CNN_WDI_20epoch.pth\n",
"        torch.save(cnn_wdi.state_dict(), 'CNN_WDI_' + str(epoch) + 'epoch.pth')" ] }, { "attachments": {}, "cell_type": "markdown", "id": "345f1ce5", "metadata": {}, "source": [ "# Confusion Matrix" ] }, { "cell_type": "code", "execution_count": null, "id": "c350bb0d", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import 
seaborn as sns\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"import pandas as pd\n",
"\n",
"def plot_metrics(title, class_names, precisions, recalls, f1_scores, acc):\n",
"    \"\"\"Draw a grouped bar chart of per-class precision/recall/F1 with an\n",
"    overall-accuracy reference line.\n",
"\n",
"    title: prefix for the chart title. class_names: x-axis labels.\n",
"    precisions / recalls / f1_scores: per-class values aligned with\n",
"    class_names. acc: overall accuracy, drawn as a dashed horizontal line.\n",
"    \"\"\"\n",
"    num_classes = len(class_names)\n",
"    index = np.arange(num_classes)\n",
"    bar_width = 0.2\n",
"\n",
"    # Three bar groups side by side, offset by bar_width\n",
"    plt.figure(figsize=(15, 7))\n",
"    plt.bar(index, precisions, bar_width, label='Precision')\n",
"    plt.bar(index + bar_width, recalls, bar_width, label='Recall')\n",
"    plt.bar(index + 2 * bar_width, f1_scores, bar_width, label='F1-score')\n",
"    plt.axhline(y=acc, color='r', linestyle='--', label='Accuracy')\n",
"\n",
"    plt.xlabel('Class')\n",
"    plt.ylabel('Scores')\n",
"    plt.title(title + ': Precision, Recall, F1-score, and Accuracy per Class')\n",
"    plt.xticks(index + bar_width, class_names)\n",
"    plt.legend(loc='upper right')\n",
"    plt.show()\n",
"\n",
"def predict_and_plot_metrics(title, model, dataloader, criterion, device):\n",
"    \"\"\"Evaluate `model` on `dataloader`, plot a normalized confusion matrix\n",
"    and per-class metric bars, and return (epoch_loss, epoch_acc, report).\n",
"    \"\"\"\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    running_corrects = 0\n",
"\n",
"    all_preds = []\n",
"    all_labels = []\n",
"    # NOTE(review): ImageFolder orders classes via case-sensitive sorted(),\n",
"    # which would place 'none' LAST, not at index 6 as listed here — verify\n",
"    # this list matches dataloader.dataset's .classes.\n",
"    class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']\n",
"\n",
"    # Accumulate predictions and true labels over the whole loader\n",
"    with torch.no_grad():\n",
"        for inputs, labels in dataloader:\n",
"            inputs = inputs.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(inputs)\n",
"            _, preds = torch.max(outputs, 1)\n",
"            loss = criterion(outputs, labels)\n",
"\n",
"            running_loss += loss.item() * inputs.size(0)\n",
"            running_corrects += torch.sum(preds == labels.data)\n",
"\n",
"            all_preds.extend(preds.cpu().numpy())\n",
"            all_labels.extend(labels.cpu().numpy())\n",
"\n",
"    epoch_loss = running_loss / len(dataloader.dataset)\n",
"    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
"\n",
"\n",
"    # Calculate classification report (per-class precision/recall/F1 as a dict)\n",
"    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)\n",
"\n",
"    # Calculate confusion matrix\n",
"    cm = 
confusion_matrix(all_labels, all_preds)\n",
"\n",
"    # Per-class precision, recall and F1 pulled from the report dict\n",
"    precisions = [report[c]['precision'] for c in class_names]\n",
"    recalls = [report[c]['recall'] for c in class_names]\n",
"    f1_scores = [report[c]['f1-score'] for c in class_names]\n",
"    print('p: ' + str(precisions))\n",
"    print('r: ' + str(recalls))\n",
"    print('f: ' + str(f1_scores))\n",
"\n",
"    # Row-normalize the confusion matrix so each true class sums to 100%.\n",
"    # NOTE(review): a class absent from the labels yields a zero row and a\n",
"    # divide-by-zero warning here.\n",
"    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
"    plt.figure(figsize=(12, 12))\n",
"    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues', xticklabels=class_names, yticklabels=class_names)\n",
"    plt.xlabel('Predicted Label')\n",
"    plt.ylabel('True Label')\n",
"    plt.title('Normalized Confusion Matrix: ' + title)\n",
"    plt.show()\n",
"\n",
"    # Plot precision, recall, f1-score, and accuracy per class\n",
"    plot_metrics(title, class_names, precisions, recalls, f1_scores, epoch_acc.item())\n",
"\n",
"    return epoch_loss, epoch_acc, report\n"
] }, { "cell_type": "markdown", "id": "dfddbcdc", "metadata": {}, "source": [ "# Evaluate" ] }, { "cell_type": "code", "execution_count": null, "id": "d57f59cd", "metadata": {}, "outputs": [], "source": [
"import os\n",
"import re\n",
"\n",
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=112, shuffle=False, num_workers=4)\n",
"\n",
"# Evaluate every checkpoint saved by the training loop, in epoch order\n",
"model_dir = '.'  # renamed from `dir`, which shadowed the builtin\n",
"models = [file for file in os.listdir(model_dir) if file.endswith('.pth')]\n",
"\n",
"def extract_number(filename):\n",
"    \"\"\"Return the first integer in `filename` (the checkpoint's epoch).\"\"\"\n",
"    return int(re.search(r'\\d+', filename).group(0))\n",
"\n",
"sorted_models = sorted(models, key=extract_number)\n",
"\n",
"for model_file in sorted_models:\n",
"    model_path = os.path.join(model_dir, model_file)\n",
"\n",
"    # Load the saved weights; map_location lets checkpoints saved on a GPU\n",
"    # machine load on a CPU-only one.\n",
"    cnn_wdi.load_state_dict(torch.load(model_path, map_location=device))\n",
"\n",
"    # Evaluate and plot metrics for this checkpoint\n",
"    epoch_loss, epoch_acc, report = predict_and_plot_metrics(model_file, cnn_wdi, 
test_loader, criterion, device)\n",
"    # print(f'Model: {model_path} Test Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')" ] }, { "attachments": {}, "cell_type": "markdown", "id": "25fa475d", "metadata": {}, "source": [ "# Loss Graph" ] }, { "cell_type": "code", "execution_count": null, "id": "fe2360b6", "metadata": {}, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"import re  # make this cell self-contained; `re` was only imported in an earlier cell\n",
"\n",
"# Read the training log written by the train loop (skip the header row)\n",
"with open('output.txt', 'r') as file:\n",
"    lines = file.readlines()[1:]\n",
"\n",
"# Parse each log line into per-epoch series\n",
"epochs = []\n",
"train_losses = []\n",
"train_accuracies = []\n",
"val_losses = []\n",
"val_accuracies = []\n",
"\n",
"for line in lines:\n",
"    if line == '\\n':\n",
"        continue\n",
"    # Fields are separated by commas and/or whitespace (tabs)\n",
"    epoch, train_loss, train_acc, val_loss, val_acc = re.split(r'[,\\s\\t]+', line.strip())\n",
"    epochs.append(int(epoch.split('/')[0]))\n",
"    train_losses.append(float(train_loss))\n",
"    train_accuracies.append(float(train_acc))\n",
"    val_losses.append(float(val_loss))\n",
"    val_accuracies.append(float(val_acc))\n",
"\n",
"# Plot all four series on one chart\n",
"plt.figure(figsize=(10, 5))\n",
"\n",
"plt.plot(epochs, train_losses, label='Train Loss')\n",
"plt.plot(epochs, train_accuracies, label='Train Acc')\n",
"plt.plot(epochs, val_losses, label='Val Loss')\n",
"plt.plot(epochs, val_accuracies, label='Val Acc')\n",
"\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Values')\n",
"plt.title('Training and Validation Loss and Accuracy')\n",
"plt.legend()\n",
"plt.show()\n"
] }, { "attachments": {}, "cell_type": "markdown", "id": "66b5fef7", "metadata": {}, "source": [ "# Print Selecting Test Model Result" ] }, { "cell_type": "code", "execution_count": null, "id": "65a5a7c4", "metadata": {}, "outputs": [], "source": [
"def output(model, dataloader, criterion, device):\n",
"    \"\"\"Evaluate `model` on `dataloader` and print macro-averaged\n",
"    precision / recall / F1 plus overall accuracy.\"\"\"\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    running_corrects = 0\n",
"\n",
"    all_preds = []\n",
"    all_labels = []\n",
"    
class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']\n",
"\n",
"    # Accumulate predictions and true labels over the whole loader\n",
"    with torch.no_grad():\n",
"        for inputs, labels in dataloader:\n",
"            inputs = inputs.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(inputs)\n",
"            _, preds = torch.max(outputs, 1)\n",
"            loss = criterion(outputs, labels)\n",
"\n",
"            running_loss += loss.item() * inputs.size(0)\n",
"            running_corrects += torch.sum(preds == labels.data)\n",
"\n",
"            all_preds.extend(preds.cpu().numpy())\n",
"            all_labels.extend(labels.cpu().numpy())\n",
"\n",
"    epoch_loss = running_loss / len(dataloader.dataset)\n",
"    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
"\n",
"\n",
"    # Per-class metrics as a dict (includes 'accuracy' and average entries)\n",
"    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)\n",
"\n",
"    precisions = [report[c]['precision'] for c in class_names]\n",
"    recalls = [report[c]['recall'] for c in class_names]\n",
"    f1_scores = [report[c]['f1-score'] for c in class_names]\n",
"    accuracy = report['accuracy']\n",
"\n",
"    # Unweighted (macro) averages over classes; equivalent to report['macro avg']\n",
"    precs = sum(precisions) / len(precisions)\n",
"    recs = sum(recalls) / len(recalls)\n",
"    f1s = sum(f1_scores) / len(f1_scores)\n",
"    print('precisions: ' + str(precs))\n",
"    print('recalls: ' + str(recs))\n",
"    print('f1_scores: ' + str(f1s))\n",
"    print('accuracy ' + str(accuracy))\n",
"\n",
"\n",
"selected_model = 'CNN_WDI_20epoch.pth'\n",
"# map_location lets a GPU-saved checkpoint load on a CPU-only machine\n",
"cnn_wdi.load_state_dict(torch.load(selected_model, map_location=device))\n",
"output(cnn_wdi, test_loader, criterion, device)"
] }, { "cell_type": "markdown", "id": "4a2753c3", "metadata": {}, "source": [] }, { "attachments": {}, "cell_type": "markdown", "id": "a9f9a39d", "metadata": {}, "source": [ "# 원본 데이터셋 학습 결과\n", "\n", "### 배치 사이즈\n", "* batch_size = 18063360\n", "\n", "### 학습 및 평가 실행\n", "* num_epochs = 100\n", "\n", "### Random sample size\n", "* train_max_images = 95\n", "* val_max_images = 
25\n", "\n", "##### 설정으로 학습 진행 시,\n", "\n", "``` plaintext\n", "Epoch,\tTrain Loss,\tTrain Acc,\tVal Loss,\tVal Acc\n", "\n", "1/100,\t1.5930,\t0.7789,\t1.8920,\t0.4800\n", "\n", "2/100,\t1.5193,\t0.8526,\t1.8520,\t0.5200\n", "\n", "3/100,\t1.4562,\t0.9158,\t1.9320,\t0.4400\n", "\n", "4/100,\t1.5088,\t0.8632,\t1.7720,\t0.6000\n", "\n", "5/100,\t1.5088,\t0.8632,\t1.8920,\t0.4800\n", "\n", "6/100,\t1.4983,\t0.8737,\t1.8520,\t0.5200\n", "\n", "7/100,\t1.4983,\t0.8737,\t1.9720,\t0.4000\n", "\n", "8/100,\t1.5720,\t0.8000,\t2.0120,\t0.3600\n", "...\n", "```\n", "\n", ": 데이터 부족으로 학습이 제대로 이루어지지 않음." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }