Zhangzhe's Blog

The projection of my life.

0%

Training data-efficient image transformers & distillation through attention

URL

https://arxiv.org/pdf/2012.12877.pdf

TL;DR

  • 基于 ViT,但解决了 ViT 依赖超大数据集预训练的问题,与 ConvNet 同样在 imageNet 数据集上训练,可以达到 SOTA
  • 提出一种基于 distillation token 的蒸馏

Algorithm

transformer (ViT)

  • Multi-head self attention layer (MSA):
    headi=Attention(QWiQ,KWiK,VWiV)=softmax[QWiQ(KWiK)Tdk]VWiVhead_i=Attention(QW_i^Q,KW_i^K,VW_i^V)=softmax[\frac{QW_i^Q(KW_i^K)^T}{\sqrt{d_k}}]VW_i^V
  • Transformer block: FFN (2 × FC + bottleneck + GeLU + LayerNorm) + MSA + Residual
  • Class token: 在 patch 维度 concat 一维 P x P (P 表示 patch_size),并将这个维度作为输出,其他 patch 维度丢弃
  • Interpolate position embeding: 当输入分辨率变化时,直接对 embeding 插值

distillation

  • 常用 Soft distillation:
    L=(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))L=(1-\lambda)L_{CE}(\psi(Z_s),y)+\lambda\tau^2KL(\psi(Z_s/\tau),\psi(Z_t/\tau))
    其中: yy 表示 GT labelZs , ZtZ_s\ ,\ Z_t 表示 logits of student model and teacher modelψ\psi 表示 softmaxτ\tau 表示蒸馏温度, LCE,KLL_{CE}, KL 分表表示交叉熵与 KL 散度

  • Hard distillation:
    L=12LCE(ψ(Zs),y)+12LCE(ψ(Zs),yt),   yt=argmax(Zt)L=\frac{1}{2}L_{CE}(\psi(Z_s),y)+\frac{1}{2}L_{CE}(\psi(Z_s),y_t),\ \ \ y_t = argmax(Z_t)

  • Hard distillation + label smooth 解决 data augmentation 中 crop 导致的图像与 label 不对应的问题

  • Distillation token:类似 Class token,在 patch 维度 concat 一维 P x P,与 Class token 一起输出计算 loss 与 inference

  • Joint classifier: pred=argmax(ψ(Cs)+ψ(Ds))pred=argmax(\psi(C_s)+\psi(D_s)) ,其中 Cs , DsC_s\ ,\ D_s 分别表示 logits of class token and distillation token

other

  • 很多很多的训练技巧

Thoughts

  • 训练技巧很重要。从 源码 角度看,本文能解决 ViT 依赖超大数据集预训练的问题,主要原因是训练技巧强大
  • 本文提出的关于蒸馏方法的理解可以作为蒸馏 Transformer 的指导