论文深度解读 · 长上下文 · 稀疏注意力 · Kernel 协同设计

MiniMax Sparse Attention
把超长上下文的注意力成本砍掉 28.4×

MiniMax & 北大 / NVIDIA / 浙大 等 · arXiv 2606.13392 · 2026-06 · 109B MoE · 已开源 kernel + MiniMax-M3

百万 token 上下文正在成为前沿大模型的刚需——agent 工作流、仓库级代码推理、持久记忆,都要求模型在几十万到上百万 token 上联合注意。但 softmax 注意力的 二次方成本 让这件事在部署规模下根本扛不住。MSA 的答案:给标准 GQA 挂一个极轻量的 Index Branch 为每个 GQA group 独立挑选 Top-k 个 KV block,Main Branch 只在选中的 block 上做精确注意力——再配一套 KV-outer 的 GPU kernel,把理论稀疏真正变成 14.2× prefill / 7.6× decode 的墙钟加速。

三句话讲完
28.4×
1M 上下文每 token
注意力 FLOPs 下降
14.2×
H800 上 prefill
墙钟加速
7.6×
decode 阶段
墙钟加速
2,048
每 query 固定预算
k·Bk=16×128

01 问题:二次方成本,在百万上下文前彻底失速

大模型正在从"短的单轮问答"急速转向"长程 agent 工作流"——写代码并部署、在开放网页里导航、编排各种工具、产出结构化文档。这些任务动辄要在几十万甚至上百万 token 的上下文上推理。可一旦上下文拉长,softmax 注意力的二次方成本就成了训练和推理的头号瓶颈,再叠加部署规模下的延迟和吞吐约束,问题被进一步放大。

标准因果注意力对每个 query 位置 t、每个 head h 计算:

逐符号拆开看这里到底贵在哪:

GQA(Grouped-Query Attention)先把 KV head 数从 H_q 压到 H_{kv},把 G=H_q/H_{kv} 个相邻 query head 绑到一个共享的 KV head 上——但它只省了 KV 的存储/读取,主注意力路径仍然吃满全上下文长度,二次方的命门没解。MSA 要动的正是这条主路径。

02 把注意力拆成"先选后算"两阶段

稀疏注意力的统一视角:把因果注意力分解成一个 indexer(选哪些 key 该被关注)+ 一次只在选中 key 上的稀疏注意力。对每个 query 位置 i

MSA 在这个框架上做了两个关键的"粒度"选择,正是它区别于 NSA / MoBA / DSA 的地方:

① GQA-group 粒度共享
同一个 GQA group 里的 G 个 query head 共享同一份 block 选择 {\mathcal{I}}_i^{(r)},但每个 head 保留自己的 query 投影。这让稀疏检索按 group 走、KV 读取保持连续。
② block 粒度选择
选的是整块 KV(block)而非单个 token。块大小 B_k=128,每个 group 选 k=16 块。降低路由开销、让访存更规整,对 GPU 矩阵运算友好。
MSA 架构图
Figure 1 · MSA 架构总览。左侧 Index Branch:Hidden States 经 Linear Projection + Norm&RoPE 得到极轻量的 {\bm{Q}}^{\rm idx}{\bm{K}}^{\rm idx},算出 score 后做 Block Max Pooling(把 token 级分数聚合到 block 级),再 Top-k 选出每个 query、每个 GQA group 的 k 个关键 block(含强制保留的 local block)。右侧 Main Branch:正常生成 Q/K/V,但只在选中的 Top-k KV block 上做精确 softmax 注意力。最右边两个 Query Group 的可视化直观展示了——不同 group 对同一段上下文会选出不同的 block 子集(紫色 vs 棕色),这正是 group-specific 稀疏检索的价值。训练时一条 KL loss 把 index 分布对齐到 Main Branch 的 group 平均分布,且 Index Branch 的梯度与主干 detach。

03 训练:Top-k 不可导,那就用 KL 把"会选"教出来

麻烦在于 \mathrm{TopK} 选择是不可导的——语言模型损失的梯度传不到 index 的 Q/K 投影 {\bm{W}}^{\rm idx}_q,{\bm{W}}^{\rm idx}_k 上,indexer 学不动。MSA 的解法是给 Index Branch 一个直接的对齐信号:让它的打分分布去逼近 Main Branch 在选中 token 上的真实注意力分布。配三个稳定器:Gradient Detach、Indexer Warmup、强制 Local Block。

KL Loss:让 indexer 模仿真实注意力

