站点图标 高效码农

把 1 分钟 480p 视频塞进 58 万 token:MoGA 如何用「分组注意力」让长视频生成不再爆显存

核心问题:当 Diffusion Transformer 做长视频时,注意力平方级膨胀怎么破?
一句话答案:MoGA 用“可学习的 token 路由器”把相似语义自动分进同一组,组内做全注意力,组外零计算,实现 71% 稀疏度却保持画质,端到端直出 60 秒多镜头 24 fps 视频。


本文欲回答的核心问题

  1. 为什么长视频生成注定会“爆显存”?
  2. MoGA 的分组注意力到底改变了什么?
  3. 如何在现有 DiT 代码里无痛换上 MoGA?
  4. 真实指标、速度、显存对比是多少?
  5. 作者踩坑:分组数 M 越大越好吗?平衡损失怎么调?

1 背景:长视频生成的三重诅咒

诅咒 现象 举例(480p@24fps)
序列长度 1 分钟 ≈ 58 万 token 单卡 A100 80 G 装不下一次全注意力
多镜头一致 切镜后人物/背景漂移 观众一眼出戏
端到端训练 多阶段流水线误差累积 关键帧→插帧→超分,三步三步地崩

传统稀疏注意力要么“手工划窗”丢失长程,要么“先分块再挑块”带来块大小玄学。MoGA 干脆让网络自己学「谁该跟谁玩」。


2 MoGA 核心思想:把 softmax 的稀疏性变成可学习路由

2.1 一句话原理

用单个线性层当“路由器”,输入 token 语义,输出 M 维 logits → softmax → 选 top-1 组 → 组内做标准 FlashAttention。没有 block,没有 k-means,没有手工阈值。

2.2 形状变化可视化

Full Attention:     [N×N] 稠密矩阵
Block Sparse:       [N/B × N/B] 先粗后细
MoGA:               [G1,G2,…,GM] 每组 ni≈N/M,组内 ni²

理论计算量从 O(N²) 降到 Σ O(ni²)。均匀分组时下界 O(N²/M)。

2.3 路由器代码(PyTorch 风格伪码)

class Router(nn.Module):
    def __init__(self, d_model, M):
        super().__init__()
        self.weight = nn.Linear(d_model, M, bias=False)  # 就是一组中心向量
    
    def forward(self, x):           # x: [B, N, d]
        logits = self.weight(x)     # [B, N, M]
        scores = logits.softmax(-1)
        group_id = scores.argmax(-1)  # [B, N]
        return group_id, scores

路由器参数量 = d_model × M,当 d=1152、M=5 时仅 5.7 k,可忽略不计。


3 系统架构:DiT + MoGA + 时空局部窗 = 长视频生成模型

3.1 宏观堆叠

  • 视觉分支:交替放置「MoGA 全局组注意力」+「STGA 局部时空窗」
  • 文本分支:每镜头一句短 prompt,用 Cross-Modal Attention 注入
  • VAE 下采样 (4,8,8),patchify (1,2,2), latent 空间 1 分钟 ≈ 578 k token

3.2 局部窗怎么对齐镜头边界

STGA 把不同镜头拆成独立时间组,但计算时把相邻镜头各补 2 帧 key/value,解决切镜瞬间闪烁——零新增查询,只增 key/value,计算开销 <1%。


4 数据管道:把“长视频”拆成“多镜头+密集字幕”

阶段 自动工具 过滤阈值举例
视频级 VQA+OCR+黑边检测 美学<4.5、清晰度<30 丢弃
镜头级 AutoShot + PySceneDetect 淡入淡出标记,重叠帧剪掉
字幕级 Qwen2.5-VL 多模态大模型 每 2-4 秒一句,物体+动作+场景

最终训练样本:≤65 秒,8-10 镜头,对应 8-10 条短 prompt,token 长度 58 万级。


5 训练细节:从 10 秒→30 秒→60 秒渐进式微调

  1. 预热:10 秒片段,3 k step,lr=1e-5,α=0.1
  2. 拉长:30 秒片段,1 k step,序列并行+FlashAttention2
  3. 极限:60 秒,MMDiT 骨干,M=20,显存仍 <80 G
    作者反思:拉长阶段若一次性上到 60 秒,梯度爆炸 100%;渐进式后 loss 曲线丝滑。

