Zhangzhe's Blog

The projection of my life.

0%

Locally Typical Sampling

URL

TL;DR

  • 本文提出来一种 Typical sampling 的解码策略,是对 top-p sampling 的改进,启发了后续的 ηsampling\eta-sampling 解码策略,旨在生成更自然、更符合人类语言使用习惯的文本。

  • 论文使用信息论的观点,局部典型性(local typicality)采样的依据是条件熵和信息量的接近程度

    • 条件熵:(log(p)p).sum()(-log(p)*p).sum()
    • 信息量:log(p)-log(p)

Algorithm

公式角度理解

Lϵ(T)={y=y0,...,yT1tT,logp(yty<t)+H(YtY<t=y<t)<ϵ}L_{\epsilon}^{(T)}=\{y=y_0,...,y_T| \forall 1\le t \le T, | \log p(y_t|y_{\lt t}) + H(Y_t|Y_{\lt t=y \lt t}) | \lt \epsilon \}

  • 其中
    • logp(yty<t)+H(YtY<t=y<t)<ϵ| \log p(y_t|y_{\lt t}) + H(Y_t|Y_{\lt t=y \lt t}) | \lt \epsilon 表示信息量和交叉熵最大相差不大于 ϵ\epsilon,写成 H(YtY<t=y<t)(logp(yty<t))<ϵ|H(Y_t|Y_{\lt t=y \lt t}) - (-\log p(y_t|y_{\lt t}) ) | \lt \epsilon 更容易理解
    • ϵ\epsilon 是超参数,需要在 (0, 1) 之间

代码角度理解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class TypicalLogitsWarper(LogitsWarper):

def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
mass = float(mass)
if not (mass > 0 and mass < 1):
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)

# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
return scores_processed

Thought

  • 单看计算过程实际上比较符合直觉,用信息论的角度理解反而会云里雾里…