본문 바로가기

Test-Time Adaptation

CoTTA: Continual Test-Time Domain Adapation 논문 공부

이 논문 (CoTTA)은 CVPR'22에 publish된 논문이며 2023.10.26 기준 146회의 citation을 기록하고 있다.

논문은 https://arxiv.org/abs/2203.13591에 있고, 코드는 https://qin.ee/cotta.에서 확인할 수 있다.

 

Introduction

이 논문은 continual TTA라는 전례없는 새로운 setting을 다루고 있다. 더 깊이 들어가기 전에 이 setting을 확실히 짚고 넘가자. 이 논문 세팅의 정확한 이름은 online continual test-time adaptation 이다.

online은 test data 를 한번에 쌓아놓고 모델을 training하는 방식이 아니라 inference - model evaluation을 번갈아 진행한다는 뜻이다.

continual이라는건 보통 class 개수가 늘어날때 쓰는 경우가 많은데, 여기서는 딱히 그렇진않고,, 왜 붙었는지 모르겠다. 
(이 논문은 CIFAR10 --> CIFAF10C, CIFAR100 --> CIFAR100C, Cityscape --> ACDC를 해서 class 는 source = target이다)

Test-time adaptation이라는 것은 보통 source와 다른 새로운 unseen target domain에 모델을 test하는 동안 adapt한다는 것이다. 특히, 이 논문에서는 source data를 전혀 사용할 수 없고, pretrained model과 unlabeled target data만 주어진다.

Online continual test-time adaptation

이 논문은 기존에 존재하던 TTA 계열 알고리즘이 다음과 같은 단점이 있다고 말한다.

[1] 기존의 self-training이나 entropy maximinzation 기반의 방식은 continuous distribution shift에 취약하고, error accumulation 현상이 발생하여 성능이 하락한다.

[2] 모델이 계속 새로운 domain에 적응하다보면 source domain의 정보를 까먹는 catestrophic forgetting이 발생한다.

[3] 아무런 off-the-shelf pretrained model에 적용할 수 없다. (흠,, TENT 같은건 바로 가능할 것 같은데..)

 

이 논문 (CoTTA)가 위 문제를 해결하는 방법은 다음과 같다.

[1] weight-averaged teacher의 output을 pseudo-label로 사용하여 self-training 정확도 향상

[2] teacher를 믿을 수 없을때 augmented-averaged prediction을 pseudo-label로 사용하여 안정성 향상

[3] model parameter일부를 source domain에 pretrained 된 모델로 stochastically restore 함으로써 forgetting 방지

* 특히 CoTTA는 아무 off-the-shelf pretrained model에 적용할 수 있는게 큰 장점이라고 한다.

Methodology

Overview of the proposed Continual test-time adaptation (CoTTA)

Weight-Averaged Pseudo-Labels

CoTTA의 작동방식은 아주 간단하다. 큰 맥락에서 보았을때, CoTTA는 Self-training 기반의 TTA 알고리즘인데 pseudo label을 teacher model을 사용해서 구한다.

CoTTA의 objective function

위 equation에서 (') 가 있는 prediction이 teacher model의 prediction인데, student model의 prediction을 teacher model의 prediction으로 supervise하고 있는것을 볼 수 있다. 이 teacher model은 실제로 업데이트가 되고 있는 student model의 moving average으로 이미 다양한 work에서 error accumulation에 강인한 reliable prediction이라는 특성이 있다고 한다.

Teacher model은 moving average이다.

Augmentation-Averaged Pseudo-Labels

CoTTA는 teacher model을 믿을 수 없을 상황을 대비하여 한가지 보험을 준비해놓았다. 만약 teacher model이 아주 낮은 confidence를 가지고 있는 상황을 가정해보자. 저자는 confidence가 낮은 sample은 source domain과 large domain difference gap이 있는 sample이라고 한다. 다시말해, confidence가 낮은 prediction은 supervisory signal로 쓰기에 안좋다는 뜻이다. 그러면 어쩌면 좋을까? 🤔. CoTTA는 이런 예외사항에 input에 대해서 여러 augmentation을 가한 뒤 teacher model에 통과시키고, 그 결과로 얻은 여러 prediction의 평균을 pseudo-label로 만든다. 

아무래도 "augmentation-average를 하면 더 믿을 수 있는 값이 나오겠지?" 라는 intution에 기반한 것 같고, 그렇다고 해서 모든 샘플에 대해서 전부 augmentation-average를 하면 너무 오래걸리기 때문에 이렇게 타협을 본 것 같다.

If needed, pseudo-label is replaced by augmentation-average

Stochastic Restoration

마지막으로 CoTTA는 catestrophic forgetting을 방지하기 위해 stochastic restoration이라는 기법을 제안하였다. 이는 간단하게 모델 paramter의 일부를 확률적으로 pretrained model의 paramter로 되돌려버리는건데, Dropout 처럼 베르누이 분포에 따라 확률적으로 restoration을 수행한다고 한다. 

Stochastic Restoration

TENT 같은 알고리즘에 비해 CoTTA는 모든 model parameter를 update할 수 있는게 장점이라고 한다. (단점 아닌가? 🤔)

Experiments

그렇다면 continual TTA라는 setting은 어떻게 구현이 되었을까? 

기존에 있던 corruption data에 여러 종류의 corruption을 돌아가면서 적용하거나, multi domain dataset의 domain을 돌아가며 바꾸는 방법으로 구현했다고 한다.

예를 들어 CIFAR10-C, CIFAR100-C 는 15개의 corruption type을 가지고 있는데 test-time adaptation을 수행하는 동안 일정 corruption sample을 사용하여 학습을 진행한 뒤 다른 corruption type으로 넘어가는 시나리오를 상상할 수 있겠다.

CoTTA는 segmentation 에서도 실험을 진행하였는데, Cityscape에서 pretrain된 모델을 Adverse Conitions Dataset (ACDC) 의 여러 domain을 바꿔가며 TTA를 수행했다고 한다. 그 결과 아래처럼 TENT, BN, Pseudo-label 보다 좋은 성능을 보였다고 하기는 하는데,,, baseline들이 너무 outdated 하고 naive 한 것들이라 좀 께림칙하기는 했다.. 🤔

Classification error rate (%) for the standard CIFAR10-to-CIFAR10C