把 Transformer 做成“终身学习者”:TTT-E2E 如何让大模型边用边学?

关键词:长上下文、Test-Time Training、TTT-E2E、滑动窗口注意力、元学习、推理加速


1. 为什么又聊“长上下文”?

“模型读完一本 10 万字小说,却记不住第一章主角的名字”——这是目前主流大模型(包括 GPT 系列)在长上下文场景下的典型痛点。
原因不复杂:

  • 全注意力(Full Attention) 的算力随长度线性增长,128 k token 时延迟已经让人皱眉;
  • RNN / 线性注意力 虽然恒定速度,却越读越“迷糊”,长度超过 32 k 后效果反而掉点;
  • 滑动窗口 只能看“最近邻居”,再远就真·失忆。

有没有一种办法,既保持恒定推理速度,又让模型像人类一样“边读边总结”?
TTT-E2E 这篇 2025 年 12 月的论文给出了完整答案:把测试阶段变成一次额外的训练阶段,让模型自己压缩历史,而不是把历史全部塞进显存。


2. TTT-E2E 是什么?一句话先说明白

TTT-E2E(End-to-End Test-Time Training)=
“在推理时,用下一个 token 预测当训练信号,只更新少量 MLP 权重,把看过的上下文压缩进模型自身参数。”

特点 一句话解释
恒定推理延迟 类似 RNN,每 token 计算量与长度无关
不额外占显存 不把 key/value 全存下来,只保留“压缩后的知识”
端到端训练 训练时就考虑“推理时会再学一次”,用元学习初始化

3. 人类是怎么读书的?

把人类当参照系,更容易看懂 TTT 的设计动机:

人类读书 传统 Transformer TTT-E2E
逐页读,边读边忘掉细节,只保留主线 每页都记成 key/value,显存爆炸 每读一段,用梯度把主线写进 MLP
二刷时只记得“这本书讲成长”,不记得第 3 页第 2 句 二刷仍要重新扫全部 key/value 二刷直接继承已压缩的 MLP,越读越精炼

4. 原理解压:四张图就能懂

4.1 训练阶段(Outer Loop)

  1. 把每条序列假想成“测试序列”。
  2. 在序列上跑内循环:像推理时一样,每 b 个 token 用下一个 token 预测损失更新一次 MLP。
  3. 内循环结束后,得到“更新过很多次”的模型,计算最终损失。
  4. 用最终损失回传梯度到初始权重(梯度之梯度,即元学习)。

图1:Outer-Inner Loop 示意
图1:Outer-Inner Loop 示意,来源论文 Figure 2

4.2 推理阶段(Inner Loop 复用)

  1. 拿到新文章,直接用训练好的初始权重起步。
  2. 每 b=1 k token 做一次小批量梯度下降,只改最后 1/4 层的 MLP。
  3. 读完立即可以续写,key/value cache 只占 8 k 窗口,显存恒定。

4.3 微观结构

组件 是否冻结 原因
Embedding 层 冻结 避免 token 表示漂移
滑动窗口注意力 冻结 负责局部 8 k 内的精准关联
前 3/4 层 MLP 冻结 保留预训练通用知识
后 1/4 层 MLP 更新 充当“快速权重”,存储长程主线

4.4 伪代码(PyTorch 风格)

# 伪代码,仅展示核心逻辑
for chunk in text.chunks(batch_size=b):
    logits = model(chunk, window_size=8_000)
    loss   = cross_entropy(logits, chunk.targets)
    loss.backward()
    optimizer.step_only_last_quarter_mlp()
    optimizer.zero_grad()

5. 实验结果:数字说话

5.1 上下文扩展能力

  • 3 B 参数、164 B 训练 token,Books 数据集上结果:

    • 128 k 上下文,TTT-E2E 比 Full Attention 低 0.02 的测试损失,同时延迟仅 37 %
    • 32 k 后,SWA、Mamba2、Gated DeltaNet 损失开始反涨,TTT-E2E 仍单调下降。

图2:损失 vs 长度,左为绝对值,右为相对 Full Attention 的差值
图2:损失与长度曲线,来源论文 Figure 1 & 9

5.2 训练算力缩放

| 模型规模 | 125 M → 3 B |
| 训练 token | 16 B → 80 B |
结论:当模型 > 760 Mtoken > 48 B 后,TTT-E2E 与 Full Attention 的缩放斜率几乎重合,意味着大预算训练下不会掉队。

5.3 推理延迟实测(H100,128 k 输入)

