
Truncation Sampling as Language Model Desmoothing

URL

TL;DR

  • The commonly used top-p random-sampling decoding method can produce low-quality long-form text
  • This paper designs an entropy-based dynamic probability threshold on top of top-p random sampling, called $\eta$-sampling

Algorithm

In formulas

$$A_{x_{<i}} = \{x \in V \mid P_\theta(x \mid x_{<i}) > \eta\}$$

$$\eta = \min\left(\epsilon,\ \alpha \exp\left(-h_{\theta,\,x_{<i}}\right)\right)$$

The first equation truncates the sampling distribution at $\eta$: only tokens whose conditional probability exceeds $\eta$ remain in the candidate set $A_{x_{<i}}$.

In the second equation, $\epsilon$ and $\alpha$ are hyperparameters, usually with $\alpha=\sqrt{\epsilon}$; $h_{\theta,\,x_{<i}}$ is the entropy of the model's next-token distribution, $h=-\sum_x p(x)\log p(x)$, where $p$ denotes the token probabilities. The $\exp(-h)$ factor makes the cutoff adaptive: on peaked (low-entropy) distributions the threshold saturates at $\epsilon$, while on flat (high-entropy) distributions it drops below $\epsilon$ and keeps more of the tail, as the sketch below illustrates.
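
A minimal numeric sketch of this adaptive behavior; the $\epsilon$ value and the two toy distributions are arbitrary choices for illustration:

import torch

epsilon = 3e-4  # hypothetical hyperparameter value

def eta_threshold(p, epsilon):
    """Compute the eta-sampling threshold for one next-token distribution p."""
    h = -(p * p.log()).sum()  # Shannon entropy in nats
    return min(epsilon, (epsilon ** 0.5) * torch.exp(-h).item())

# Peaked distribution: h ≈ 0.17 nats, exp(-h) is large, so eta saturates at epsilon.
peaked = torch.tensor([0.97, 0.01, 0.01, 0.01])
print(eta_threshold(peaked, epsilon))  # 3e-4

# Near-uniform distribution over 1000 tokens: h ≈ 6.9 nats, so eta drops well
# below epsilon and more of the tail survives truncation.
flat = torch.full((1000,), 1e-3)
print(eta_threshold(flat, epsilon))  # ≈ 1.7e-5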

In code

import torch
import transformers


class EtaWarper(transformers.LogitsWarper):
    """Our proposed eta sampling warper."""

    def __init__(self, epsilon):
        self.epsilon = epsilon
        self.filter_value = -float("Inf")

    def __call__(self, input_ids, scores) -> torch.FloatTensor:
        probabilities = scores.softmax(dim=-1)
        log_probabilities = scores.log_softmax(dim=-1)
        # Shannon entropy of the next-token distribution, one value per batch row.
        entropy = -(probabilities * log_probabilities).sum(dim=-1, keepdim=True)
        # eta = min(epsilon, sqrt(epsilon) * exp(-entropy))
        eta = torch.clamp(
            torch.sqrt(torch.tensor(self.epsilon)) * torch.exp(-entropy),
            max=self.epsilon,
        )
        indices_to_remove = probabilities < eta
        # Make sure at least one candidate survives: never remove the argmax token.
        max_word = torch.argmax(scores, dim=-1, keepdim=True)
        indices_to_remove.scatter_(-1, max_word, False)
        new_scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return new_scores
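
A sketch of how the warper might be plugged into generation, assuming the EtaWarper class above; the model name, prompt, and $\epsilon$ value are arbitrary choices for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The meaning of life is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=50,
    # Apply eta sampling on top of ordinary random sampling.
    logits_processor=LogitsProcessorList([EtaWarper(epsilon=3e-4)]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Recent transformers releases also ship a built-in version of this warper, exposed through the eta_cutoff generation argument.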

Thought

  • A small upgrade of top-p with little extra computation; essentially a handy trick