|
|
|
@ -13,11 +13,13 @@ https://gitea.pinblog.codes/attachments/9b2424f7-7e7d-4ad1-a368-86a523d67504
|
|
|
|
|
* 원본
|
|
|
|
|
https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-----
|
|
|
|
|
|
|
|
|
|
### Dataset
|
|
|
|
|
[Kaggle - WDI Data](https://www.kaggle.com/qingyi/wm811k-wafer-map/code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-----
|
|
|
|
|
|
|
|
|
|
### 수행방법
|
|
|
|
@ -29,8 +31,8 @@ https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad
|
|
|
|
|
|
|
|
|
|
https://gitea.pinblog.codes/CBNU/03_WDI_CNN/releases/tag/info
|
|
|
|
|
|
|
|
|
|
# Model
|
|
|
|
|
|
|
|
|
|
# Model
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
import torch
|
|
|
|
@ -99,7 +101,6 @@ cnn_wdi = CNN_WDI(class_num=9)
|
|
|
|
|
|
|
|
|
|
# Load Data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
from torchvision import transforms, datasets
|
|
|
|
|
|
|
|
|
@ -124,7 +125,6 @@ test_dataset = datasets.ImageFolder(root='E:/wm_images/test/', transform=data_tr
|
|
|
|
|
|
|
|
|
|
# Settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
|
|
|
|
@ -149,11 +149,8 @@ val_max_images = 25
|
|
|
|
|
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Train Function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 학습 함수 정의
|
|
|
|
|
def train(model, dataloader, criterion, optimizer, device):
|
|
|
|
@ -185,7 +182,6 @@ def train(model, dataloader, criterion, optimizer, device):
|
|
|
|
|
|
|
|
|
|
# Evaluate Function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 평가 함수 정의
|
|
|
|
|
def evaluate(model, dataloader, criterion, device):
|
|
|
|
@ -246,14 +242,15 @@ for epoch in range(num_epochs + 1):
|
|
|
|
|
torch.save(cnn_wdi.state_dict(), 'CNN_WDI_' + str(epoch) + 'epoch.pth')
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-----
|
|
|
|
|
|
|
|
|
|
### 평가방법
|
|
|
|
|
|
|
|
|
|
* 모델의 성능지표(Precision, Recall, Accuracy, F1-Score)를 혼동행렬(Confusion Metrix)로 구현한다.
|
|
|
|
|
|
|
|
|
|
# Confusion Metrix
|
|
|
|
|
|
|
|
|
|
# Confusion Metrix
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
import numpy as np
|
|
|
|
@ -340,7 +337,6 @@ def predict_and_plot_metrics(title, model, dataloader, criterion, device):
|
|
|
|
|
|
|
|
|
|
# Evaluate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
@ -368,7 +364,6 @@ for model in sorted_models:
|
|
|
|
|
|
|
|
|
|
# Loss Graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
@ -412,7 +407,6 @@ plt.show()
|
|
|
|
|
|
|
|
|
|
# Print Selecting Test Model Result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
def output(model, dataloader, criterion, device):
|
|
|
|
|
model.eval()
|
|
|
|
|