Zhangzhe's Blog

The projection of my life.


MOTR: End-to-End Multiple-Object Tracking with Transformer

URL

TL;DR

  • Proposes a fully end-to-end multiple-object tracking framework

  • Formulates multiple-object tracking as a set of sequence prediction problems

  • Introduces tracklet-aware label assignment

  • Proposes the Collective Average Loss (CAL) and the Temporal Aggregation Network (TAN) for temporal modeling

Algorithm

Overall MOTR Pipeline

(Figure: MOTR.png, overall MOTR pipeline)

  1. Feature extraction: a CNN backbone extracts features from every frame of the video (Enc in the figure above)

  2. Query decoding: a Deformable Transformer decoder decodes the features from step 1 against a set of queries (Dec in the figure above)

    • For the first frame of the video, only the object detection queries (q_d in the figure) are decoded, yielding the hidden state

    • For every later frame, the object detection queries (q_d) are first concatenated with the previous frame's tracking queries (q_tr) and then decoded, yielding the hidden state

  3. Prediction: a simple head maps the hidden state from the previous step to the task space; the predictions contain both object detection results and tracking results (a sketch of such a head follows the pseudocode below)

  4. Next frame's tracking queries: the QIM (Query Interaction Module) maps the predictions from the previous step into the tracking queries for the next frame

  5. Loss / output: during training, the Collective Average Loss (CAL) is computed; during inference, the results from step 3 are output directly

  • Pseudocode describing the MOTR pipeline
def process_frame(frame, detect_queries, track_queries=None, ground_truths=None):
    # Extract per-frame features with the CNN backbone
    # frame shape: (height, width, channels)
    frame_features = extract_frame_features(frame)  # Shape: (height, width, channels)

    if track_queries is None:
        # First frame: decode only the detect queries with the Deformable DETR decoder
        # detect_queries shape: (num_queries, query_dim)
        hidden_states = deformable_detr_decoder(detect_queries, frame_features)  # Shape: (num_queries, hidden_dim)
    else:
        # Later frames: concatenate the previous frame's track queries with the detect queries
        queries = concatenate(track_queries, detect_queries)  # Shape: (num_tracks + num_queries, query_dim)
        hidden_states = deformable_detr_decoder(queries, frame_features)  # Shape: (num_tracks + num_queries, hidden_dim)

    # Map hidden states to predictions (class scores + box for every query)
    predictions = predict(hidden_states)  # Shape: (num_tracks + num_queries, num_classes + 4)

    # Use the Query Interaction Module (QIM) to build the next frame's track queries
    track_queries = qim(hidden_states)  # Shape: (num_tracks, query_dim)

    if ground_truths is not None:
        # Training: compute the Collective Average Loss (CAL)
        loss = cal(predictions, ground_truths)
        backpropagate(loss)

    return predictions, track_queries  # Shapes: (num_tracks + num_queries, num_classes + 4), (num_tracks, query_dim)

def process_video(video, ground_truths=None):
    # Initialize the learnable detect queries
    detect_queries = initialize_detect_queries()  # Shape: (num_queries, query_dim)
    track_queries = None  # Shape: (num_tracks, query_dim)

    for frame in video:
        predictions, track_queries = process_frame(frame, detect_queries, track_queries, ground_truths)
        if ground_truths is None:
            yield predictions
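
  • The predict call above (step 3 of the pipeline) is only named in the paper as a simple head that maps each hidden state to class scores and a box. The snippet below is a minimal sketch of what such a DETR-style head could look like in PyTorch; the layer sizes and the names hidden_dim and num_classes are illustrative assumptions, not the paper's exact configuration.

import torch.nn as nn

class PredictionHead(nn.Module):
    # Minimal sketch: a linear classifier plus a small MLP box regressor,
    # in the spirit of DETR-style heads (sizes are illustrative).
    def __init__(self, hidden_dim=256, num_classes=1):
        super().__init__()
        self.class_head = nn.Linear(hidden_dim, num_classes)
        self.box_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 4),  # (cx, cy, w, h), normalized to [0, 1]
        )

    def forward(self, hidden_states):
        # hidden_states: (num_queries, hidden_dim)
        logits = self.class_head(hidden_states)         # (num_queries, num_classes)
        boxes = self.box_head(hidden_states).sigmoid()  # (num_queries, 4)
        return logits, boxes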

Query Interaction Module

  • The Query Interaction Module (QIM) is a key component of MOTR: it handles object entrance and exit and strengthens long-term temporal relation modeling

  • The input of QIM is the current frame's predicted detection results and tracking results; its output is the next frame's tracking queries

  • Intuitively, QIM takes the current frame's predictions and produces the "questions" to ask in the next frame

  • Pseudocode of the QIM process

def query_interaction_module(hidden_states, scores, tau_en, tau_ex, M):
    # hidden_states shape: (num_queries, hidden_dim)
    # scores shape: (num_queries, num_classes)
    # tau_en, tau_ex: entrance and exit score thresholds
    # M: number of consecutive low-score frames before a track exits

    # Object entrance: keep queries whose classification score exceeds tau_en
    entrance_mask = scores.max(dim=1).values > tau_en  # Shape: (num_queries,)
    hidden_states = hidden_states[entrance_mask]  # Shape: (num_entrance_queries, hidden_dim)

    # Temporal Aggregation Network (TAN): fuses temporal information;
    # the paper implements it with a multi-head self-attention block
    hidden_states = temporal_aggregation_network(hidden_states)  # Shape: (num_entrance_queries, hidden_dim)

    # Object exit: a track is removed only if its score stays below tau_ex
    # for M consecutive frames (the low-score flags are kept across frames)
    exit_mask = scores[entrance_mask].max(dim=1).values < tau_ex  # Shape: (num_entrance_queries,)
    exit_mask = exit_mask.rolling(window=M).sum() == M  # True only after M consecutive low-score frames
    hidden_states = hidden_states[~exit_mask]  # Shape: (num_track_queries, hidden_dim)

    return hidden_states  # Shape: (num_track_queries, hidden_dim)
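
  • The temporal_aggregation_network call above is described in the paper as multi-head self-attention over the track queries. Below is a minimal sketch of such a TAN in PyTorch; the residual/FFN structure and the dimensions (hidden_dim, num_heads) are assumptions for illustration rather than the paper's exact configuration.

import torch.nn as nn

class TemporalAggregationNetwork(nn.Module):
    # Minimal sketch: multi-head self-attention over the track queries,
    # followed by a feed-forward layer with residual connections (sizes illustrative).
    def __init__(self, hidden_dim=256, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4), nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, track_queries):
        # track_queries: (num_tracks, hidden_dim); add a batch dim for attention
        x = track_queries.unsqueeze(0)
        x = self.norm1(x + self.attn(x, x, x)[0])  # self-attention + residual
        x = self.norm2(x + self.ffn(x))            # feed-forward + residual
        return x.squeeze(0)  # (num_tracks, hidden_dim)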

Collective Average Loss

  • The Collective Average Loss (CAL) is the loss function used to train MOTR. Instead of the traditional per-frame loss, CAL collects the predictions of the whole video clip and then computes an overall loss over the entire clip

  • Code description of the Collective Average Loss

def collective_average_loss(predictions, ground_truths, matching_results):
    # predictions / ground_truths / matching_results: one entry per frame of the clip,
    # each split into 'tracked' (already-tracked objects) and 'detected' (newborn objects)
    total_loss = 0
    total_objects = 0

    for i in range(len(predictions)):
        pred_tracked = predictions[i]['tracked']
        pred_detected = predictions[i]['detected']
        gt_tracked = ground_truths[i]['tracked']
        gt_detected = ground_truths[i]['detected']
        match_tracked = matching_results[i]['tracked']
        match_detected = matching_results[i]['detected']

        # Per-frame losses for tracked objects and newly detected objects
        total_loss += single_frame_loss(pred_tracked, match_tracked, gt_tracked)
        total_loss += single_frame_loss(pred_detected, match_detected, gt_detected)

        # Count ground-truth objects for clip-level normalization
        total_objects += len(gt_tracked) + len(gt_detected)

    return total_loss / total_objects
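
  • The matching_results consumed above come from MOTR's tracklet-aware label assignment: track queries keep the ground-truth identity assigned to them in earlier frames, while detect queries are matched by Hungarian matching only against newborn (not yet tracked) objects. The snippet below is a minimal sketch of that assignment, assuming scipy's linear_sum_assignment and a hypothetical pairwise_cost helper (e.g. a class + L1 + GIoU cost); names and data layout are illustrative.

from scipy.optimize import linear_sum_assignment

def tracklet_aware_assignment(track_ids, detect_preds, gt_boxes, gt_ids, pairwise_cost):
    # track_ids: ground-truth identities already assigned to the track queries
    # detect_preds: current-frame predictions of the detect queries
    # gt_boxes / gt_ids: ground-truth boxes and identities of the current frame
    # pairwise_cost: hypothetical helper returning a (num_detect, num_newborn) cost matrix

    # Track queries inherit their previous assignment (identity is kept across frames)
    tracked_assignment = {i: gid for i, gid in enumerate(track_ids)}

    # Detect queries are matched only against newborn objects
    newborn = [j for j, gid in enumerate(gt_ids) if gid not in track_ids]
    cost = pairwise_cost(detect_preds, gt_boxes[newborn])
    det_idx, gt_idx = linear_sum_assignment(cost)
    detected_assignment = {int(d): gt_ids[newborn[g]] for d, g in zip(det_idx, gt_idx)}

    return tracked_assignment, detected_assignment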

Thought

  • Solves end-to-end multiple-object tracking in a very elegant way, breaking the earlier "NN detection + hand-crafted tracking logic" tracking paradigm

  • This non-black-box (detection bboxes are explicitly supervised) end-to-end training of a complex task inspired many later, even more complex end-to-end tasks such as UniAD