Note: I'm considering running more seeds and reproducing the performance of SB3's PPO on some harder domains, so stay tuned!

PPO-Clip 🤗

In this repo, I minimally but very carefully implemented PPO-Clip in PyTorch, without parallelism. It has been very time-consuming since there are many details to get right, even though I had stable-baselines3 as a reference.
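For context, here is a minimal sketch of the clipped surrogate loss that PPO-Clip optimizes, written in PyTorch. The function and argument names are illustrative and not taken from this repo's code.

```python
import torch

def ppo_clip_loss(logp_new, logp_old, advantages, clip_ratio=0.2):
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    ratio = torch.exp(logp_new - logp_old)
    # Clipped surrogate objective: take the elementwise minimum of the unclipped
    # and clipped terms, then negate to turn the objective into a loss to minimize.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    return -torch.min(unclipped, clipped).mean()
```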

Now, it's working very well and converges stably within 2-3 minutes on CartPole-v0 and Pendulum-v0 (100k steps). It reaches good performance on HalfCheetah-v3 (1M steps) in about 2 hours.

[Learning curves and rollout GIFs for CartPole-v0, Pendulum-v0, and HalfCheetah-v3; HalfCheetah-v3 rollout video: openaigym.video.0.60021.video000000.mp4]

Requirements

pip install wandb==0.12.2 numpy==1.19.5 torch==1.9.1 scipy==1.6.2 gym==0.18.3

Don't use a newer gym version, because Pendulum-v0 or CartPole-v0 might be unavailable (i.e., only newer versions of these environments are registered).
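As a quick sanity check (a sketch assuming the pinned gym==0.18.3 above), you can verify that the environment ids used in this repo are registered:

```python
# Verify that the environment ids used in this repo are registered.
# Assumes gym==0.18.3 from the pinned requirements; newer gym versions may raise here.
import gym

for env_id in ["CartPole-v0", "Pendulum-v0"]:
    env = gym.make(env_id)  # raises a gym error if the id is not registered
    env.close()
```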

Scripts

Set up wandb.

Train:

python launch.py --expdir experiments/CartPole-v0_ppo_sb3 --run_id 1

Visualize the trained policy (requires you to download the trained models from wandb):

python launch.py --expdir experiments/CartPole-v0_ppo_sb3 --run_id 1 --enjoy

Plots

You can view plots on wandb. Here's an example screenshot:

[Example wandb dashboard screenshot]

FAQs

Why doesn't parallel data collection bring as much speedup as you might think?

Sure, parallel data collection brings some speedup, but PPO is fast primarily because it takes far fewer gradient steps than, e.g., DDPG, given the same number of environment timesteps. See DLR-RM/stable-baselines3#643 for a thorough discussion.
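As a rough back-of-the-envelope illustration (all hyper-parameter values below are hypothetical, not the exact values used by this repo or SB3):

```python
# Rough comparison of gradient steps per 1M environment timesteps.
# Hyper-parameter values are made up for illustration only.
total_timesteps = 1_000_000

# PPO: collect a rollout, then do n_epochs passes over it in minibatches.
rollout_len, n_epochs, batch_size = 2048, 10, 64
ppo_grad_steps = (total_timesteps // rollout_len) * n_epochs * (rollout_len // batch_size)

# DDPG-style off-policy training: typically one gradient step per environment step.
ddpg_grad_steps = total_timesteps

print(ppo_grad_steps)   # 156160
print(ddpg_grad_steps)  # 1000000
```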

How is this implementation different from SB3's PPO?

They are supposed to do exactly the same thing (although the code design choices are rather different), except that this repo doesn't have certain arguments, which default to None or False in SB3 anyway. The hyper-parameters in the config files were copied from rl-baselines3-zoo; these values work well because they have been tuned.

What hyper-parameter values are you using?

Check out the config files in the two subdirectories of experiments: one for the simplest discrete control domain and the other for the simplest continuous control domain.

How do I extend this codebase for research purposes?

Create a new file in algorithms and add the new algorithm to the algo_name2class dictionary inside launch.py. Then specify that algorithm in a config file, which means creating a folder in experiments to contain that config file.
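As a rough illustration, here is a hedged sketch of what that might look like (the class name, constructor signature, and dictionary key below are made up; the actual interface expected by launch.py may differ):

```python
# algorithms/my_new_algo.py — a hypothetical new algorithm (sketch only).
class MyNewAlgo:
    def __init__(self, config):
        self.config = config  # hyper-parameters loaded from the config file

    def train(self):
        ...  # your research idea goes here

# In launch.py, register the class under a name that config files can refer to
# (existing entries omitted; "my_new_algo" is a made-up key):
# algo_name2class["my_new_algo"] = MyNewAlgo
```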

Thanks

I'd like to thank the maintainers of SB3 for answering my questions during this process.
