0%
URL
TL;DR
FlashMLA 是针对 DeepSeek 提出的 MLA (Multi-head Latent Attention) 模块,在 Nvidia Hopper 架构 GPU 上的加速算法,尤其对边长序列推理做了优化
FlashMLA 和 MLA 的关系大概可以类比到 FlashAttention 和 Attention 的关系
关键特性和能力
- 仅对
Hopper (sm_90) 架构做优化:充分挖掘了硬件 (sm) 的计算能力和 Memory Hierachy 的 存储 / IO 能力
- 支持可变长度序列:和现实世界中推理场景更贴合
- 分页 KV cache:使用
block size = 64 的分页存储(这里的 64 的含义是:一个 block 存储某一层某个头的连续 64 个 token 对应的 kv cache)
- 高性能带宽和算力:在
H800 SXM5 设备上,在内存带宽受限配置环境下,内存带宽可达 3000 GB/s,在算力受限配置下,算力可达 580 TFLOPS
- 支持多种精度:支持
BF16 和 FP16
技术实现
- 基于
Nvidia cutlass 库 + cuda 实现主要计算
- 用干净的
python API 包装,方便集成到 PyTorch-base 框架中
- 用元数据管理的方式支持变长序列
Thoughts
- 做大模型不懂
cuda 是不行了,cutlass 要开始学起来了…
FlashMLA 只在 sm_90 上实现了,其他显卡编译不通,DeepSeek 只打高端局