Zhangzhe's Blog

The projection of my life.

Deformable DETR: Deformable Transformers for End-to-end Object Detection

URL

TL;DR

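  • Proposes Deformable DETR: a new object detection model that addresses the slow convergence and limited feature spatial resolution of the existing DETR model.
  • Uses deformable attention modules: these attend only to a small set of key sampling points around a reference point, which improves performance with fewer training epochs, especially on small objects.
  • Combines the sparse spatial sampling of deformable convolution with the relation modeling capability of Transformers: the model stays efficient on large inputs while still capturing complex contextual relations.
  • Introduces a two-stage variant: region proposals are generated by Deformable DETR and then refined iteratively, which lets the model localize and recognize objects more precisely.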

Algorithm

deformable_detr.png

Overall architecture of Deformable DETR

Deformable Attention Block

deformable_attention.png

  • Multi-Head Attention:
    $\mathrm{MultiHeadAtten}(z_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k \in \Omega_k} A_{mqk} \cdot W_m' x_k \right]$
    • The inputs are the representation $z_q$ of one query and the full feature map $x$; the output is the result vector for that query
    • $M$ is the number of heads
    • $A_{mqk}$ is the attention weight, i.e. $\mathrm{softmax}(\frac{QK^T}{\sqrt{d}})$
    • $W_m' x_k$ is effectively the $V$ of self-attention
  • Deformable Attention:
    $\mathrm{DeformableAtten}(z_q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot W_m' x(p_q + \Delta p_{mqk}) \right]$
    • The inputs are the representation $z_q$ of one query, the full feature map $x$, and the preset sampling position (reference point) $p_q$ of that query; the output is the result vector for that query
    • $\Delta p_{mqk}$ is the horizontal and vertical offset relative to the preset reference point, computed from $z_q$
    • $A_{mqk} = \mathrm{softmax}(z_q W_a)$, with $W_a \in \mathbb{R}^{dim \times num\_points}$ and $z_q \in \mathbb{R}^{dim}$. In other words, the attention over sampling points is a linear projection of the query, so Deformable Attention has no Key, only Query and Value
    • $K$ is the number of sampling points
  • Multi-Scale Deformable Attention:
    $\mathrm{MSDeformableAtten}(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}) = \sum_{m=1}^{M} W_m \left[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} \cdot W_m' x^l(\phi_l(\hat{p}_q) + \Delta p_{mlqk}) \right]$
    • Unlike Deformable Attention, the input $x$ becomes a set of multi-scale features $\{x^l\}$ (for example, features from different depths of the backbone), which is closer to how vision models are engineered in practice
    • $\hat{p}_q$ is the reference point in normalized coordinates, and $\phi_l$ rescales it to the feature map of level $l$
    • The sampling points are drawn from the feature maps of all levels, so MSDeformableAttention has access to attention information across every scale
  • Comparison of Deformable Attention and Self Attention:
import math

import torch
import torch.nn.functional as F


def attention(query, key, value):
    # query, key, value shapes: (batch_size, sequence_length, embedding_dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    # scores shape: (batch_size, sequence_length, sequence_length)
    probs = F.softmax(scores, dim=-1)
    output = torch.matmul(probs, value)
    # output shape: (batch_size, sequence_length, embedding_dim)
    return output


def deformable_attention(query, value, reference_points, num_sampling_points,
                         atten_linear, offset_linear, spatial_shape):
    # Simplified single-head, single-scale version (value/output projections omitted).
    # query shape: (batch_size, num_queries, embedding_dim)
    # value shape: (batch_size, height * width, embedding_dim), flattened row by row
    # reference_points shape: (batch_size, num_queries, 2), normalized (x, y) in [0, 1]
    # num_sampling_points: integer K, number of points sampled per query
    # atten_linear: nn.Linear(embedding_dim, K)
    # offset_linear: nn.Linear(embedding_dim, 2 * K)
    # spatial_shape: (height, width) of the feature map behind `value`
    batch_size, num_queries, embed_dim = query.size()
    height, width = spatial_shape

    # Offsets are predicted from the query, not from the reference points
    offsets = offset_linear(query).view(batch_size, num_queries, num_sampling_points, 2)
    # Offsets are in pixels; divide by (width, height) to get normalized units
    normalizer = offsets.new_tensor([width, height])
    sampling_positions = reference_points.unsqueeze(2) + offsets / normalizer
    # sampling_positions shape: (batch_size, num_queries, K, 2)

    # Sample the value feature map at fractional positions with bilinear interpolation
    value_map = value.transpose(1, 2).reshape(batch_size, embed_dim, height, width)
    grid = 2.0 * sampling_positions - 1.0  # grid_sample expects coordinates in [-1, 1]
    sampling_values = F.grid_sample(value_map, grid, mode="bilinear",
                                    padding_mode="zeros", align_corners=False)
    # sampling_values shape: (batch_size, embedding_dim, num_queries, K)

    # Attention weights come from a linear map of the query alone (there is no key)
    scores = atten_linear(query)
    probs = F.softmax(scores, dim=-1)
    # probs shape: (batch_size, num_queries, K)

    # Weighted sum over the K sampled values
    output = (sampling_values * probs.unsqueeze(1)).sum(dim=-1).transpose(1, 2)
    # output shape: (batch_size, num_queries, embedding_dim)
    return output
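The comparison above covers only a single scale. As a rough illustration of the Multi-Scale Deformable Attention formula, here is a minimal pure-Python sketch for one query and one head. It is not the official implementation. The helper names (`bilinear_sample`, `softmax`, `ms_deformable_attention`), the nested-list feature maps, and the way $\phi_l$ is implemented (scaling the normalized reference point by `width - 1` and `height - 1`) are all assumptions made for illustration, and the projections $W_m$ and $W_m'$ are left out.

```python
import math


def bilinear_sample(feature_map, x, y):
    """Bilinearly sample feature_map[row][col] (a list of channels) at pixel (x, y).

    Positions outside the map contribute zeros (zero padding).
    """
    height, width = len(feature_map), len(feature_map[0])
    channels = len(feature_map[0][0])
    x0, y0 = math.floor(x), math.floor(y)
    out = [0.0] * channels
    for yi, wy in ((y0, 1.0 - (y - y0)), (y0 + 1, y - y0)):
        for xi, wx in ((x0, 1.0 - (x - x0)), (x0 + 1, x - x0)):
            if 0 <= yi < height and 0 <= xi < width:
                for c in range(channels):
                    out[c] += wy * wx * feature_map[yi][xi][c]
    return out


def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def ms_deformable_attention(feature_maps, ref_point, offsets, scores):
    """Multi-scale deformable attention for one query and one head.

    feature_maps: list of L maps, each indexed as [row][col][channel]
    ref_point:    normalized (x, y) in [0, 1], shared by all levels
    offsets:      offsets[l][k] = (dx, dy), in pixels of level l
    scores:       scores[l][k], unnormalized attention logits
    """
    # Attention weights are normalized jointly over all L * K sampling points
    weights = softmax([s for level in scores for s in level])
    channels = len(feature_maps[0][0][0])
    output = [0.0] * channels
    i = 0
    for level, fmap in enumerate(feature_maps):
        height, width = len(fmap), len(fmap[0])
        # phi_l (assumed form): rescale the reference point to this level's grid
        base_x = ref_point[0] * (width - 1)
        base_y = ref_point[1] * (height - 1)
        for dx, dy in offsets[level]:
            sampled = bilinear_sample(fmap, base_x + dx, base_y + dy)
            for c in range(channels):
                output[c] += weights[i] * sampled[c]
            i += 1
    return output


# Two levels with constant values 1.0 and 3.0, K = 2 points per level,
# zero offsets and equal logits: every point gets weight 0.25.
level0 = [[[1.0] for _ in range(4)] for _ in range(4)]
level1 = [[[3.0] for _ in range(2)] for _ in range(2)]
zero_offsets = [[(0.0, 0.0), (0.0, 0.0)], [(0.0, 0.0), (0.0, 0.0)]]
equal_scores = [[0.0, 0.0], [0.0, 0.0]]
print(ms_deformable_attention([level0, level1], (0.5, 0.5), zero_offsets, equal_scores))
```

Because the softmax runs over the points of all levels together, a query can shift its weight toward whichever scale is most informative, which is how the cross-scale fusion happens without an explicit FPN.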

Thought

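  • Replacing the query-key dot product with a linear projection of the query to produce the attention weights makes the mechanism harder to interpret mathematically, but it lowers the computational cost
  • Deformable Conv is a classic example of an NPU-unfriendly operator, and Deformable Attention is even more complex: the fear of being dominated by Jifeng Dai
  • Multi-scale attention fuses information across feature scales, opening up a new paradigm for computer vision models: a CNN as the backbone plus a Deformable Transformer as the head, with no need even for an FPN
  • In short, it approximates expensive global attention with cheap computations: each query attends to $K$ sampled points instead of all $H \times W$ positions, so with $n = HW$ queries the attention cost drops from $O(n^2)$ to $O(nK)$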
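To put rough numbers on the complexity argument, the sketch below counts how many attention weights each scheme computes when every position of a feature map acts as a query. The function name `attention_weight_counts` and the default values of `num_heads`, `num_levels`, and `num_points` are illustrative assumptions. The count also ignores the cost of the linear projections and of bilinear sampling.

```python
def attention_weight_counts(height, width, num_heads=8, num_levels=1, num_points=4):
    """Count attention weights for dense vs. deformable attention.

    Every one of the n = height * width positions acts as a query.
    Dense attention scores all n positions per query and head;
    deformable attention scores only num_levels * num_points sampled points.
    """
    n = height * width
    dense = num_heads * n * n
    deformable = num_heads * n * num_levels * num_points
    return dense, deformable


dense, deformable = attention_weight_counts(100, 100)
print(dense, deformable, dense // deformable)
```

For a 100 x 100 feature map with these assumed defaults, the ratio between the two counts is n / K = 2500, and it grows linearly with the number of positions.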