- 공유 링크 만들기
- X
- 이메일
- 기타 앱
Paper Overview
AAAI'24
Abstract
저자들은 end-to-end generative GZSL framework $D^{3}GZSL$을 소개한다.
이 framework는 seen, 합성된 unseen data를 각각 in-distributino, out-of distribtuion으로 여긴다.
$D^{3}GZSL$는 two core module로 구성되어 있다.
In-distribution dual space distillation ($ID^{2}SD$)
out-of-distribution batch distillation ($O^{2}DBD$)
$ID^{2}SD$은 embedding, label space에 있는 teacher-student 결과를 align하여, 학습 일관성(coherence)를 향상시킨다.
$O^{2}DBD$는 batch 단위의 low-dimensional out-of-distribution representation을 도입하여, seen과 unseen category 사이의 공유된 구조를 감지한다.
Keywords
Generalized Zero-Shot Recognition, Data Distribution Distillation
Introduction
ZSL은 seen class로부터 얻은 knowledge를 transfer하여 모델이 useen class object를 구별하는 능력을 학습하도록 노력한다.
GZSL은 단순히 unseen data만을 구별하는 것이 아닌 한 모델이 seen, unseen을 함께 구별하도록 하는 task다.
unseen class에 대한 sample을 생성함으로써 GZSL은 우리가 익히 아는 supervised laerning problem이 된다.
위 방법을 사용하는 대부분의 논문들은 seen data의 희소성에 영향을 받아, semantic knowledge와 seen data 사이의 correlation에만 기술하는것에 집중하게 된다.
Out-of-distribution (OOD) detection은 data sample을 색인하는데 조점을 맞춘다.
OOD detection의 측면에서 GZSL을 생각하면, deen data는 ID (in-distribution) data이고 반면에, unseen은 OOD data가 된다.
generative approach는 먼저 unseen sample을 만든 후 OOD detector가 이를 감지한다.
그 다음, two expert classifier가 분리되어 학습하여 생성된 unseen과 실제 seen data를 다룬다.
그러나 non-end-to-end training 전략은 seen과 unseen class를 거치는 data distribution 모델링을 방치하게 된다.
Methodology
Problem statement
저자들의 framework는 다음과 같다.
Feature Generation(FG)
저자들은 여러 Generative 모델을 베이스라인 삼아 feature generation 하도록 했다.
$L_{gen}$은 이 generative GZSL method의 loss를 나타낸다.
$G$는 generative model을 뜻하고, $a$는 semantic embedding, $w$는 normal sampling이다.
generated feature를 $x''$라 한다.
In-Distribution Dual-Space Distillation ($ID^{2}SD$)
저자들의 목적은 seen, unseen category 사이를 구별하는 능력을 classifier에 부여하는 것이다.
저자들은 seen data만을 사용하여 신뢰할 수 있는 teacher network를 학습할 수 있다.
저자들은 dual-space distillation의 teacher-student network framework를 적용한다.
distilled knowledge는 feature information뿐만 아니라 data sample의 상호 관계도 포함한다.
저자들은 feature 유사성을 측정하기 위해 batch matrix 내의 sample correlation을 이용한다.
저자들의 teacher network는 embedding function $E_{o}$와 classifier $C_{o}$가 있다.
또 student network는 $E_{s}$와 classifier $C_{s}$가 있는데 $C_{s}$는 seen, unseen을 모두 포함한다.
$\phi$는 teacher network 전체를 나타내고 $\psi$는 student network 전체를 나타낸다.
Batch-Wise ID Embedding Identical Loss
저자들은 sample의 uniformity를 고려한다.
저자들은 batch embedding similarity matrix $A$를 먼저 구성한다.
각 element $a_{ij}$는 cosine similarity로 구한다.
$z$는 embedding feature다.
저자들은 ground-truth는 1 나머지는 0을 타깃으로 한다.
따라서 저자들의 task는 각 elements의 binary classification이 된다.
저자들은 seen category의 sample feature를 배타적으로 이용하기 위해 loss를 만들었다.
저자들의 loss는 다음과 같다.
위 loss식을 보면 class가 겹치는 feature만 유사도가 1이 되고 나머지는 0이 되도록 한다.
cosine similarity가 0이라는 것은 두 feature가 직교관계가 되도록 한다는 것이다.
Instance-Wise ID Logit Identical Loss
이 loss의 목적은 label space에서 studne
t, teacher nwtwork의 output을 align하기 위함이다.
먼저 각 출력 $\tilde{v}, \ddot{v}$을 L2 norm을 건다.
$v_{o} = \frac{\tilde{v}}{\parallel\tilde{v}\parallel_{2}}, v_{s} = \frac{\ddot{v}}{\parallel\ddot{v}\parallel_{2}}$
그 다음 각 network의 확률을 다음과 같이 구한다.
그 다음 다음과 같이 KL divergence를 건다.
Seen에 대한 softmax 확률을 두 모델이 같게 출력하도록 만드는 것이다.
(seen 성능 저하 방지 목적)
Classification Loss
저자들의 student network는 GZSL head기 때문에 unseen도 구분해야 한다.
따라서 GZSL cross-entropy loss는 다음과 같다.
Total Loss of $ID^{2}SD$
Out-of-Distribution Batch Distillation ($O^{2}DBD$)
batch내의 각 sample에 대해 OOD information를 인코딩 하기위해 저자들은 low-dimensional representation을 도입한다.
그다음 저자들은 이 low-dimensional OOD representation사이의 correlatoin을 모델링한다.
이상적으로, input이 unseen data가 입력될때, teacher network는 uniform distribuion을 생성한다.
이 특성을 이용하여 OOD를 검출하고자 한다.
OOD Logits
모델은 각 input sample에 대해 score를 계산하고 이것은 confidence를 나타낸다.
이 방법의 key step은 threshold $\gamma$를 결정하여 ID와 OOD를 구별하는 것이다.
위 방법과 다르게 저자들은 threshold value $\gamma$없이 confidence 추정 접근법을 만들었다.
즉 이 OOD 검출을 backbone과 classifier단에서 어느정도 해내도록 만든다.
저자들은 ID와 OOD 정보를 인코딩하는 low-level representation $\tilde{h}$를 얻는다.
첫번째 element는 ID confidence $c$고 두번째 element는 $\hat{c}$다.
따라서 $\tilde{h}$는 2차원 representation이다.
저자들은 OOD detection score $\tilde{s}$를 다음을 통해 얻는다.
즉 teacher network의 output의 max값을 $\tilde{s}$라고 한다.
저자들은 score $\tilde{s}$를 0~1로 보내기 위해 learnable sigmoid function $\epsilon$를 이용한다.
$\epsilon = \frac{1}{1-e^{-\alpha(x-\beta)}}$
$c = \epsilon(\tilde{s})$
$\hat{c} = 1-c$다.
저자들은 mapping function $H$를 student network에 달아
softmax 확률을 OOD representation embedding space로 매핑한다.
Batch-Wise OOD Logits Identical Loss
저자들은 식(3)을 이용하여 위 loss-dimensional space를 학습한다.
따라서 loss도 다음과 같다.
이전과 다르게 이 loss는 seen, generated unseen이 모두 포함된 data로 학습한다.
본 loss의 의미는 $C_s$자체가 seen, unseen의 전체적인 범주를 구분하도록 돕는다.
Model Optimization
Training
저자들은먼저 teacher model을 real seen sample로 먼저 학습한다.
위 teacher model을 학습하는 것을 OOD detection model라 한다.
학습이후 FG, $ID^{2}SD$, $O^{2}DBD$를 end-to-end로 학습한다.
그다음 teacher network로부터 OOD confidence label을 계산한 뒤, student network의 softmax probability를 OOD representation embedding space로 매핑한다.
Inference
Experiments
Datasets
Evaluation Protocols
Implementation Details
Comparisons with Previous Method
Ablation Study
Training Strategy Analysis
TS는 two-stage framework를 의미하고, IV-TS는 OOD모듈이 완벽하게 seen, unseen을 구분해내는 경우를 의미한다.
Component Analysis
OOD Scoring Strategy Analysis
Conclusion
In this paper, we introduce a generative GZSL framework(D3GZSL) that combines OOD detection and knowledge distillation technologies. Our D3GZSL leverages the OOD detection model to distill the student model, effectively aligning the distribution of generated samples more closely with the distribution of actual samples. By training a unified classifier as the final GZSL classifier, our framework addresses the issue of accumulated error stemming from two-stage classification in previous ZSL methods based on OOD detection. Empirical validation through comprehensive experiments demonstrates that our hybrid D3GZSL framework consistently enhances the performance of existing generative GZSL approaches. This novel approach not only leverages cutting-edge techniques but also addresses the limitations of traditional ZSL methodologies, propelling the field towards more precise and reliable zero-shot learning outcomes.























댓글
댓글 쓰기