RL-06-04-Stable-Baselines3与生态

← 上级:RL-06.评估环境与工具链


一、Stable-Baselines3(SB3)

1
pip install stable-baselines3[extra]
1
2
3
4
5
6
7
8
from stable_baselines3 import PPO, DQN, SAC
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, seed=42)
model.learn(total_timesteps=50_000)
mean, std = evaluate_policy(model, env, n_eval_episodes=20)
model.save("ppo_cartpole")
算法 SB3 类 典型环境
DQN DQN 离散、低维/Atari(CnnPolicy)
PPO PPO 通用
A2C A2C 轻量 On-Policy
SAC SAC 连续 MuJoCo
TD3 TD3 连续

二、Callback 与日志

1
2
3
4
5
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback

eval_env = gym.make("CartPole-v1")
eval_cb = EvalCallback(eval_env, best_model_save_path="./best", eval_freq=5000)
model.learn(100_000, callback=eval_cb)

tensorboard_log="./tb" 启用 TensorBoard。


三、CleanRL

单文件 PyTorch 实现,适合读源码对照 RL-04 自研实现

GitHub: vwxyzjn/cleanrl


四、RLlib

Ray 生态,分布式采样与大规模超参搜索;入门成本高于 SB3。


五、选用建议

目标 工具
快速 baseline SB3
学实现 CleanRL → 本系列 RL-04
生产原型 SB3 + 自研 env
大规模 RLlib

六、小结

-------------本文结束感谢您的阅读-------------