MixGRPO:用“混合采样+滑动窗口”让 AI 绘图模型训练快 71%

一句话总结
在 FLUX.1-dev 之上,MixGRPO 用“ODE+SDE 混合采样”只优化最关键的 4 步,训练时间比 DanceGRPO 再省一半,图像质量还能再高一点。文章末尾有完整上手步骤,可在 32 张 A100 上复现。


为什么训练“人味”图模型这么慢?

过去半年,研究者发现:
在扩散/流匹配模型里加 RLHF(人类反馈强化学习),能显著提升文本-图像一致性美学分数,但代价是训练时间爆炸

典型方法 每步采样 需优化步数 训练瓶颈
Flow-GRPO SDE 全程随机 25 步全算 NFE 25×2=50
DanceGRPO SDE 随机子集 14 步 NFE 25+14=39
MixGRPO(本文) ODE+SDE 混合 仅 4 步 NFE 25+4=29

NFE(Number of Function Evaluations)≈ GPU 时间。数字越小越快。


MixGRPO 的“三步提速法”

1. 混合采样:只在“窗口”内用随机 SDE

  • ODE 段(前 & 后):确定性采样,省算力。
  • SDE 段(中间 4 步):保留随机性,让 RL 继续探索。

这样做带来的好处:

  • 优化长度从 25 步缩到 4 步,GPU 内存和显式时间都下降
  • 由于 ODE 段不参与优化,可以用更高阶的 ODE solver(DPM-Solver++)再提速。

2. 滑动窗口:让优化从“高噪声”滑到“低噪声”

把 4 步窗口像卷积核一样从 t=0 移到 t=T,直觉上:

  • 早期噪声大,搜索空间大,奖励折扣高,先优化。
  • 后期细节微调,留给 ODE 快速推进。

实验结果:
固定窗口(frozen)也能工作,但“指数衰减位移”策略在 PickScore 上再涨 0.002(表 3)。

3. Flash 版本:用二阶 ODE 再砍 71% 时间

方法 NFEπθold 单图训练时间 ImageReward↑
DanceGRPO 25 9.30 s 1.335
MixGRPO 25 7.34 s 1.564
MixGRPO-Flash* 12 4.86 s 1.588

Flash 使用冻结窗口,l=0,ODE 后段直接二阶采样。)


实际效果:看得见的提升

客观指标

指标 FLUX DanceGRPO MixGRPO MixGRPO-Flash
HPS-v2.1 0.313 0.356 0.367 0.358
PickScore 0.227 0.233 0.237 0.236
ImageReward 1.088 1.436 1.629 1.528
UnifiedReward 3.370 3.397 3.418 3.407

所有分数都是“越高越好”,MixGRPO 在 4 个公开奖励模型上均排名第一。

主观对比

图 3 与图 6 展示了同一 prompt 下的输出差异。
可以看出:

  • FLUX:构图正确,细节略显平淡。
  • DanceGRPO:颜色与纹理有提升,但偶尔过饱和。
  • MixGRPO:主体清晰、光影自然、文字一致性更好。

上手实践:30 分钟跑通 32 卡训练

环境安装(CentOS 为例)

# 1. 新建环境
conda create -n MixGRPO python=3.12
conda activate MixGRPO

# 2. 系统依赖
sudo yum install -y pdsh pssh mesa-libGL
bash env_setup.sh         # 自动装好 torch、diffusers、transformers

模型下载

模型/权重 路径 一键命令
FLUX.1-dev ./data/flux huggingface-cli download black-forest-labs/FLUX.1-dev --local-dir ./data/flux
HPS-v2.1 ./hps_ckpt 见 README.md 2.2 节
ImageReward ./image_reward_ckpt 同上
MixGRPO 已训练权重 ./mix_grpo_ckpt huggingface-cli download tulvgengenr/MixGRPO diffusion_pytorch_model.safetensors --local-dir ./mix_grpo_ckpt

数据预处理

# 把 HPDv2 的 prompt 转成 embedding
bash scripts/preprocess/preprocess_flux_rl_embeddings.sh

多机训练(4 节点 × 8 GPU)

  1. 编辑 data/hosts/hostfile 写入 4 台机器 IP。
  2. 每台机器执行

    bash scripts/preprocess/set_env_multinode.sh
    

    自动设置 INDEX_CUSTOM=0/1/2/3

  3. 把 WandB key 写入 scripts/finetune/finetune_flux_grpo_FastGRPO.sh
  4. 启动训练

    bash scripts/finetune/finetune_flux_grpo_FastGRPO.sh
    

单卡推理 & 评分

# 1. 生成
bash scripts/inference/inference_flux.sh
# 2. 评分
bash scripts/evaluate/eval_reward.sh

FAQ:你可能关心的 7 个问题

Q1:为什么只优化 4 步不会掉效果?
早期 4 步决定整体布局,后期 21 步只影响细节;用 ODE 快速走完后段,不影响奖励计算。

Q2:能接自己的奖励模型吗?
可以。在 reward_models 目录新建类,按 ImageReward 的接口返回 score 即可。

Q3:显存不够怎么办?

  • 把窗口 w 调到 2;
  • 梯度累积步数从 3 提到 6;
  • 用 bf16 混合精度(已默认)。

Q4:单机 8 卡能跑吗?
把 hostfile 只写本机 IP,INDEX_CUSTOM=0 即可。代码会自动退化到单机多卡。

Q5:与 LoRA/Distillation 冲突吗?
不冲突。LoRA 只改参数形式,MixGRPO 改采样流程,二者正交。

Q6:训练会过拟合奖励模型吗?
使用混合推理(附录 A)pmix=80%,即前 80% 步用 MixGRPO,后 20% 用原模型,可缓解 reward hacking。

Q7:代码 License?
见仓库根目录 License.txt,遵循腾讯 Hunyuan 团队自定义协议,科研可自由使用,商业需二次确认。


小结:下一步可以做什么?

  • 继续压缩步数:把 Flash* 的二阶 ODE 换成三阶,或尝试蒸馏到 4 步以内。
  • 多模态奖励:将 UnifiedReward 替换为同时考虑 OCR、人脸、风格的多头奖励。
  • 视频/3D 扩展:滑动窗口思想天然适合帧序列,可迁移到视频生成 RLHF。

引用

如果本文对你的工作有帮助,请使用以下格式引用:

@misc{li2025mixgrpounlockingflowbasedgrpo,
  title={MixGRPO: Unlocking Flow-based GRPO Efficiency with Mixed ODE-SDE}, 
  author={Junzhe Li and Yutao Cui and Tao Huang and Yinping Ma and Chun Fan and Miles Yang and Zhao Zhong},
  year={2025},
  eprint={2507.21802},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2507.21802}, 
}