import math

import torch
import torch.nn.functional as F


def attention(query, key, value):
    # Standard scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    probs = F.softmax(scores, dim=-1)
    output = torch.matmul(probs, value)
    return output


def deformable_attention(query, value, reference_points, num_sampling_points,
                         atten_linear, offset_linear):
    # Simplified single-head deformable attention.
    #   query:            (batch, seq_len, embed_dim)
    #   value:            (batch, embed_dim, H, W) feature map, sampled bilinearly
    #   reference_points: (batch, seq_len, 2), normalized (x, y) in [0, 1]
    #   atten_linear:     maps embed_dim -> num_sampling_points (per-point scores)
    #   offset_linear:    maps 2 -> num_sampling_points * 2 (per-point (x, y) offsets)
    batch_size, seq_len, embed_dim = query.size()
    # Predict an (x, y) offset for each sampling point around its reference point.
    offsets = offset_linear(reference_points).view(batch_size, seq_len, num_sampling_points, 2)
    sampling_positions = reference_points.unsqueeze(2) + offsets
    # Bilinearly sample the value map; grid_sample expects coordinates in [-1, 1].
    grid = 2.0 * sampling_positions - 1.0
    sampling_values = F.grid_sample(value, grid, mode="bilinear", align_corners=False)
    sampling_values = sampling_values.permute(0, 2, 3, 1)  # (batch, seq_len, points, embed_dim)
    # Attention weights are predicted directly from the query; no query-key dot products.
    scores = atten_linear(query).view(batch_size, seq_len, num_sampling_points)
    probs = F.softmax(scores, dim=-1)
    # Weighted sum over the sampled values.
    output = torch.matmul(probs.unsqueeze(2), sampling_values).squeeze(2)
    return output
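To make the assumed interfaces concrete, here is a minimal usage sketch. The toy dimensions, the 2D feature-map shape for value, and the choice of plain nn.Linear layers for atten_linear and offset_linear are illustrative assumptions, not part of the original listing.

# Minimal usage sketch under the assumptions documented in the functions above.
import torch
from torch import nn

batch, seq_len, embed_dim, num_points = 2, 5, 64, 4
H = W = 16

query = torch.randn(batch, seq_len, embed_dim)
key = torch.randn(batch, seq_len, embed_dim)
value_seq = torch.randn(batch, seq_len, embed_dim)      # sequence values for standard attention
value_map = torch.randn(batch, embed_dim, H, W)         # 2D feature map for deformable attention
reference_points = torch.rand(batch, seq_len, 2)        # normalized (x, y) in [0, 1]

atten_linear = nn.Linear(embed_dim, num_points)         # assumed: per-point attention scores
offset_linear = nn.Linear(2, num_points * 2)            # assumed: per-point (x, y) offsets

out_attn = attention(query, key, value_seq)             # (batch, seq_len, embed_dim)
out_deform = deformable_attention(query, value_map, reference_points,
                                  num_points, atten_linear, offset_linear)  # (batch, seq_len, embed_dim)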