TTA(Test time Augmentation)

2024. 11. 28. 14:11·CV(컴퓨터비전)

1. TTA란?

TTA란 Train 과정이 아닌 Test(Inference) 과정에서 Augmentation을 적용하여 나온 결과들에 대해 대표값 (대체로는 평균값)을 최종 예측값으로 활용한다. 이렇게 하면 보다 모델이 일관되고 강력한 예측을 할 수 있게된다. 

 

https://stepup.ai/test_time_data_augmentation/

 

해당 방법이 효과적인 이유는 무작위로 변형된 이미지에 대한 예측을평균 내면서 오류도 평균화 하기 때문이다. 단일 벡터에서는 오류가 커질 수 있지만, 이를 평균내면 올바른 예측을 할 수 있도록 유도할 수 있다. 이 때문에 TTA는 모델이 확신하지 못하는 테스트 이미지에 특히 유용하다. 

 

2. 코드실습(Pytorch) 

https://github.com/qubvel/ttach

 

GitHub - qubvel/ttach: Image Test Time Augmentation with PyTorch!

Image Test Time Augmentation with PyTorch! Contribute to qubvel/ttach development by creating an account on GitHub.

github.com

https://github.com/andrewekhalel/edafa

 

GitHub - andrewekhalel/edafa: Test Time Augmentation (TTA) wrapper for computer vision tasks: segmentation, classification, supe

Test Time Augmentation (TTA) wrapper for computer vision tasks: segmentation, classification, super-resolution, ... etc. - andrewekhalel/edafa

github.com

 

 

여러 블로그 글들을 찾아보니 해당 라이브러리를 많이 사용하는 듯 했다.

간단하게 사용할 수 있는 라이브러리는 tta 이고 ,  직접 custom을 원하는 경우 edafa를 주로 사용하시는 거 같다. 

 

이번실습에서는 총 8가지로 증강하고자 한다.

→ (0도, 90도 , 180도, 270도 회전) X ( 그대로, 뒤집어서) 

img_path = './dacon/data/test_input/TEST_000.png'

def load_image(image_path):
    image = Image.open(image_path).convert('RGB') #컬러채널 추가 
    tensor_image = transforms.ToTensor()(image) #PyTorch Tensor로 변환 
    tensor_image = tensor_image.unsqueeze(0) #배치 차원 추가 
    return tensor_image 

tensor_image = load_image(img_path)
tensor_image

# define 2*4 = 8 augmentations 
transforms_tta = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles = [0,90,180,270])
    ]
)

augmented_image = [transformer.augment_image(tensor_image) for transformer in transforms_tta] 

fig , axes = plt.subplots(2,4,figsize = (10,5)) 
for ax , img in zip(axes.flatten(),augmented_image): 
    ax.imshow(img.squeeze(0).numpy().transpose(1,2,0)) # 배치 차원 제거 및 HWC로 변경 
    ax.axis('off')
plt.show()
  • axes 는 plt.subplots로 생성된 2X4 형태의 numpy 배열이다. 
  • .flatten()을 활용해 1차원 배열로 변환 (for 문을 순회하기 위해서는 1차원 배열이 되어야 함) 
axes = [[ax1, ax2, ax3, ax4], [ax5, ax6, ax7, ax8]]
axes.flatten() = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8]
  • image.squeeze(0) : 이미지가 배치(batch) 차원을 포함하고 있다고 가정할 때, 배치 차원 제거 
  • transpose(1,2,0) : Tensor의 차원 순서 변경
    • Pytorch는  (C, H, W) 형태를 사용, matplotlib는 (H, W, C) 형태를 요구 

결과


3. 참고자료 

https://modulabs.co.kr/blog/test-time-augmentation-computer-vision/ 

 

테스트 타임 증강 (Test Time Augmentation, TTA)

테스트 타임 증강(Test Time Augmentation, TTA)은 데이터 증강 (Data Augmentation)과 유사하게 테스트 이미지에 무작위 변형을 적용하는 방법입니다. 데이터 증강은 성능을 향상시키기 위해 훈련 데이터를

modulabs.co.kr

https://stepup.ai/test_time_data_augmentation/

 

How to Correctly Use Test-Time Data Augmentation to Improve Predictions

Learn how to boost a model's accuracy with test-time data augmentation (TTA). Experiment for yourself with the Colab notebook.

stepup.ai

 

'CV(컴퓨터비전)' 카테고리의 다른 글

🦙 LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions  (0) 2024.12.02
'CV(컴퓨터비전)' 카테고리의 다른 글
  • 🦙 LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions
zzoming_00
zzoming_00
꾸준함🍀
  • zzoming_00
    ZZOMING'S TECH BLOG
    zzoming_00
  • 전체
    오늘
    어제
    • ALL (56)
      • Docker (0)
      • 경진대회 (2)
      • 사이드 프로젝트 (5)
      • NLP(자연어처리) (4)
      • CV(컴퓨터비전) (2)
      • ML&DL (9)
      • Git (4)
      • Python (10)
      • Algorithm (19)
  • 블로그 메뉴

    • 홈
    • 글쓰기
    • 태그
    • 방명록
  • 링크

    • 글쓰기
  • 공지사항

  • 인기 글

  • 태그

    내장함수
    LLM
    챗봇
    참조방식
    고차함수
    알고리즘
    클래스
    파이썬
    object
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
zzoming_00
TTA(Test time Augmentation)
상단으로

티스토리툴바