From 17f0b44d97202ad8a8bcdf079b49997fec0cc4b9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Aug 2024 07:46:26 -0400 Subject: [PATCH] [Algorithm] TD3 fast ghstack-source-id: 16037b74edf8c66efec9c5aad1d6713ff2635762 Pull Request resolved: https://github.com/pytorch/rl/pull/2389 --- sota-implementations/td3/config-fast.yaml | 55 +++++ sota-implementations/td3/td3-fast.py | 246 ++++++++++++++++++++++ sota-implementations/td3/utils.py | 109 +++++++--- torchrl/collectors/collectors.py | 4 +- 4 files changed, 381 insertions(+), 33 deletions(-) create mode 100644 sota-implementations/td3/config-fast.yaml create mode 100644 sota-implementations/td3/td3-fast.py diff --git a/sota-implementations/td3/config-fast.yaml b/sota-implementations/td3/config-fast.yaml new file mode 100644 index 00000000000..27e894bfc90 --- /dev/null +++ b/sota-implementations/td3/config-fast.yaml @@ -0,0 +1,55 @@ +# task and env +env: + name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency + task: "" + library: gymnasium + seed: 42 + max_episode_steps: 1000 + +# collector +collector: + total_frames: 1000000 + init_random_frames: 25_000 + init_env_steps: 1000 + frames_per_batch: 1000 + reset_at_each_iter: False + device: cpu + env_per_collector: 1 + num_workers: 8 + +# replay buffer +replay_buffer: + prb: 0 # use prioritized experience replay + size: 1000000 + scratch_dir: null + device: null + +# optim +optim: + utd_ratio: 1.0 + gamma: 0.99 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 + adam_eps: 1e-4 + batch_size: 256 + target_update_polyak: 0.995 + policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 + +# network +network: + hidden_sizes: [256, 256] + activation: relu + device: null + +# logging +logger: + backend: wandb + project_name: torchrl_example_td3 + group_name: null + exp_name: ${env.name}_TD3 + mode: online + eval_iter: 25000 + video: False diff --git a/sota-implementations/td3/td3-fast.py b/sota-implementations/td3/td3-fast.py new file mode 100644 index 00000000000..a433b238acf --- /dev/null +++ b/sota-implementations/td3/td3-fast.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""TD3 Example. + +This is a simple self-contained example of a TD3 training script. + +It supports state environments like MuJoCo. + +The helper functions are coded in the utils.py associated with this script. +""" +import time + +import hydra +import numpy as np +import torch +import torch.cuda +import tqdm +from torchrl._utils import logger as torchrl_logger +from torchrl.data.utils import CloudpickleWrapper + +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_async_collector, + make_environment, + make_loss_module, + make_optimizer, + make_replay_buffer, + make_simple_environment, + make_td3_agent, +) + + +@hydra.main(version_base="1.1", config_path="", config_name="config-fast") +def main(cfg: "DictConfig"): # noqa: F821 + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + device = torch.device(device) + + # Create logger + exp_name = generate_exp_name("TD3", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="td3_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg, logger=logger) + + # Create agent + model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device) + + # Create TD3 loss + loss_module, target_net_updater = make_loss_module(cfg, model) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + scratch_dir=cfg.replay_buffer.scratch_dir, + device=cfg.replay_buffer.device if cfg.replay_buffer.device else device, + prefetch=0, + mmap=False, + ) + reshape = CloudpickleWrapper(lambda td: td.reshape(-1)) + replay_buffer.append_transform(reshape, invert=True) + + # Create off-policy collector + envname = cfg.env.name + task = cfg.env.task + library = cfg.env.library + seed = cfg.env.seed + max_episode_steps = cfg.env.max_episode_steps + collector = make_async_collector( + cfg, + lambda: make_simple_environment( + envname, task, library, seed, max_episode_steps + ), + exploration_policy, + replay_buffer, + ) + + # Create optimizers + optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) + + # Main loop + start_time = time.time() + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + max(1, cfg.collector.env_per_collector) + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + delayed_updates = cfg.optim.policy_update_delay + prb = cfg.replay_buffer.prb + update_counter = 0 + + sampling_start = time.time() + current_frames = cfg.collector.frames_per_batch + update_actor = False + + test_env = make_simple_environment(envname, task, library, seed, max_episode_steps) + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean() + print(f"reward before training: {reward: 4.4f}") + + # loss_module.value_loss = torch.compile( + # loss_module.value_loss, mode="reduce-overhead" + # ) + # loss_module.actor_loss = torch.compile( + # loss_module.actor_loss, mode="reduce-overhead" + # ) + + def train_update(sampled_tensordict): + # Compute loss + q_loss, *_ = loss_module.value_loss(sampled_tensordict) + + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_losses.append(q_loss.item()) + + # Update actor + if update_actor: + actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + actor_losses.append(actor_loss.item()) + + # Update target params + target_net_updater.step() + + train_update_cuda = None + g = torch.cuda.CUDAGraph() + + for _ in collector: + sampling_time = time.time() - sampling_start + exploration_policy[1].step(current_frames) + + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(current_frames) + + # Add to replay buffer + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + + if collected_frames >= init_random_frames: + ( + actor_losses, + q_losses, + ) = ([], []) + for _ in range(num_updates): + + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to( + device, non_blocking=True + ) + else: + sampled_tensordict = sampled_tensordict.clone() + + if train_update_cuda is None: + static_sample = sampled_tensordict + with torch.cuda.graph(g): + train_update(static_sample) + + def train_update_cuda(x): + static_sample.copy_(x) + g.replay() + else: + train_update_cuda(sampled_tensordict) + + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + + # Logging + metrics_to_log = {} + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses) + if update_actor: + metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() + + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean() + print(f"reward before training: {reward: 4.4f}") + test_env.close() + + end_time = time.time() + execution_time = end_time - start_time + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 60a4d046355..906685b4479 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -10,9 +10,9 @@ from tensordict.nn import TensorDictSequential from torch import nn, optim -from torchrl.collectors import SyncDataCollector +from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage from torchrl.envs import ( CatTensors, Compose, @@ -46,20 +46,18 @@ # ----------------- -def env_maker(cfg, device="cpu", from_pixels=False): - lib = cfg.env.library +def env_maker(envname, task, library, device="cpu", from_pixels=False): + lib = library # cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, + envname, # cfg.env.name device=device, from_pixels=from_pixels, pixels_only=False, ) elif lib == "dm_control": - env = DMControlEnv( - cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False - ) + env = DMControlEnv(envname, task, from_pixels=from_pixels, pixels_only=False) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -82,33 +80,56 @@ def apply_env_transforms(env, max_episode_steps): def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" - partial = functools.partial(env_maker, cfg=cfg) - parallel_env = ParallelEnv( - cfg.collector.env_per_collector, - EnvCreator(partial), - serial_for_single=True, + partial = functools.partial( + env_maker, envname=cfg.env.name, task=cfg.env.task, library=cfg.env.library ) + if cfg.collector.env_per_collector == 0: + parallel_env = partial() + else: + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(partial), + serial_for_single=True, + ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) - partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) + partial = functools.partial( + env_maker, + envname=cfg.env.name, + task=cfg.env.task, + library=cfg.env.library, + from_pixels=cfg.logger.video, + ) trsf_clone = train_env.transform.clone() if cfg.logger.video: trsf_clone.insert( 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) ) - eval_env = TransformedEnv( - ParallelEnv( - cfg.collector.env_per_collector, - EnvCreator(partial), - serial_for_single=True, - ), - trsf_clone, - ) + if cfg.collector.env_per_collector == 0: + eval_env = TransformedEnv(partial(), trsf_clone) + else: + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(partial), + serial_for_single=True, + ), + trsf_clone, + ) return train_env, eval_env +def make_simple_environment( + envname, task, library, seed, max_episode_steps, logger=None +): + """Make environments for training and evaluation.""" + env = env_maker(envname=envname, task=task, library=library) + env.set_seed(seed) + return apply_env_transforms(env, max_episode_steps) + + # ==================================================================== # Collector and replay buffer # --------------------------- @@ -129,6 +150,23 @@ def make_collector(cfg, train_env, actor_model_explore): return collector +def make_async_collector(cfg, train_env, actor_model_explore, rb): + """Make fast collector.""" + collector = MultiaSyncDataCollector( + [EnvCreator(train_env)] * cfg.collector.num_workers, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + reset_at_each_iter=cfg.collector.reset_at_each_iter, + device=cfg.collector.device, + replay_buffer=rb, + replay_buffer_chunk=False, + ) + collector.set_seed(cfg.env.seed) + return collector + + def make_replay_buffer( batch_size, prb=False, @@ -136,34 +174,39 @@ def make_replay_buffer( scratch_dir=None, device="cpu", prefetch=3, + mmap=True, ): with ( tempfile.TemporaryDirectory() if scratch_dir is None else nullcontext(scratch_dir) ) as scratch_dir: + if mmap: + storage = LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ) + else: + storage = LazyTensorStorage( + buffer_size, + device=device, + ) + if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( alpha=0.7, beta=0.5, pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage, batch_size=batch_size, ) else: replay_buffer = TensorDictReplayBuffer( pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage, batch_size=batch_size, ) return replay_buffer @@ -282,12 +325,14 @@ def make_optimizer(cfg, loss_module): lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, + foreach=True, ) optimizer_critic = optim.Adam( critic_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, + foreach=True, ) return optimizer_actor, optimizer_critic diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e7f2c94b1c2..ae5ddc91866 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1673,7 +1673,9 @@ def _check_replay_buffer_init(self): fake_td = self.create_env_fn[0]( **self.create_env_kwargs[0] ).fake_tensordict() - fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long) + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) self.replay_buffer._storage._init(fake_td) except AttributeError: