1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| import torch import torch.nn as nn from torchvision.ops import deform_conv2d import math class LFA(nn.Module): def __init__(self, in_channels, out_channels, num_adjacent_keypoints=5): super(LFA, self).__init__() self.offset_conv = nn.Conv2d(in_channels, 2 * num_adjacent_keypoints, kernel_size=1) self.deform_conv_weight = nn.Parameter(torch.Tensor(out_channels, in_channels, 3, 3)) nn.init.kaiming_uniform_(self.deform_conv_weight, a=math.sqrt(5)) def forward(self, x): offsets = self.offset_conv(x) aggregated_features = deform_conv2d(x, offsets, self.deform_conv_weight, stride=1, padding=1) return aggregated_features, offsets
if __name__ == "__main__":
    # Demo: run LFA on a random batch of 16 feature maps (64 channels, 128 x 128).
    # Guarded so that importing this module does not allocate these large
    # tensors or run a forward pass.
    lfa = LFA(in_channels=64, out_channels=128, num_adjacent_keypoints=5)
    input_feature_map = torch.randn(16, 64, 128, 128)
    aggregated_features, predicted_offsets = lfa(input_feature_map)
    print(f"aggregated_features: {tuple(aggregated_features.shape)}")
    print(f"predicted_offsets: {tuple(predicted_offsets.shape)}")