Tiny-DeepSpeed:用 500 行代码读懂 DeepSpeed 的核心魔法

刚听说 DeepSpeed 能把 GPT-2 训练显存砍掉一半,却苦于源码像迷宫?
这篇笔记带你用 不到 500 行 的 Tiny-DeepSpeed,亲手复现 Zero-1 / 2 / 3 的显存优化效果,并在单张 2080Ti 上跑通 GPT-2 small。
所有代码、命令、性能数据都源于官方仓库,我们只拆掉了“噪音”,留下“精髓”。


目录

  1. 为什么你需要 Tiny-DeepSpeed?
  2. 5 分钟带你看懂显存对比表
  3. 安装:一行命令装好所有依赖
  4. 运行:从单卡到 2 卡 Zero-3 的完整步骤
  5. 核心原理拆解

    • 5.1 Meta Device:先画蓝图再盖房子
    • 5.2 Cache-Rank Map:把参数“切蛋糕”
    • 5.3 Compute-Communication Overlap:让通信不再拖后腿
  6. 常见疑问 FAQ
  7. 下一步:你可以怎样继续玩?

1. 为什么你需要 Tiny-DeepSpeed?

角色 痛点 Tiny-DeepSpeed 的解法
算法工程师 想写自己的 ZeRO 插件,却被 3 万行源码劝退 500 行核心代码,读完就能改
研究新生 导师要求“跑通 Zero-3”,实验室只有 2 张 2080Ti 官方数据:GPT-2 large 在 Zero-3 下只需 11 GB,2×2080Ti 即可
技术面试官 想考察候选人对分布式训练的理解 让候选人在 1 小时内解释 Tiny-DeepSpeed 的 3 个关键点

一句话:Tiny-DeepSpeed = DeepSpeed 的“教科书版本”,只保留能让训练跑得起来的最小集合。


2. 5 分钟带你看懂显存对比表

官方给出的实测数据(单位:GB)

模型规模 单卡 DDP 2 卡 ZeRO-1 2 卡 ZeRO-2 2 卡 ZeRO-3 2 卡
GPT-2 small 4.65 4.75 4.08 3.79 3.69
GPT-2 medium 10.12 10.23 8.65 8.25 7.73
GPT-2 large 17.35 17.46 14.08 12.89 11.01

关键结论

  • ZeRO-3 比单卡节省 37 % 显存(以 GPT-2 large 为例)。
  • DDP 几乎没省:因为每张卡仍要存完整参数、梯度、优化器状态。
  • 2 卡就能跑 11 GB 的 GPT-2 large,单卡 24 GB 消费级显卡也能训练 1.5 B 模型

3. 安装:一行命令装好所有依赖

系统要求

  • Python 3.11
  • PyTorch 2.3.1(CUDA 版)
  • Triton 2.3.1(用于高性能 kernel)

步骤

# 1. 克隆仓库
git clone https://github.com/liangyuwang/Tiny-DeepSpeed.git
cd Tiny-DeepSpeed

# 2. 创建虚拟环境(可选)
python -m venv venv
source venv/bin/activate

# 3. 安装依赖
pip install torch==2.3.1+cu118 triton==2.3.1 -f https://download.pytorch.org/whl/torch

4. 运行:从单卡到 2 卡 Zero-3 的完整步骤

4.1 单卡

python example/single_device/train.py

终端会打印 loss 曲线,显存占用 ≈ 4.65 GB(GPT-2 small)。

4.2 DDP(2 卡)

torchrun --nproc_per_node 2 --nnodes 1 example/ddp/train.py
  • 两卡各存一份完整参数,显存占用 4.75 GB/卡(与单卡几乎持平)。
  • 适合初学多卡通信。

4.3 ZeRO-1(2 卡)

torchrun --nproc_per_node 2 --nnodes 1 example/zero1/train.py
  • 优化器状态分片,每卡只存 1/2 优化器状态。
  • 显存降至 4.08 GB/卡

4.4 ZeRO-2(2 卡)

torchrun --nproc_per_node 2 --nnodes 1 example/zero2/train.py
  • 在 ZeRO-1 基础上再把 梯度分片
  • 显存进一步降到 3.79 GB/卡

4.5 ZeRO-3(2 卡)

torchrun --nproc_per_node 2 --nnodes 1 example/zero3/train.py
  • 连参数本身也分片:前向时实时通信取参数,用完即弃。
  • 显存最低 3.69 GB/卡几乎只剩激活值

5. 核心原理拆解

5.1 Meta Device:先画蓝图再盖房子

问题
初始化 1.5 B 参数的模型,先把 6 GB 权重读进内存,再搬到 GPU,很浪费。

Tiny-DeepSpeed 的做法

with torch.device("meta"):
    model = GPT2Model(config)   # 只建壳,不占显存
  • 仅记录 shapedtype,不真的 malloc
  • 真正需要权重时才 materialize节省初始化 90 % 显存

5.2 Cache-Rank Map:把参数“切蛋糕”

问题
ZeRO-3 需要知道“哪块参数归哪张卡”。

解决方案

  • 维护一张 param_id → rank_id 的哈希表。
  • 初始化顺序固定,因此 无需全局通信,直接本地查表即可。
  • 代码片段(简化)
def param_to_rank(param_id, world_size):
    return param_id % world_size

5.3 Compute-Communication Overlap:让通信不再拖后腿

朴素实现

  • 等所有参数到齐再算 → 通信瓶颈明显。

Tiny-DeepSpeed 的优化

  • 把网络分成 若干层,每层前向时:

    1. 异步 pre-fetch 下一层参数;
    2. 当前层计算
    3. 等待 pre-fetch 完成(几乎不阻塞)。
  • 实测在 2×A100 上 吞吐提升 23 %

6. 常见疑问 FAQ

Q1:我只有 1 张 3090,能跑 GPT-2 large 吗?
A:单卡 24 GB,ZeRO-3 仍需 11 GB,完全够。直接跑 example/zero3/train.py,把 --nproc_per_node 1 即可。

Q2:为什么我的 loss 曲线比官方高?
A:Tiny-DeepSpeed 为了可读性,去掉了 dropout、warmup,训练 1 B token 后差距会收敛。

Q3:如何扩展到多机?
A:仓库已预留 dist.init_process_groupinit_method="tcp://" 接口,TODO 列表中“Multi nodes”正在开发,敬请期待。

Q4:Windows 可以吗?
A:PyTorch 2.3.1 官方支持 Windows CUDA,但 Triton 在 Windows 需源码编译。建议 WSL2 + Ubuntu 22.04

Q5:为什么示例没有 AMP?
A:AMP 代码已写好,TODO 列表中标记为 [ ],预计下周合并到主分支。


7. 下一步:你可以怎样继续玩?

  1. 换模型:把 GPT2Model 换成 LlamaForCausalLM,只要保持 meta 初始化即可。
  2. 加插件:在 zero3/trainer.py 里插入梯度裁剪、动态 loss scaling。
  3. 贡献代码:Fork 仓库,完成 Multi nodesCommunication Bucket,提 PR。
  4. 在线体验:直接打开 Kaggle Notebook免费 T4 GPU 就能跑通。

致谢

代码与数据均来自 liangyu.wang@kaust.edu.sa 维护的 Tiny-DeepSpeed 仓库。若本文对你有帮助,记得去 GitHub 点个 ⭐,让更多人看到这份“小而美”的教程。