cifar10 Swin Transformer (Base) 학습 - 정리
모델 학습에 관련된 환경 셋팅 및 코드 정리
오늘은 Swin, vit, ResNet, EfficientNet-B0 중 Swin 에 대해 정리 해보겠습니다.
회사에서 업무를 진행하며 정리 하는것이라 내용이 부족할수도 있습니다.
pytorch Swin Transformer (Base)
https://pytorch.org/vision/main/models/swin_transformer.html
SwinTransformer — Torchvision main documentation
Shortcuts
pytorch.org
Swin Transformer의 주요 특징
- 윈도우 기반 어텐션(Window-based Attention): 전체 이미지에 대해 어텐션을 계산하는 것이 아니라, 작은 윈도우(서브 이미지)를 나누어 지역적인 어텐션을 적용합니다. 이를 통해 계산 복잡도를 줄입니다.
- 윈도우 이동(Shifted Window): 한 번의 윈도우 어텐션 후 윈도우를 약간 이동시키고, 이 새로운 영역에 대해 다시 어텐션을 계산합니다. 이렇게 하면 전체 이미지에 대한 관계를 효율적으로 파악할 수 있습니다.
- 다단계 구조(Hierarchical Structure): Swin Transformer는 여러 레이어로 구성되며, 각 레이어에서 이미지의 해상도를 점차 줄여나가며 더 넓은 영역을 한 번에 처리할 수 있도록 만듭니다. 이는 CNN의 풀링(Pooling)과 비슷한 역할을 합니다.
Swin Transformer와 CIFAR-10
CIFAR-10과 같은 작은 이미지 데이터셋에 Swin Transformer를 적용하면, Swin의 지역적인 어텐션 메커니즘이 이미지 내의 패턴을 더 잘 잡아내어 성능을 높일 수 있습니다. 특히, Swin Transformer는 대규모 이미지 데이터셋에서 강력한 성능을 발휘하지만, CIFAR-10과 같은 소규모 데이터셋에서도 경쟁력 있는 결과를 얻을 수 있습니다.
1. 환경 설치
1 CUDA 환경에 맞게 torch, torchvision, torchaudio 를 설치한다.
https://pytorch.org/get-started/previous-versions/
2 학습 시 필요한 라이브러리를 설치한다.
pip3 install torch torchvision torchaudio tqdm
pip install -r requirements.txt (파일 안 tqdm)
기타 주의사항
- 학습 epoch는 50 epoch로 고정한다
- 첫번째 epoch 시작 전 epoch 0으로 모니터링 로그를 추가한다. (데이터 로드 타임 반영)
- 매번 epoch마다 모니터링 로그를 추가한다.
train_cifar10_swin.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm
# from monitoring import upd_hw_usage as usage
import time
import timm
# 데이터 전처리: 이미지를 텐서로 변환하고, CIFAR-10 이미지 크기가 작으므로 크기를 조정
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 데이터 정규화
])
model_nm = 'Swin Transformer'
dataset_nm = 'CIFAR-10'
data_size = 0.1592 # GB
# data 다운로드 안된 경우 다운로드
savdir = './data'
if not os.path.exists(savdir) or len(os.listdir(savdir)) == 0:
torchvision.datasets.CIFAR10(root=savdir, train=True, download=True, transform=transform)
torchvision.datasets.CIFAR10(root=savdir, train=False, download=True, transform=transform)
load_stt = time.time()
# CIFAR-10 데이터셋 로드
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
load_time = time.time() - load_stt
### 1. 데이터 로드 시간 모니터링 (다운로드 시간 제외)
# usage.add_to_db(0, model_nm, dataset_nm, data_load_time=load_time, data_tot_amount=data_size)
# Swin Transformer 모델 불러오기
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
# CIFAR-10에 맞게 출력 레이어 수정 (출력 클래스: 10)
num_ftrs = model.head.in_features
model.fc = nn.Linear(num_ftrs, 10)
# GPU가 사용 가능하다면 GPU로 전송
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = [model.to](http://model.to/)(device)
# 손실 함수와 최적화 알고리즘 설정
criterion = nn.CrossEntropyLoss() # 분류 문제에 적합한 손실 함수
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimizer
def train(model, train_loader, criterion, optimizer, device, num_epochs=10):
model.train() # 모델을 학습 모드로 설정
print('Training start')
for epoch in range(num_epochs):
running_loss = 0.0
for inputs, labels in tqdm(train_loader):
inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)
```
# 그라디언트 초기화
optimizer.zero_grad()
# 순전파 및 손실 계산
outputs = model(inputs)
loss = criterion(outputs, labels)
# 역전파 및 최적화
loss.backward()
optimizer.step()
running_loss += loss.item()
# usage.add_to_db(epoch+1, model_nm, dataset_nm)
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
print('Finished Training')
```
# train 에서 모델을 학습 시키고 난후 evaluate 에 새로운 데이터를 넣어 정확도가 올바른지 확인하는 목적
def evaluate(model, test_loader, device):
model.eval() # 평가 모드로 설정 (dropout, batchnorm 등 비활성화)
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(test_loader):
inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
```
print(f'Accuracy: {100 * correct / total:.2f}%')
```
# 모델 학습
train(model, train_loader, criterion, optimizer, device, num_epochs=2)
-sh, py 파일 생성 후
run_train_cifar10_swin.sh
-gpu 5번 사용
CUDA_VISIBLE_DEVICES=5 python "경로/train_cifar10_swin.py"
-백그라운드로 실행
nohup sh run_train_cifar10_swin.sh > nohup_train_cifar10_swin.log 2>&1 &
-로그 확인
tail -500f nohup_train_cifar10_swin.log