init commit
d7df4d2537
@@ -0,0 +1,618 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1857acd7",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efb5db0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN_WDI(nn.Module):\n",
    "    def __init__(self, class_num=9):\n",
    "        super(CNN_WDI, self).__init__()\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=0)\n",
    "        self.bn1 = nn.BatchNorm2d(16)\n",
    "        self.pool1 = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)\n",
    "        self.bn3 = nn.BatchNorm2d(32)\n",
    "        self.pool2 = nn.MaxPool2d(2, 2)\n",
    "        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n",
    "        self.bn4 = nn.BatchNorm2d(32)\n",
    "\n",
    "        self.conv5 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
    "        self.bn5 = nn.BatchNorm2d(64)\n",
    "        self.pool3 = nn.MaxPool2d(2, 2)\n",
    "        self.conv6 = nn.Conv2d(64, 64, kernel_size=3, padding=1)\n",
    "        self.bn6 = nn.BatchNorm2d(64)\n",
    "\n",
    "        self.conv7 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
    "        self.bn7 = nn.BatchNorm2d(128)\n",
    "        self.pool4 = nn.MaxPool2d(2, 2)\n",
    "        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)\n",
    "        self.bn8 = nn.BatchNorm2d(128)\n",
    "\n",
    "        self.spatial_dropout = nn.Dropout2d(0.2)\n",
    "        self.pool5 = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # A 224x224 input shrinks to 6x6 after five pools: 128 * 6 * 6 = 4608\n",
    "        self.fc1 = nn.Linear(4608, 512)\n",
    "        self.fc2 = nn.Linear(512, class_num)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.bn1(self.conv1(x)))\n",
    "        x = self.pool1(F.relu(self.bn2(self.conv2(x))))\n",
    "\n",
    "        x = F.relu(self.bn3(self.conv3(x)))\n",
    "        x = self.pool2(F.relu(self.bn4(self.conv4(x))))\n",
    "\n",
    "        x = F.relu(self.bn5(self.conv5(x)))\n",
    "        x = self.pool3(F.relu(self.bn6(self.conv6(x))))\n",
    "\n",
    "        x = F.relu(self.bn7(self.conv7(x)))\n",
    "        x = self.pool4(F.relu(self.bn8(self.conv8(x))))\n",
    "\n",
    "        x = self.spatial_dropout(x)\n",
    "        x = self.pool5(x)\n",
    "\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "\n",
    "        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,\n",
    "        # so an F.softmax here would normalize twice and dampen the gradients.\n",
    "        return x\n",
    "\n",
    "cnn_wdi = CNN_WDI(class_num=9)\n",
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1c383602",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a865c00c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms, datasets\n",
    "\n",
    "# Data preprocessing\n",
    "rotation_angles = list(range(0, 361, 15))\n",
    "rotation_transforms = [transforms.RandomRotation(degrees=(angle, angle), expand=False, center=None, fill=0) for angle in rotation_angles]\n",
    "\n",
    "data_transforms = transforms.Compose([\n",
    "    transforms.Pad(padding=224, fill=0, padding_mode='constant'),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomVerticalFlip(),\n",
    "    # RandomChoice picks one rotation per sample; RandomApply(p=1) would chain\n",
    "    # all 25 rotations, whose angles always sum to a fixed 180-degree turn.\n",
    "    transforms.RandomChoice(rotation_transforms),\n",
    "    transforms.CenterCrop((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Load the datasets with ImageFolder\n",
    "train_dataset = datasets.ImageFolder(root='E:/wm_images/train/', transform=data_transforms)\n",
    "val_dataset = datasets.ImageFolder(root='E:/wm_images/val/', transform=data_transforms)\n",
    "test_dataset = datasets.ImageFolder(root='E:/wm_images/test/', transform=data_transforms)\n",
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "36039ab9",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b466f397",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "cnn_wdi.to(device)\n",
    "print(str(device) + ' loaded.')\n",
    "\n",
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(cnn_wdi.parameters(), lr=0.001)\n",
    "\n",
    "# Batch size\n",
    "batch_size = 18063360  # 112\n",
    "\n",
    "# Number of training epochs\n",
    "num_epochs = 100  # * 192\n",
    "# num_epochs = 50\n",
    "\n",
    "# Random sample size\n",
    "train_max_images = 95\n",
    "val_max_images = 25\n",
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "cd9ed634",
   "metadata": {},
   "source": [
    "# Train Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8020581f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training function\n",
    "def train(model, dataloader, criterion, optimizer, device):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    running_corrects = 0\n",
    "\n",
    "    for inputs, labels in dataloader:\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(inputs)\n",
    "        _, preds = torch.max(outputs, 1)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item() * inputs.size(0)\n",
    "        running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "    epoch_loss = running_loss / len(dataloader.dataset)\n",
    "    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
    "\n",
    "    return epoch_loss, epoch_acc\n",
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2fa4e672",
   "metadata": {},
   "source": [
    "# Evaluate Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674a1e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation function\n",
    "def evaluate(model, dataloader, criterion, device):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    running_corrects = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            _, preds = torch.max(outputs, 1)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "            running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "    epoch_loss = running_loss / len(dataloader.dataset)\n",
    "    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
    "\n",
    "    return epoch_loss, epoch_acc"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "42148a41",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95074e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log file for train/validation loss and accuracy\n",
    "s_title = 'Epoch,\\tTrain Loss,\\tTrain Acc,\\tVal Loss,\\tVal Acc\\n'\n",
    "with open('output.txt', 'a') as file:\n",
    "    file.write(s_title)\n",
    "print(s_title)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Draw a fresh random subset of images for this epoch\n",
    "    train_indices = torch.randperm(len(train_dataset))[:train_max_images]\n",
    "    train_random_subset = torch.utils.data.Subset(train_dataset, train_indices)\n",
    "    train_loader = torch.utils.data.DataLoader(train_random_subset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "    val_indices = torch.randperm(len(val_dataset))[:val_max_images]\n",
    "    # Sample the validation subset from val_dataset (the original drew from train_dataset by mistake)\n",
    "    val_random_subset = torch.utils.data.Subset(val_dataset, val_indices)\n",
    "    val_loader = torch.utils.data.DataLoader(val_random_subset, batch_size=batch_size, shuffle=False, num_workers=4)\n",
    "\n",
    "    # Train, then evaluate on the validation subset\n",
    "    train_loss, train_acc = train(cnn_wdi, train_loader, criterion, optimizer, device)\n",
    "    val_loss, val_acc = evaluate(cnn_wdi, val_loader, criterion, device)\n",
    "\n",
    "    # Append to the log\n",
    "    s_output = f'{epoch + 1}/{num_epochs},\\t{train_loss:.4f},\\t{train_acc:.4f},\\t{val_loss:.4f},\\t{val_acc:.4f}\\n'\n",
    "    with open('output.txt', 'a') as file:\n",
    "        file.write(s_output)\n",
    "    print(s_output)\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        # Save a checkpoint every 10 epochs\n",
    "        torch.save(cnn_wdi.state_dict(), 'CNN_WDI_' + str(epoch) + 'epoch.pth')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "345f1ce5",
   "metadata": {},
   "source": [
    "# Confusion Matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c350bb0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "import pandas as pd\n",
    "\n",
    "def plot_metrics(title, class_names, precisions, recalls, f1_scores, acc):\n",
    "    num_classes = len(class_names)\n",
    "    index = np.arange(num_classes)\n",
    "    bar_width = 0.2\n",
    "\n",
    "    plt.figure(figsize=(15, 7))\n",
    "    plt.bar(index, precisions, bar_width, label='Precision')\n",
    "    plt.bar(index + bar_width, recalls, bar_width, label='Recall')\n",
    "    plt.bar(index + 2 * bar_width, f1_scores, bar_width, label='F1-score')\n",
    "    plt.axhline(y=acc, color='r', linestyle='--', label='Accuracy')\n",
    "\n",
    "    plt.xlabel('Class')\n",
    "    plt.ylabel('Scores')\n",
    "    plt.title(title + ': Precision, Recall, F1-score, and Accuracy per Class')\n",
    "    plt.xticks(index + bar_width, class_names)\n",
    "    plt.legend(loc='upper right')\n",
    "    plt.show()\n",
    "\n",
    "def predict_and_plot_metrics(title, model, dataloader, criterion, device):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    running_corrects = 0\n",
    "\n",
    "    all_preds = []\n",
    "    all_labels = []\n",
    "    class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            _, preds = torch.max(outputs, 1)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "            running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "            all_preds.extend(preds.cpu().numpy())\n",
    "            all_labels.extend(labels.cpu().numpy())\n",
    "\n",
    "    epoch_loss = running_loss / len(dataloader.dataset)\n",
    "    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
    "\n",
    "    # Calculate classification report\n",
    "    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)\n",
    "\n",
    "    # Calculate confusion matrix\n",
    "    cm = confusion_matrix(all_labels, all_preds)\n",
    "\n",
    "    # Calculate precision, recall, and f1-score per class\n",
    "    precisions = [report[c]['precision'] for c in class_names]\n",
    "    recalls = [report[c]['recall'] for c in class_names]\n",
    "    f1_scores = [report[c]['f1-score'] for c in class_names]\n",
    "    print('p: ' + str(precisions))\n",
    "    print('r: ' + str(recalls))\n",
    "    print('f: ' + str(f1_scores))\n",
    "\n",
    "    # Plot confusion matrix with normalized values (percentage)\n",
    "    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues', xticklabels=class_names, yticklabels=class_names)\n",
    "    plt.xlabel('Predicted Label')\n",
    "    plt.ylabel('True Label')\n",
    "    plt.title('Normalized Confusion Matrix: ' + title)\n",
    "    plt.show()\n",
    "\n",
    "    # Plot precision, recall, f1-score, and accuracy per class\n",
    "    plot_metrics(title, class_names, precisions, recalls, f1_scores, epoch_acc.item())\n",
    "\n",
    "    return epoch_loss, epoch_acc, report\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfddbcdc",
   "metadata": {},
   "source": [
    "# Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57f59cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=112, shuffle=False, num_workers=4)\n",
    "\n",
    "model_dir = '.'  # renamed from 'dir' to avoid shadowing the built-in\n",
    "models = [file for file in os.listdir(model_dir) if file.endswith('.pth')]\n",
    "\n",
    "def extract_number(filename):\n",
    "    return int(re.search(r'\\d+', filename).group(0))\n",
    "\n",
    "# Evaluate the checkpoints in ascending epoch order\n",
    "sorted_models = sorted(models, key=extract_number)\n",
    "\n",
    "for model in sorted_models:\n",
    "    model_path = os.path.join(model_dir, model)\n",
    "\n",
    "    # Load the saved model weights\n",
    "    cnn_wdi.load_state_dict(torch.load(model_path))\n",
    "\n",
    "    # Call the predict_and_plot_metrics function with the appropriate arguments\n",
    "    epoch_loss, epoch_acc, report = predict_and_plot_metrics(model, cnn_wdi, test_loader, criterion, device)\n",
    "    # print(f'Model: {model} Test Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "25fa475d",
   "metadata": {},
   "source": [
    "# Loss Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe2360b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Read the log file\n",
    "with open('output.txt', 'r') as file:\n",
    "    lines = file.readlines()[1:]  # skip the header line\n",
    "\n",
    "# Parse the log into lists\n",
    "epochs = []\n",
    "train_losses = []\n",
    "train_accuracies = []\n",
    "val_losses = []\n",
    "val_accuracies = []\n",
    "\n",
    "for line in lines:\n",
    "    if line == '\\n':\n",
    "        continue\n",
    "    epoch, train_loss, train_acc, val_loss, val_acc = re.split(r'[,\\s\\t]+', line.strip())\n",
    "    epochs.append(int(epoch.split('/')[0]))\n",
    "    train_losses.append(float(train_loss))\n",
    "    train_accuracies.append(float(train_acc))\n",
    "    val_losses.append(float(val_loss))\n",
    "    val_accuracies.append(float(val_acc))\n",
    "\n",
    "# Plot the curves\n",
    "plt.figure(figsize=(10, 5))\n",
    "\n",
    "plt.plot(epochs, train_losses, label='Train Loss')\n",
    "plt.plot(epochs, train_accuracies, label='Train Acc')\n",
    "plt.plot(epochs, val_losses, label='Val Loss')\n",
    "plt.plot(epochs, val_accuracies, label='Val Acc')\n",
    "\n",
    "plt.xlabel('Epochs')\n",
    "plt.ylabel('Values')\n",
    "plt.title('Training and Validation Loss and Accuracy')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "66b5fef7",
   "metadata": {},
   "source": [
    "# Print Selected Test Model Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a5a7c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def output(model, dataloader, criterion, device):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    running_corrects = 0\n",
    "\n",
    "    all_preds = []\n",
    "    all_labels = []\n",
    "    class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            _, preds = torch.max(outputs, 1)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "            running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "            all_preds.extend(preds.cpu().numpy())\n",
    "            all_labels.extend(labels.cpu().numpy())\n",
    "\n",
    "    epoch_loss = running_loss / len(dataloader.dataset)\n",
    "    epoch_acc = running_corrects.double() / len(dataloader.dataset)\n",
    "\n",
    "    # Calculate classification report\n",
    "    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)\n",
    "\n",
    "    # Calculate precision, recall, and f1-score per class\n",
    "    precisions = [report[c]['precision'] for c in class_names]\n",
    "    recalls = [report[c]['recall'] for c in class_names]\n",
    "    f1_scores = [report[c]['f1-score'] for c in class_names]\n",
    "    accuracy = report['accuracy']\n",
    "\n",
    "    # Macro (unweighted) averages over the nine classes\n",
    "    precs = sum(precisions) / len(precisions)\n",
    "    recs = sum(recalls) / len(recalls)\n",
    "    f1s = sum(f1_scores) / len(f1_scores)\n",
    "    print('precisions: ' + str(precs))\n",
    "    print('recalls: ' + str(recs))\n",
    "    print('f1_scores: ' + str(f1s))\n",
    "    print('accuracy: ' + str(accuracy))\n",
    "\n",
    "\n",
    "selected_model = 'CNN_WDI_20epoch.pth'\n",
    "cnn_wdi.load_state_dict(torch.load(selected_model))\n",
    "output(cnn_wdi, test_loader, criterion, device)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a9f9a39d",
   "metadata": {},
   "source": [
    "# Training Results on the Original Dataset\n",
    "\n",
    "### Batch size\n",
    "* batch_size = 18063360\n",
    "\n",
    "### Training and evaluation\n",
    "* num_epochs = 100\n",
    "\n",
    "### Random sample size\n",
    "* train_max_images = 95\n",
    "* val_max_images = 25\n",
    "\n",
    "##### Training with these settings produced:\n",
    "\n",
    "``` plaintext\n",
    "Epoch,\tTrain Loss,\tTrain Acc,\tVal Loss,\tVal Acc\n",
    "\n",
    "1/100,\t1.5930,\t0.7789,\t1.8920,\t0.4800\n",
    "2/100,\t1.5193,\t0.8526,\t1.8520,\t0.5200\n",
    "3/100,\t1.4562,\t0.9158,\t1.9320,\t0.4400\n",
    "4/100,\t1.5088,\t0.8632,\t1.7720,\t0.6000\n",
    "5/100,\t1.5088,\t0.8632,\t1.8920,\t0.4800\n",
    "6/100,\t1.4983,\t0.8737,\t1.8520,\t0.5200\n",
    "7/100,\t1.4983,\t0.8737,\t1.9720,\t0.4000\n",
    "8/100,\t1.5720,\t0.8000,\t2.0120,\t0.3600\n",
    "...\n",
    "```\n",
    "\n",
    ": Training did not proceed properly because of the shortage of data."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,21 @@
# Intelligence Capstone Project #1 - WDI-CNN
### *(Building a CNN model that classifies Wafer Map data into 9 classes)*


-----

### Paper
![Deep convolutional neural network for wafer defect identification on an imbalanced dataset in semiconductor manufacturing](https://file.notion.so/f/s/f80d2b5c-ac36-435b-8cef-19f3f5675940/(%EC%B5%9C%EC%A2%85)%EC%B0%B8%EA%B3%A0%EC%9E%90%EB%A3%8C_%EB%85%BC%EB%AC%B8_Wafer_Map_%EB%B6%88%EB%9F%89%EA%B2%80%EC%B6%9C_%EB%B2%88%EC%97%AD.pdf?id=805060e2-1f8d-4cc0-aba3-26e7b83ed5e9&table=block&spaceId=5c35ea55-42f1-4c42-b112-94f6eb8e2c2e&expirationTimestamp=1682615407281&signature=fK1nI4Wihr3P4g6i0RxbQ8insN8Gcr27vY5_DI0tctk&downloadName=%28%EC%B5%9C%EC%A2%85%29%EC%B0%B8%EA%B3%A0%EC%9E%90%EB%A3%8C_%EB%85%BC%EB%AC%B8_Wafer+Map+%EB%B6%88%EB%9F%89%EA%B2%80%EC%B6%9C_%EB%B2%88%EC%97%AD.pdf)

### Dataset
[Kaggle - WDI Data](https://www.kaggle.com/qingyi/wm811k-wafer-map/code)

-----

### Method

* Referring to the paper above, implement a CNN model and train it on the WDI dataset to classify the nine classes (Center, Donut, Edge-Loc, Edge-Ring, Loc, Near-full, none, Random, Scratch).

### Evaluation

* Report the model's performance metrics (Precision, Recall, Accuracy, F1-score) together with a confusion matrix, as in the sketch below.
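
A minimal sketch of that evaluation, assuming `y_true` and `y_pred` are integer label arrays (placeholder names, not from the notebook; there they are collected from the test DataLoader):

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Placeholder label arrays for illustration only
y_true = np.array([0, 1, 2, 2, 5, 8])
y_pred = np.array([0, 1, 2, 3, 5, 8])

# 9x9 confusion matrix: rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true, y_pred, labels=list(range(9)))
print(cm)

# Per-class precision, recall, and F1, plus overall accuracy
report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
print(report['accuracy'], accuracy_score(y_true, y_pred))
```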