Zhangzhe's Blog

The projection of my life.

0%

URL

https://arxiv.org/pdf/2012.12877.pdf

TL;DR

  • 基于 ViT,但解决了 ViT 依赖超大数据集预训练的问题,与 ConvNet 同样在 imageNet 数据集上训练,可以达到 SOTA
  • 提出一种基于 distillation token 的蒸馏

Algorithm

transformer (ViT)

  • Multi-head self attention layer (MSA):
    headi=Attention(QWiQ,KWiK,VWiV)=softmax[QWiQ(KWiK)Tdk]VWiVhead_i=Attention(QW_i^Q,KW_i^K,VW_i^V)=softmax[\frac{QW_i^Q(KW_i^K)^T}{\sqrt{d_k}}]VW_i^V
  • Transformer block: FFN (2 × FC + bottleneck + GeLU + LayerNorm) + MSA + Residual
  • Class token: 在 patch 维度 concat 一维 P x P (P 表示 patch_size),并将这个维度作为输出,其他 patch 维度丢弃
  • Interpolate position embeding: 当输入分辨率变化时,直接对 embeding 插值

distillation

  • 常用 Soft distillation:
    L=(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))L=(1-\lambda)L_{CE}(\psi(Z_s),y)+\lambda\tau^2KL(\psi(Z_s/\tau),\psi(Z_t/\tau))
    其中: yy 表示 GT labelZs , ZtZ_s\ ,\ Z_t 表示 logits of student model and teacher modelψ\psi 表示 softmaxτ\tau 表示蒸馏温度, LCE,KLL_{CE}, KL 分表表示交叉熵与 KL 散度
  • Hard distillation:
    L=12LCE(ψ(Zs),y)+12LCE(ψ(Zs),yt),   yt=argmax(Zt)L=\frac{1}{2}L_{CE}(\psi(Z_s),y)+\frac{1}{2}L_{CE}(\psi(Z_s),y_t),\ \ \ y_t = argmax(Z_t)
  • Hard distillation + label smooth 解决 data augmentation 中 crop 导致的图像与 label 不对应的问题
  • Distillation token:类似 Class token,在 patch 维度 concat 一维 P x P,与 Class token 一起输出计算 loss 与 inference
  • Joint classifier: pred=argmax(ψ(Cs)+ψ(Ds))pred=argmax(\psi(C_s)+\psi(D_s)) ,其中 Cs , DsC_s\ ,\ D_s 分别表示 logits of class token and distillation token

other

  • 很多很多的训练技巧

Thoughts

  • 训练技巧很重要。从 源码 角度看,本文能解决 ViT 依赖超大数据集预训练的问题,主要原因是训练技巧强大
  • 本文提出的关于蒸馏方法的理解可以作为蒸馏 Transformer 的指导

URL

https://arxiv.org/pdf/2008.01232.pdf

TL;DR

  • 本文将 bert 模型结构用于多帧动作识别网络的末尾的时间信息融合部分,在 HMDB51 和 UCF101 两个 Action Recognition 数据集上目前仍是 SOTA

Algorithm

一句话总结本文的主要工作:SOTA - TGAP + BERT = NEW SOTA

之前 Action Recognition 常用的网络结构

1. 3D Conv + TGAP

  • 将连续多帧视频一起送入网络,使用 3D Conv 或 C(2 + 1)D 降维时间与空间,升维 Channel
  • 使用 TGAP (temporal global average pooling ) (torch.nn.AdaptiveAvgPool3d) 对时间空间一起全局平均池化到一个 scalar,然后 Channel 维做 FC 分类

2. 3D Conv + GAP + LSTM

  • backbone 部分与 1 相似
  • 对时空 feature map 使用 GAP,保留时间维度的特征,使用 LSTM 等结构处理时间序列,输出 FC 分类

3. 基于 2D Conv + 时序 等

本文网络结构

  • 本文认为 TGAP 会丢失很多时序信息,GAP + LSTM 效果也不好
  • 在末尾使用 GAP + BERT 是一个较好的选择,并只对 Transformer 的 ClassToken 监督

对 Transformer 一个有趣的解释

Transformer 的数学表达式: yi=PFFN(1N(x)jg(xi)f(xi,xj))y_i=PFFN(\frac{1}{N(x)}\sum_{\forall{j}}g(x_i)f(x_i,x_j))
其中:

  • PFFN: Position-wise Feed-forward Networ
  • f(xi,xj)=softmaxj(θ(xi)Tϕ(xj))f(x_i,x_j)=softmax_j(\theta(x_i)^T\phi(x_j)),其中 g,ϕ,θg,\phi,\theta 都是 projection function (FC)
  • 如果 g,ϕ,θg,\phi,\theta 都变成 1 × 1 × 1 Conv,那 Transformer 就变成了 non-local ,所以用 BERT 处理图像序列就非常合理了…
    non-local

baseline

对 R(2 + 1)D 网络的改进

对 SlowFastNet 的改进

  • BERT 的后融合实现: SlowFastNet 的两路序列各自经过 BERT 再 Concat 比 Concat 后再 BERT 效果好…
    bert2

对比实验

  • 作者做了非常完善的对比实验,包括是否使用光流信息,是否在backbone尾部降维,Transformer 用几层几个 head 等,详细见 paper

Thoughts

  • 关于 BERT 与 Non-local 的关系还是挺有趣的