우선 상속관계는 다음과 같다. (보라색이 custom model), 순서대로 분석해보자.
DARTHQDTrack > KDQDTrack > TeacherQDTrack > QDTrack > BaseMultiObjectTracker > BaseMultiObjectTracker > BaseModule > nn.Module...
이중 DARTHQDTrack, KDQTrack은 darth/models/adapters/base 폴더 안에 들어있다.
0. DARTH 복습하기
DARTH는 여러 view 사이에 contrastive learning을 수행한다. 👉 좀더 자세한 설명은 이전 포스팅 참고 (링크)
구조적으로는 Faster-RCNN 기반의 QDTrack을 base model로 사용하고 3개의 view (teacher, student, contrastive)가 존재한다. Loss는 (1) Patch Contrastive Loss (PCL) 과 (2) Detection Consistency (DC) loss가 존재한다.
(위 그림만 본다면..) PCL loss는 student vs contrastive 사이에 구해지는 듯 하고,,, DC loss는 teacher vs student 사이에 구해지는 듯 하다.. PCL loss는 infoNCE 형태의 loss를 가지고 있고, DC loss는 L1 norm 형태로 정의된다.
DC loss는 RPN feature (중간 feature) 와 RoI feature (최종 output) 두가지에 대해서 계산된다.
1. DARTHQDTrack 코드 분석하기
생성자 (__init__) 함수 : 생성자에서 teacher, student, contrastive view에 대한 data augmentation pipeline을 인자로 받는다. 뭔지 모르겠지만 bbox filtering을 위한 confidence threshold (0.7) 도 인자로 받는다.
forward_train 함수 : 모델 학습을 위한 학습. 1차적으로 constrasitve learning을 위해서 input 이미지를 여러 view로 transform한다. 아래와 같이 teacher / student / contrastive view가 있으며 뒤로 갈수록 transform의 강도가 강해진다. Unsupervised learning이기 때문에 teacher view의 예측을 pseudo label로 사용하는데, 아래 그림을 보면 정확하지 않은 pred도 있다.
* 메모리 절약을 위해 pseudo label 에는 pytorch graph가 연결되지 않는다.
* 아래와 같이 Teacher view와 Student view는 색깔만 좀 다르지 geometrically 하게는 다르지 않다!
(pixel-wise consistency loss를 줘도 아무 문제가 없다는 뜻..!)
◾ Base Tracker QDTrack은 backbone으로 ResNet50을 사용하고 5개의 multi-scale 중간 feature 를 가진다.
◾ Teacher model과 student model의 backbone이 (당연히) 다르며, 각각을 통과한 features를 t_x, s_x 라고 부른다.
◽ Teacher model은 Teacher view를 input으로, Student model은 Student view를 input으로 받는다.
◽ Contrastive view는 Student model로 들어가서 feature c_x 가 된다.
◾ t_x, s_x를 각 Teacher model, Student model (둘다 FasterRCNN) 의 RPN에 통과시켜 proposal (t_proposal_list, s_proposal_list) 를 얻는다. 각 proposal_list는 이미지당 1000개의 proposal 정보를 가지고 있는데 bbox의 위치 정보 (4개의 값) 와 confidence score (0~1) 을 가지고 있다.
◾ c_x도 RPN을 통과하여 c_proposal_list가 계산된다. 이때 img_meta 정보도 함께 들어가는데,, 이 정보를 사용하여 transformation에 의해 geometric distortion이 일어났던것을 realign 할 수 있는 것으로 보인다.
◽ Contrasitve view에 대한 prediction (c_x, c_proposal_list) 과 Student view에 대한 prediction (s_x, s_proposal_list) 가 self.track_head (QuasiDenseTrackHead) 의 forward_train 함수를 통과하여 PCL loss가 계산된다.
◽ config 파일을 보면 QDTrack 모델을 만들때 track_head 에 QuasiDenseTrackHead가 지정되어 있다.
◾ DC loss를 계산하기 위해서 t_x, s_x 각각을 get_rpn_displacement 함수에 통과시킨다. 이 함수는 TeacherQDTrack 모델에 정의되어 있는데 별게 아니고 그냥 RPN output을 뱉는 함수다.
백문이불여일견... rpn_outs의 cls / reg output의 크기가 각각 3/12 인 이유는 (원래 1/4 여야함) 오른쪽 위에 나와있듯이 Anchohead에 3개 모양의 anchors가 정의되어 있기 때문이다.
◾ RPN_DC loss와 ROI_DC loss는 각각 RPNDistillationLoss와 ROIDistillationLoss를 이용하여 계산된다. MMlab을 바탕으로하다보니 Loss가 클래스 형태로 정의되어 있고, 코드들은 darth/models/losses 안에 저장되어 있다.
(각 loss에 대한 자세한 분석은 다른 포스팅에서 다루도록 하자..)
◾ RPN_DC loss, ROI_DC loss, PCL loss 는 losses 라는 dict안에 "이름:값"의 형태로 저장되어 리턴된다.
◽ 이건 mmcv에서 공통적으로 losses를 처리하는 방식이다.
2. Epoch_based_runner 에서의 loss 처리
tools/run/train.py는 darth/apis/train.py의 train_model 함수를 호출하고, 이는 또 runner (epoch_based_runner)의 run함수를 호출한다. train_model 함수는 (1) dataloade의 선언, (2) Multi-GPU 학습을 위한 모델 wrapping, (3) optimizer 선언, (4) 여러 hook 등록, (5) runner의 선언 및 runner.run의 호출을 통한 모델의 학습 시작.... 을 담당한다.
위 그림을 보면 첫째로 mode를 고르고 다음으로 모델을 가동한다.
◾ 배치별로 반복문이 돌아간다.
◾ "before_train_iter" 에서 loss를 계산하는데,, 이때 모든 loss들은 dictionary의 형태로 Runner의 멤버 변수 self.ouput에 저장된다 👉 이 값은 "after_train_iter"에서 모델 (이것도 멤버번수임) 을 업데이트하는데 사용된다.
흠.. 이제 대충 코드가 어떻게 돌아가는지는 알겠고,,, 저번에 실패했던 성능 reproduce를 다시 해보자..!
DARTH는 여러 TTA 상황에서 human category만 고려하는 방법을 사용했다고 하는데,, 이것도 눈여겨 봐야할 듯 하다..
이상 끝! 💪😁🤜
'Object Tracking 연구' 카테고리의 다른 글
DARTH 코드 분석하기 (3) Test 코드 분석 및 DARTH 성능 reproduce (0) | 2024.01.11 |
---|---|
DARTH 코드 분석하기 (1) Train/test.py, config 분석 (0) | 2024.01.07 |
[디버깅] DARTH : CostumOptimizerHook is not in the hook registry (0) | 2024.01.07 |
MMCV에서 MODEL, DATASET을 어떻게 build 하는가? (0) | 2024.01.07 |
MMCV에서 Registry를 어떻게 만드는가? (0) | 2023.12.29 |