Compare commits
14 Commits
Author | SHA1 | Date |
---|---|---|
pinb | 38f2c2c47d | 2 years ago |
pinb | ac750b1d1a | 2 years ago |
pinb | d8ccdc87a7 | 2 years ago |
pinb | 5a4956063d | 2 years ago |
pinb | 9b4f8ac3d4 | 2 years ago |
pinb | 5954cf2cfe | 2 years ago |
pinb | c9dd9d969d | 2 years ago |
pinb | 73f9072254 | 2 years ago |
pinb | 3f7b346fd7 | 2 years ago |
pinb | c1a6685630 | 2 years ago |
pinb | c4f1bf542c | 2 years ago |
pinb | 93a3d862dc | 2 years ago |
pinb | b1787f033b | 2 years ago |
pinb | 8fb51417bb | 2 years ago |
@ -1,31 +1,551 @@
|
|||||||
# 지능화 캡스톤 프로젝트 #1 - WDI-CNN
|
# 지능화 캡스톤 프로젝트 #1 - WDI-CNN
|
||||||
### *(Wafer Map 데이터를 9종류의 Class로 분류하는 CNN 모델 만들기)*
|
### *Wafer Map 데이터를 9종류의 Class로 분류하는 CNN 모델 만들기*
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
|
||||||
|
[PINBlog Gitea Repository](https://gitea.pinblog.codes/CBNU/03_WDI_CNN)
|
||||||
|
|
||||||
-----
|
-----
|
||||||
|
|
||||||
|
|
||||||
### 논문
|
### 논문
|
||||||
반도체 제조공정의 불균형 데이터셋에 대한 웨이퍼 불량 식별을 위한 심층 컨볼루션 신경망
|
A Deep Convolutional Neural Network for Wafer Defect Identification on an Imbalanced Dataset in Semiconductor Manufacturing Processes
|
||||||
|
|
||||||
|
(반도체 제조공정의 불균형 데이터셋에 대한 웨이퍼 불량 식별을 위한 심층 컨볼루션 신경망)
|
||||||
|
|
||||||
* 번역본
|
* [번역본](https://gitea.pinblog.codes/attachments/9b2424f7-7e7d-4ad1-a368-86a523d67504)
|
||||||
https://gitea.pinblog.codes/attachments/9b2424f7-7e7d-4ad1-a368-86a523d67504
|
|
||||||
|
* [원본](https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad)
|
||||||
|
|
||||||
|
* 인용된 논문 리스트
|
||||||
|
|
||||||
|
| 번호 | 논문 제목 | 저자 | 출판사 및 링크 |
|
||||||
|
|------|-----------|------|----------------|
|
||||||
|
| 1 | Deep Residual Learning for Image Recognition | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016. [링크](https://arxiv.org/abs/1512.03385) |
|
||||||
|
| 2 | Very Deep Convolutional Networks for Large-Scale Image Recognition | Karen Simonyan, Andrew Zisserman | International Conference on Learning Representations (ICLR), 2015. [링크](https://arxiv.org/abs/1409.1556) |
|
||||||
|
| 3 | Going Deeper with Convolutions | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015. [링크](https://arxiv.org/abs/1409.4842) |
|
||||||
|
| 4 | Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift | Sergey Ioffe, Christian Szegedy | International Conference on Machine Learning (ICML), 2015. [링크](https://arxiv.org/abs/1502.03167) |
|
||||||
|
| 5 | Dropout: A Simple Way to Prevent Neural Networks from Overfitting | Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, Ruslan Salakhutdinov | Journal of Machine Learning Research (JMLR), 2014. [링크](http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) |
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
* 원본
|
|
||||||
https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad
|
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
[Kaggle - WDI Data](https://www.kaggle.com/qingyi/wm811k-wafer-map/code)
|
[Kaggle - WDI Data](https://www.kaggle.com/qingyi/wm811k-wafer-map/code)
|
||||||
|
|
||||||
|
[Pickle Dataset](https://gitea.pinblog.codes/attachments/d16767f7-a31a-4455-a550-70fa4c660b7d)
|
||||||
|
|
||||||
|
[BMP Dataset](https://gitea.pinblog.codes/attachments/be9fa247-3c31-4db1-88a0-390814190532)
|
||||||
|
|
||||||
-----
|
-----
|
||||||
|
|
||||||
|
|
||||||
### 수행방법
|
### 수행방법
|
||||||
|
|
||||||
* 위 논문을 참고하여 CNN 모델을 구현하고,
|
* 위 논문을 참고하여 CNN 모델을 구현하고,
|
||||||
WDI Dataset을 학습하여 9개의 클래스로 분류한다.
|
WDI Dataset을 학습하여 9개의 클래스로 분류한다.
|
||||||
(Center, Donut, Edge-Loc, Edge-Ring, Loc, Near-full, none, Random, Scratch)
|
|
||||||
|
| 클래스 | 라벨 | Train 이미지 개수 | Validation 이미지 개수 | Test 이미지 개수 |
|
||||||
|
|--------|------|-------------------|------------------------|------------------|
|
||||||
|
| None | 0 | 117,431 | 15,000 | 15,000 |
|
||||||
|
| Center | 1 | 3,294 | 500 | 500 |
|
||||||
|
| Donut | 2 | 444 | 50 | 50 |
|
||||||
|
| Edge-Loc | 3 | 4,189 | 500 | 500 |
|
||||||
|
| Edge-Ring |4 |7,680 |1,000 |1,000 |
|
||||||
|
| Local |5 |2,794 |400 |400 |
|
||||||
|
| Random |6 |666 |100 |100 |
|
||||||
|
| Scratch |7 |894 |150 |150 |
|
||||||
|
| Near-full |8 |149 |- |- |
|
||||||
|
|
||||||
|
[프로젝트 관련 자료](https://gitea.pinblog.codes/CBNU/03_WDI_CNN/releases/tag/info)
|
||||||
|
|
||||||
|
|
||||||
|
# Model
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class CNN_WDI(nn.Module):
|
||||||
|
def __init__(self, class_num=9):
|
||||||
|
super(CNN_WDI, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=0)
|
||||||
|
self.bn1 = nn.BatchNorm2d(16)
|
||||||
|
self.pool1 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
|
||||||
|
self.bn2 = nn.BatchNorm2d(16)
|
||||||
|
|
||||||
|
self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
||||||
|
self.bn3 = nn.BatchNorm2d(32)
|
||||||
|
self.pool2 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
|
||||||
|
self.bn4 = nn.BatchNorm2d(32)
|
||||||
|
|
||||||
|
self.conv5 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
||||||
|
self.bn5 = nn.BatchNorm2d(64)
|
||||||
|
self.pool3 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv6 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
|
||||||
|
self.bn6 = nn.BatchNorm2d(64)
|
||||||
|
|
||||||
|
self.conv7 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
||||||
|
self.bn7 = nn.BatchNorm2d(128)
|
||||||
|
self.pool4 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
|
||||||
|
self.bn8 = nn.BatchNorm2d(128)
|
||||||
|
|
||||||
|
self.spatial_dropout = nn.Dropout2d(0.2)
|
||||||
|
self.pool5 = nn.MaxPool2d(2, 2)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(4608, 512)
|
||||||
|
self.fc2 = nn.Linear(512, class_num)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
x = self.pool1(F.relu(self.bn2(self.conv2(x))))
|
||||||
|
|
||||||
|
x = F.relu(self.bn3(self.conv3(x)))
|
||||||
|
x = self.pool2(F.relu(self.bn4(self.conv4(x))))
|
||||||
|
|
||||||
|
x = F.relu(self.bn5(self.conv5(x)))
|
||||||
|
x = self.pool3(F.relu(self.bn6(self.conv6(x))))
|
||||||
|
|
||||||
|
x = F.relu(self.bn7(self.conv7(x)))
|
||||||
|
x = self.pool4(F.relu(self.bn8(self.conv8(x))))
|
||||||
|
|
||||||
|
x = self.spatial_dropout(x)
|
||||||
|
x = self.pool5(x)
|
||||||
|
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = F.relu(self.fc1(x))
|
||||||
|
x = self.fc2(x)
|
||||||
|
|
||||||
|
return F.softmax(x, dim=1)
|
||||||
|
|
||||||
|
cnn_wdi = CNN_WDI(class_num=9)
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Load Data
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
from torchvision import transforms, datasets
|
||||||
|
|
||||||
|
# 데이터 전처리
|
||||||
|
rotation_angles = list(range(0, 361, 15))
|
||||||
|
rotation_transforms = [transforms.RandomRotation(degrees=(angle, angle), expand=False, center=None, fill=None) for angle in rotation_angles]
|
||||||
|
|
||||||
|
data_transforms = transforms.Compose([
|
||||||
|
transforms.Pad(padding=224, fill=0, padding_mode='constant'),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.RandomVerticalFlip(),
|
||||||
|
transforms.RandomApply(rotation_transforms, p=1),
|
||||||
|
transforms.CenterCrop((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
|
||||||
|
# ImageFolder를 사용하여 데이터셋 불러오기
|
||||||
|
train_dataset = datasets.ImageFolder(root='E:/wm_images/train/', transform=data_transforms)
|
||||||
|
val_dataset = datasets.ImageFolder(root='E:/wm_images/val/', transform=data_transforms)
|
||||||
|
test_dataset = datasets.ImageFolder(root='E:/wm_images/test/', transform=data_transforms)
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
cnn_wdi.to(device)
|
||||||
|
print(str(device) + ' loaded.')
|
||||||
|
|
||||||
|
# 손실 함수 및 최적화 알고리즘 설정
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(cnn_wdi.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# 배치사이즈
|
||||||
|
batch_size = 18063360 #112
|
||||||
|
|
||||||
|
# 학습 및 평가 실행
|
||||||
|
num_epochs = 100 #* 192
|
||||||
|
# num_epochs = 50
|
||||||
|
|
||||||
|
# Random sample size
|
||||||
|
train_max_images = 95
|
||||||
|
val_max_images = 25
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Train Function
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
# 학습 함수 정의
|
||||||
|
def train(model, dataloader, criterion, optimizer, device):
|
||||||
|
model.train()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
for inputs, labels in dataloader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(dataloader.dataset)
|
||||||
|
epoch_acc = running_corrects.double() / len(dataloader.dataset)
|
||||||
|
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Evaluate Function
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
# 평가 함수 정의
|
||||||
|
def evaluate(model, dataloader, criterion, device):
|
||||||
|
model.eval()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in dataloader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(dataloader.dataset)
|
||||||
|
epoch_acc = running_corrects.double() / len(dataloader.dataset)
|
||||||
|
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Train
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
# Train & Validation의 Loss, Acc 기록 파일
|
||||||
|
s_title = 'Epoch,\tTrain Loss,\tTrain Acc,\tVal Loss,\tVal Acc\n'
|
||||||
|
with open('output.txt', 'a') as file:
|
||||||
|
file.write(s_title)
|
||||||
|
print(s_title)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs + 1):
|
||||||
|
# 무작위 샘플 추출
|
||||||
|
train_indices = torch.randperm(len(train_dataset))[:train_max_images]
|
||||||
|
train_random_subset = torch.utils.data.Subset(train_dataset, train_indices)
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_random_subset, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
|
val_indices = torch.randperm(len(val_dataset))[:val_max_images]
|
||||||
|
val_random_subset = torch.utils.data.Subset(train_dataset, val_indices)
|
||||||
|
val_loader = torch.utils.data.DataLoader(val_random_subset, batch_size=batch_size, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
# 학습 및 Validation 평가
|
||||||
|
train_loss, train_acc = train(cnn_wdi, train_loader, criterion, optimizer, device)
|
||||||
|
val_loss, val_acc = evaluate(cnn_wdi, val_loader, criterion, device)
|
||||||
|
|
||||||
|
# 로그 기록
|
||||||
|
s_output = f'{epoch + 1}/{num_epochs},\t{train_loss:.4f},\t{train_acc:.4f},\t{val_loss:.4f},\t{val_acc:.4f}\n'
|
||||||
|
with open('output.txt', 'a') as file:
|
||||||
|
file.write(s_output)
|
||||||
|
print(s_output)
|
||||||
|
|
||||||
|
if epoch % 10 == 0:
|
||||||
|
# 모델 저장
|
||||||
|
torch.save(cnn_wdi.state_dict(), 'CNN_WDI_' + str(epoch) + 'epoch.pth')
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
|
||||||
### 평가방법
|
### 평가방법
|
||||||
|
|
||||||
* 모델의 성능지표(Precision, Recall, Accuracy, F1-Score)를 혼동행렬(Confusion Metrix)로 구현한다.
|
* 모델의 성능지표(Precision, Recall, Accuracy, F1-Score)를 혼동행렬(Confusion Metrix)로 구현한다.
|
||||||
|
|
||||||
|
|
||||||
|
# Confusion Metrix
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from sklearn.metrics import classification_report, confusion_matrix
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
def plot_metrics(title, class_names, precisions, recalls, f1_scores, acc):
|
||||||
|
num_classes = len(class_names)
|
||||||
|
index = np.arange(num_classes)
|
||||||
|
bar_width = 0.2
|
||||||
|
|
||||||
|
plt.figure(figsize=(15, 7))
|
||||||
|
plt.bar(index, precisions, bar_width, label='Precision')
|
||||||
|
plt.bar(index + bar_width, recalls, bar_width, label='Recall')
|
||||||
|
plt.bar(index + 2 * bar_width, f1_scores, bar_width, label='F1-score')
|
||||||
|
plt.axhline(y=acc, color='r', linestyle='--', label='Accuracy')
|
||||||
|
|
||||||
|
plt.xlabel('Class')
|
||||||
|
plt.ylabel('Scores')
|
||||||
|
plt.title(title + ': Precision, Recall, F1-score, and Accuracy per Class')
|
||||||
|
plt.xticks(index + bar_width, class_names)
|
||||||
|
plt.legend(loc='upper right')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def predict_and_plot_metrics(title, model, dataloader, criterion, device):
|
||||||
|
model.eval()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
all_preds = []
|
||||||
|
all_labels = []
|
||||||
|
class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in dataloader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
all_preds.extend(preds.cpu().numpy())
|
||||||
|
all_labels.extend(labels.cpu().numpy())
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(dataloader.dataset)
|
||||||
|
epoch_acc = running_corrects.double() / len(dataloader.dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate classification report
|
||||||
|
report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
|
||||||
|
|
||||||
|
# Calculate confusion matrix
|
||||||
|
cm = confusion_matrix(all_labels, all_preds)
|
||||||
|
|
||||||
|
# Calculate precision, recall, and f1-score per class
|
||||||
|
precisions = [report[c]['precision'] for c in class_names]
|
||||||
|
recalls = [report[c]['recall'] for c in class_names]
|
||||||
|
f1_scores = [report[c]['f1-score'] for c in class_names]
|
||||||
|
print('p: ' + str(precisions))
|
||||||
|
print('r: ' + str(recalls))
|
||||||
|
print('f: ' + str(f1_scores))
|
||||||
|
|
||||||
|
# Plot confusion matrix with normalized values (percentage)
|
||||||
|
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||||
|
plt.figure(figsize=(12, 12))
|
||||||
|
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
|
||||||
|
plt.xlabel('Predicted Label')
|
||||||
|
plt.ylabel('True Label')
|
||||||
|
plt.title('Normalized Confusion Matrix: ' + title)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Plot precision, recall, f1-score, and accuracy per class
|
||||||
|
plot_metrics(title, class_names, precisions, recalls, f1_scores, epoch_acc.item())
|
||||||
|
|
||||||
|
return epoch_loss, epoch_acc, report
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=112, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
dir = '.'
|
||||||
|
models = [file for file in os.listdir(dir) if file.endswith(('.pth'))]
|
||||||
|
|
||||||
|
def extract_number(filename):
|
||||||
|
return int(re.search(r'\d+', filename).group(0))
|
||||||
|
|
||||||
|
sorted_models = sorted(models, key=extract_number)
|
||||||
|
|
||||||
|
for model in sorted_models:
|
||||||
|
model_path = os.path.join(dir, model)
|
||||||
|
|
||||||
|
# Load the saved model weights
|
||||||
|
cnn_wdi.load_state_dict(torch.load(model_path))
|
||||||
|
|
||||||
|
# Call the predict_and_plot_metrics function with the appropriate arguments
|
||||||
|
epoch_loss, epoch_acc, report = predict_and_plot_metrics(model, cnn_wdi, test_loader, criterion, device)
|
||||||
|
# print(f'Model: {model} Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}')
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Loss Graph
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# 파일에서 데이터를 읽어들입니다.
|
||||||
|
with open('output.txt', 'r') as file:
|
||||||
|
lines = file.readlines()[1:] # 첫 번째 줄은 헤더이므로 건너뜁니다.
|
||||||
|
|
||||||
|
# 데이터를 분석하여 리스트에 저장합니다.
|
||||||
|
epochs = []
|
||||||
|
train_losses = []
|
||||||
|
train_accuracies = []
|
||||||
|
val_losses = []
|
||||||
|
val_accuracies = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line == '\n':
|
||||||
|
continue
|
||||||
|
# epoch, train_loss, train_acc, val_loss, val_acc = line.strip().split(', \t')
|
||||||
|
epoch, train_loss, train_acc, val_loss, val_acc = re.split(r'[,\s\t]+', line.strip())
|
||||||
|
epochs.append(int(epoch.split('/')[0]))
|
||||||
|
train_losses.append(float(train_loss))
|
||||||
|
train_accuracies.append(float(train_acc))
|
||||||
|
val_losses.append(float(val_loss))
|
||||||
|
val_accuracies.append(float(val_acc))
|
||||||
|
|
||||||
|
# 선 그래프를 그립니다.
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
|
||||||
|
plt.plot(epochs, train_losses, label='Train Loss')
|
||||||
|
plt.plot(epochs, train_accuracies, label='Train Acc')
|
||||||
|
plt.plot(epochs, val_losses, label='Val Loss')
|
||||||
|
plt.plot(epochs, val_accuracies, label='Val Acc')
|
||||||
|
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Values')
|
||||||
|
plt.title('Training and Validation Loss and Accuracy')
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
# Print Selecting Test Model Result
|
||||||
|
<details>
|
||||||
|
<summary>Code View</summary>
|
||||||
|
<div markdown="1">
|
||||||
|
|
||||||
|
````python
|
||||||
|
def output(model, dataloader, criterion, device):
|
||||||
|
model.eval()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
all_preds = []
|
||||||
|
all_labels = []
|
||||||
|
class_names = ['Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc', 'Near-full', 'none', 'Random', 'Scratch']
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in dataloader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
all_preds.extend(preds.cpu().numpy())
|
||||||
|
all_labels.extend(labels.cpu().numpy())
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(dataloader.dataset)
|
||||||
|
epoch_acc = running_corrects.double() / len(dataloader.dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate classification report
|
||||||
|
report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
|
||||||
|
|
||||||
|
# Calculate precision, recall, and f1-score per class
|
||||||
|
precisions = [report[c]['precision'] for c in class_names]
|
||||||
|
recalls = [report[c]['recall'] for c in class_names]
|
||||||
|
f1_scores = [report[c]['f1-score'] for c in class_names]
|
||||||
|
accuracy = report['accuracy']
|
||||||
|
|
||||||
|
precs = sum(precisions) / len(precisions)
|
||||||
|
recs = sum(recalls) / len(recalls)
|
||||||
|
f1s = sum(f1_scores) / len(f1_scores)
|
||||||
|
print('precisions: ' + str(precs))
|
||||||
|
print('recalls: ' + str(recs))
|
||||||
|
print('f1_scores: ' + str(f1s))
|
||||||
|
print('accuracy ' + str(accuracy))
|
||||||
|
|
||||||
|
|
||||||
|
selected_model = 'CNN_WDI_20epoch.pth'
|
||||||
|
cnn_wdi.load_state_dict(torch.load(selected_model))
|
||||||
|
output(cnn_wdi, test_loader, criterion, device)
|
||||||
|
````
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
|
||||||
|
### 테스트 결과
|
||||||
|
|
||||||
|
[1차 테스트](https://gitea.pinblog.codes/CBNU/03_WDI_CNN/wiki/1%EC%B0%A8-%ED%85%8C%EC%8A%A4%ED%8A%B8_%EC%9B%90%EB%B3%B8-%EB%8D%B0%EC%9D%B4%ED%84%B0-%ED%95%99%EC%8A%B5)
|
||||||
|
Loading…
Reference in New Issue