6 实验结果:稀疏 71% 但指标反杀全注意力

6.1 5 秒单镜头短视频

方法 主体一致↑ 背景一致↑ 运动平滑↑ 稀疏度
Wan 全注意力 0.9611 0.9560 0.9936 0 %
DiTFastAttn 0.9456 0.9394 0.9924 50 %
MoGA 0.9699 0.9542 0.9927 71 %

稀疏度更高,反而更稳——因为无关 token 被提前屏蔽,噪声减少。

6.2 30 秒多镜头长视频

方法 Cross-Shot CLIP↑ 显存占用 训练速度
IC-LoRA+Wan 0.7169 2×A100 80 G 0.7× 实速
MoGA 0.8654 1×A100 80 G 1.7× 实速

6.3 计算量随 M 增大线性下降

30 秒视频
M= 5  → 2.26 PFLOPs (↓67%)
M=20  → 1.22 PFLOPs (↓82%)

7 消融:分组数 M 与平衡损失 α 怎么选?

M Cross-Shot DINO 显存 作者手记
2 0.8589 稀疏不足,收益小
8 0.8606 最佳折中
16 0.8569 组过碎,长程断

平衡损失 α 经验

  • α=0 时,80% token 涌进同一组,退化成全注意力;
  • α=0.1,三步内收敛,分组熵≈M;
  • α>0.3,训练初期梯度被拉偏,画质下降 2%。

8 一步落地:把 MoGA 插进你自己的 DiT 仓库

  1. 安装依赖
pip install flash-attn --no-build-isolation
  1. 替换 attention 文件
from moga import MoGALayer
# 原:nn.MultiheadAttention
# 新:
self.attn = MoGALayer(d_model=1152, n_head=24, M=5, alpha=0.1)
  1. 开启序列并行(可选)
export SEQUENCE_PARALLEL_SIZE=4
torchrun --nproc_per_node=4 train.py
  1. 验证长视频
prompts = ["A woman in a red hat enters the cafe."] * 8  # 8 镜头
video = model.generate(prompts, n_frames=1441, fps=24, seed=42)

9 常见失败模式与排查表

现象 可能原因 快速修复
首帧切镜闪烁 STGA 没补相邻镜头 key 确认 augment_kv=2
路由器全部塌陷到 1 组 α 太小或 warmup 不足 提高 α=0.1→0.15,warmup 500→1000 step
显存仍爆 M 太小 / 序列并行未开 M→10,打开 flash_attn + ulysses

10 结论与实用清单

  1. 长视频生成的最大敌人是「稠密注意力」,不是参数量。
  2. MoGA 用“可学习分组”替代“手工分块”,在 580 k token 上首次实现单卡端到端训练与推理。
  3. 稀疏度 71% 的情况下,主体一致性反涨 0.8%,训练提速 1.7×。
  4. 分组数 M 选 5-10,平衡损失 α 选 0.1-0.15,最稳。
  5. 代码改动仅 3 行,兼容 FlashAttention、序列并行,零额外内存开销。

一页速览(One-page Summary)

  • 痛点:DiT 长视频注意力 O(N²) 显存爆炸
  • 解法:MoGA → 线性路由器 → 语义分组 → 组内 FlashAttention
  • 结果:60 s / 1441 帧 / 578 k token / 24 fps / 480p / 单卡 A100
  • 指标:稀疏 71%,Cross-Shot CLIP 0.865,反超全注意力
  • 落地:pip 装 flash-attn → 替换 MoGALayer → 调 M & α → 开序列并行

FAQ

  1. :M 一定要取 5 或 20 吗?
    :任意整数≤32 均可,越大越稀疏,8 左右最均衡。

  2. :路由器会额外占多少显存?
    :权重仅 d×M 浮点,5 k-20 k 量级,可忽略。

  3. :训练能从零开始吗,还是必须微调?
    :基于 Wan2.1 权重微调 4 k step 即可;从零需更多数据。

  4. :镜头描述必须一句一个?
    :支持任意长度,实验表明 8-10 词最稳。

  5. :生成 90 秒行不行?
    :理论上 2 M token 也能装,需 M≥40 并开 cpu offload,未开放测试。

  6. :与 SVG2 的在线 k-means 区别?
    :SVG2 推理期做聚类,不可导且耗时;MoGA 用可导中心向量,端到端训练。

退出移动版