모델 학습

cifar10 Swin Transformer (Base) 학습 - 정리

이나주니 2024. 10. 13. 15:57
반응형

모델 학습에 관련된 환경 셋팅 및 코드 정리

 

오늘은 Swin, vit, ResNet, EfficientNet-B0 중 Swin 에 대해 정리 해보겠습니다.

회사에서 업무를 진행하며 정리 하는것이라 내용이 부족할수도 있습니다.

 

pytorch Swin Transformer (Base)

https://pytorch.org/vision/main/models/swin_transformer.html

 

SwinTransformer — Torchvision main documentation

Shortcuts

pytorch.org

 

Swin Transformer의 주요 특징

  1. 윈도우 기반 어텐션(Window-based Attention): 전체 이미지에 대해 어텐션을 계산하는 것이 아니라, 작은 윈도우(서브 이미지)를 나누어 지역적인 어텐션을 적용합니다. 이를 통해 계산 복잡도를 줄입니다.
  2. 윈도우 이동(Shifted Window): 한 번의 윈도우 어텐션 후 윈도우를 약간 이동시키고, 이 새로운 영역에 대해 다시 어텐션을 계산합니다. 이렇게 하면 전체 이미지에 대한 관계를 효율적으로 파악할 수 있습니다.
  3. 다단계 구조(Hierarchical Structure): Swin Transformer는 여러 레이어로 구성되며, 각 레이어에서 이미지의 해상도를 점차 줄여나가며 더 넓은 영역을 한 번에 처리할 수 있도록 만듭니다. 이는 CNN의 풀링(Pooling)과 비슷한 역할을 합니다.

Swin Transformer와 CIFAR-10

CIFAR-10과 같은 작은 이미지 데이터셋에 Swin Transformer를 적용하면, Swin의 지역적인 어텐션 메커니즘이 이미지 내의 패턴을 더 잘 잡아내어 성능을 높일 수 있습니다. 특히, Swin Transformer는 대규모 이미지 데이터셋에서 강력한 성능을 발휘하지만, CIFAR-10과 같은 소규모 데이터셋에서도 경쟁력 있는 결과를 얻을 수 있습니다.

 

 

1. 환경 설치

1 CUDA 환경에 맞게 torch, torchvision, torchaudio 를 설치한다.

https://pytorch.org/get-started/previous-versions/

2 학습 시 필요한 라이브러리를 설치한다.

pip3 install torch torchvision torchaudio tqdm

pip install -r requirements.txt (파일 안 tqdm)

기타 주의사항

  • 학습 epoch는 50 epoch로 고정한다
  • 첫번째 epoch 시작 전 epoch 0으로 모니터링 로그를 추가한다. (데이터 로드 타임 반영)
  • 매번 epoch마다 모니터링 로그를 추가한다.

 

train_cifar10_swin.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torchvision import models

from torch.utils.data import DataLoader
from tqdm import tqdm

# from monitoring import upd_hw_usage as usage

import time
import timm

# 데이터 전처리: 이미지를 텐서로 변환하고, CIFAR-10 이미지 크기가 작으므로 크기를 조정

transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 데이터 정규화
])

model_nm = 'Swin Transformer'
dataset_nm = 'CIFAR-10'
data_size = 0.1592 # GB

# data 다운로드 안된 경우 다운로드

savdir = './data'
if not os.path.exists(savdir) or len(os.listdir(savdir)) == 0:
torchvision.datasets.CIFAR10(root=savdir, train=True, download=True, transform=transform)
torchvision.datasets.CIFAR10(root=savdir, train=False, download=True, transform=transform)

load_stt = time.time()

# CIFAR-10 데이터셋 로드

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

load_time = time.time() - load_stt

### 1. 데이터 로드 시간 모니터링 (다운로드 시간 제외)

# usage.add_to_db(0, model_nm, dataset_nm, data_load_time=load_time, data_tot_amount=data_size)

# Swin Transformer 모델 불러오기

model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)

# CIFAR-10에 맞게 출력 레이어 수정 (출력 클래스: 10)
num_ftrs = model.head.in_features
model.fc = nn.Linear(num_ftrs, 10)

# GPU가 사용 가능하다면 GPU로 전송

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = [model.to](http://model.to/)(device)

# 손실 함수와 최적화 알고리즘 설정

criterion = nn.CrossEntropyLoss()  # 분류 문제에 적합한 손실 함수
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer

def train(model, train_loader, criterion, optimizer, device, num_epochs=10):
model.train()  # 모델을 학습 모드로 설정
print('Training start')
for epoch in range(num_epochs):
running_loss = 0.0
for inputs, labels in tqdm(train_loader):
inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)

```
        # 그라디언트 초기화
        optimizer.zero_grad()

        # 순전파 및 손실 계산
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 역전파 및 최적화
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # usage.add_to_db(epoch+1, model_nm, dataset_nm)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
print('Finished Training')

```

# train 에서 모델을 학습 시키고 난후 evaluate 에 새로운 데이터를 넣어 정확도가 올바른지 확인하는 목적

def evaluate(model, test_loader, device):
model.eval()  # 평가 모드로 설정 (dropout, batchnorm 등 비활성화)
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(test_loader):
inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

```
print(f'Accuracy: {100 * correct / total:.2f}%')

```

# 모델 학습

train(model, train_loader, criterion, optimizer, device, num_epochs=2)

 

 

-sh, py 파일 생성 후 

run_train_cifar10_swin.sh

-gpu 5번 사용

CUDA_VISIBLE_DEVICES=5 python "경로/train_cifar10_swin.py"

-백그라운드로 실행

nohup sh run_train_cifar10_swin.sh > nohup_train_cifar10_swin.log 2>&1 &

-로그 확인

tail -500f nohup_train_cifar10_swin.log

 

반응형