Zhangzhe's Blog

The projection of my life.

0%

DETR: End-to-End Object Detection with Transformers

URL

https://arxiv.org/pdf/2005.12872.pdf

Algorithm

Architecture

detr.png

DETR inference

import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
    """Minimal DETR detector: ResNet-50 trunk + nn.Transformer + prediction heads.

    The backbone's final feature map is projected to ``hidden_dim`` channels,
    flattened into a sequence, summed with learned 2-D positional embeddings
    and fed to the transformer encoder. A fixed set of learned object queries
    is decoded in parallel; each query yields class logits and a box.

    Args:
        num_classes: number of object classes; the classifier emits
            ``num_classes + 1`` logits (the extra one is DETR's "no object"
            class).
        hidden_dim: transformer width. Must be even, because it is split
            evenly between the row and column positional embeddings.
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        num_queries: number of learned object queries, i.e. predictions per
            image. Defaults to 100.

    Note:
        The positional embedding tables have 50 entries per axis, so the
        backbone feature map must be at most 50x50 cells.
    """

    def __init__(
        self,
        num_classes,
        hidden_dim,
        nheads,
        num_encoder_layers,
        num_decoder_layers,
        num_queries=100,
    ):
        super().__init__()
        # Keep only the convolutional trunk of ResNet-50: drop the final
        # average-pool and fully-connected classifier children.
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        # 1x1 conv projecting the 2048-channel backbone output to hidden_dim.
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers
        )
        # Prediction heads; "+ 1" adds the extra (no-object) class logit.
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        # Learned object queries, used directly as the decoder input.
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))
        # Learned row / column positional embeddings (50 positions each);
        # each supplies half of the hidden_dim channels.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """Run detection on a batch of images.

        Args:
            inputs: image tensor of shape (batch, 3, height, width).

        Returns:
            ``(logits, boxes)`` with shapes
            ``(num_queries, batch, num_classes + 1)`` and
            ``(num_queries, batch, 4)``. Box values go through a sigmoid and
            so lie in (0, 1). In the DETR paper these are normalized
            (cx, cy, w, h).
        """
        x = self.backbone(inputs)
        h = self.conv(x)
        batch_size = h.shape[0]
        H, W = h.shape[-2:]
        # (H*W, 1, hidden_dim) positional encoding: column embedding in the
        # first half of the channels, row embedding in the second half. The
        # singleton batch axis broadcasts over every image in the batch.
        pos = (
            torch.cat(
                [
                    self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                    self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
                ],
                dim=-1,
            )
            .flatten(0, 1)
            .unsqueeze(1)
        )
        # Feature map (B, C, H, W) -> sequence (H*W, B, C): the default
        # sequence-first layout expected by nn.Transformer.
        src = pos + h.flatten(2).permute(2, 0, 1)
        # nn.Transformer rejects src/tgt with different batch sizes, so the
        # learned queries are replicated for every image in the batch.
        tgt = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)
        h = self.transformer(src, tgt)
        return self.linear_class(h), self.linear_bbox(h).sigmoid()
# Build the model with the COCO-style setting of 91 classes.
detr = DETR(
    num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6
)
# Eval mode: dropout off, BatchNorm uses its running statistics.
detr.eval()
# One random 800x1200 RGB "image" as a smoke-test input.
inputs = torch.randn(1, 3, 800, 1200)
# Pure inference: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    logits, bboxes = detr(inputs)
print(logits.shape)  # [100, 1, 92]: [num_query, batch, classes]
print(bboxes.shape)  # [100, 1, 4]: [num_query, batch, box]
  • Model topology diagram (模型拓扑图)
    detr.jpg