Tiny-DeepSpeed:用 500 行代码读懂 DeepSpeed 的核心魔法
刚听说 DeepSpeed 能把 GPT-2 训练显存砍掉一半,却苦于源码像迷宫?
这篇笔记带你用 不到 500 行 的 Tiny-DeepSpeed,亲手复现 Zero-1 / 2 / 3 的显存优化效果,并在单张 2080Ti 上跑通 GPT-2 small。
所有代码、命令、性能数据都源于官方仓库,我们只拆掉了“噪音”,留下“精髓”。
目录
-
为什么你需要 Tiny-DeepSpeed? -
5 分钟带你看懂显存对比表 -
安装:一行命令装好所有依赖 -
运行:从单卡到 2 卡 Zero-3 的完整步骤 -
核心原理拆解 -
5.1 Meta Device:先画蓝图再盖房子 -
5.2 Cache-Rank Map:把参数“切蛋糕” -
5.3 Compute-Communication Overlap:让通信不再拖后腿
-
-
常见疑问 FAQ -
下一步:你可以怎样继续玩?
1. 为什么你需要 Tiny-DeepSpeed?
角色 | 痛点 | Tiny-DeepSpeed 的解法 |
---|---|---|
算法工程师 | 想写自己的 ZeRO 插件,却被 3 万行源码劝退 | 500 行核心代码,读完就能改 |
研究新生 | 导师要求“跑通 Zero-3”,实验室只有 2 张 2080Ti | 官方数据:GPT-2 large 在 Zero-3 下只需 11 GB,2×2080Ti 即可 |
技术面试官 | 想考察候选人对分布式训练的理解 | 让候选人在 1 小时内解释 Tiny-DeepSpeed 的 3 个关键点 |
一句话:Tiny-DeepSpeed = DeepSpeed 的“教科书版本”,只保留能让训练跑得起来的最小集合。
2. 5 分钟带你看懂显存对比表
官方给出的实测数据(单位:GB)
模型规模 | 单卡 | DDP 2 卡 | ZeRO-1 2 卡 | ZeRO-2 2 卡 | ZeRO-3 2 卡 |
---|---|---|---|---|---|
GPT-2 small | 4.65 | 4.75 | 4.08 | 3.79 | 3.69 |
GPT-2 medium | 10.12 | 10.23 | 8.65 | 8.25 | 7.73 |
GPT-2 large | 17.35 | 17.46 | 14.08 | 12.89 | 11.01 |
关键结论
-
ZeRO-3 比单卡节省 37 % 显存(以 GPT-2 large 为例)。 -
DDP 几乎没省:因为每张卡仍要存完整参数、梯度、优化器状态。 -
2 卡就能跑 11 GB 的 GPT-2 large,单卡 24 GB 消费级显卡也能训练 1.5 B 模型。
3. 安装:一行命令装好所有依赖
系统要求
-
Python 3.11 -
PyTorch 2.3.1(CUDA 版) -
Triton 2.3.1(用于高性能 kernel)
步骤
# 1. 克隆仓库
git clone https://github.com/liangyuwang/Tiny-DeepSpeed.git
cd Tiny-DeepSpeed
# 2. 创建虚拟环境(可选)
python -m venv venv
source venv/bin/activate
# 3. 安装依赖
pip install torch==2.3.1+cu118 triton==2.3.1 -f https://download.pytorch.org/whl/torch
4. 运行:从单卡到 2 卡 Zero-3 的完整步骤
4.1 单卡
python example/single_device/train.py
终端会打印 loss 曲线,显存占用 ≈ 4.65 GB(GPT-2 small)。
4.2 DDP(2 卡)
torchrun --nproc_per_node 2 --nnodes 1 example/ddp/train.py
-
两卡各存一份完整参数,显存占用 4.75 GB/卡(与单卡几乎持平)。 -
适合初学多卡通信。
4.3 ZeRO-1(2 卡)
torchrun --nproc_per_node 2 --nnodes 1 example/zero1/train.py
-
优化器状态分片,每卡只存 1/2 优化器状态。 -
显存降至 4.08 GB/卡。
4.4 ZeRO-2(2 卡)
torchrun --nproc_per_node 2 --nnodes 1 example/zero2/train.py
-
在 ZeRO-1 基础上再把 梯度分片。 -
显存进一步降到 3.79 GB/卡。
4.5 ZeRO-3(2 卡)
torchrun --nproc_per_node 2 --nnodes 1 example/zero3/train.py
-
连参数本身也分片:前向时实时通信取参数,用完即弃。 -
显存最低 3.69 GB/卡,几乎只剩激活值。
5. 核心原理拆解
5.1 Meta Device:先画蓝图再盖房子
问题
初始化 1.5 B 参数的模型,先把 6 GB 权重读进内存,再搬到 GPU,很浪费。
Tiny-DeepSpeed 的做法
with torch.device("meta"):
model = GPT2Model(config) # 只建壳,不占显存
-
仅记录 shape
、dtype
,不真的malloc
。 -
真正需要权重时才 materialize
,节省初始化 90 % 显存。
5.2 Cache-Rank Map:把参数“切蛋糕”
问题
ZeRO-3 需要知道“哪块参数归哪张卡”。
解决方案
-
维护一张 param_id → rank_id
的哈希表。 -
初始化顺序固定,因此 无需全局通信,直接本地查表即可。 -
代码片段(简化)
def param_to_rank(param_id, world_size):
return param_id % world_size
5.3 Compute-Communication Overlap:让通信不再拖后腿
朴素实现
-
等所有参数到齐再算 → 通信瓶颈明显。
Tiny-DeepSpeed 的优化
-
把网络分成 若干层,每层前向时: -
异步 pre-fetch 下一层参数; -
当前层计算; -
等待 pre-fetch 完成(几乎不阻塞)。
-
-
实测在 2×A100 上 吞吐提升 23 %。
6. 常见疑问 FAQ
Q1:我只有 1 张 3090,能跑 GPT-2 large 吗?
A:单卡 24 GB,ZeRO-3 仍需 11 GB,完全够。直接跑 example/zero3/train.py
,把 --nproc_per_node 1
即可。
Q2:为什么我的 loss 曲线比官方高?
A:Tiny-DeepSpeed 为了可读性,去掉了 dropout、warmup,训练 1 B token 后差距会收敛。
Q3:如何扩展到多机?
A:仓库已预留 dist.init_process_group
的 init_method="tcp://"
接口,TODO 列表中“Multi nodes”正在开发,敬请期待。
Q4:Windows 可以吗?
A:PyTorch 2.3.1 官方支持 Windows CUDA,但 Triton 在 Windows 需源码编译。建议 WSL2 + Ubuntu 22.04。
Q5:为什么示例没有 AMP?
A:AMP 代码已写好,TODO 列表中标记为 [ ]
,预计下周合并到主分支。
7. 下一步:你可以怎样继续玩?
-
换模型:把 GPT2Model
换成LlamaForCausalLM
,只要保持meta
初始化即可。 -
加插件:在 zero3/trainer.py
里插入梯度裁剪、动态 loss scaling。 -
贡献代码:Fork 仓库,完成 Multi nodes 或 Communication Bucket,提 PR。 -
在线体验:直接打开 Kaggle Notebook,免费 T4 GPU 就能跑通。
致谢
代码与数据均来自 liangyu.wang@kaust.edu.sa 维护的 Tiny-DeepSpeed 仓库。若本文对你有帮助,记得去 GitHub 点个 ⭐,让更多人看到这份“小而美”的教程。