본문 바로가기

Test-Time Adaptation

SHOT: Do We Really Need to Access the Source Data? Source Hypothesis Transfer forUnsupervised Domain Adaptation 논문 공부

이 논문은 ICML'20에 publish된 논문으로 2023.11.27 기준 858회의 citation을 보유하고 있다. (와우! 😮)

논문은 https://dl.acm.org/doi/abs/10.5555/3524938.3525498 에서 찾아볼 수 있고, 여러모로 TTA의 시작을 연 기념비적인 논문인 듯 하다.

 

Introduction

이 논문은 unsupervised domain adapatation (UDA)를 푸는 논문인데, 구체적으로 source data를 전혀 사용하지 않고 오직 source model만 사용하는 test-time UDA 를 푼다. 기존의 UDA 논문 중 대부분은 source data를 사용하여 MMD 를 줄이는 등의 distribution alignment가 많이 이루어졌는데, 이 방법론들은 source data를 활용하는 과정에서 privacy를 침해하거나 (특히 FL 세팅에서) 높은 communication cost가 발생한다고 한다.

 

이 논문은 Neural network이 (1) feature encoding module, (2) hypothesis (classifier) 두가지 part로 구성되어 있다고 한다. 이 논문은 classifier는 freeze하고 target data의 feature distribution이 source datad의 경우와 비슷해지도록 오직 feature encoding module을 optimize 하는 방법을 제시한다. 구체적으로, 두가지 방법는데 (1) Information Maximization (IM) loss를 사용하고, (2) self-supervised pseudo labeling 을 사용해서 feature encoding module을 학습시킨다. 이와 별개로 label smoothing, weight normalization, batch normalization을 orthogonal하게 적용하였고, 성능 향상을 확인하다고 한다. Partial-set, open-set 세팅에서도 잘 작동한다고 한다.

 

Method (1) -  Source Model Generation

Target domain에 adapt시킬 source model을 훈련하는 과정이다. 이 논문은 classification을 다루기 때문에 단순한 CE 를 사용해서 model을 train한다. 여기에 label smoothing (LS) 까지 적용한 CE를 사용했다고 하는데, LS를 적용할 경우 CE의 target이 binary one-hot vector에서 👉 one-hot + uniform 을 적절히 섞은 soft label로 바뀐다. 자세한건 아래 이미지 참고.

Source model을 학습하기 위한 Loss function

Method (2) -  Source Hypothesis Transfer with Infomration Maximization (SHOT-IM)

이제 여기서부터는 pre-trained source model을 unseen target domain에 adapt (in an unspervised way) 하는 과정을 다룬다. 별로 복잡한건 없고, classifier를 고정시킨 뒤 feature extractor를 다음 두가지 loss에 따라 학습시킨다고 한다. method가 간단하니 논리의 흐름을 이해하는 것이 중요하다 (따라 읽으며 생각해보자) :

🤔 어떻게 해야 target domain에 잘 adapt할 수 있을까?

👉 target feature가 source feature 사이의 domain gap을 줄이면 되겠다

🤔 어떻게 해야 domain gap을 줄일 수 있을까?

👉 target feature가 source feature 가 비슷한 특성을 갖도록 하면 되겠다.

🤔 source feature는 individually certain (각각에 대해 확신) + globally diverse (전체적으로 not biased) 한 특성을 가진다.

👉 target feature도 이렇게 되도록 loss를 주면 되겠다.

이렇게 해서 나온 Loss fuctions은 아래와 같다. (individually certain + globally diverse) 

 

(1) Entropty minimization $L_{ent}(f_t;X_t)$: entropy가 uniform일때 maximum이란걸 고려했을때, target data에 대한 prediction이 one-hot vector와 비슷하도록 만들어 주는 loss

(2) Feature diversification $L_{div}(f_t;X_t)$: target domain에 대한 mean prediction이 어느 한 class에 치우치지 않고 전체적으로 분포하도록 만들어 주는 loss. $L_{ent}(f_t;X_t)$ 만 사용할 경우 한 class에 치우쳐지기 쉽다고 한다. (그런데 이건 I.I.D setting이 아니면 쓸모 없는 loss가 아닌가? 🤔)

Adaptation을 위한 Information Maximization (IM) losses

Method (3) -  Source Hypothesis Transfer Augmented with Self-supervised Pseudo-labeling

이름을 보면 알겠지만, pseudo-label을 만드는 방법에 대한 내용이고 그냥 Kmeans 로 매우 단순하게 구한다. (2020년에는 이런 알고리즘으로도 ICML을 갈 수 있었나보다) Motivation을 위해 아래 Figure를 보자.  SHOT-IM을 적용해서 source와 target의 distribution을 어느정도 align하는데에는 성공했지만, 일부 false positive들이 관찰된다. 👉 이를 해소하기 위해서 거리 기반 pseudo-label을 만들어서 사용하겠다는 뜻이다.

Pseudo-labeling을 하는 이유

Pseudo-label을 구하는 방법 : target data에 대해서 Kmeans clustering을 여러번 돌린다. 이때 최초의 centroid는 random point로 구하는게 아니라 model prediction 값을 사용한 weighted average 값으로 구한다. (아래 equation 참고)

First iter.에 centroid 구하는 법

그 다음부터는 익히 잘 알고 있는 Kmeans clustering 알고리즘을 통해 centroid와 각 샘플의 probability를 update한다.

Pseudo-label 구하는 방법

원래 Kmeans는 여러번의 iteration을 거치는게 보통이지만, 논문 저자 왈, 한번만 iteration 해도 잘 작동한다고 한다. 위와 같이 계산된 pseudo-label $\hat{y_t}$는 아래와 같이 IM loss와 결합되어 source model을 adapt하는데 사용된다.

SHOT의 최종 loss

Method (4) -  Network Architecture of Source Model

자잘하지만 중요한 학습 디테일에 관한 내용이다. 

(1) Weight normalization (WN) : FC layer의 weight의 weight를 normalize 시켜서 feature의 크기가 아닌 방향에 집중한다.

(2) Batch normalization (BN) : Source feature과 Target feature를 통일 시켜서 domain shift를 완화한다고 한다.

 

이상 끝! 😁