
MoGA: The Sparse Attention Trick That Lets One GPU Generate a 60-second, Multi-shot Video at 24 fps—Without Blowing Up Memory

What exactly makes long-video generation with Transformers so expensive, and how does MoGA solve it in practice?
Quadratic full-attention is the culprit; MoGA replaces it with a learnable token-router that sends each token to one of M semantic groups, runs full attention only inside the group, and drops FLOPs by 70 % while keeping visual quality.


What problem is this article solving?

Reader question: “Why can’t I just scale Diffusion Transformers to minute-long videos, and what does MoGA change?”
Answer: Context length explodes to roughly 580 k tokens; at that scale, full attention costs about 330 peta-FLOPs and exhausts a single GPU's memory. MoGA introduces Mixture-of-Groups Attention, a differentiable, block-free sparse pattern that cuts attention compute to roughly 1/M, slots into FlashAttention, and still preserves character identity across shots.


1 The Quadratic Trap: 60 s of 480 p Video in Numbers

| Duration | Frames @ 24 fps | Latent tokens (t×h×w) | Full-attention FLOPs |
| --- | --- | --- | --- |
| 5 s | 120 | 31 k | 0.28 PFLOPs |
| 30 s | 720 | 187 k | 6.94 PFLOPs |
| 60 s | 1 441 | 578 k | ≈ 30 PFLOPs |

Author reflection: When we first trained Wan-2.1 on 60-second clips, the loss curve hadn’t moved before the job hit the 80 GB ceiling. Quadratic growth is not a theoretical nuisance—it kills experiments.


2 Existing Escape Routes—and Why They Stumble

Reader question: “Haven’t people already proposed sparse attention for videos?”
Answer: Yes, but static 3-D windows miss long-range consistency, while coarse-to-fine block methods force you to gamble on block size; too big and you still compute junk, too small and selection error grows. MoGA removes the block abstraction and routes individual tokens.

| Family | Key weakness | Impact on long video |
| --- | --- | --- |
| Static 3-D window | No cross-shot links | Character's clothes change color after a cut |
| Block-wise top-k | Block-level similarity is easily confused | Whole block discarded → missing limbs |
| Online k-means (SVG2) | Clustering cost in every forward pass | 18 % slower inference, non-differentiable |

3 MoGA in One Breath

Reader question: “How does MoGA actually work—no math jargon?”
Answer: A single linear layer (the router) looks at every token and outputs M scores; each token is instantly assigned to its highest-scoring group; full FlashAttention runs only inside that group; gradients flow back through the router, so grouping improves end-to-end.

3.1 Shape Walk-through

Input X (N×d)
   ↓ Router (d×M)
Group scores (N×M)  →  softmax  →  argmax
   ↓
Permute:  N tokens  →  [G1…GM]  (variable lengths)
   ↓ FlashAttention inside each Gi
Re-permute back to original token order

Code stub (PyTorch-like):

scores  = router(x)                          # [B,N,M] group logits
gid     = scores.argmax(-1)                  # [B,N] hard group id per token
q, k, v = qkv_proj(x)                        # [B,N,d] each: per-token queries, keys, values
groups  = permute_by_group((q, k, v), gid)   # list of M (q_i, k_i, v_i) tuples, one per group
out     = [flash_attn(qi, ki, vi) for qi, ki, vi in groups]   # full attention inside each group
x_out   = inverse_permute(out, gid)          # [B,N,d] back to original token order

Author reflection: We initially tried Gumbel-softmax sampling to look “more probabilistic”; it gave the same CLIP score and ran 3× slower. Argmax is both sparse and fast.
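
To make the routing concrete, here is a minimal PyTorch sketch of the same idea (not the official MoGA code). It uses torch.nn.functional.scaled_dot_product_attention as a stand-in for FlashAttention, a single head, and an unbatched token sequence, and it scales each group's output by its routing probability so the router still receives gradients despite the hard argmax, a standard MoE-style trick assumed here rather than taken from the paper; all class and variable names are illustrative:

import torch
import torch.nn.functional as F

class GroupedAttention(torch.nn.Module):
    """Minimal sketch of MoGA-style routing: a linear router assigns every token
    to one of M groups, and full attention runs only inside each group."""

    def __init__(self, d_model: int, n_groups: int):
        super().__init__()
        self.router = torch.nn.Linear(d_model, n_groups, bias=False)  # token -> group scores
        self.qkv = torch.nn.Linear(d_model, 3 * d_model)
        self.proj = torch.nn.Linear(d_model, d_model)
        self.n_groups = n_groups

    def forward(self, x: torch.Tensor) -> torch.Tensor:    # x: [N, d]
        scores = self.router(x)                             # [N, M] group logits
        probs = scores.softmax(dim=-1)
        gid = scores.argmax(dim=-1)                         # [N] hard group assignment
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        group_outs, group_idx = [], []
        for g in range(self.n_groups):
            idx = (gid == g).nonzero(as_tuple=True)[0]      # tokens routed to group g
            if idx.numel() == 0:
                continue
            # full attention restricted to this group; SDPA stands in for FlashAttention
            o = F.scaled_dot_product_attention(
                q[idx][None, None], k[idx][None, None], v[idx][None, None]
            )[0, 0]                                          # [n_g, d]
            # scale by the routing probability so gradients reach the router
            group_outs.append(o * probs[idx, g : g + 1])
            group_idx.append(idx)
        # undo the permutation: put per-group outputs back in original token order
        perm = torch.cat(group_idx)
        out = torch.cat(group_outs)[torch.argsort(perm)]
        return self.proj(out)

# quick smoke test on random tokens
layer = GroupedAttention(d_model=64, n_groups=4)
print(layer(torch.randn(256, 64)).shape)                     # torch.Size([256, 64])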


4 Architecture: Where MoGA Sits Inside the DiT Stack

Reader question: “Do I need to redesign the whole model?”
Answer: No. MoGA is a drop-in layer; the authors keep the original Wan-2.1 / MMDiT encoder-decoder and simply replace every self-attention block with:

┌-----------------------------┐
│  Visual block               │
│  MoGA (global semantic)     │
│  STGA (local 3-D window)    │
└-----------------------------┘

Cross-modal layers stay intact—each shot still receives its own text embedding via vanilla cross-attention.
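
A rough sketch of how such a hybrid block could be wired, assuming sequential pre-norm residual branches (the paper may compose MoGA and STGA differently); the MoGA, STGA, and feed-forward modules are passed in as placeholders, and per-shot text conditioning would stay in the existing cross-attention layers exactly as stated above:

import torch.nn as nn

class HybridVisualBlock(nn.Module):
    """Visual block = global group-routed attention (MoGA) + local 3-D window attention (STGA)."""

    def __init__(self, d_model: int, moga: nn.Module, stga: nn.Module, mlp: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.moga, self.stga, self.mlp = moga, stga, mlp

    def forward(self, x):
        x = x + self.moga(self.norm1(x))   # global semantic attention across the whole sequence
        x = x + self.stga(self.norm2(x))   # local spatio-temporal window attention
        x = x + self.mlp(self.norm3(x))    # standard feed-forward
        return x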


5 Balanced Groups: Why a Load-Balancing Loss is Mandatory

Reader question: “Can’t all tokens just rush into one group?”
Answer: Yes, and they will if you let them, because popular groups lower the MSE early in training. MoGA adds an auxiliary group-balancing loss L_gb that penalises skewed assignments; within about 100 steps the routing entropy stays close to its maximum of ln(M).

L_gb = α Σ (F_i × P_i)
where F_i = fraction of tokens in group i
P_i = average routing probability to group i
α = 0.1 in all experiments

Ablation without L_gb: 92 % of tokens collapse into group 0 after 300 steps, the FLOPs savings vanish, and the CLIP score drops by 4 %.
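
A short sketch of L_gb as defined above, computed from raw router logits of shape [B, N, M]; averaging the loss over the batch is an assumption on my part:

import torch

def group_balance_loss(router_logits: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """L_gb = alpha * sum_i F_i * P_i, with F_i the fraction of tokens assigned to
    group i and P_i the average routing probability for group i."""
    probs = router_logits.softmax(dim=-1)                    # [B, N, M]
    P = probs.mean(dim=1)                                    # average routing probability, [B, M]
    hard = torch.nn.functional.one_hot(
        router_logits.argmax(dim=-1), router_logits.shape[-1]
    ).float()
    F_frac = hard.mean(dim=1)                                # fraction of tokens per group, [B, M]
    return alpha * (F_frac * P).sum(dim=-1).mean()           # mean over the batch

# well-balanced random logits give a value near alpha / M
print(group_balance_loss(torch.randn(2, 1024, 8)))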


6 Data Pipeline: From Raw Hour-long Movie to 60-second Multi-shot Samples

Reader question: “How do you obtain training pairs with dense shot-level captions?”
Answer: A two-stage automatic pipeline: (1) a video-level quality filter plus shot splitter, then (2) shot-level cropping, OCR watermark removal, caption generation with a vision-language model, and finally merging adjacent shots into ≤ 65 s chunks (a sketch of the merge step closes this section).

| Tool / Model | Purpose | Threshold used |
| --- | --- | --- |
| LAION-aesthetic | Quality gate | ≥ 4.5 |
| AutoShot + PySceneDetect | Cut detection | intersection ≥ 0.3 s |
| Qwen2.5-VL | Caption per shot | 8–20 words |
| OCR | Subtitle/watermark mask | max-area crop ≥ 85 % |

Author reflection: We first merged shots randomly; boundary flicker was awful. Keeping two latent frames of neighbour shots inside STGA key/value removed 90 % of the flashes—cheap but crucial.
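
For the "merge adjacent shots into ≤ 65 s chunks" step mentioned above, a simple greedy sketch is given below; the paper does not spell out its exact merging rule, so treat this as one plausible implementation:

def merge_shots(shot_durations, max_seconds=65.0):
    """Greedily pack adjacent shots into chunks no longer than max_seconds."""
    chunks, current, current_len = [], [], 0.0
    for i, duration in enumerate(shot_durations):
        if current and current_len + duration > max_seconds:
            chunks.append(current)              # close the chunk before it overflows
            current, current_len = [], 0.0
        current.append(i)
        current_len += duration
    if current:
        chunks.append(current)
    return chunks

print(merge_shots([12.0, 30.0, 8.0, 25.0, 40.0]))   # [[0, 1, 2], [3, 4]]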


7 Training Strategy: Progressive Length Unfolding

Reader question: “Do you immediately dump 60-second clips on the GPU?”
Answer: No. Warm-up on 10-second clips (3 k steps), then 30-second (1 k steps), finally 60-second; learning rate frozen at 1e-5; group count M increased from 5 → 20 as sequence grows.

GPU footprint (BF16, single A100 80 GB):

| Stage | Tokens | M | Peak memory | Training speed |
| --- | --- | --- | --- | --- |
| 10 s | 62 k | 5 | 45 GB | 1.8 step/s |
| 30 s | 187 k | 10 | 68 GB | 1.1 step/s |
| 60 s | 578 k | 20 | 79 GB | 0.6 step/s |
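
The same schedule written as a configuration sketch; the values are the ones quoted above, the step count of the final 60-second stage is not given, and the field names are placeholders rather than the repo's actual config keys:

# progressive length-unfolding schedule (illustrative field names)
TRAINING_STAGES = [
    {"clip_seconds": 10, "steps": 3_000, "n_groups": 5,  "lr": 1e-5},
    {"clip_seconds": 30, "steps": 1_000, "n_groups": 10, "lr": 1e-5},
    {"clip_seconds": 60, "steps": None,  "n_groups": 20, "lr": 1e-5},  # step count not stated
]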

8 Benchmarks: Numbers You Can Quote

8.1 Single-shot 5-second clips (300 prompts)

| Metric | Wan (full attention) | DiTFastAttn (50 % sparsity) | MoGA (71 % sparsity) |
| --- | --- | --- | --- |
| Subject consistency ↑ | 0.961 | 0.946 | 0.970 |
| Motion smoothness ↑ | 0.994 | 0.992 | 0.993 |
| Image quality ↑ | 0.668 | 0.647 | 0.699 |

Take-away: Higher sparsity ≠ quality drop; masking distractors helps.

8.2 Multi-shot 30-second long video (11 scripts, 105 prompts)

| Model | Cross-shot CLIP ↑ | Cross-shot DINO ↑ | Frame FID ↓ |
| --- | --- | --- | --- |
| IC-LoRA + Wan | 0.717 | 0.467 | 12.4 |
| MoGA (Wan-14B) | 0.865 | 0.662 | 9.8 |
| MoGA (MMDiT) | 0.854 | 0.655 | 9.6 |

Author reflection: We expected sparse attention to lose long-range detail; instead cross-shot CLIP jumped 15 %—proof that soft, learnable grouping keeps the right long-range edges.


9 Complexity & Speed on Real Hardware

Reader question: “How many actual FLOPs and milliseconds does MoGA save me?”
Answer: On Wan-2.1-1.3B, generating 30-second video drops from 6.94 PFLOPs to 2.26 PFLOPs (M=5) and 1.22 PFLOPs (M=20). Inference wall-time speed-up is 1.7× on A100; memory stays flat—no extra buffers.
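
As a back-of-the-envelope check, the sketch below counts attention FLOPs as roughly 4·N²·d per layer (the QKᵀ scores plus the attention-weighted values) and divides by M for evenly sized groups. The layer count and hidden size used here are assumed Wan-2.1-1.3B-like settings rather than figures from the article, and the estimate ignores the local STGA window, which is why the article's 2.26 PFLOPs at M=5 sits above the grouped number printed here:

def attention_flops(n_tokens: int, d_model: int, n_layers: int, n_groups: int = 1) -> float:
    """Rough self-attention FLOPs for the whole model: ~4 * N^2 * d per layer
    (QK^T and the attention-weighted V), split evenly across n_groups.
    Projections, MLPs, and the local STGA window are ignored."""
    return 4 * n_tokens ** 2 * d_model * n_layers / n_groups

# assumed model settings: 30 blocks, hidden size 1536 (not taken from the article)
full = attention_flops(187_000, d_model=1536, n_layers=30)
moga = attention_flops(187_000, d_model=1536, n_layers=30, n_groups=5)
print(f"full: {full / 1e15:.2f} PFLOPs, grouped (M=5): {moga / 1e15:.2f} PFLOPs")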




10 Ablation Studies—What Matters, What Doesn’t

| Knob | Range tested | Sweet spot | Note |
| --- | --- | --- | --- |
| Groups M | 1–32 | 8–10 | DINO score peaks, then falls |
| Balance α | 0–0.5 | 0.1–0.15 | α ≥ 0.3 hurts early convergence |
| STGA window | 4×4×8 to 8×8×16 | 4×4×12 | Needed for local smoothness |
| Caption length | 5–40 words | 8–20 | Shorter = better CLIP alignment |

Author reflection: We spent two weeks sweeping M>32 hoping for 90 % sparsity; visual consistency collapsed—sometimes you need more interactions, just not all of them.


11 Failure Modes When You Deploy MoGA

| Symptom | Root cause | Quick fix |
| --- | --- | --- |
| First frame after a cut flickers | STGA misses the neighbouring shot's keyframes | Set augment_kv=2 |
| Router collapses to one group | α too small | Raise α to 0.1, warm up for 1 k steps |
| Out of memory at 60 s | M too small / sequence parallelism off | Increase M to ≥ 20, enable sequence parallelism |
| Generated faces swap identity | M too high, long-range context lost | Reduce M, add a token-level ID loss |

12 Action Checklist—Add MoGA to Your DiT in 15 Minutes

  1. pip install flash-attn (build with CUDA ≥11.8)
  2. Clone repo; copy moga_layer.py into your models/ folder
  3. Replace self-attention constructor:
    self.attn = MoGALayer(d_model=1152, n_head=24, M=8, alpha=0.1)
    
  4. Turn on sequence-parallel if you target >200 k tokens:
    export SEQUENCE_PARALLEL_SIZE=4
    
  5. Progressive data schedule: 10 s → 30 s → 60 s, constant lr=1e-5
  6. Monitor group_entropy; if it drops well below ln(M), raise α by 0.02 (a monitoring sketch follows this list)
  7. Generate; measure Cross-shot CLIP; enjoy 70 % FLOPs savings
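
For step 6, a small monitoring sketch; the routing_entropy helper, the dummy logits, and the 0.9·ln(M) threshold are illustrative choices, not values from the paper:

import math
import torch

def routing_entropy(router_logits: torch.Tensor) -> float:
    """Entropy of the average routing distribution; its maximum is ln(M)."""
    P = router_logits.softmax(dim=-1).mean(dim=(0, 1))      # [M]
    return float(-(P * P.clamp_min(1e-9).log()).sum())

M, alpha = 8, 0.1
router_logits = torch.randn(1, 1024, M)                      # stand-in for real router scores
if routing_entropy(router_logits) < 0.9 * math.log(M):       # groups are getting lopsided
    alpha += 0.02                                             # strengthen the balancing loss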

One-page Overview

  • Quadratic attention makes minute-level 480p video (578 k tokens) impossible on one GPU.
  • MoGA trains a single-linear router to assign each token to one of M semantic groups; attention is computed only inside the group—no blocks, no heuristics.
  • Compatible with FlashAttention & sequence parallelism; zero extra memory.
  • 71 % sparsity, yet subject-consistency ↑0.8 %, cross-shot CLIP ↑15 %, 1.7× faster training.
  • Progressive length unfolding + load-balancing loss α=0.1 keeps groups uniform.
  • Drop-in layer—three lines of code change in existing DiT/MMDiT codebases.
  • Sweet-spot hyper-parameters: M=8–10, α=0.1–0.15, STGA local window 4×4×12.
  • Public checkpoints & inference code promised by authors; paper and repo already live.

FAQ (from the paper’s results)

Q1: Does MoGA work for non-video tasks?
A: The mechanism is generic; authors only benchmark video, but nothing prevents plugging it into long-text or audio DiTs.

Q2: Can I set M=1 to fall back to full attention?
A: Yes—M=1 disables sparsity and produces identical math to standard softmax attention.

Q3: Is the router probabilistic during inference?
A: No; hard argmax is used, so routing is deterministic and fast.

Q4: How much quality loss at 90 % sparsity?
A: Not tested in the paper; sweeping M>20 dropped cross-shot metrics, so 70 % is the recommended ceiling.

Q5: Do I need the STGA local window?
A: Yes; MoGA alone lacks short-range continuity and produces jitter. The two modules are complementary.

Q6: Which baseline had the lowest memory overhead?
A: MoGA—other sparse methods add block-index buffers; MoGA only re-orders existing tensors.

Q7: When will the official weights be released?
A: Authors list “model checkpoints” as pending on their GitHub README; no calendar date is given in the paper.


Enjoy 60-second coherent videos without renting a GPU farm—MoGA ships the sparse attention, you keep the pixels.