方法 预填充延迟 (s) 速度倍数
Full Attention 0.94 1.0 ×
TTT-E2E 0.35 2.7 ×
Mamba2 0.33 2.8 ×
SWA 0.32 2.9 ×

TTT-E2E 把延迟压到与 RNN 同级,却不掉精度


6. 如何复现?官方代码与关键超参

GitHub:https://github.com/test-time-training/e2e

6.1 最小可跑配置(760 M 模型)

超参 备注
窗口大小 k 8 192 必须 ≥ 内循环 batch
内循环 batch b 1 024 再小训练不稳
更新层数 最后 6 层(共 24 层) 比例≈1/4
学习率 η 5 e-3 内循环 SGD,无动量
外循环优化器 AdamW 与常规预训练一致

6.2 两阶段训练流程

  1. 预训练

    • 数据:DCLM-Baseline,只保留 ≥ 8 k 文档
    • 长度:8 k token
    • 目标:让模型学会“元学习初始化”
  2. 长上下文微调

    • 数据:Books
    • 长度:32 k / 64 k / 128 k 各跑一份 checkpoint
    • 技巧:RoPE θ 按 1 M→10 M 对数线性放大

7. 失败场景与局限

场景 表现 原因
Needle-in-Haystack(128 k 找 UUID) 准确率 3 % 压缩机制把“看似无关”的 needle 当成噪音
训练阶段短文本 延迟劣势 3.4 × 梯度之梯度开销大,短序列摊销不足
多轮对话角色扮演 未评估 需要持续外循环更新,当前仅单序列

作者直言:如果任务极度依赖逐字召回,全注意力仍是不可替代;TTT-E2E 更适合**“读长文、抓主旨”**的场景。


8. 常见疑问 FAQ

Q1:推理时做反向传播,显存不会爆吗?
A:只更新最后 1/4 层 MLP,且用梯度检查点,激活值不存历史,显存≈同规模 SWA。

Q2:和动态评估(Dynamic Evaluation)有何不同?
A:动态评估用训练阶段相同的目标微调;TTT-E2E 用元学习专门准备“好初始化”,内外循环目标一致,更稳定。

Q3:能否直接加载 HuggingFace 上的 Llama3?
A:目前代码基于 JAX,且需要重新跑两阶段训练。作者透露未来会支持 PyTorch 并放出热启动脚本

Q4:窗口大小 k 与内循环 batch b 必须 8 k/1 k 吗?
A:不是。图 4 显示 k ≥ b 即可,k=2 k~32 k 都能 work,只是 8 k 在速度与精度间折中最好。


9. 快速上手:Colab 迷你 Demo(CPU 可跑)

  1. 打开官方提供的 [notebooks/ttt_toy.ipynb]
  2. 安装依赖

    pip install -q jax transformers datasets
    
  3. 下载 125 M 预训练权重

    from huggingface_hub import hf_hub_download
    ckpt = hf_hub_download(repo_id="ttt-e2e/125m", filename="ckpt.msgpack")
    
  4. 跑内循环

    from ttt import TTTModel
    model = TTTModel.from_checkpoint(ckpt, window=1024, inner_batch=128)
    prompt = "Alice was beginning to get very tired"
    out = model.generate(prompt, max_len=200, temp=0.7)
    print(out)
    
  5. 观察 loss 曲线是否随 token 下降——下降即代表 TTT 正在“学进去”

10 未来展望:作者列的 To-Do List

  • FlashAttention 适配:cuDNN 目前不支持梯度之梯度,自定义 kernel 可把训练延迟再降 2×。
  • 热启动微调:先正常预训练 Llama3,再外挂 TTT 微调 5 % token,可把成本压到“额外 10 %” 以内。
  • 自生成 TTT:让模型自己总结前文,把总结当标签做自监督,有望缓解 needle 问题。
  • 多模态:视频、音频序列同样符合“长上下文+可压缩”假设,已有并行工作 One-Minute Video Generation with TTT 在 CVPR 2025 接收。

11 结语:把“读书”还给模型,也还给我们

TTT-E2E 的核心启示是——大模型不必记住每一页,只需像人类一样“边读边忘、保留精髓”
它用最朴素的下一个 token 预测,把“终身学习”能力塞进标准 Transformer,不新增复杂算子,不依赖定制芯片
如果你正在做:

  • 长文档问答
  • 代码库级理解
  • 低成本长文本推理服务

TTT-E2E 值得跑一次 baseline。
把长上下文的显存账单砍到恒定,把效果曲线继续拉长——这一次,模型真的可以“读完”一整本书,并且告诉你它在哪一页学会了成长。