1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
    """Select the candidate id with the best probability/similarity trade-off.

    Every one of the ``beam_width`` candidates is scored as
    ``(1 - alpha) * prob - alpha * max_sim``, where ``max_sim`` is the largest
    cosine similarity between the candidate's hidden state and any hidden
    state of its context. The candidate with the highest score is returned.

    Args:
        context_hidden: Tensor of shape ``beam_width x context_len x embed_dim``.
        next_hidden: Tensor of shape ``beam_width x 1 x embed_dim``.
        next_top_k_ids: Tensor of shape ``beam_width x 1`` with the candidate
            ids; the returned id is gathered from it along dim 0.
        next_top_k_probs: Tensor holding one value per candidate. It is
            flattened before use, so it must provide ``beam_width`` elements.
        alpha: Scalar weight of the similarity penalty; ``1 - alpha`` weights
            ``next_top_k_probs``.

    Returns:
        Tensor of shape ``1 x 1`` containing the id of the selected candidate.

    Raises:
        AssertionError: If ``next_hidden`` or an intermediate tensor does not
            have the expected shape.
    """
    beam_width, context_len, embed_dim = context_hidden.size()
    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])

    # Clamp the norms away from zero before dividing. An all-zero hidden
    # vector would otherwise produce 0/0 = NaN, the NaN would propagate into
    # the final score (even for alpha == 0), and the ranking would no longer
    # reflect the candidates' real scores. The clamp value is the smallest
    # positive normal number of the dtype, so non-zero vectors are unaffected.
    tiny = torch.finfo(context_hidden.dtype).tiny
    context_norm = context_hidden.norm(dim=2, keepdim=True).clamp_min(tiny)
    next_norm = next_hidden.norm(dim=2, keepdim=True).clamp_min(tiny)
    norm_context_hidden = context_hidden / context_norm
    norm_next_hidden = next_hidden / next_norm

    # (beam, ctx, dim) @ (beam, dim, 1) -> (beam, ctx, 1) -> (beam, ctx)
    cosine_matrix = torch.matmul(
        norm_context_hidden, norm_next_hidden.transpose(1, 2)
    ).squeeze(-1)
    assert cosine_matrix.size() == torch.Size([beam_width, context_len])

    # Only the most similar context position counts against a candidate.
    max_similarity, _ = torch.max(cosine_matrix, dim=-1)
    assert max_similarity.size() == torch.Size([beam_width])

    # reshape (rather than view) also accepts layouts that view cannot flatten.
    flat_probs = next_top_k_probs.reshape(-1)
    scores = (1.0 - alpha) * flat_probs - alpha * max_similarity

    _, selected_idx = torch.topk(scores, k=1)
    assert selected_idx.size() == torch.Size([1])
    # gather needs an index with the same number of dims as next_top_k_ids.
    selected_idx = selected_idx.unsqueeze(0)
    assert selected_idx.size() == torch.Size([1, 1])

    next_id = torch.gather(next_top_k_ids, dim=0, index=selected_idx)
    assert next_id.size() == torch.Size([1, 1])
    return next_id
|