核心问题:当 Diffusion Transformer 做长视频时,注意力平方级膨胀怎么破?
一句话答案:MoGA 用“可学习的 token 路由器”把相似语义自动分进同一组,组内做全注意力,组外零计算,实现 71% 稀疏度却保持画质,端到端直出 60 秒多镜头 24 fps 视频。
本文欲回答的核心问题
-
为什么长视频生成注定会“爆显存”? -
MoGA 的分组注意力到底改变了什么? -
如何在现有 DiT 代码里无痛换上 MoGA? -
真实指标、速度、显存对比是多少? -
作者踩坑:分组数 M 越大越好吗?平衡损失怎么调?
1 背景:长视频生成的三重诅咒
| 诅咒 | 现象 | 举例(480p@24fps) |
|---|---|---|
| 序列长度 | 1 分钟 ≈ 58 万 token | 单卡 A100 80 G 装不下一次全注意力 |
| 多镜头一致 | 切镜后人物/背景漂移 | 观众一眼出戏 |
| 端到端训练 | 多阶段流水线误差累积 | 关键帧→插帧→超分,三步三步地崩 |
传统稀疏注意力要么“手工划窗”丢失长程,要么“先分块再挑块”带来块大小玄学。MoGA 干脆让网络自己学「谁该跟谁玩」。
2 MoGA 核心思想:把 softmax 的稀疏性变成可学习路由
2.1 一句话原理
用单个线性层当“路由器”,输入 token 语义,输出 M 维 logits → softmax → 选 top-1 组 → 组内做标准 FlashAttention。没有 block,没有 k-means,没有手工阈值。
2.2 形状变化可视化
Full Attention: [N×N] 稠密矩阵
Block Sparse: [N/B × N/B] 先粗后细
MoGA: [G1,G2,…,GM] 每组 ni≈N/M,组内 ni²
理论计算量从 O(N²) 降到 Σ O(ni²)。均匀分组时下界 O(N²/M)。
2.3 路由器代码(PyTorch 风格伪码)
class Router(nn.Module):
def __init__(self, d_model, M):
super().__init__()
self.weight = nn.Linear(d_model, M, bias=False) # 就是一组中心向量
def forward(self, x): # x: [B, N, d]
logits = self.weight(x) # [B, N, M]
scores = logits.softmax(-1)
group_id = scores.argmax(-1) # [B, N]
return group_id, scores
路由器参数量 = d_model × M,当 d=1152、M=5 时仅 5.7 k,可忽略不计。
3 系统架构:DiT + MoGA + 时空局部窗 = 长视频生成模型
3.1 宏观堆叠
-
视觉分支:交替放置「MoGA 全局组注意力」+「STGA 局部时空窗」 -
文本分支:每镜头一句短 prompt,用 Cross-Modal Attention 注入 -
VAE 下采样 (4,8,8),patchify (1,2,2), latent 空间 1 分钟 ≈ 578 k token
3.2 局部窗怎么对齐镜头边界
STGA 把不同镜头拆成独立时间组,但计算时把相邻镜头各补 2 帧 key/value,解决切镜瞬间闪烁——零新增查询,只增 key/value,计算开销 <1%。
4 数据管道:把“长视频”拆成“多镜头+密集字幕”
| 阶段 | 自动工具 | 过滤阈值举例 |
|---|---|---|
| 视频级 | VQA+OCR+黑边检测 | 美学<4.5、清晰度<30 丢弃 |
| 镜头级 | AutoShot + PySceneDetect | 淡入淡出标记,重叠帧剪掉 |
| 字幕级 | Qwen2.5-VL 多模态大模型 | 每 2-4 秒一句,物体+动作+场景 |
最终训练样本:≤65 秒,8-10 镜头,对应 8-10 条短 prompt,token 长度 58 万级。
5 训练细节:从 10 秒→30 秒→60 秒渐进式微调
-
预热:10 秒片段,3 k step,lr=1e-5,α=0.1 -
拉长:30 秒片段,1 k step,序列并行+FlashAttention2 -
极限:60 秒,MMDiT 骨干,M=20,显存仍 <80 G
作者反思:拉长阶段若一次性上到 60 秒,梯度爆炸 100%;渐进式后 loss 曲线丝滑。
6 实验结果:稀疏 71% 但指标反杀全注意力
6.1 5 秒单镜头短视频
| 方法 | 主体一致↑ | 背景一致↑ | 运动平滑↑ | 稀疏度 |
|---|---|---|---|---|
| Wan 全注意力 | 0.9611 | 0.9560 | 0.9936 | 0 % |
| DiTFastAttn | 0.9456 | 0.9394 | 0.9924 | 50 % |
| MoGA | 0.9699 | 0.9542 | 0.9927 | 71 % |
稀疏度更高,反而更稳——因为无关 token 被提前屏蔽,噪声减少。
6.2 30 秒多镜头长视频
| 方法 | Cross-Shot CLIP↑ | 显存占用 | 训练速度 |
|---|---|---|---|
| IC-LoRA+Wan | 0.7169 | 2×A100 80 G | 0.7× 实速 |
| MoGA | 0.8654 | 1×A100 80 G | 1.7× 实速 |
6.3 计算量随 M 增大线性下降
30 秒视频
M= 5 → 2.26 PFLOPs (↓67%)
M=20 → 1.22 PFLOPs (↓82%)
7 消融:分组数 M 与平衡损失 α 怎么选?
| M | Cross-Shot DINO | 显存 | 作者手记 |
|---|---|---|---|
| 2 | 0.8589 | 高 | 稀疏不足,收益小 |
| 8 | 0.8606 | 中 | 最佳折中 |
| 16 | 0.8569 | 低 | 组过碎,长程断 |
平衡损失 α 经验:
-
α=0 时,80% token 涌进同一组,退化成全注意力; -
α=0.1,三步内收敛,分组熵≈M; -
α>0.3,训练初期梯度被拉偏,画质下降 2%。
8 一步落地:把 MoGA 插进你自己的 DiT 仓库
-
安装依赖
pip install flash-attn --no-build-isolation
-
替换 attention 文件
from moga import MoGALayer
# 原:nn.MultiheadAttention
# 新:
self.attn = MoGALayer(d_model=1152, n_head=24, M=5, alpha=0.1)
-
开启序列并行(可选)
export SEQUENCE_PARALLEL_SIZE=4
torchrun --nproc_per_node=4 train.py
-
验证长视频
prompts = ["A woman in a red hat enters the cafe."] * 8 # 8 镜头
video = model.generate(prompts, n_frames=1441, fps=24, seed=42)
9 常见失败模式与排查表
| 现象 | 可能原因 | 快速修复 |
|---|---|---|
| 首帧切镜闪烁 | STGA 没补相邻镜头 key | 确认 augment_kv=2 |
| 路由器全部塌陷到 1 组 | α 太小或 warmup 不足 | 提高 α=0.1→0.15,warmup 500→1000 step |
| 显存仍爆 | M 太小 / 序列并行未开 | M→10,打开 flash_attn + ulysses |
10 结论与实用清单
-
长视频生成的最大敌人是「稠密注意力」,不是参数量。 -
MoGA 用“可学习分组”替代“手工分块”,在 580 k token 上首次实现单卡端到端训练与推理。 -
稀疏度 71% 的情况下,主体一致性反涨 0.8%,训练提速 1.7×。 -
分组数 M 选 5-10,平衡损失 α 选 0.1-0.15,最稳。 -
代码改动仅 3 行,兼容 FlashAttention、序列并行,零额外内存开销。
一页速览(One-page Summary)
-
痛点:DiT 长视频注意力 O(N²) 显存爆炸 -
解法:MoGA → 线性路由器 → 语义分组 → 组内 FlashAttention -
结果:60 s / 1441 帧 / 578 k token / 24 fps / 480p / 单卡 A100 -
指标:稀疏 71%,Cross-Shot CLIP 0.865,反超全注意力 -
落地:pip 装 flash-attn → 替换 MoGALayer → 调 M & α → 开序列并行
FAQ
-
问:M 一定要取 5 或 20 吗?
答:任意整数≤32 均可,越大越稀疏,8 左右最均衡。 -
问:路由器会额外占多少显存?
答:权重仅 d×M 浮点,5 k-20 k 量级,可忽略。 -
问:训练能从零开始吗,还是必须微调?
答:基于 Wan2.1 权重微调 4 k step 即可;从零需更多数据。 -
问:镜头描述必须一句一个?
答:支持任意长度,实验表明 8-10 词最稳。 -
问:生成 90 秒行不行?
答:理论上 2 M token 也能装,需 M≥40 并开 cpu offload,未开放测试。 -
问:与 SVG2 的在线 k-means 区别?
答:SVG2 推理期做聚类,不可导且耗时;MoGA 用可导中心向量,端到端训练。