对 query 位置 i、GQA group r,在选中 block 诱导出的可见 token 集合 {\mathcal{I}}_{i,\rm tok}^{(r)} 上,定义 Index 分布 P^{\rm idx} 和 Main Branch 老师分布 P:

然后 indexer 去最小化 KL 散度(对所有位置和 group 取平均):

\mathrm{stopgrad}(P) 是关键:老师分布不接受梯度,所以这个 KL 只往学生(index 投影)方向流。它教会 indexer:"Main Branch 真正重视哪些 token,你打分时也要把它们排到前面"——这样后续的 block 选择才语义有意义。

💡 我的看法:这套"对齐"本质是注意力的自蒸馏

KL loss 让 indexer 去蒸馏 Main Branch 自己的注意力图——不需要额外标签、不需要外部 teacher,信号完全自生。对我们 WELM 这种线上长上下文场景,最值钱的是它原生可训:不像 H2O / SnapKV / Quest 那些 inference-time 稀疏方法,要在一个全注意力 backbone 上事后裁剪,至少留一个阶段贴着全注意力速度跑。MSA 的 indexer 是训练期就长出来的,prefill 和 decode 两头都能吃到稀疏红利。

三个稳定器:别让 KL 污染主干、别让早期乱选

Gradient Detach
对 Index Branch 的输入做 stop-gradient:{\bm{Q}}^{\rm idx}=\mathrm{stopgrad}({\bm{X}}){\bm{W}}^{\rm idx}_q。加上老师已 detach,\mathcal{L}_{\rm KL}只更新 index 的两个投影,既不碰 Main Branch 投影、也不经由 X 反传到主干。一个干净的旁路。
Indexer Warmup
两阶段调度:头 40B token 两个分支都跑全注意力,只用 KL 把新加的 index 投影练热;之后才切到稀疏。把预训练好的 dense checkpoint 转稀疏(CPT)时同样用这招,先对齐 indexer 再让它接管路由。
强制 Local Block
每个 query 的本地块(含自己的那块)永远入选,占掉一个 slot,剩下交给 indexer 挑。防止退化选择漏掉 query 紧邻的上下文——局部信息是注意力的命根子。

04 复杂度:为什么省的钱随上下文越拉越多

同样的 H_q,H_{kv},d_h,N 下,GQA 与 MSA 的因果注意力 FLOPs:

结论:当 k B_k\ll NH_{kv}d_{\rm idx}\ll H_q d_h 时,两者的 FLOPs 差距随 N 一路拉大。这就是为什么 1M 上下文能省到 28.4× 而短上下文省得少。每个 query、每个 group 只看 k B_k=2048 个 token——这个数字恒定,无论上下文是 128K 还是 1M。

05 Kernel:把理论稀疏榨成真实墙钟加速

理论 FLOPs 降 28× 不等于跑得快 28×——稀疏注意力多了 index 构建、Top-k、反向索引、query gather、负载均衡这些开销,访存还更不规整。MSA 把算法和 GPU 执行路径协同设计,这才是它能落地的真正功夫。

① Exp-free Top-k:绕开 softmax 直接排序

因为 softmax 保序(s_i\le s_j\iff\mathrm{softmax}(s)_i\le\mathrm{softmax}(s)_j),Top-k 的结果和原始分数排序完全一致。所以前向直接跳过 max/exp/sum,把原始分数喂给选择。再加 per-thread 寄存器级 Top-k:warp 的 32 个 lane 各扫 1/32 的行,各维护一个 k 元素小顶堆,堆顶缓存在寄存器、延迟写入,最后一轮 shuffle merge——专门为 B_k=128,k=16 这个"小 k"区间定制。

Table 1 · Top-k 延迟对比(μs,H800,fp32,50 次中位数)
序列长 NBlocks BktorchTileLangMSA(本文)vs torchvs TileLang
128K102416397028647795.1×3.7×
128K2048325378363019912.7×1.8×
512K409616338101777978804.3×2.3×
512K8192325765926100213262.7×1.2×
部署设定 k=16 时优势最大(对 torch 4–5×)。三方产出的 index 集合完全一致。

② KV-outer 稀疏注意力:让 KV block 反向 gather query

稀疏 prefill 该让谁在外层循环?MSA 算了两套账(都按 bf16 计 IO):

逐符号看:Q-outer(query 在外层)的算术强度只有 G(GQA ratio,实验里 16);KV-outer(KV block 在外层、反向 gather 选中它的 query)是 \tfrac{2}{3}B_k(≈85)。后者大得多,所以选 KV-outer 来填满 tensor-core 的 MMA。关键技巧:

06 实验:109B MoE、3T token、原生多模态

