0%
URL
TL;DR
FlashMLA
是针对 DeepSeek
提出的 MLA (Multi-head Latent Attention)
模块,在 Nvidia Hopper
架构 GPU
上的加速算法,尤其对边长序列推理做了优化
FlashMLA
和 MLA
的关系大概可以类比到 FlashAttention
和 Attention
的关系
关键特性和能力
- 仅对
Hopper (sm_90)
架构做优化:充分挖掘了硬件 (sm
) 的计算能力和 Memory Hierachy
的 存储 / IO
能力
- 支持可变长度序列:和现实世界中推理场景更贴合
- 分页 KV cache:使用
block size = 64
的分页存储(这里的 64
的含义是:一个 block
存储某一层某个头的连续 64
个 token
对应的 kv cache
)
- 高性能带宽和算力:在
H800 SXM5
设备上,在内存带宽受限配置环境下,内存带宽可达 3000 GB/s
,在算力受限配置下,算力可达 580 TFLOPS
- 支持多种精度:支持
BF16
和 FP16
技术实现
- 基于
Nvidia cutlass
库 + cuda
实现主要计算
- 用干净的
python API
包装,方便集成到 PyTorch-base
框架中
- 用元数据管理的方式支持变长序列
Thoughts
- 做大模型不懂
cuda
是不行了,cutlass
要开始学起来了…
FlashMLA
只在 sm_90
上实现了,其他显卡编译不通,DeepSeek
只打高端局