-
cifar10 Swin Transformer (Base) 학습 - 정리모델 학습 2024. 10. 13. 15:57반응형
모델 학습에 관련된 환경 셋팅 및 코드 정리
오늘은 Swin, vit, ResNet, EfficientNet-B0 중 Swin 에 대해 정리 해보겠습니다.
회사에서 업무를 진행하며 정리 하는것이라 내용이 부족할수도 있습니다.
pytorch Swin Transformer (Base)
https://pytorch.org/vision/main/models/swin_transformer.html
Swin Transformer의 주요 특징
- 윈도우 기반 어텐션(Window-based Attention): 전체 이미지에 대해 어텐션을 계산하는 것이 아니라, 작은 윈도우(서브 이미지)를 나누어 지역적인 어텐션을 적용합니다. 이를 통해 계산 복잡도를 줄입니다.
- 윈도우 이동(Shifted Window): 한 번의 윈도우 어텐션 후 윈도우를 약간 이동시키고, 이 새로운 영역에 대해 다시 어텐션을 계산합니다. 이렇게 하면 전체 이미지에 대한 관계를 효율적으로 파악할 수 있습니다.
- 다단계 구조(Hierarchical Structure): Swin Transformer는 여러 레이어로 구성되며, 각 레이어에서 이미지의 해상도를 점차 줄여나가며 더 넓은 영역을 한 번에 처리할 수 있도록 만듭니다. 이는 CNN의 풀링(Pooling)과 비슷한 역할을 합니다.
Swin Transformer와 CIFAR-10
CIFAR-10과 같은 작은 이미지 데이터셋에 Swin Transformer를 적용하면, Swin의 지역적인 어텐션 메커니즘이 이미지 내의 패턴을 더 잘 잡아내어 성능을 높일 수 있습니다. 특히, Swin Transformer는 대규모 이미지 데이터셋에서 강력한 성능을 발휘하지만, CIFAR-10과 같은 소규모 데이터셋에서도 경쟁력 있는 결과를 얻을 수 있습니다.
1. 환경 설치
1 CUDA 환경에 맞게 torch, torchvision, torchaudio 를 설치한다.
https://pytorch.org/get-started/previous-versions/
2 학습 시 필요한 라이브러리를 설치한다.
pip3 install torch torchvision torchaudio tqdm
pip install -r requirements.txt (파일 안 tqdm)
기타 주의사항
- 학습 epoch는 50 epoch로 고정한다
- 첫번째 epoch 시작 전 epoch 0으로 모니터링 로그를 추가한다. (데이터 로드 타임 반영)
- 매번 epoch마다 모니터링 로그를 추가한다.
train_cifar10_swin.py
import os import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from torchvision import models from torch.utils.data import DataLoader from tqdm import tqdm # from monitoring import upd_hw_usage as usage import time import timm # 데이터 전처리: 이미지를 텐서로 변환하고, CIFAR-10 이미지 크기가 작으므로 크기를 조정 transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 데이터 정규화 ]) model_nm = 'Swin Transformer' dataset_nm = 'CIFAR-10' data_size = 0.1592 # GB # data 다운로드 안된 경우 다운로드 savdir = './data' if not os.path.exists(savdir) or len(os.listdir(savdir)) == 0: torchvision.datasets.CIFAR10(root=savdir, train=True, download=True, transform=transform) torchvision.datasets.CIFAR10(root=savdir, train=False, download=True, transform=transform) load_stt = time.time() # CIFAR-10 데이터셋 로드 train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False) load_time = time.time() - load_stt ### 1. 데이터 로드 시간 모니터링 (다운로드 시간 제외) # usage.add_to_db(0, model_nm, dataset_nm, data_load_time=load_time, data_tot_amount=data_size) # Swin Transformer 모델 불러오기 model = timm.create_model('swin_base_patch4_window7_224', pretrained=True) # CIFAR-10에 맞게 출력 레이어 수정 (출력 클래스: 10) num_ftrs = model.head.in_features model.fc = nn.Linear(num_ftrs, 10) # GPU가 사용 가능하다면 GPU로 전송 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = [model.to](http://model.to/)(device) # 손실 함수와 최적화 알고리즘 설정 criterion = nn.CrossEntropyLoss() # 분류 문제에 적합한 손실 함수 optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimizer def train(model, train_loader, criterion, optimizer, device, num_epochs=10): model.train() # 모델을 학습 모드로 설정 print('Training start') for epoch in range(num_epochs): running_loss = 0.0 for inputs, labels in tqdm(train_loader): inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device) ``` # 그라디언트 초기화 optimizer.zero_grad() # 순전파 및 손실 계산 outputs = model(inputs) loss = criterion(outputs, labels) # 역전파 및 최적화 loss.backward() optimizer.step() running_loss += loss.item() # usage.add_to_db(epoch+1, model_nm, dataset_nm) print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}') print('Finished Training') ``` # train 에서 모델을 학습 시키고 난후 evaluate 에 새로운 데이터를 넣어 정확도가 올바른지 확인하는 목적 def evaluate(model, test_loader, device): model.eval() # 평가 모드로 설정 (dropout, batchnorm 등 비활성화) correct = 0 total = 0 with torch.no_grad(): for inputs, labels in tqdm(test_loader): inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() ``` print(f'Accuracy: {100 * correct / total:.2f}%') ``` # 모델 학습 train(model, train_loader, criterion, optimizer, device, num_epochs=2)
-sh, py 파일 생성 후
run_train_cifar10_swin.sh
-gpu 5번 사용
CUDA_VISIBLE_DEVICES=5 python "경로/train_cifar10_swin.py"
-백그라운드로 실행
nohup sh run_train_cifar10_swin.sh > nohup_train_cifar10_swin.log 2>&1 &
-로그 확인
tail -500f nohup_train_cifar10_swin.log
반응형