验证规模够狠:41 层 MoE backbone、约 109B 总参 / 6B 激活、200K 词表、d_{\rm model}=3072,64 query head / 4 KV head / head dim 128,RoPE 维 64。128 routed experts + 1 shared,top-4 路由。两条训练路线:

LM loss 曲线
Figure 2(a) · 预训练 LM loss。蓝线 Full Attention vs 青线 MSA-PT,跨 3T token 几乎完全重合(右上角放大了最后 50B token 窗口,两条线粘在 1.21–1.22)。gradient norm 也始终在同一区间——大规模下训练稀疏注意力和训全注意力一样稳
Table 2 · 代表性评测(3T token 预算;PPL 越低越好,其余越高越好)
类别BenchmarkFullMSA-PTMSA-CPT
GeneralMMLU67.067.266.8
BBH67.766.666.1
WinoGrande58.360.962.0
MathGSM8K76.277.773.7
MathVista43.846.844.5
CodeHumanEval61.064.057.9
EvalPlus59.461.860.0
RetrievalRULER-8K79.884.277.2
RULER-32K75.077.575.7
ImageVisualWebBench55.668.459.4
CharXiv37.5541.5537.15
VideoEgoSchema29.637.625.8
VideoMME41.1145.4839.65
两条稀疏路线都跟 Full 打平。有意思的是 MSA-PT 在大量 math/image/video/RULER 上反而超过 Full——原生稀疏预训练似乎让表征适配了稀疏模式。MSA-CPT 更保守、贴近原 checkpoint,是已有 dense 模型时的实用转换路径。
Table 3 · 长上下文扩展(MSA-CPT 再续 140B token,Δ = 相对 Full)
Benchmark子项FullMSA-CPTΔ
HELMET-128KOverall46.5345.93-0.60
ICL70.4072.80+2.40
RULER-128KOverall72.0072.12+0.12
MK/MQ/MV96.6398.87+2.24
每个 query/group 只看 kB_k=2048 个 token,128K 上下文下仍贴着 Full——在极紧的注意力预算下守住了长上下文能力。
效率对比
Figure 4 · 效率。左:理论每 token 注意力 FLOPs,1M 上下文降 28.4×,且随上下文越长降得越多。中/右:实测 prefill / decode 加速。墙钟加速(14.2× / 7.6×)小于理论 FLOPs 降幅——因为多了 index 构建、Top-k、反向索引、gather、负载均衡这些不规整开销;但同样随上下文增长而扩大,因为 dense baseline 一直随全序列涨,MSA 主预算恒定。

07 它和 NSA / MoBA / DSA 差在哪

NSA
针对 MQA/MHA,三条并行分支(压缩注意力 + 选择注意力 + 滑窗)。MSA 砍到只留"选择",更简。
MoBA
也基于 GQA,但用很大的 KV block、块平均 key 打分,且 indexer 只靠 LM 梯度训练。MSA 用小块 + KL 显式对齐。
DSA
坐在 MLA 的 MQA 模式上,ReLU lightning indexer 逐 token 打分、所有 head 共享一份 Top-k、token 级选择。MSA 是 group 级 + block 级。

MSA 的差异点凝练成两条同时采用的轴:per-GQA-group 的 Top-k 共享 + block 级选择——既拿到多 group 的 block 粒度检索,又让 KV 读取保持连续。

💡 我的看法:对 WELM 线上长上下文最实在的一点

我们最头疼的是多轮请求漂移导致 KV-Cache 跨机命中不了、被迫重算 prefill、TTFT 飙升。MSA 这套的价值在于:它把"每个 query 实际要算的 KV 预算钉死在 2048 个 token"——无论上下文 128K 还是 1M。这意味着 prefill 的算力和访存都有了可预测的上界,对前缀缓存/路由调度是极友好的信号:缓存的有效 KV 集合是 block 对齐、group 共享的,比逐 token 漂移的 cache 好复用得多。而且 indexer 原生可训、kernel 已开源(MiniMax-AI/MSA),GQA backbone 几乎不用改就能接——它在 outlook 里也明确点了 RL 后训练和 agentic 部署是下一步,正好是长上下文成本最吃紧的地方。值得拿我们的 serving 栈实测一把 prefill 命中率和 TTFT。

一句话收尾:MSA 用"奥卡姆剃刀"砍掉了稀疏注意力里所有花哨的分支,只留下 group 级 block 选择 + KL 对齐的 indexer,再用 KV-outer kernel 把理论稀疏榨成真实加速——在 109B、3T token、原生多模态这个足够认真的规模上,证明了"稀疏可以和全注意力打平,而且越长越赚"。