URL
- paper: https://arxiv.org/pdf/2202.06417
- code: https://github.com/yxuansu/SimCTG/tree/main/contrastive_search_explanation
TL;DR
- 本文提出一种叫做
Contrastive search
的解码方法,比top-k
/top-p
等随机解码方法和argmax
这种贪心解码方法更优,模型输出内容语意更连贯,模型更不容易退化 Contrastive search
是在top-k
基础上增加了degeneration penalty
(退化惩罚项),算是random sampling
类算法的升级版- 但需要将候选的
top-k
输出追加到序列末尾作为输入重新输入到模型中输出top-k
对应的hidden state
,因此时间复杂度较高
Algorithm
总体架构
其中 是超参数,用来平衡模型预测结果和退化惩罚项
h
表示hidden state
,用 和序列中每一个token
的hidden state
计算相似度
由于 只能在下一次inference
时产生,所以需要在decoding
阶段再推理一次,这一步比较麻烦
用一段代码来描述一下解码过程
- 首先在
t-1
时刻输入seq_len
长度的序列,输出vocab_size
长度的logits
,选top-k
1 | # compute the logits in this step |
- 将
top-k
预测的每一种可能追加到原始序列,并在batch size
维度concat
成一个shape = [beam_width, seq_len + 1, dim]
的输入
1 | # concatenate each candidate token with prefix |
- 将
shape = [beam_width, seq_len + 1, dim]
输入到模型中,计算得到每个token
的hidden state
1 | # feed these candidates into next round to get their hidden states |
- 计算每一个
next hidden state
和context hidden state
之间的相关性,相关性越高,说明重复性越高,惩罚越重,选择此token
作为next input id
的可能性越低
1 | def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha): |
Thought
- 道理上讲得通,但多推理的一次是实实在在的计算代价,总计算量直接翻倍,真的会有大模型选择这种解码方式吗?