Zhangzhe's Blog

The projection of my life.

0%

Transformers without Normalization

URL

TL;DR

  • 这是由恺明和杨立昆提出的一篇关于 transformer 算子优化的论文,主要观点是去掉 transformer 结构中的 normalization 层,改成 tanh
  • 改用 tanh 算子的 transformer 模型,在大多数任务上可达到使用归一化层的模型相同的性能,甚至更好

Algorithm

dyt.png

  • 简单来说,这篇论文的核心思想是将 transformer 中的 normalization 层(可以是 LayerNormRMSNorm)替换成 dynamic tanh 层(简称 DyT
  • normalization 计算公式:

normalization(x)=γ×xμσ2+ϵ+β\text{normalization}(x) = \gamma \times \frac{x - \mu}{\sqrt{\sigma^2+\epsilon}} + \beta

其中 μ\muσ\sigma 分别是 meanstdγ\gammaβ\betascaleshift 参数

  • DyT 计算公式:

DyT(x)=γ×tanh(αx)+β\text{DyT}(x) = \gamma \times \tanh(\alpha x) + \beta

其中 α\alpha 是个可学习参数,γ\gammaβ\betascaleshift 参数(和 normalization 一样)

  • DyT 实现伪代码:
1
2
3
4
5
6
7
8
9
10
11
# input x has the shape of [B, T, C]
# B: batch size, T: tokens, C: dimension
class DyT(Module):
def __init__(self, C, init_α):
super().__init__()
self.α = Parameter(ones(1) * init_α)
self.γ = Parameter(ones(C))
self.β = Parameter(zeros(C))
def forward(self, x):
x = tanh(self.alpha * x)
return self.γ * x + self.β

α\alpha 默认初始化值为 0.5

Results

  • 作者在多个领域的知名模型上都对比了修改前后训练精度,DyT 的性能和 normalization 的性能基本一致,打的有来有回
    dyt2.png
    dyt3.png
    dyt4.png
    dyt5.png
    dyt6.png
    dyt1.png
  • 作者还对比了 DyTnormalization 的训练/推理速度,DyT 的训练/推理速度要快很多
    dyt7.png
  • 作者同时做 tanhα\alpha 做了消融实验,发现 tanhα\alpha 都是必要的
    dyt8.png
    dyt9.png

Thoughts

  • 属于是恺明和立昆的梦幻联动了…,这种对最火的结构的优化,非大佬不能为也,想象下如果这篇论文是大学实验室发表的,大家第一反应恐怕是:Who think you are? 😂
  • 之前算是稍微接触过硬件,DyT 这种 element-wise opnormalization 这种 reduce op 一定快多了,想怎么 tiling 都行…
Powered By Valine
v1.5.2