"""Minimal DETR (DEtection TRansformer) model plus a small shape demo."""

from typing import Optional, Tuple

import torch
from torch import nn


def _resnet50_backbone() -> nn.Module:
    """Return an ImageNet-pretrained ResNet-50 trunk without its pooling and fc head."""
    # Imported lazily so torchvision is only required when the default backbone is used.
    from torchvision.models import resnet50

    try:
        from torchvision.models import ResNet50_Weights
    except ImportError:  # torchvision < 0.13 only offers the legacy boolean flag
        model = resnet50(pretrained=True)
    else:
        # IMAGENET1K_V1 is exactly what the deprecated ``pretrained=True`` selected.
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    # Drop the last two children (global average pool and classifier) to keep a
    # spatial feature map.
    return nn.Sequential(*list(model.children())[:-2])


class DETR(nn.Module):
    """CNN backbone + ``nn.Transformer`` + per-query classification and box heads.

    Args:
        num_classes: Number of object classes; the classifier emits
            ``num_classes + 1`` logits (the extra slot is presumably DETR's
            "no object" class).
        hidden_dim: Transformer width. Must be even, because half the channels
            carry the column embedding and half the row embedding.
        nheads: Attention heads per transformer layer.
        num_encoder_layers: Number of transformer encoder layers.
        num_decoder_layers: Number of transformer decoder layers.
        backbone: Optional feature extractor mapping images to a
            ``(B, backbone_channels, H, W)`` map. Defaults to a pretrained ResNet-50.
        backbone_channels: Channel count produced by ``backbone``.
        num_queries: Number of learned object queries (predictions per image).
        max_grid_size: Largest feature-map height/width the learned positional
            tables can cover.

    Raises:
        ValueError: If ``hidden_dim`` is odd.
    """

    def __init__(
        self,
        num_classes: int,
        hidden_dim: int,
        nheads: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        *,
        backbone: Optional[nn.Module] = None,
        backbone_channels: int = 2048,
        num_queries: int = 100,
        max_grid_size: int = 50,
    ) -> None:
        super().__init__()
        if hidden_dim % 2:
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")
        # Registration order is kept as in the original so state_dict keys and
        # the RNG consumption order of parameter initialization are unchanged.
        self.backbone = backbone if backbone is not None else _resnet50_backbone()
        # 1x1 conv projects backbone channels down to the transformer width.
        self.conv = nn.Conv2d(backbone_channels, hidden_dim, 1)
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers
        )
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(max_grid_size, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(max_grid_size, hidden_dim // 2))

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict class logits and boxes for every object query.

        Args:
            inputs: Image batch accepted by the backbone, e.g. ``(B, 3, H_img, W_img)``
                for the default ResNet-50.

        Returns:
            ``(logits, boxes)``, sequence-first as produced by ``nn.Transformer``:
            logits are ``(num_queries, B, num_classes + 1)`` and boxes are
            ``(num_queries, B, 4)`` squashed into ``[0, 1]`` by a sigmoid.

        Raises:
            ValueError: If the feature map is larger than ``max_grid_size``
                in either dimension.
        """
        h = self.conv(self.backbone(inputs))  # (B, hidden_dim, H, W)
        batch_size = h.shape[0]
        H, W = h.shape[-2:]
        if H > self.row_embed.shape[0] or W > self.col_embed.shape[0]:
            raise ValueError(
                f"feature map {H}x{W} exceeds the positional-embedding grid "
                f"{self.row_embed.shape[0]}x{self.col_embed.shape[0]}; "
                "use smaller images or a larger max_grid_size"
            )
        # Build an (H*W, 1, hidden_dim) encoding: column half + row half per cell.
        pos = (
            torch.cat(
                [
                    self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                    self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
                ],
                dim=-1,
            )
            .flatten(0, 1)
            .unsqueeze(1)
        )
        src = pos + h.flatten(2).permute(2, 0, 1)  # (H*W, B, hidden_dim)
        # nn.Transformer requires src and tgt to share the batch dimension, so
        # replicate the learned queries for every image in the batch.
        tgt = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)
        h = self.transformer(src, tgt)  # (num_queries, B, hidden_dim)
        return self.linear_class(h), self.linear_bbox(h).sigmoid()


if __name__ == "__main__":
    detr = DETR(
        num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6
    )
    detr.eval()
    inputs = torch.randn(1, 3, 800, 1200)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits, bboxes = detr(inputs)
    print(logits.shape)  # torch.Size([100, 1, 92])
    print(bboxes.shape)  # torch.Size([100, 1, 4])