MixGRPO:用“混合采样+滑动窗口”让 AI 绘图模型训练快 71%
一句话总结
在 FLUX.1-dev 之上,MixGRPO 用“ODE+SDE 混合采样”只优化最关键的 4 步,训练时间比 DanceGRPO 再省一半,图像质量还能再高一点。文章末尾有完整上手步骤,可在 32 张 A100 上复现。
为什么训练“人味”图模型这么慢?
过去半年,研究者发现:
在扩散/流匹配模型里加 RLHF(人类反馈强化学习),能显著提升文本-图像一致性和美学分数,但代价是训练时间爆炸。
典型方法 | 每步采样 | 需优化步数 | 训练瓶颈 |
---|---|---|---|
Flow-GRPO | SDE 全程随机 | 25 步全算 | NFE 25×2=50 |
DanceGRPO | SDE 随机子集 | 14 步 | NFE 25+14=39 |
MixGRPO(本文) | ODE+SDE 混合 | 仅 4 步 | NFE 25+4=29 |
NFE(Number of Function Evaluations)≈ GPU 时间。数字越小越快。
MixGRPO 的“三步提速法”
1. 混合采样:只在“窗口”内用随机 SDE
-
ODE 段(前 & 后):确定性采样,省算力。 -
SDE 段(中间 4 步):保留随机性,让 RL 继续探索。
这样做带来的好处:
-
优化长度从 25 步缩到 4 步,GPU 内存和显式时间都下降。 -
由于 ODE 段不参与优化,可以用更高阶的 ODE solver(DPM-Solver++)再提速。
2. 滑动窗口:让优化从“高噪声”滑到“低噪声”
把 4 步窗口像卷积核一样从 t=0 移到 t=T,直觉上:
-
早期噪声大,搜索空间大,奖励折扣高,先优化。 -
后期细节微调,留给 ODE 快速推进。
实验结果:
固定窗口(frozen)也能工作,但“指数衰减位移”策略在 PickScore 上再涨 0.002(表 3)。
3. Flash 版本:用二阶 ODE 再砍 71% 时间
方法 | NFEπθold | 单图训练时间 | ImageReward↑ |
---|---|---|---|
DanceGRPO | 25 | 9.30 s | 1.335 |
MixGRPO | 25 | 7.34 s | 1.564 |
MixGRPO-Flash* | 12 | 4.86 s | 1.588 |
(Flash 使用冻结窗口,l=0,ODE 后段直接二阶采样。)
实际效果:看得见的提升
客观指标
指标 | FLUX | DanceGRPO | MixGRPO | MixGRPO-Flash |
---|---|---|---|---|
HPS-v2.1 | 0.313 | 0.356 | 0.367 | 0.358 |
PickScore | 0.227 | 0.233 | 0.237 | 0.236 |
ImageReward | 1.088 | 1.436 | 1.629 | 1.528 |
UnifiedReward | 3.370 | 3.397 | 3.418 | 3.407 |
所有分数都是“越高越好”,MixGRPO 在 4 个公开奖励模型上均排名第一。
主观对比
图 3 与图 6 展示了同一 prompt 下的输出差异。
可以看出:
-
FLUX:构图正确,细节略显平淡。 -
DanceGRPO:颜色与纹理有提升,但偶尔过饱和。 -
MixGRPO:主体清晰、光影自然、文字一致性更好。
上手实践:30 分钟跑通 32 卡训练
环境安装(CentOS 为例)
# 1. 新建环境
conda create -n MixGRPO python=3.12
conda activate MixGRPO
# 2. 系统依赖
sudo yum install -y pdsh pssh mesa-libGL
bash env_setup.sh # 自动装好 torch、diffusers、transformers
模型下载
模型/权重 | 路径 | 一键命令 |
---|---|---|
FLUX.1-dev | ./data/flux |
huggingface-cli download black-forest-labs/FLUX.1-dev --local-dir ./data/flux |
HPS-v2.1 | ./hps_ckpt |
见 README.md 2.2 节 |
ImageReward | ./image_reward_ckpt |
同上 |
MixGRPO 已训练权重 | ./mix_grpo_ckpt |
huggingface-cli download tulvgengenr/MixGRPO diffusion_pytorch_model.safetensors --local-dir ./mix_grpo_ckpt |
数据预处理
# 把 HPDv2 的 prompt 转成 embedding
bash scripts/preprocess/preprocess_flux_rl_embeddings.sh
多机训练(4 节点 × 8 GPU)
-
编辑 data/hosts/hostfile
写入 4 台机器 IP。 -
每台机器执行 bash scripts/preprocess/set_env_multinode.sh
自动设置
INDEX_CUSTOM=0/1/2/3
。 -
把 WandB key 写入 scripts/finetune/finetune_flux_grpo_FastGRPO.sh
。 -
启动训练 bash scripts/finetune/finetune_flux_grpo_FastGRPO.sh
单卡推理 & 评分
# 1. 生成
bash scripts/inference/inference_flux.sh
# 2. 评分
bash scripts/evaluate/eval_reward.sh
FAQ:你可能关心的 7 个问题
Q1:为什么只优化 4 步不会掉效果?
早期 4 步决定整体布局,后期 21 步只影响细节;用 ODE 快速走完后段,不影响奖励计算。
Q2:能接自己的奖励模型吗?
可以。在reward_models
目录新建类,按 ImageReward 的接口返回score
即可。
Q3:显存不够怎么办?
-
把窗口 w
调到 2; -
梯度累积步数从 3 提到 6; -
用 bf16 混合精度(已默认)。
Q4:单机 8 卡能跑吗?
把 hostfile 只写本机 IP,INDEX_CUSTOM=0
即可。代码会自动退化到单机多卡。
Q5:与 LoRA/Distillation 冲突吗?
不冲突。LoRA 只改参数形式,MixGRPO 改采样流程,二者正交。
Q6:训练会过拟合奖励模型吗?
使用混合推理(附录 A)pmix=80%
,即前 80% 步用 MixGRPO,后 20% 用原模型,可缓解 reward hacking。
Q7:代码 License?
见仓库根目录License.txt
,遵循腾讯 Hunyuan 团队自定义协议,科研可自由使用,商业需二次确认。
小结:下一步可以做什么?
-
继续压缩步数:把 Flash* 的二阶 ODE 换成三阶,或尝试蒸馏到 4 步以内。 -
多模态奖励:将 UnifiedReward 替换为同时考虑 OCR、人脸、风格的多头奖励。 -
视频/3D 扩展:滑动窗口思想天然适合帧序列,可迁移到视频生成 RLHF。
引用
如果本文对你的工作有帮助,请使用以下格式引用:
@misc{li2025mixgrpounlockingflowbasedgrpo,
title={MixGRPO: Unlocking Flow-based GRPO Efficiency with Mixed ODE-SDE},
author={Junzhe Li and Yutao Cui and Tao Huang and Yinping Ma and Chun Fan and Miles Yang and Zhao Zhong},
year={2025},
eprint={2507.21802},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2507.21802},
}