Simple end-to-end RLHF (Reinforcement Learning from Human Feedback) for diffusion models (DDPO) on personal hardware.
- Download this repo.
- Download a diffusers-format checkpoint from Hugging Face, or convert a single-file-format checkpoint to diffusers format (a conversion sketch is shown after this list). Put the model under `/models/checkpoint` (the default). Currently only SD1.5 models are tested. Also be sure to use the DDIM scheduler, as it's the one supported by DDPO.
- Set your parameters in `config.py` (an illustrative, hypothetical config sketch follows the list). The default parameters are tested on a 16GB GPU, but you should tweak them to fit your needs.
- Run `python generate.py`. This will generate sample images under `/samples` (a rough sketch of a generation pass follows the list).
- You can use the provided WebApp to score the generated images; a detailed usage guide can be found in this repo: AlmostPerfect_frontend. Once finished, you should have a `scores.json` file in your `/samples` directory (a loading sketch follows the list).
- Run `python train_reward.py`. This trains a simple latent CNN model and saves it under `/models` (an example architecture follows the list). Take special care not to overfit the reward model. You are free to try more complex reward models, which may require modifications to this process.
- Run `python train_policy.py`. This will periodically save trained LoRAs under `/outputs`; you can then use `ui_lora_epoch_xxx.safetensors` in WebUI or ComfyUI for inference (the core DDPO objective is sketched after the list).
- Go back to Step 1 with the LoRA loaded if you want to continue (a LoRA-loading sketch follows the list).
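
A minimal sketch of the checkpoint-conversion step, assuming a recent `diffusers` release that provides `from_single_file`; the input filename is a placeholder:

```python
# Sketch: convert a single-file SD1.5 checkpoint to diffusers format
# and pin the DDIM scheduler. Paths are placeholders.
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_single_file("my_sd15_model.safetensors")

# DDPO computes per-step log-probabilities from the DDIM transition,
# so swap the scheduler before saving.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

pipe.save_pretrained("models/checkpoint")  # the repo's default location
```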
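
The actual fields in `config.py` are repo-specific; the sketch below only illustrates the kind of knobs that trade quality against a 16GB VRAM budget. Every name and value here is hypothetical:

```python
# Hypothetical config.py-style settings -- check the real file for the actual fields.
pretrained_model = "models/checkpoint"  # diffusers-format SD1.5 checkpoint
num_inference_steps = 50                # DDIM denoising steps per sample
sample_batch_size = 4                   # lower this first if you hit OOM
train_batch_size = 1
gradient_accumulation_steps = 4         # preserves the effective batch size
lora_rank = 4                           # smaller rank = less VRAM
learning_rate = 1e-4
mixed_precision = "fp16"                # the usual choice on 16GB cards
```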
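
Generation itself is just running the repo's script, but a rough equivalent of what such a pass does (not the actual `generate.py`; the prompt and sample count are placeholders) is:

```python
# Rough sketch of a generation pass -- NOT the repo's generate.py.
import os
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "models/checkpoint", torch_dtype=torch.float16
).to("cuda")

os.makedirs("samples", exist_ok=True)
for i in range(8):  # sample count is arbitrary here
    image = pipe("a photo of a cat", num_inference_steps=50).images[0]
    image.save(f"samples/{i:05d}.png")
```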
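
The exact schema of `scores.json` is defined by the AlmostPerfect_frontend app; assuming a flat filename-to-score mapping (an assumption, not a documented format), pairing scores with images for reward training could look like:

```python
# Sketch: load scores.json and pair scores with sample images.
# The {"filename": score} schema is assumed -- adjust to the real format.
import json
from pathlib import Path

with open("samples/scores.json") as f:
    scores = json.load(f)  # assumed shape: {"00000.png": 0.8, ...}

pairs = [(Path("samples") / name, float(s)) for name, s in scores.items()]
print(f"{len(pairs)} scored samples ready for reward-model training")
```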
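
A "simple latent CNN" can be as small as the sketch below: a few strided convolutions over the 4x64x64 SD1.5 latent, pooled to a scalar score. This is an illustrative architecture, not the one in `train_reward.py`, but a model of roughly this size is one way to heed the overfitting warning when you only have a few hundred hand-scored samples:

```python
# Illustrative latent reward model -- not the repo's exact architecture.
import torch
import torch.nn as nn

class LatentRewardCNN(nn.Module):
    """Maps a 4x64x64 SD1.5 latent to a scalar score."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=2, padding=1),    # 64 -> 32
            nn.SiLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),   # 32 -> 16
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 16 -> 8
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),                     # global average pool
            nn.Flatten(),
            nn.Linear(128, 1),
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        # latents: (B, 4, 64, 64) -> scores: (B,)
        return self.net(latents).squeeze(-1)
```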
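
The heart of `train_policy.py` (and of the upstream ddpo-pytorch code it builds on) is a PPO-style clipped objective applied per denoising step: each DDIM step is treated as an action whose log-probability comes from the scheduler's Gaussian transition, and each image's reward is broadcast over all of its timesteps. A schematic of that loss (simplified; not this repo's exact code, with the default clip range taken from ddpo-pytorch):

```python
# Schematic DDPO/PPO clipped surrogate loss over denoising steps.
import torch

def ddpo_clipped_loss(new_logprobs, old_logprobs, advantages, clip_range=1e-4):
    # new/old_logprobs: log p(x_{t-1} | x_t, c) under the current policy and
    # the policy that generated the samples; advantages: normalized rewards.
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Take the pessimistic (larger) loss, as in PPO.
    return torch.mean(torch.maximum(unclipped, clipped))
```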
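
For the final step, the saved LoRAs also load back into a diffusers pipeline, so the next round of samples reflects the current policy. Assuming the `ui_lora_epoch_xxx.safetensors` files use the standard LoRA layout (which WebUI/ComfyUI compatibility suggests), diffusers' `load_lora_weights` should work; the epoch number below is a placeholder:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("models/checkpoint")
# "010" is a placeholder for whichever epoch you want to continue from.
pipe.load_lora_weights("outputs", weight_name="ui_lora_epoch_010.safetensors")
```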
The scripts in `utils` are adapted from the original ddpo-pytorch repo. The `train_policy.py` script is also based on that repo, with many simplifications, optimizations, and compatibility fixes for the latest libraries.

