URL
TL;DR
- 常用的
top-p random sampling decoding method
可能会导致输出的长文本质量较差
- 本文设计了一种基于熵的动态概率阈值优化
top-p
随机采样算法,叫做 η−sampling
Algorithm
公式表示
Ax<i={x∈V∣Pθ(x∣x<i)>η}
η=min(ϵ,α(−hη,x<i))
其中第一行公式表示随机采样概率值截断到 η
第二行中 ϵ,α 是超参数,通常 α=ϵ,h
表示输出的熵,h=−(p∗log(p)).sum(),p
表示概率
代码表示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| class EtaWarper(transformers.LogitsWarper): """Our proposed eta sampling warper.""" def __init__(self, epsilon): self.epsilon = epsilon self.filter_value = -float("Inf") def __call__(self, input_ids, scores) -> torch.FloatTensor: probabilities = scores.softmax(dim=-1) entropy = - (probabilities * torch.log(probabilities)).sum() eta = min(self.epsilon, torch.sqrt(torch.tensor(self.epsilon))*torch.exp(-entropy)) indices_to_remove = probabilities < eta max_word = torch.argmax(scores,dim=-1) indices_to_remove[...,max_word.squeeze()] = 0 new_scores = scores.masked_fill(indices_to_remove, self.filter_value) return new_scores
|
Thought
- 是
top-p
的小升级,计算量较小,算是一个小 trick