Zhangzhe's Blog

The projection of my life.

0%

2025.02 DeepSeek 开源周第一弹 —— FlashMLA

URL

TL;DR

  • FlashMLA 是针对 DeepSeek 提出的 MLA (Multi-head Latent Attention) 模块,在 Nvidia Hopper 架构 GPU 上的加速算法,尤其对边长序列推理做了优化
  • FlashMLAMLA 的关系大概可以类比到 FlashAttentionAttention 的关系

关键特性和能力

  1. 仅对 Hopper (sm_90) 架构做优化:充分挖掘了硬件 (sm) 的计算能力和 Memory Hierachy 的 存储 / IO 能力
  2. 支持可变长度序列:和现实世界中推理场景更贴合
  3. 分页 KV cache:使用 block size = 64 的分页存储(这里的 64 的含义是:一个 block 存储某一层某个头的连续 64token 对应的 kv cache
  4. 高性能带宽和算力:在 H800 SXM5 设备上,在内存带宽受限配置环境下,内存带宽可达 3000 GB/s,在算力受限配置下,算力可达 580 TFLOPS
  5. 支持多种精度:支持 BF16FP16

技术实现

  1. 基于 Nvidia cutlass 库 + cuda 实现主要计算
  2. 用干净的 python API 包装,方便集成到 PyTorch-base 框架中
  3. 用元数据管理的方式支持变长序列

Thoughts

  • 做大模型不懂 cuda 是不行了,cutlass 要开始学起来了…
  • FlashMLA 只在 sm_90 上实现了,其他显卡编译不通,DeepSeek 只打高端局