论文深度解读 · 长上下文 · 稀疏注意力 · Kernel 协同设计
MiniMax Sparse Attention
把超长上下文的注意力成本砍掉 28.4×
MiniMax & 北大 / NVIDIA / 浙大 等 · arXiv 2606.13392 · 2026-06 · 109B MoE · 已开源 kernel + MiniMax-M3
百万 token 上下文正在成为前沿大模型的刚需——agent 工作流、仓库级代码推理、持久记忆,都要求模型在几十万到上百万 token 上联合注意。但 softmax 注意力的 二次方成本 让这件事在部署规模下根本扛不住。MSA 的答案:给标准 GQA 挂一个极轻量的 Index Branch 为每个 GQA group 独立挑选 Top-k 个 KV block,Main Branch 只在选中的 block 上做精确注意力——再配一套 KV-outer 的 GPU kernel,把理论稀疏真正变成 14.2× prefill / 7.6× decode 的墙钟加速。
三句话讲完
- 架构极简:只给 GQA 加两个投影矩阵(index 的 Q/K),在 block 粒度、每个 GQA group 独立地选 Top-k 个 KV block;Main Branch 就是普通 softmax 注意力,只不过把可见 token 限制在选中的 block 里。没有花哨的三分支、没有压缩注意力——奥卡姆剃刀式的设计。
- 训练靠 KL 对齐:Top-k 选择不可导,所以用一个 KL loss 让 Index Branch 的打分分布去逼近 Main Branch 的真实注意力分布,并用 stop-gradient 把这个辅助损失牢牢关在 index 投影里,不污染主干。配 indexer warmup + 强制 local block 两个稳定器。
- 效率靠 kernel 协同:exp-free 的 Top-k 选择 + KV-outer 的稀疏注意力(让 KV block 反向 gather 它被哪些 query 选中,把 query 拼起来填满 tensor-core)。在 109B MoE、3T token、原生多模态训练下,MSA 与 GQA 打平,1M 上下文每 token 注意力算力降 28.4×。
28.4×
1M 上下文每 token
注意力 FLOPs 下降
2,048
每 query 固定预算
k·Bk=16×128
01 问题:二次方成本,在百万上下文前彻底失速
大模型正在从"短的单轮问答"急速转向"长程 agent 工作流"——写代码并部署、在开放网页里导航、编排各种工具、产出结构化文档。这些任务动辄要在几十万甚至上百万 token 的上下文上推理。可一旦上下文拉长,softmax 注意力的二次方成本就成了训练和推理的头号瓶颈,再叠加部署规模下的延迟和吞吐约束,问题被进一步放大。
标准因果注意力对每个 query 位置 t、每个 head h 计算:
逐符号拆开看这里到底贵在哪:
- \sum_{i\le t}:对每个 query t,要遍历它前面所有 i\le t 个 key——这就是 O(N) 的来源。
- \alpha_{t,i}^{(h)}:注意力权重,分母 \sum_{j\le t}\exp(\cdot) 又是一次对全历史的求和(softmax 归一化)。
- 对 N 个 query 各做一遍 O(N) 的求和 → 总计 \Theta(2H_q N^2 d_h) FLOPs,随序列长度 N 二次方增长。
GQA(Grouped-Query Attention)先把 KV head 数从 H_q 压到 H_{kv},把 G=H_q/H_{kv} 个相邻 query head 绑到一个共享的 KV head 上——但它只省了 KV 的存储/读取,主注意力路径仍然吃满全上下文长度,二次方的命门没解。MSA 要动的正是这条主路径。
02 把注意力拆成"先选后算"两阶段
稀疏注意力的统一视角:把因果注意力分解成一个 indexer(选哪些 key 该被关注)+ 一次只在选中 key 上的稀疏注意力。对每个 query 位置 i:
- \mathrm{Index}_\phi:第一阶段,称 Index Branch。输入 query 和全部可见 key,吐出一个被选中的下标集合 {\mathcal{I}}_i\subseteq\{1,\dots,i\}。\phi 是它的参数(固定规则的 indexer 没有参数,可学习的有)。
- \mathrm{Attn}:第二阶段,称 Main Branch。就是普通的 scaled dot-product softmax 注意力,但只在 {\mathcal{I}}_i 这个子集上算。
MSA 在这个框架上做了两个关键的"粒度"选择,正是它区别于 NSA / MoBA / DSA 的地方:
① GQA-group 粒度共享
同一个 GQA group 里的 G 个 query head 共享同一份 block 选择 {\mathcal{I}}_i^{(r)},但每个 head 保留自己的 query 投影。这让稀疏检索按 group 走、KV 读取保持连续。
② block 粒度选择
选的是整块 KV(block)而非单个 token。块大小 B_k=128,每个 group 选 k=16 块。降低路由开销、让访存更规整,对 GPU 矩阵运算友好。
Figure 1 · MSA 架构总览。左侧 Index Branch:Hidden States 经 Linear Projection + Norm&RoPE 得到极轻量的 {\bm{Q}}^{\rm idx}、{\bm{K}}^{\rm idx},算出 score 后做 Block Max Pooling(把 token 级分数聚合到 block 级),再 Top-k 选出每个 query、每个 GQA group 的 k 个关键 block(含强制保留的 local block)。右侧 Main Branch:正常生成 Q/K/V,但只在选中的 Top-k KV block 上做精确 softmax 注意力。最右边两个 Query Group 的可视化直观展示了——不同 group 对同一段上下文会选出不同的 block 子集(紫色 vs 棕色),这正是 group-specific 稀疏检索的价值。训练时一条 KL loss 把 index 分布对齐到 Main Branch 的 group 平均分布,且 Index Branch 的梯度与主干 detach。
03 训练:Top-k 不可导,那就用 KL 把"会选"教出来
麻烦在于 \mathrm{TopK} 选择是不可导的——语言模型损失的梯度传不到 index 的 Q/K 投影 {\bm{W}}^{\rm idx}_q,{\bm{W}}^{\rm idx}_k 上,indexer 学不动。MSA 的解法是给 Index Branch 一个直接的对齐信号:让它的打分分布去逼近 Main Branch 在选中 token 上的真实注意力分布。配三个稳定器:Gradient Detach、Indexer Warmup、强制 Local Block。
KL Loss:让 indexer 模仿真实注意力
对 query 位置 i、GQA group r,在选中 block 诱导出的可见 token 集合 {\mathcal{I}}_{i,\rm tok}^{(r)} 上,定义 Index 分布 P^{\rm idx} 和 Main Branch 老师分布 P:
- P^{\rm idx}:Index Branch 自己那套打分 S^{\rm idx} 过 softmax——这是"学生"。
- P^{(r)}:老师。注意它是 \frac{1}{G}\sum_{\ell\in\mathcal{H}_r}——把 group 里 G 个 query head 各自的真实注意力分布,在概率层面平均。因为一个 group 共享一份 block 选择,所以老师必须是这 G 个 head 的"共识"。
然后 indexer 去最小化 KL 散度(对所有位置和 group 取平均):
\mathrm{stopgrad}(P) 是关键:老师分布不接受梯度,所以这个 KL 只往学生(index 投影)方向流。它教会 indexer:"Main Branch 真正重视哪些 token,你打分时也要把它们排到前面"——这样后续的 block 选择才语义有意义。
💡 我的看法:这套"对齐"本质是注意力的自蒸馏
KL loss 让 indexer 去蒸馏 Main Branch 自己的注意力图——不需要额外标签、不需要外部 teacher,信号完全自生。对我们 WELM 这种线上长上下文场景,最值钱的是它原生可训:不像 H2O / SnapKV / Quest 那些 inference-time 稀疏方法,要在一个全注意力 backbone 上事后裁剪,至少留一个阶段贴着全注意力速度跑。MSA 的 indexer 是训练期就长出来的,prefill 和 decode 两头都能吃到稀疏红利。
三个稳定器:别让 KL 污染主干、别让早期乱选
Gradient Detach
对 Index Branch 的输入做 stop-gradient:{\bm{Q}}^{\rm idx}=\mathrm{stopgrad}({\bm{X}}){\bm{W}}^{\rm idx}_q。加上老师已 detach,\mathcal{L}_{\rm KL} 就只更新 index 的两个投影,既不碰 Main Branch 投影、也不经由 X 反传到主干。一个干净的旁路。
Indexer Warmup
两阶段调度:头 40B token 两个分支都跑全注意力,只用 KL 把新加的 index 投影练热;之后才切到稀疏。把预训练好的 dense checkpoint 转稀疏(CPT)时同样用这招,先对齐 indexer 再让它接管路由。
强制 Local Block
每个 query 的本地块(含自己的那块)永远入选,占掉一个 slot,剩下交给 indexer 挑。防止退化选择漏掉 query 紧邻的上下文——局部信息是注意力的命根子。
04 复杂度:为什么省的钱随上下文越拉越多
同样的 H_q,H_{kv},d_h,N 下,GQA 与 MSA 的因果注意力 FLOPs:
- GQA 的 N^2 直接乘在满头 H_q 和满 head 维 d_h 上——最贵的一项随 N 二次方涨。
- MSA 的 Main Branch 是 4H_q d_h N k B_k:其中 k B_k 是固定预算(16×128=2048),不随 N 变,所以这项只是线性增长。
- MSA 还有个 Index Branch 的 H_{kv}d_{\rm idx}N^2,看着也是二次方——但 H_{kv}\ll H_q(实验里 4 vs 64)、d_{\rm idx} 也小,所以这一项的系数极小。
结论:当 k B_k\ll N 且 H_{kv}d_{\rm idx}\ll H_q d_h 时,两者的 FLOPs 差距随 N 一路拉大。这就是为什么 1M 上下文能省到 28.4× 而短上下文省得少。每个 query、每个 group 只看 k B_k=2048 个 token——这个数字恒定,无论上下文是 128K 还是 1M。
05 Kernel:把理论稀疏榨成真实墙钟加速
理论 FLOPs 降 28× 不等于跑得快 28×——稀疏注意力多了 index 构建、Top-k、反向索引、query gather、负载均衡这些开销,访存还更不规整。MSA 把算法和 GPU 执行路径协同设计,这才是它能落地的真正功夫。
① Exp-free Top-k:绕开 softmax 直接排序
因为 softmax 保序(s_i\le s_j\iff\mathrm{softmax}(s)_i\le\mathrm{softmax}(s)_j),Top-k 的结果和原始分数排序完全一致。所以前向直接跳过 max/exp/sum,把原始分数喂给选择。再加 per-thread 寄存器级 Top-k:warp 的 32 个 lane 各扫 1/32 的行,各维护一个 k 元素小顶堆,堆顶缓存在寄存器、延迟写入,最后一轮 shuffle merge——专门为 B_k=128,k=16 这个"小 k"区间定制。
Table 1 · Top-k 延迟对比(μs,H800,fp32,50 次中位数)
| 序列长 N | Blocks B | k | torch | TileLang | MSA(本文) | vs torch | vs TileLang |
| 128K | 1024 | 16 | 3970 | 2864 | 779 | 5.1× | 3.7× |
| 128K | 2048 | 32 | 5378 | 3630 | 1991 | 2.7× | 1.8× |
| 512K | 4096 | 16 | 33810 | 17779 | 7880 | 4.3× | 2.3× |
| 512K | 8192 | 32 | 57659 | 26100 | 21326 | 2.7× | 1.2× |
部署设定 k=16 时优势最大(对 torch 4–5×)。三方产出的 index 集合完全一致。
② KV-outer 稀疏注意力:让 KV block 反向 gather query
稀疏 prefill 该让谁在外层循环?MSA 算了两套账(都按 bf16 计 IO):
逐符号看:Q-outer(query 在外层)的算术强度只有 G(GQA ratio,实验里 16);KV-outer(KV block 在外层、反向 gather 选中它的 query)是 \tfrac{2}{3}B_k(≈85)。后者大得多,所以选 KV-outer 来填满 tensor-core 的 MMA。关键技巧:
- Query 拼接:单个 query 位置只贡献 G 个 head(MMA 的 M 维只有 16,严重填不满)。KV-outer 下同一个 KV tile 的所有 gather query 共享 KV 操作数,于是把 \lceil 128/G\rceil 个 query 位置拼进一个 128×128 的 score MMA。
- Pre-scheduled chunking:"sink 行"问题——某个早期 KV block 几乎被每个 query 选中,变成热点。调度器把热 tile 沿 query 维切成 \sim 2kB_k 大小的 chunk,扇出到多个共享同一份 K/V 的 CTA,无需 atomic。
- 两阶段 forward:每个 query 的 k 份 partial 由 k 个不同 CTA 产出,没法 inline 归一化。拆成 attention kernel(各写 partial 到预分配 slot)+ combine kernel(读回 LSE 做 split-K 合并)。
- Sparse KL Loss 的 LSE fusion:KL 只影响反向,所以前向时直接把 \mathrm{LSE} 顺手吐到显存,整个 KL 前向 pass 被省掉;反向 kernel 直接读这些标量。
06 实验:109B MoE、3T token、原生多模态
验证规模够狠:41 层 MoE backbone、约 109B 总参 / 6B 激活、200K 词表、d_{\rm model}=3072,64 query head / 4 KV head / head dim 128,RoPE 维 64。128 routed experts + 1 shared,top-4 路由。两条训练路线:
- MSA-PT:从零原生稀疏预训练,40B token indexer warmup 后全程稀疏。
- MSA-CPT:从一个 2.6T token 的 GQA 全注意力 checkpoint 出发,换成 MSA 再续训 400B token(前 40B warmup)。务实的"转换"路线。
Figure 2(a) · 预训练 LM loss。蓝线 Full Attention vs 青线 MSA-PT,跨 3T token 几乎完全重合(右上角放大了最后 50B token 窗口,两条线粘在 1.21–1.22)。gradient norm 也始终在同一区间——大规模下训练稀疏注意力和训全注意力一样稳。
Table 2 · 代表性评测(3T token 预算;PPL 越低越好,其余越高越好)
| 类别 | Benchmark | Full | MSA-PT | MSA-CPT |
| General | MMLU | 67.0 | 67.2 | 66.8 |
| BBH | 67.7 | 66.6 | 66.1 |
| WinoGrande | 58.3 | 60.9 | 62.0 |
| Math | GSM8K | 76.2 | 77.7 | 73.7 |
| MathVista | 43.8 | 46.8 | 44.5 |
| Code | HumanEval | 61.0 | 64.0 | 57.9 |
| EvalPlus | 59.4 | 61.8 | 60.0 |
| Retrieval | RULER-8K | 79.8 | 84.2 | 77.2 |
| RULER-32K | 75.0 | 77.5 | 75.7 |
| Image | VisualWebBench | 55.6 | 68.4 | 59.4 |
| CharXiv | 37.55 | 41.55 | 37.15 |
| Video | EgoSchema | 29.6 | 37.6 | 25.8 |
| VideoMME | 41.11 | 45.48 | 39.65 |
两条稀疏路线都跟 Full 打平。有意思的是 MSA-PT 在大量 math/image/video/RULER 上反而超过 Full——原生稀疏预训练似乎让表征适配了稀疏模式。MSA-CPT 更保守、贴近原 checkpoint,是已有 dense 模型时的实用转换路径。
Table 3 · 长上下文扩展(MSA-CPT 再续 140B token,Δ = 相对 Full)
| Benchmark | 子项 | Full | MSA-CPT | Δ |
| HELMET-128K | Overall | 46.53 | 45.93 | -0.60 |
| ICL | 70.40 | 72.80 | +2.40 |
| RULER-128K | Overall | 72.00 | 72.12 | +0.12 |
| MK/MQ/MV | 96.63 | 98.87 | +2.24 |
每个 query/group 只看 kB_k=2048 个 token,128K 上下文下仍贴着 Full——在极紧的注意力预算下守住了长上下文能力。
Figure 4 · 效率。左:理论每 token 注意力 FLOPs,1M 上下文降 28.4×,且随上下文越长降得越多。中/右:实测 prefill / decode 加速。墙钟加速(14.2× / 7.6×)小于理论 FLOPs 降幅——因为多了 index 构建、Top-k、反向索引、gather、负载均衡这些不规整开销;但同样随上下文增长而扩大,因为 dense baseline 一直随全序列涨,MSA 主预算恒定。
07 它和 NSA / MoBA / DSA 差在哪
NSA
针对 MQA/MHA,三条并行分支(压缩注意力 + 选择注意力 + 滑窗)。MSA 砍到只留"选择",更简。
MoBA
也基于 GQA,但用很大的 KV block、块平均 key 打分,且 indexer 只靠 LM 梯度训练。MSA 用小块 + KL 显式对齐。
DSA
坐在 MLA 的 MQA 模式上,ReLU lightning indexer 逐 token 打分、所有 head 共享一份 Top-k、token 级选择。MSA 是 group 级 + block 级。
MSA 的差异点凝练成两条同时采用的轴:per-GQA-group 的 Top-k 共享 + block 级选择——既拿到多 group 的 block 粒度检索,又让 KV 读取保持连续。
💡 我的看法:对 WELM 线上长上下文最实在的一点
我们最头疼的是多轮请求漂移导致 KV-Cache 跨机命中不了、被迫重算 prefill、TTFT 飙升。MSA 这套的价值在于:它把"每个 query 实际要算的 KV 预算钉死在 2048 个 token"——无论上下文 128K 还是 1M。这意味着 prefill 的算力和访存都有了可预测的上界,对前缀缓存/路由调度是极友好的信号:缓存的有效 KV 集合是 block 对齐、group 共享的,比逐 token 漂移的 cache 好复用得多。而且 indexer 原生可训、kernel 已开源(MiniMax-AI/MSA),GQA backbone 几乎不用改就能接——它在 outlook 里也明确点了 RL 后训练和 agentic 部署是下一步,正好是长上下文成本最吃紧的地方。值得拿我们的 serving 栈实测一把 prefill 命中率和 TTFT。
一句话收尾:MSA 用"奥卡姆剃刀"砍掉了稀疏注意力里所有花哨的分支,只留下 group 级 block 选择 + KL 对齐的 indexer,再用 KV-outer kernel 把理论稀疏榨成真实加速——在 109B、3T token、原生多模态这个足够认真的规模上,证明了"稀疏可以和全注意力打平,而且越长越赚"。