본문 바로가기

Test-Time Adaptation

SAR : TOWARDS STABLE TEST-TIME ADAPTATION INDYNAMIC WILD WORLD 논문

이 논문은 ICLR'23에 publish된 논문으로 2023.11.29 기준 50회의 citation을 보유하고 있다. (아니 어떻게 벌써? 😮)

논문은 https://openreview.net/forum?id=g2YraF75Tj 에서 찾아볼 수 있고, 코드는 https://github.com/mr-eggplant/SAR 에서 찾아볼 수 있다.

Introduction

이 논문은 기존의 TTA (test-time adaptation) method와 TTT (test-time training) method가 realistic wild setting 에서는 쉽게 fail하는 것을 보이고 이를 막을 수 있는 방법들을 제안한 논문이다. 여기서 저자는 3가지의 wild setup을 고려했는데 👉 (1) Multiple distribution shift, (2) Small batch size, (3) Imbalanced label distribution 아래 figure와 같이 성능하락이 컸다:

😉 그리고 이 논문은 online TTA를 다루는 것도 기억해둘 것!

세가지 wild scenario에서 기존 TTA 성능이 크게 하락하는 것을 확인할 수 있다.

논문의 저자는 크게 두가지 이유로 인해 TTA in wild setup이 실패한다고 했다.

(1) batch norm (BN) layer : 기존의 TTA는 test domain의 batch statistics를 estimate하여 기존의 statistics을 replace하는 방법을 주로 사용하였다. 하지만 wild setup에서는 batch 상태가 이상해서 올바른 batch statistics를 계산할 수 없다.

(2) model collapse : 기존의 entropy minimization 기반의 method는 model collapse가 발생하여 단하나의 class에 대해서만 prediction을 하는 bias가 발생하 수 있다. 논문의 저자는 이 현상이 large gradient norm을 가지는 일부 toxic sample 때문이라고 하고, 이 toxic samples에 잘 대처할 수 있는 여러 methods를 제안했다. 

 

이 논문의 contribution 을 정리하자면 다음과 같다.

(1) 기존의 BN 기반 TTA/TTT 을 BN/GN/LN (Batch/Group/Layer) 으로 확장하여 wild setup에서 성능 변화를 분석하였다.

(2) SAR이라는 알고리즘을 제안 👉 model collapse 에 강인하여 wild-setup 에도 성능 감소가 크지 않음.

Preliminaries

간단하게 딱 TTT랑 TTA만 보고 넘어가자

TTT : training/inference stage 전부 개조해야하는 방식 😣. Training 때 main-task 말고 self-supervised task (e.g., rotation prediction) 을 함께 수행함. 👉 Inference 때 self-supervised task로 모델을 추가적으로 adapt한뒤 inference 수행

TTA : training stage는 건들지 않음. 이 논문에서는 주로 TENT를 고려. TENT는 entropy minimization을 objective function으로 두고 (batch) normalization layer의 parameter를 opmitze 👉 이 블로그에서 설명한게 있으니 (참고!!)

Method - Intuition

본 method로 들어가기 이전에 intutition만 가볍게 얻고 넘어가자. 아래 figure (a)는 model collapse가 발생해서 단하나의 class만 predict 하는 현상이 발생한다. (collapse가 발생하지 않으면,, (b) 처럼 uniform 하게 나와야한다..

🤔 Model collapse가 언제 일어나는걸까? (figure (c))

👉 갑자기 gradient norm의 솓구치다가 바닥으로 내리꽂히는 것을 확인했다.

😮 Gradient norm이 큰 샘플들을 조심해야겠구나!

🤔 그런데 gradient norm이 큰 샘플을 어떻게 찾아내지? (figure(d))

👉 Entropy 가 높은 샘플들이 gradient도 같이 높으니까 filter out하면 되겠다! (근데 그냥 gradient 측정하면 안되나? 🤔)

👉 아래 figure (d) 에 있는 Area1, Area2 의 샘플을 exclude 해보자!

Wild-seupt 에서 기존 TA methods의 failure case analyses

Method - SAR (Sharpness-Aware and Reliable test-time entropy minimization)

Large Entropy Filtering. 말그대로 entropy가 높은 샘플들을 filter out한다.

Large Entropy Filtering

Sharpness-aware Entropy Minimization. 말위의 filtering으로 figure (d) 의 Area1 샘플들은 exclude할 수 있다. 하지만 Area4에 해당하는 샘플들은 여전히 학습에 사용되야한다. 어짜피 filtering 할수가 없으니 이 논문은 large gradient sample로부터 robust하도록 model을 training하는 방법을 대안으로 제시한다. 

Loss (entropy) surface (오른쪽이 좋은거임)

구체적으로 이논문은 loss (entropy) surface를 편평하게 만들 수 있는 loss를 제시하는데, 이렇게 loss surface가 flat해지면 large gradient samples에 대해 강인해진다고 한다. 이를 위해, 이 논문은 Shapness of entropy라는 개념을 정의한다. 

minimization of the sharpness of the entropy

저기 epsilon에 대해 max를 취한 값이 sharpness인데, $\theta$ 기준 어떤 euclidean ball 안에서 가장 큰 entropy 값이다. 결국 min-max 꼴의 bi-level optimization 형태가 되었는데, inner optimization은 Taylor expansion으로 1차 근사하여 문제를 해결한다. (이게 SAM이라는 방법이라 한다. 🤔)

convex를 잘 몰라서 어떻게 저런 solution이 나왔는지는 모르겠으나,,,,, 어찌저찌 dual solution을 찾고 대입해서 최종적으로 equation (5) 와 같은 형태의 objective function을 어게 된다. 

Overall Optimization. Final learning objective는 아래와 같다. Shaprness-aware entropy minimization + entropy-based filtering이 합쳐진 꼴이다. 추가적으로 stability를 위해서 Model Recovery Scheme 이라는 기술을 적용했다고 하는데,  이는 가끔씩 loss (entropy) 값이 moving average 값으로 갑자기 작아졌을때 model parameter를 update 이전으로 되돌리는 기술이다. 위의 Figure (c)에서 알 수 있듯이, 모델이 망가지기 시작하면 entropy가 갑자기 낮아지기 때문이라 한다. 😉

Final training objective

Experimental Results

Empirical Studies of Normalization Layer Effects in TTA. 말기존의 TTA method를 여러 종류의 normalization (BN,GN,LN)  wild setting에서 성능이 어떻게 변화하는지를 관측하였다.

공통적으로 layer norm을 쓰는게 좋고,, batch norm은 batch size에 많이 민감할 뿐만 아니라, BS 가 어느 정도 높아지더라도 성능이 낮다.
avg.adapt는 각각 domain에 따로 adapt시킨뒤 얻은 성능을 평균한 것이고, mix.adapt는 여러 domain을 섞어서 한번에 adapt한 것이다.BN이 mixed 랑 avg 랑 차이가 많이 난다는 것은... domain shift에 더 sensitive하다는 것.. mixed를 못할 수도 있으니까 avg성능이 높은게 좋다.
layer norm이 imbalance dataset에 강인 하고, 특히 BN은 빠르게 골로 간다.
Comparisons with state-of-the-art methods on ImageNet-C (severity level 5) under ONLINE IMBALANCED LABEL SHIFTS
Comparisons with state-of-the-art methods on ImageNet-C (severity level 5) with BATCH SIZE=1
Effects of components in SAR

 

이상 끝!