URL
TL;DR
- 这是由恺明和杨立昆提出的一篇关于
transformer
算子优化的论文,主要观点是去掉transformer
结构中的normalization
层,改成tanh
层 - 改用
tanh
算子的transformer
模型,在大多数任务上可达到使用归一化层的模型相同的性能,甚至更好
Algorithm
- 简单来说,这篇论文的核心思想是将
transformer
中的normalization
层(可以是LayerNorm
或RMSNorm
)替换成dynamic tanh
层(简称DyT
) normalization
计算公式:
其中 和 分别是
mean
和std
, 和 是scale
和shift
参数
DyT
计算公式:
其中 是个可学习参数, 和 是
scale
和shift
参数(和normalization
一样)
DyT
实现伪代码:
1 | # input x has the shape of [B, T, C] |
默认初始化值为
0.5
Results
- 作者在多个领域的知名模型上都对比了修改前后训练精度,
DyT
的性能和normalization
的性能基本一致,打的有来有回
- 作者还对比了
DyT
和normalization
的训练/推理速度,DyT
的训练/推理速度要快很多
- 作者同时做
tanh
和 做了消融实验,发现tanh
和 都是必要的
Thoughts
- 属于是恺明和立昆的梦幻联动了…,这种对最火的结构的优化,非大佬不能为也,想象下如果这篇论文是大学实验室发表的,大家第一反应恐怕是:Who think you are? 😂
- 之前算是稍微接触过硬件,
DyT
这种element-wise op
比normalization
这种reduce op
一定快多了,想怎么tiling
都行…
v1.5.2