Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Puffer speed up #321

Merged · 8 commits · Jan 14, 2025
182 changes: 113 additions & 69 deletions baselines/ippo/ippo_pufferlib.py
@@ -15,6 +15,8 @@
import numpy as np
import wandb
from box import Box
import time
import random

from integrations.rl.puffer import ppo
from integrations.rl.puffer.puffer_env import env_creator
@@ -32,6 +34,73 @@

app = Typer()

def log_normal(mean, scale, clip):
    '''Samples points that are normally distributed on a log-10 scale.
    mean: the center of the distribution
    scale: standard deviation, in base-10 orders of magnitude
    clip: maximum deviation from the mean, in base-10 orders of magnitude

    Example: mean=0.001, scale=1, clip=2 will produce data from
    0.1 to 0.00001, with most of it between 0.01 and 0.0001.
    '''
    return 10**np.clip(
        np.random.normal(
            np.log10(mean),
            scale,
        ),
        a_min=np.log10(mean) - clip,
        a_max=np.log10(mean) + clip,
    )

def logit_normal(mean, scale, clip):
    '''log_normal, but for logit-scale data such as gamma and gae_lambda.'''
    return 1 - log_normal(1 - mean, scale, clip)
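
# For intuition (illustrative numbers, not part of this PR):
# logit_normal(mean=0.98, scale=0.5, clip=1.0) draws 1 - x log-normally
# between roughly 0.002 and 0.2, so x lands between roughly 0.8 and 0.998,
# concentrated around 0.98; this keeps gamma and gae_lambda close to 1.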

def uniform_pow2(min, max):
    '''Uniform distribution over the powers of 2 between min and max, inclusive.'''
    min_base = np.log2(min)
    max_base = np.log2(max)
    return 2**np.random.randint(min_base, max_base + 1)

def uniform(min, max):
    '''Uniform distribution over the half-open interval [min, max).'''
    return np.random.uniform(min, max)

def int_uniform(min, max):
    '''Uniform distribution over the integers between min and max, inclusive.'''
    return np.random.randint(min, max + 1)

def sample_hyperparameters(sweep_config):
    samples = {}
    for name, param in sweep_config.items():
        if name in ('method', 'name', 'metric'):
            continue

        assert isinstance(param, dict)
        if any(isinstance(param[k], dict) for k in param):
            samples[name] = sample_hyperparameters(param)
        elif 'values' in param:
            assert 'distribution' not in param
            samples[name] = random.choice(param['values'])
        elif 'distribution' in param:
            if param['distribution'] == 'uniform':
                samples[name] = uniform(param['min'], param['max'])
            elif param['distribution'] == 'int_uniform':
                samples[name] = int_uniform(param['min'], param['max'])
            elif param['distribution'] == 'uniform_pow2':
                samples[name] = uniform_pow2(param['min'], param['max'])
            elif param['distribution'] == 'log_normal':
                samples[name] = log_normal(
                    param['mean'], param['scale'], param['clip'])
            elif param['distribution'] == 'logit_normal':
                samples[name] = logit_normal(
                    param['mean'], param['scale'], param['clip'])
            else:
                raise ValueError(f'Invalid distribution: {param["distribution"]}')
        else:
            raise ValueError('Must specify either values or distribution')

    return samples
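
# A minimal usage sketch (hypothetical keys, mirroring the sweep section of
# the YAML below); the function recurses into nested dicts and returns a
# plain dict of sampled values:
#
#   sweep = {
#       'train': {
#           'learning_rate': {'distribution': 'log_normal',
#                             'mean': 0.005, 'scale': 1.0, 'clip': 2.0},
#           'minibatch_size': {'values': [8192, 16384, 32768]},
#       },
#   }
#   hypers = sample_hyperparameters(sweep)
#   # e.g. {'train': {'learning_rate': 0.0021, 'minibatch_size': 16384}}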

def get_model_parameters(policy):
    """Helper function to count the number of trainable parameters."""
@@ -56,42 +125,20 @@ def make_policy(env, config):
        dropout=config.train.network.dropout,
    ).to(config.train.device)


def train(args, make_env):
def train(args, vecenv):
    """Main training loop for the PPO agent."""

    backend_mapping = {
        # Note: Only native backend is currently supported with GPUDrive
        "native": pufferlib.vector.Native,
        "serial": pufferlib.vector.Serial,
        "multiprocessing": pufferlib.vector.Multiprocessing,
        "ray": pufferlib.vector.Ray,
    }

    backend = backend_mapping.get(args.vec.backend)
    if not backend:
        raise ValueError("Invalid --vec.backend.")

    vecenv = pufferlib.vector.make(
        make_env,
        num_envs=1,  # GPUDrive is already batched
        num_workers=args.vec.num_workers,
        batch_size=args.vec.env_batch_size,
        zero_copy=args.vec.zero_copy,
        backend=backend,
    )

    policy = make_policy(env=vecenv.driver_env, config=args).to(
        args.train.device
    )

    args.train.network.num_parameters = get_model_parameters(policy)
    args.train.env = args.environment.name

    args.wandb = init_wandb(args, args.train.exp_id, id=args.train.exp_id)
    args.train.__dict__.update(dict(args.wandb.config.train))
    wandb_run = init_wandb(args, args.train.exp_id, id=args.train.exp_id)
    args.train.update(dict(wandb_run.config.train))

    data = ppo.create(args.train, vecenv, policy, wandb=args.wandb)
    data = ppo.create(args.train, vecenv, policy, wandb=wandb_run)
    while data.global_step < args.train.total_timesteps:
        try:
            ppo.evaluate(data)  # Rollout
@@ -107,8 +154,24 @@ def train(args, make_env):
    ppo.evaluate(data)
    ppo.close(data)

def set_experiment_metadata(config):
    datetime_ = datetime.now().strftime("%m_%d_%H_%M_%S_%f")[:-3]
    if config["train"]["resample_scenes"]:
        dataset_size = config["train"]["resample_dataset_size"]
        config["train"][
            "exp_id"
        ] = f'PPO_R_{dataset_size}__{datetime_}'
    else:
        dataset_size = str(config["environment"]["k_unique_scenes"])
        config["train"][
            "exp_id"
        ] = f'PPO_S_{dataset_size}__{datetime_}'

    config["environment"]["dataset_size"] = dataset_size


def init_wandb(args, name, id=None, resume=True):
def init_wandb(args, name, id=None, resume=True, tag=None):
    wandb.init(
        id=id or wandb.util.generate_id(),
        project=args.wandb.project,
@@ -128,29 +191,6 @@ def init_wandb(args, name, id=None, resume=True):

    return wandb


def sweep(args, project="PPO", sweep_name="my_sweep"):
    """Initialize a WandB sweep with hyperparameters."""
    sweep_id = wandb.sweep(
        sweep=dict(
            method="random",
            name=sweep_name,
            metric={"goal": "maximize", "name": "environment/episode_return"},
            parameters={
                "learning_rate": {
                    "distribution": "log_uniform_values",
                    "min": 1e-4,
                    "max": 1e-1,
                },
                "batch_size": {"values": [512, 1024, 2048]},
                "minibatch_size": {"values": [128, 256, 512]},
            },
        ),
        project=project,
    )
    wandb.agent(sweep_id, lambda: train(args), count=100)


@app.command()
def run(
    config_path: Annotated[
@@ -185,6 +225,7 @@ def run(
    project: Annotated[Optional[str], typer.Option(help="WandB project name")] = None,
    entity: Annotated[Optional[str], typer.Option(help="WandB entity name")] = None,
    group: Annotated[Optional[str], typer.Option(help="WandB group name")] = None,
    max_runs: Annotated[Optional[int], typer.Option(help="Maximum number of sweep runs")] = 100,
    render: Annotated[Optional[int], typer.Option(help="Whether to render the environment; 0 or 1")] = None,
):
    """Run PPO training with the given configuration."""
@@ -240,21 +281,6 @@ def run(
        {k: v for k, v in wandb_config.items() if v is not None}
    )

    datetime_ = datetime.now().strftime("%m_%d_%H_%M_%S_%f")[:-3]

    if config["train"]["resample_scenes"]:
        dataset_size = config["train"]["resample_dataset_size"]
        config["train"][
            "exp_id"
        ] = f'{config["train"]["exp_id"]}__R_{dataset_size}__{datetime_}'
    else:
        dataset_size = str(config["environment"]["k_unique_scenes"])
        config["train"][
            "exp_id"
        ] = f'{config["train"]["exp_id"]}__S_{dataset_size}__{datetime_}'

    config["environment"]["dataset_size"] = dataset_size
config["train"]["device"] = config["train"].get(
"device", "cpu"
) # Default to 'cpu' if not set
@@ -279,11 +305,29 @@ def run(
        train_config=config.train,
        device=config.train.device,
    )
    vecenv = pufferlib.vector.make(
        make_env,
        num_envs=1,  # GPUDrive is already batched
        num_workers=config.vec.num_workers,
        batch_size=config.vec.env_batch_size,
        zero_copy=config.vec.zero_copy,
        backend=pufferlib.vector.Native,
    )

if config.mode == "train":
train(config, make_env)

set_experiment_metadata(config)
train(config, vecenv)
elif config.mode == "sweep":
for i in range(max_runs):
np.random.seed(int(time.time()))
random.seed(int(time.time()))
set_experiment_metadata(config)
hypers = sample_hyperparameters(config.sweep)
config.train.update(hypers['train'])
config.environment.update(hypers['environment'])
train(config, vecenv)

if __name__ == "__main__":

app()
import cProfile
cProfile.run('app()', 'profiled')
#app()
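
# cProfile dumps its stats to the file 'profiled'; a minimal sketch for
# inspecting the result afterwards with the standard-library pstats module:
#
#   import pstats
#   pstats.Stats('profiled').sort_stats('cumulative').print_stats(20)
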
68 changes: 55 additions & 13 deletions examples/experiments/ippo_ff_p1_self_play.yaml
@@ -1,14 +1,14 @@
mode: "train"
mode: "sweep"
use_rnn: false
eval_model_path: null
baseline: false
data_dir: "data/processed/training"
data_dir: "data/processed/examples"

environment: # Overrides default environment configs (see pygpudrive/env/config.py)
name: "gpudrive"
num_worlds: 100 # Number of parallel environments
k_unique_scenes: 100 # Number of unique scenes to sample from
max_controlled_agents: 128 # Maximum number of agents controlled by the model. Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp
num_worlds: 300 # Number of parallel environments
k_unique_scenes: 3 # Number of unique scenes to sample from
max_controlled_agents: 64 # Maximum number of agents controlled by the model. Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp
ego_state: true
road_map_obs: true
partner_obs: true
@@ -26,11 +26,11 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.py)
  sampling_seed: 42 # If given, the set of scenes to sample from will be deterministic; if None, the set of scenes will be random
  obs_radius: 50.0 # Visibility radius of the agents
wandb:
  entity: ""
  entity: "jsuarez"
  project: "paper_1_self_play"
  group: "verify"
  mode: "online" # Options: online, offline, disabled
  tags: ["ppo", "ff"]
  tags: ["ppo", "ff", "basic-sweep"]

## NOTES
## Good batch size: 128 * number of controlled agents (e.g. 2**18)
@@ -46,24 +46,24 @@ train:
  compile_mode: "reduce-overhead"

  # # # Data sampling # # #
  resample_scenes: true
  resample_scenes: False
  resample_criterion: "global_step"
  resample_dataset_size: 500 # Number of unique scenes to sample from
  resample_interval: 1_000_000
  resample_interval: 20_000_000
  resample_limit: 1000 # Resample until the limit is reached; set to a large number to continue resampling indefinitely
  sample_with_replacement: true
  shuffle_dataset: false

  # # # PPO # # #
  torch_deterministic: false
  total_timesteps: 1_000_000_000
  total_timesteps: 50_000_000
  batch_size: 131_072
  minibatch_size: 8192
  minibatch_size: 16384
  learning_rate: 3e-4
  anneal_lr: false
  gamma: 0.99
  gae_lambda: 0.95
  update_epochs: 5
  update_epochs: 1
  norm_adv: true
  clip_coef: 0.2
  clip_vloss: false
@@ -76,7 +76,7 @@ train:
  # # # Network # # #
  network:
    input_dim: 64 # Embedding of the input features
    hidden_dim: 128 # Latent dimension
    hidden_dim: 192 # Latent dimension
    pred_heads_arch: [64] # Arch of the prediction heads (actor and critic)
    num_transformer_layers: 0 # Number of transformer layers
    dropout: 0.01
@@ -96,6 +96,48 @@ train:
  render_format: "mp4" # Options: gif, mp4
  render_fps: 15 # Frames per second

sweep:
  train:
    learning_rate:
      distribution: "log_normal"
      mean: 0.005
      scale: 1.0
      clip: 2.0

    ent_coef:
      distribution: "log_normal"
      mean: 0.005
      scale: 0.5
      clip: 1.0

    gamma:
      distribution: "logit_normal"
      mean: 0.98
      scale: 0.5
      clip: 1.0

    gae_lambda:
      distribution: "logit_normal"
      mean: 0.95
      scale: 0.5
      clip: 1.0

  environment:
    collision_weight:
      distribution: "uniform"
      min: -1.0
      max: 0.0

    off_road_weight:
      distribution: "uniform"
      min: -1.0
      max: 0.0

    goal_achieved_weight:
      distribution: "uniform"
      min: 0.0
      max: 1.0
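
  # One draw from this block via sample_hyperparameters might look like
  # (illustrative values, not from this PR):
  #   train: {learning_rate: 0.0021, ent_coef: 0.004, gamma: 0.985, gae_lambda: 0.93}
  #   environment: {collision_weight: -0.42, off_road_weight: -0.11, goal_achieved_weight: 0.73}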

vec:
  backend: "native" # Only native is currently supported
  num_workers: 1
2 changes: 1 addition & 1 deletion integrations/rl/puffer/ppo.py
@@ -391,7 +391,7 @@ def train(data):


def close(data):
    data.vecenv.close()
    # data.vecenv.close()
    data.utilization.stop()
    config = data.config
    if data.wandb is not None: