ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • cifar10 Swin Transformer (Base) 학습 - 정리
    모델 학습 2024. 10. 13. 15:57
    반응형

    모델 학습에 관련된 환경 셋팅 및 코드 정리

     

    오늘은 Swin, vit, ResNet, EfficientNet-B0 중 Swin 에 대해 정리 해보겠습니다.

    회사에서 업무를 진행하며 정리 하는것이라 내용이 부족할수도 있습니다.

     

    pytorch Swin Transformer (Base)

    https://pytorch.org/vision/main/models/swin_transformer.html

     

    SwinTransformer — Torchvision main documentation

    Shortcuts

    pytorch.org

     

    Swin Transformer의 주요 특징

    1. 윈도우 기반 어텐션(Window-based Attention): 전체 이미지에 대해 어텐션을 계산하는 것이 아니라, 작은 윈도우(서브 이미지)를 나누어 지역적인 어텐션을 적용합니다. 이를 통해 계산 복잡도를 줄입니다.
    2. 윈도우 이동(Shifted Window): 한 번의 윈도우 어텐션 후 윈도우를 약간 이동시키고, 이 새로운 영역에 대해 다시 어텐션을 계산합니다. 이렇게 하면 전체 이미지에 대한 관계를 효율적으로 파악할 수 있습니다.
    3. 다단계 구조(Hierarchical Structure): Swin Transformer는 여러 레이어로 구성되며, 각 레이어에서 이미지의 해상도를 점차 줄여나가며 더 넓은 영역을 한 번에 처리할 수 있도록 만듭니다. 이는 CNN의 풀링(Pooling)과 비슷한 역할을 합니다.

    Swin Transformer와 CIFAR-10

    CIFAR-10과 같은 작은 이미지 데이터셋에 Swin Transformer를 적용하면, Swin의 지역적인 어텐션 메커니즘이 이미지 내의 패턴을 더 잘 잡아내어 성능을 높일 수 있습니다. 특히, Swin Transformer는 대규모 이미지 데이터셋에서 강력한 성능을 발휘하지만, CIFAR-10과 같은 소규모 데이터셋에서도 경쟁력 있는 결과를 얻을 수 있습니다.

     

     

    1. 환경 설치

    1 CUDA 환경에 맞게 torch, torchvision, torchaudio 를 설치한다.

    https://pytorch.org/get-started/previous-versions/

    2 학습 시 필요한 라이브러리를 설치한다.

    pip3 install torch torchvision torchaudio tqdm

    pip install -r requirements.txt (파일 안 tqdm)

    기타 주의사항

    • 학습 epoch는 50 epoch로 고정한다
    • 첫번째 epoch 시작 전 epoch 0으로 모니터링 로그를 추가한다. (데이터 로드 타임 반영)
    • 매번 epoch마다 모니터링 로그를 추가한다.

     

    train_cifar10_swin.py

    import os
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torchvision.transforms as transforms
    
    from torchvision import models
    
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    
    # from monitoring import upd_hw_usage as usage
    
    import time
    import timm
    
    # 데이터 전처리: 이미지를 텐서로 변환하고, CIFAR-10 이미지 크기가 작으므로 크기를 조정
    
    transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 데이터 정규화
    ])
    
    model_nm = 'Swin Transformer'
    dataset_nm = 'CIFAR-10'
    data_size = 0.1592 # GB
    
    # data 다운로드 안된 경우 다운로드
    
    savdir = './data'
    if not os.path.exists(savdir) or len(os.listdir(savdir)) == 0:
    torchvision.datasets.CIFAR10(root=savdir, train=True, download=True, transform=transform)
    torchvision.datasets.CIFAR10(root=savdir, train=False, download=True, transform=transform)
    
    load_stt = time.time()
    
    # CIFAR-10 데이터셋 로드
    
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
    
    load_time = time.time() - load_stt
    
    ### 1. 데이터 로드 시간 모니터링 (다운로드 시간 제외)
    
    # usage.add_to_db(0, model_nm, dataset_nm, data_load_time=load_time, data_tot_amount=data_size)
    
    # Swin Transformer 모델 불러오기
    
    model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
    
    # CIFAR-10에 맞게 출력 레이어 수정 (출력 클래스: 10)
    num_ftrs = model.head.in_features
    model.fc = nn.Linear(num_ftrs, 10)
    
    # GPU가 사용 가능하다면 GPU로 전송
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = [model.to](http://model.to/)(device)
    
    # 손실 함수와 최적화 알고리즘 설정
    
    criterion = nn.CrossEntropyLoss()  # 분류 문제에 적합한 손실 함수
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
    
    def train(model, train_loader, criterion, optimizer, device, num_epochs=10):
    model.train()  # 모델을 학습 모드로 설정
    print('Training start')
    for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader):
    inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)
    
    ```
            # 그라디언트 초기화
            optimizer.zero_grad()
    
            # 순전파 및 손실 계산
            outputs = model(inputs)
            loss = criterion(outputs, labels)
    
            # 역전파 및 최적화
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item()
    
        # usage.add_to_db(epoch+1, model_nm, dataset_nm)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
    print('Finished Training')
    
    ```
    
    # train 에서 모델을 학습 시키고 난후 evaluate 에 새로운 데이터를 넣어 정확도가 올바른지 확인하는 목적
    
    def evaluate(model, test_loader, device):
    model.eval()  # 평가 모드로 설정 (dropout, batchnorm 등 비활성화)
    correct = 0
    total = 0
    with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
    inputs, labels = [inputs.to](http://inputs.to/)(device), [labels.to](http://labels.to/)(device)
    outputs = model(inputs)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    
    ```
    print(f'Accuracy: {100 * correct / total:.2f}%')
    
    ```
    
    # 모델 학습
    
    train(model, train_loader, criterion, optimizer, device, num_epochs=2)

     

     

    -sh, py 파일 생성 후 

    run_train_cifar10_swin.sh

    -gpu 5번 사용

    CUDA_VISIBLE_DEVICES=5 python "경로/train_cifar10_swin.py"

    -백그라운드로 실행

    nohup sh run_train_cifar10_swin.sh > nohup_train_cifar10_swin.log 2>&1 &

    -로그 확인

    tail -500f nohup_train_cifar10_swin.log

     

    반응형

    댓글

Designed by Tistory.