From 5954cf2cfe02ba62d5b9390b8b59d0691bca2ad1 Mon Sep 17 00:00:00 2001 From: pinb Date: Thu, 27 Apr 2023 00:58:12 +0000 Subject: [PATCH] =?UTF-8?q?=EC=97=85=EB=8D=B0=EC=9D=B4=ED=8A=B8=20'readme.?= =?UTF-8?q?md'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- readme.md | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/readme.md b/readme.md index c34791e..201c52c 100644 --- a/readme.md +++ b/readme.md @@ -13,11 +13,13 @@ https://gitea.pinblog.codes/attachments/9b2424f7-7e7d-4ad1-a368-86a523d67504 * 원본 https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad + ----- ### Dataset [Kaggle - WDI Data](https://www.kaggle.com/qingyi/wm811k-wafer-map/code) + ----- ### 수행방법 @@ -29,8 +31,8 @@ https://gitea.pinblog.codes/attachments/9a31bb80-bc0a-4d5a-83b1-4ef0557456ad https://gitea.pinblog.codes/CBNU/03_WDI_CNN/releases/tag/info - # Model +# Model ```python import torch @@ -99,7 +101,6 @@ cnn_wdi = CNN_WDI(class_num=9) # Load Data - ```python from torchvision import transforms, datasets @@ -124,7 +125,6 @@ test_dataset = datasets.ImageFolder(root='E:/wm_images/test/', transform=data_tr # Settings - ```python import torch.optim as optim @@ -149,11 +149,8 @@ val_max_images = 25 ``` - - # Train Function - ```python # 학습 함수 정의 def train(model, dataloader, criterion, optimizer, device): @@ -185,7 +182,6 @@ def train(model, dataloader, criterion, optimizer, device): # Evaluate Function - ```python # 평가 함수 정의 def evaluate(model, dataloader, criterion, device): @@ -246,14 +242,15 @@ for epoch in range(num_epochs + 1): torch.save(cnn_wdi.state_dict(), 'CNN_WDI_' + str(epoch) + 'epoch.pth') ``` + ----- ### 평가방법 * 모델의 성능지표(Precision, Recall, Accuracy, F1-Score)를 혼동행렬(Confusion Metrix)로 구현한다. -# Confusion Metrix +# Confusion Metrix ```python import numpy as np @@ -340,7 +337,6 @@ def predict_and_plot_metrics(title, model, dataloader, criterion, device): # Evaluate - ```python import os import re @@ -368,7 +364,6 @@ for model in sorted_models: # Loss Graph - ```python import matplotlib.pyplot as plt @@ -412,7 +407,6 @@ plt.show() # Print Selecting Test Model Result - ```python def output(model, dataloader, criterion, device): model.eval()