diff --git a/test/test_cost.py b/test/test_cost.py index 469bff3fe81..1084e452bf5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -113,7 +113,7 @@ from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss from torchrl.objectives.utils import ( - _vmap_func, + _maybe_vmap_maybe_func, HardUpdate, hold_out_net, SoftUpdate, @@ -249,11 +249,11 @@ def __init__(self): layers.append(nn.Linear(4, 4)) net = nn.Sequential(*layers).to(device) model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) - self.convert_to_functional(model, "model", expand_dim=4) + self.maybe_convert_to_functional(model, "model", expand_dim=4) self._make_vmap() def _make_vmap(self): - self.vmap_model = _vmap_func( + self.vmap_model = _maybe_vmap_maybe_func( self.model, (None, 0), randomness=( @@ -3871,6 +3871,116 @@ def test_sac_vmap_equiv( assert_allclose_td(loss_vmap, loss_novmap) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("as_list", [True, False]) + @pytest.mark.parametrize("provide_target", [True, False]) + @pytest.mark.parametrize("delay_value", (True, False)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + def test_sac_nofunc( + self, + device, + version, + as_list, + provide_target, + delay_value, + delay_actor, + delay_qvalue, + num_qvalue=4, + td_est=None, + ): + if (delay_actor or delay_qvalue) and not delay_value: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + td = self._create_mock_data_sac(device=device) + + kwargs = {} + + actor = self._create_mock_actor(device=device) + if delay_actor: + kwargs["delay_actor"] = True + if provide_target: + kwargs["target_actor_network"] = self._create_mock_actor(device=device) + kwargs["target_actor_network"].load_state_dict(actor.state_dict()) + if as_list: + qvalue = [ + self._create_mock_qvalue(device=device) for _ in range(num_qvalue) + ] + else: + qvalue = self._create_mock_qvalue(device=device) + + if delay_qvalue: + kwargs["delay_qvalue"] = True + if provide_target: + if as_list: + kwargs["target_qvalue_network"] = [ + self._create_mock_qvalue(device=device) for _ in qvalue + ] + for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue): + qval_t.load_state_dict(qval.state_dict()) + + else: + kwargs["target_qvalue_network"] = self._create_mock_qvalue( + device=device + ) + kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict()) + + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + if delay_value: + kwargs["delay_value"] = True + if provide_target and version == 1: + kwargs["target_value_network"] = self._create_mock_value(device=device) + kwargs["target_value_network"].load_state_dict(value.state_dict()) + + rng_state = torch.random.get_rng_state() + with pytest.warns( + UserWarning, match="The target network is ignored as the" + ) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext(): + loss_fn_nofunc = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + use_vmap=False, + functional=False, + **kwargs, + ) + torch.random.set_rng_state(rng_state) + loss_fn_func = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + use_vmap=False, + functional=True, + **kwargs, + ) + assert_allclose_td( + torch.stack( + list( + TensorDict.from_module(loss_fn_nofunc.qvalue_network)[ + "module" + ].values() + ) + ), + loss_fn_func.qvalue_network_params.data, + ) + with torch.no_grad(), _check_td_steady(td), pytest.warns( + UserWarning, match="No target network updater" + ): + rng_state = torch.random.get_rng_state() + loss_func = loss_fn_nofunc(td.clone()) + torch.random.set_rng_state(rng_state) + loss_nofunc = loss_fn_func(td.clone()) + + assert_allclose_td(loss_func, loss_nofunc) + @pytest.mark.parametrize("delay_value", (True, False)) @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @@ -12378,7 +12488,7 @@ class MyLoss(LossModule): def __init__(self, actor_network): super().__init__() - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=create_target_params, @@ -12527,7 +12637,7 @@ class custom_module(LossModule): def __init__(self, delay_module=True): super().__init__() module1 = torch.nn.BatchNorm2d(10).eval() - self.convert_to_functional( + self.maybe_convert_to_functional( module1, "module1", create_target_params=delay_module ) @@ -14291,12 +14401,12 @@ class MyLoss(LossModule): def __init__(self, actor_network, qvalue_network): super().__init__() - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=True, ) - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", 3, @@ -14864,8 +14974,8 @@ def __init__(self, compare_against, expand_dim): module_b = TensorDictModule( nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"] ) - self.convert_to_functional(module_a, "module_a") - self.convert_to_functional( + self.maybe_convert_to_functional(module_a, "module_a") + self.maybe_convert_to_functional( module_b, "module_b", compare_against=module_a.parameters() if compare_against else [], @@ -14913,8 +15023,8 @@ def __init__(self, expand_dim=2): module_b = TensorDictModule( nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"] ) - self.convert_to_functional(module_a, "module_a") - self.convert_to_functional( + self.maybe_convert_to_functional(module_a, "module_a") + self.maybe_convert_to_functional( module_b, "module_b", compare_against=module_a.parameters(), @@ -14962,8 +15072,8 @@ class MyLoss(LossModule): def __init__(self, module_a, module_b0, module_b1, expand_dim=2): super().__init__() - self.convert_to_functional(module_a, "module_a") - self.convert_to_functional( + self.maybe_convert_to_functional(module_a, "module_a") + self.maybe_convert_to_functional( [module_b0, module_b1], "module_b", # This will be ignored @@ -15332,14 +15442,14 @@ def __init__(self): TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]), ) super().__init__() - self.convert_to_functional( + self.maybe_convert_to_functional( actor, "actor", expand_dim=None, create_target_params=False, compare_against=None, ) - self.convert_to_functional( + self.maybe_convert_to_functional( value, "value", expand_dim=2, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index d3b2b4d2ac2..5b3b5064fbd 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -278,7 +278,7 @@ def __init__( ) if functional: - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", ) @@ -292,7 +292,7 @@ def __init__( else: policy_params = None if functional: - self.convert_to_functional( + self.maybe_convert_to_functional( critic_network, "critic_network", compare_against=policy_params ) else: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 1cf3837a638..c11d193fe3f 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -14,7 +14,13 @@ from tensordict import is_tensor_collection, TensorDict, TensorDictBase -from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams +from tensordict.nn import ( + TensorDictModule, + TensorDictModuleBase, + TensorDictParams, + TensorDictSequential, +) +from tensordict.utils import Buffer from torch import nn from torch.nn import Parameter from torchrl._utils import RL_WARNINGS @@ -252,13 +258,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """ raise NotImplementedError - def convert_to_functional( + def maybe_convert_to_functional( self, module: TensorDictModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, + target_network: TensorDictModule | None = None, **kwargs, ) -> None: """Converts a module to functional to be used in the loss. @@ -298,6 +305,9 @@ def convert_to_functional( the resulting parameters will be a detached version of the original parameters. If ``None``, the resulting parameters will carry gradients as expected. + target_network (TensorDictModule, optional): if the loss module is + not functional, the optional target network associated with + the input network. """ for name in ( @@ -310,6 +320,11 @@ def convert_to_functional( f"The name {name} wasn't part of the annotations ({self.__class__.__annotations__.keys()}). Make sure it is present in the definition class." ) + name_target_network = "target_" + module_name + + def _custom_resample(param_dst): + return param_dst.uniform_(param_dst.min().item(), param_dst.max().item()) + if kwargs: raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}") # To make it robust to device casting, we must register list of @@ -323,11 +338,113 @@ def convert_to_functional( "The ``expand_dim`` value must match the length of the module list/tuple " "if a single module isn't provided." ) - params = TensorDict.from_modules( - *module, as_module=True, expand_identical=True - ) + if self.functional: + params = TensorDict.from_modules( + *module, + as_module=True, + expand_identical=True, + lazy_stack=not self.functional, + ) + module = module[0] + else: + setattr(self, module_name, TensorDictSequential(*module)) + module = getattr(self, module_name) + params = TensorDict.from_module(module, as_module=True) + # We don't want the parameters to appear twice + setattr(self, module_name + "_params", None) + if target_network is not None: + if not isinstance(target_network, (list, tuple)): + raise RuntimeError( + "The target network must be a list or a tuple of modules of the same length as " + f"the module for {module}." + ) + target_network = TensorDictSequential( + *target_network + ).requires_grad_(False) + # convert all params to buffers + target_params = TensorDict.from_module(target_network) + target_params = target_params.apply(Buffer) + target_params.to_module(target_network) + setattr(self, name_target_network, target_network) + elif create_target_params: + target_params = TensorDict.from_module( + module, as_module=True + ).apply(lambda x: Buffer(x.data.clone())) + with target_params.to("meta").to_module(module): + # The only way to get this working is to deepcopy the module + setattr(self, name_target_network, deepcopy(module)) + target_params.to_module(getattr(self, name_target_network)) + else: + self.__name__[name_target_network] = module + + setattr(self, name_target_network + "_params", None) + self._has_update_associated[module_name] = not create_target_params + return else: params = TensorDict.from_module(module, as_module=True) + if not self.functional: + if expand_dim: + + new_params = params.data.expand(expand_dim, *params.shape).clone() + new_params = new_params.apply(_custom_resample) + new_params = new_params.unbind(0) + + def deepcopy_module(module, new_params): + with new_params.to("meta").to_module(module): + module = deepcopy(module) + new_params = new_params.data.apply( + lambda x, y: nn.Parameter(x, requires_grad=True) + if isinstance(y, nn.Parameter) + else Buffer(x), + params, + ) + new_params.to_module(module) + return module + + module = TensorDictSequential( + *( + deepcopy_module(module, new_params_) + for new_params_ in new_params + ) + ) + params = TensorDict.from_module(module, as_module=True) + setattr(self, module_name, module) + setattr(self, module_name + "_params", None) + if target_network is not None and not expand_dim: + target_network.requires_grad_(False) + # convert all params to buffers + target_params = TensorDict.from_module(target_network) + target_params = target_params.apply(Buffer) + target_params.to_module(target_network) + + # Check shape of the target params + def assert_shape(p1, p2): + if not p1.shape == p2.shape: + raise ValueError( + "The shape of the target parameters must match the original ones." + ) + + target_params.apply(assert_shape, params, filter_empty=True) + setattr(self, name_target_network, target_network) + elif create_target_params: + if target_network is not None: + warnings.warn( + "The target network is ignored as the network must be deepcopied anyway. " + f"If you want to use a precise target network for {module_name}, please provide " + f"a list of targets instead." + ) + target_params = TensorDict.from_module( + module, as_module=True + ).apply(lambda x: Buffer(x.data.clone())) + with target_params.to("meta").to_module(module): + # The only way to get this working is to deepcopy the module + setattr(self, name_target_network, deepcopy(module)) + target_params.to_module(getattr(self, name_target_network)) + else: + self.__dict__[name_target_network] = module + setattr(self, name_target_network + "_params", None) + self._has_update_associated[module_name] = not create_target_params + return for key in params.keys(True): if sep in key: @@ -362,11 +479,7 @@ def _compare_and_expand(param): return expanded_param else: p_out = param.expand(expand_dim, *param.shape).clone() - p_out = nn.Parameter( - p_out.uniform_( - p_out.min().item(), p_out.max().item() - ).requires_grad_() - ) + p_out = nn.Parameter(_custom_resample(p_out).requires_grad_()) return p_out params = TensorDictParams( @@ -396,7 +509,6 @@ def _compare_and_expand(param): # A deepcopy with meta device could be used but that assumes that the model is copyable! self.__dict__[module_name] = module - name_params_target = "target_" + module_name if create_target_params: # if create_target_params: # we create a TensorDictParams to keep the target params as Buffer instances @@ -406,9 +518,13 @@ def _compare_and_expand(param): ), no_convert=True, ) - setattr(self, name_params_target + "_params", target_params) + setattr(self, name_target_network + "_params", target_params) + self.__dict__[name_target_network] = module self._has_update_associated[module_name] = not create_target_params + # legacy + convert_to_functional = maybe_convert_to_functional + def __getattr__(self, item): if item.startswith("target_") and item.endswith("_params"): params = self._modules.get(item, None) @@ -582,6 +698,18 @@ def set_vmap_randomness(self, value): self._vmap_randomness = value self._make_vmap() + def _maybe_func_call( + self, *args, module: nn.Module, module_params: TensorDictBase, func=None + ): + if func is None: + func = "forward" + module_func = getattr(module, func) + if self.functional: + with module_params.to_module(module): + return module_func(*args) + else: + return module_func(*args) + @staticmethod def _make_meta_params(param): is_param = isinstance(param, nn.Parameter) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 6a6cf8548e4..8390fa0c8bf 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -29,8 +29,8 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -301,7 +301,7 @@ def __init__( # Actor self.delay_actor = delay_actor - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -311,7 +311,7 @@ def __init__( self.delay_qvalue = delay_qvalue self.num_qvalue_nets = 2 - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", self.num_qvalue_nets, @@ -377,10 +377,10 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qvalue_networkN0 = _vmap_func( + self._vmap_qvalue_networkN0 = _maybe_vmap_maybe_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self._vmap_qvalue_network00 = _vmap_func( + self._vmap_qvalue_network00 = _maybe_vmap_maybe_func( self.qvalue_network, randomness=self.vmap_randomness ) @@ -1067,7 +1067,7 @@ def __init__( action_space=action_space, ) - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=self.delay_value, diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index d86442fca12..b5628a8fc34 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -22,8 +22,8 @@ from torchrl.objectives.utils import ( _cache_values, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -277,7 +277,7 @@ def __init__( self._set_deprecated_ctor_keys(priority_key=priority_key) # Actor - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=False, @@ -294,7 +294,7 @@ def __init__( self.num_qvalue_nets = num_qvalue_nets q_value_policy_params = policy_params - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -342,7 +342,7 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qnetworkN0 = _vmap_func( + self._vmap_qnetworkN0 = _maybe_vmap_maybe_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 7dc6b23212a..1ed842c6980 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -218,7 +218,7 @@ def __init__( with params_meta.to_module(actor_critic): self.__dict__["actor_critic"] = deepcopy(actor_critic) - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -229,7 +229,7 @@ def __init__( policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=self.delay_value, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index eb34b021484..01c3c4f4fd6 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -95,7 +95,7 @@ def __init__( super().__init__() # Actor Network - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=False, @@ -300,7 +300,7 @@ def __init__( super().__init__() # Actor Network - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=False, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 4f805c1b411..d6e864bd2d3 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -24,8 +24,8 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -166,7 +166,7 @@ def __init__( super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -181,7 +181,7 @@ def __init__( self.actor_network.return_log_prob = True self.delay_qvalue = delay_qvalue - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", expand_dim=num_qvalue_nets, @@ -228,7 +228,9 @@ def __init__( raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _make_vmap(self): - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + self._vmap_qvalue_networkN0 = _maybe_vmap_maybe_func( + self.qvalue_network, (None, 0) + ) @property def target_entropy(self): diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 1f3ec714f53..f9b96b4968d 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -199,7 +199,7 @@ def __init__( action_space=action_space, ) - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=self.delay_value, @@ -462,7 +462,7 @@ def __init__( module=value_network, wrapper_type=DistributionalQValueActor ) - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=self.delay_value, diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index 3c0050fca84..5ba0da70164 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -86,7 +86,7 @@ def __init__( super().__init__() # Discriminator Network - self.convert_to_functional( + self.maybe_convert_to_functional( discriminator_network, "discriminator_network", create_target_params=False, diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 04d7e020551..895c5533f03 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -20,8 +20,8 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -279,7 +279,7 @@ def __init__( self.expectile = expectile # Actor Network - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=False, @@ -291,7 +291,7 @@ def __init__( else: policy_params = None # Value Function Network - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=False, @@ -307,7 +307,7 @@ def __init__( ) else: qvalue_policy_params = None - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -322,7 +322,7 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qvalue_networkN0 = _vmap_func( + self._vmap_qvalue_networkN0 = _maybe_vmap_maybe_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index ce4cc8ddbb8..e1a7d363ca3 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -225,12 +225,12 @@ def __init__( ).to_module(global_value_network): self.__dict__["global_value_network"] = deepcopy(global_value_network) - self.convert_to_functional( + self.maybe_convert_to_functional( local_value_network, "local_value_network", create_target_params=self.delay_value, ) - self.convert_to_functional( + self.maybe_convert_to_functional( mixer_network, "mixer_network", create_target_params=self.delay_value, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index d79f0b2ea84..f6854751815 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -328,7 +328,7 @@ def __init__( self._out_keys = None super().__init__() if functional: - self.convert_to_functional(actor_network, "actor_network") + self.maybe_convert_to_functional(actor_network, "actor_network") else: self.actor_network = actor_network self.actor_network_params = None @@ -341,7 +341,7 @@ def __init__( else: policy_params = None if functional: - self.convert_to_functional( + self.maybe_convert_to_functional( critic_network, "critic_network", compare_against=policy_params ) else: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index cda2c62894e..75f3fdc1acb 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -23,8 +23,8 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -280,7 +280,7 @@ def __init__( self._in_keys = None self._set_deprecated_ctor_keys(priority_key=priority_key) - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -295,7 +295,7 @@ def __init__( else: policy_params = None self.delay_qvalue = delay_qvalue - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -339,10 +339,10 @@ def __init__( self._make_vmap() def _make_vmap(self): - self._vmap_qvalue_network00 = _vmap_func( + self._vmap_qvalue_network00 = _maybe_vmap_maybe_func( self.qvalue_network, randomness=self.vmap_randomness ) - self._vmap_getdist = _vmap_func( + self._vmap_getdist = _maybe_vmap_maybe_func( self.actor_network, func="get_dist_params", randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 08ff896610c..8c5c93a3994 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -275,7 +275,7 @@ def __init__( # Actor if self.functional: - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=False, @@ -292,7 +292,7 @@ def __init__( # Value if critic_network is not None: if self.functional: - self.convert_to_functional( + self.maybe_convert_to_functional( critic_network, "critic_network", create_target_params=self.delay_value, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index f5c69dd68cd..c104b855d07 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import contextlib import math import warnings from dataclasses import dataclass @@ -28,11 +29,11 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, - _LoopVmapModule, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, + hold_out_net, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -54,16 +55,24 @@ class SACLoss(LossModule): Reinforcement Learning with a Stochastic Actor" https://arxiv.org/abs/1801.01290 and "Soft Actor-Critic Algorithms and Applications" https://arxiv.org/abs/1812.05905 + SACLoss has four different losses that are executed in the following order in :meth:`.forward`: + + - :meth:`~.qvalue_loss`; + - :meth:`~.value_loss`, which can be ignored for SAC-v2; + - :meth:`~.actor_loss`; + - :meth:`~.alpha_loss`. + + Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (TensorDictModule): Q(s, a) parametric model. + qvalue_network (TensorDictModule or list of modules): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded). - .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + .. warning:: When a list of modules if passed, it will not be compared against the policy parameters and all the parameters will be considered as untied. value_network (TensorDictModule, optional): V(s) parametric model. @@ -75,6 +84,9 @@ class SACLoss(LossModule): Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. + If a list of :attr:`qvalue_network` is passed and ``num_qvalue_nets`` is not specified, + the vlaue of ``num_qvalue_nets`` is determined accordingly. If the two are passed, the length of the + module list must match ``num_qvalue_nets``. Defaults to ``2``. loss_function (str, optional): loss function to be used with the value function loss. Default is `"smooth_l1"`. @@ -296,7 +308,7 @@ def __init__( qvalue_network: TensorDictModule | List[TensorDictModule], value_network: Optional[TensorDictModule] = None, *, - num_qvalue_nets: int = 2, + num_qvalue_nets: int | None = None, loss_function: str = "smooth_l1", alpha_init: float = 1.0, min_alpha: float = None, @@ -312,8 +324,12 @@ def __init__( separate_losses: bool = False, reduction: str = None, use_vmap: bool = True, + functional: bool = True, + target_actor_network: ProbabilisticActor | None = None, + target_qvalue_network: ProbabilisticActor | None = None, + target_value_network: ProbabilisticActor | None = None, ) -> None: - self.use_vmap = use_vmap + self._in_keys = None self._out_keys = None if reduction is None: @@ -321,12 +337,22 @@ def __init__( super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) + if num_qvalue_nets is None: + if isinstance(qvalue_network, (list, tuple)): + num_qvalue_nets = len(qvalue_network) + else: + num_qvalue_nets = 2 + + self.use_vmap = use_vmap + self.functional = functional + # Actor self.delay_actor = delay_actor - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, + target_network=target_actor_network, ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -339,11 +365,12 @@ def __init__( if value_network is not None: self._version = 1 self.delay_value = delay_value - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=self.delay_value, compare_against=policy_params, + target_network=target_value_network, ) else: self._version = 2 @@ -359,12 +386,13 @@ def __init__( q_value_policy_params = policy_params else: q_value_policy_params = policy_params - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, create_target_params=self.delay_qvalue, compare_against=q_value_policy_params, + target_network=target_qvalue_network, ) self.loss_function = loss_function @@ -412,23 +440,33 @@ def __init__( self._make_vmap() self.reduction = reduction + @property + def functional(self): + return self._functional + + @functional.setter + def functional(self, value): + self._functional = value + def _make_vmap(self): - if self.use_vmap: - self._vmap_qnetworkN0 = _vmap_func( - self.qvalue_network, (None, 0), randomness=self.vmap_randomness + if self.use_vmap and not self.functional: + raise RuntimeError( + "functional=False required use_vmap to be set to False too." ) - if self._version == 1: - self._vmap_qnetwork00 = _vmap_func( - self.qvalue_network, randomness=self.vmap_randomness - ) - else: - self._vmap_qnetworkN0 = _LoopVmapModule( - self.qvalue_network, (None, 0), functional=True + self._vmap_qnetworkN0 = _maybe_vmap_maybe_func( + self.qvalue_network, + (None, 0), + randomness=self.vmap_randomness, + functional=self.functional, + use_vmap=self.use_vmap, + ) + if self._version == 1: + self._vmap_qnetwork00 = _maybe_vmap_maybe_func( + self.qvalue_network, + randomness=self.vmap_randomness, + functional=self.functional, + use_vmap=self.use_vmap, ) - if self._version == 1: - self._vmap_qnetwork00 = _LoopVmapModule( - self.qvalue_network, functional=True - ) @property def target_entropy_buffer(self): @@ -588,14 +626,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - if self._version == 1: - loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape) - loss_value, _ = self._value_loss(tensordict_reshape) - else: - loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape) - loss_value = None + loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape) + loss_value, _ = self.value_loss(tensordict_reshape) loss_actor, metadata_actor = self.actor_loss(tensordict_reshape) - loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + loss_alpha, _ = self.alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape @@ -624,16 +659,58 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) return td_out + def qvalue_loss(self, tensordict): + """The loss for the QValue network. + + In :meth:`~.forward`, this method is the first to be executed. + + Args: + tensordict (TensorDictBase): the input data. See :attr:`~.in_keys` for more details + on the required fields. + + Returns: a tensor containing the qvalue loss along with a dictionary of metadata. + + """ + if self._version == 1: + loss_qvalue, qvalue_metadata = self._qvalue_v1_loss(tensordict) + else: + loss_qvalue, qvalue_metadata = self._qvalue_v2_loss(tensordict) + return loss_qvalue, qvalue_metadata + + def value_loss(self, tensordict): + """The loss for the QValue network. + + In :meth:`~.forward`, this method is the second to be executed. + It's a no-op for SAC-v2. + + Args: + tensordict (TensorDictBase): the input data. See :attr:`~.in_keys` for more details + on the required fields. + + Returns: a tensor containing the qvalue loss along with a dictionary of metadata. + + """ + if self._version == 1: + loss_value, metadata = self._value_loss(tensordict) + else: + loss_value = None + metadata = {} + return loss_value, metadata + @property @_cache_values def _cached_detached_qvalue_params(self): - return self.qvalue_network_params.detach() + if self.functional: + return self.qvalue_network_params.detach() + return None def actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: """The loss for the actor. + In :meth:`~.forward`, this method is the third to be executed. + Args: tensordict (TensorDictBase): the input data. See :attr:`~.in_keys` for more details on the required fields. @@ -641,20 +718,25 @@ def actor_loss( Returns: a tensor containing the actor loss along with a dictionary of metadata. """ - with set_exploration_type( - ExplorationType.RANDOM - ), self.actor_network_params.to_module(self.actor_network): - dist = self.actor_network.get_dist(tensordict) - a_reparm = dist.rsample() + with set_exploration_type(ExplorationType.RANDOM): + dist = self._maybe_func_call( + tensordict, + func="get_dist", + module=self.actor_network, + module_params=self.actor_network_params, + ) + a_reparm = dist.rsample() log_prob = dist.log_prob(a_reparm) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q.set(self.tensor_keys.action, a_reparm) - - td_q = self._vmap_qnetworkN0( - td_q, - self._cached_detached_qvalue_params, # should we clone? - ) + with hold_out_net( + self.qvalue_network + ) if not self.functional else contextlib.nullcontext(): + td_q = self._vmap_qnetworkN0( + td_q, + self._cached_detached_qvalue_params, # should we clone? + ) min_q_logprob = ( td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -667,28 +749,52 @@ def actor_loss( return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} + def alpha_loss(self, log_prob: Tensor) -> Tensor: + """The loss for the entropy factor. + + In :meth:`~.forward`, this method is the fourth to be executed. + + Args: + log_prob (Tensor): the log probability of the action taken in the actor loss. + + Returns: a tensor containing the entropy loss for the `alpha` parameter and a dictionary of metadata. + """ + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_prob) + return alpha_loss, {} + @property @_cache_values def _cached_target_params_actor_value(self): - return TensorDict._new_unsafe( - { - "module": { - "0": self.target_actor_network_params, - "1": self.target_value_network_params, - } - }, - torch.Size([]), - ) + if self.functional: + return TensorDict._new_unsafe( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + torch.Size([]), + ) + return None def _qvalue_v1_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: - target_params = self._cached_target_params_actor_value with set_exploration_type(self.deterministic_sampling_mode): - target_value = self.value_estimator.value_estimate( - tensordict, target_params=target_params - ).squeeze(-1) - + if self.functional: + target_params = self._cached_target_params_actor_value + target_value = self.value_estimator.value_estimate( + tensordict, target_params=target_params + ).squeeze(-1) + else: + target_value = self.value_estimator.value_estimate(tensordict).squeeze( + -1 + ) # Q-nets must be trained independently: as such, we split the data in 2 # if required and train each q-net on one half of the data. shape = tensordict.shape @@ -732,14 +838,18 @@ def _compute_target_v2(self, tensordict) -> Tensor: tensordict = tensordict.clone(False) # get actions and log-probs with torch.no_grad(): - with set_exploration_type( - ExplorationType.RANDOM - ), self.actor_network_params.to_module(self.actor_network): - next_tensordict = tensordict.get("next").clone(False) - next_dist = self.actor_network.get_dist(next_tensordict) - next_action = next_dist.rsample() - next_tensordict.set(self.tensor_keys.action, next_action) - next_sample_log_prob = next_dist.log_prob(next_action) + next_tensordict = tensordict.get("next").clone(False) + # We need to set the mode if the module is explorative (?) + with set_exploration_type(ExplorationType.RANDOM): + next_dist = self._maybe_func_call( + next_tensordict, + module=self.actor_network, + func="get_dist", + module_params=self.actor_network_params, + ) + next_action = next_dist.rsample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = next_dist.log_prob(next_action) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -788,11 +898,19 @@ def _value_loss( ) -> Tuple[Tensor, Dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() - with self.value_network_params.to_module(self.value_network): - self.value_network(td_copy) + + td_copy = self._maybe_func_call( + td_copy, module=self.value_network, module_params=self.value_network_params + ) + pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1) - with self.target_actor_network_params.to_module(self.actor_network): - action_dist = self.actor_network.get_dist(td_copy) # resample an action + + action_dist = self._maybe_func_call( + td_copy, + module=self.target_actor_network, + module_params=self.target_actor_network_params, + func="get_dist", + ) action = action_dist.rsample() td_copy.set(self.tensor_keys.action, action, inplace=False) @@ -818,15 +936,6 @@ def _value_loss( ) return loss_value, {} - def _alpha_loss(self, log_prob: Tensor) -> Tensor: - if self.target_entropy is not None: - # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) - else: - # placeholder - alpha_loss = torch.zeros_like(log_prob) - return alpha_loss - @property def _alpha(self): if self.min_log_alpha is not None: @@ -1055,7 +1164,7 @@ def __init__( super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -1067,7 +1176,7 @@ def __init__( else: policy_params = None self.delay_qvalue = delay_qvalue - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -1132,7 +1241,7 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qnetworkN0 = _vmap_func( + self._vmap_qnetworkN0 = _maybe_vmap_maybe_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 922d6df7a74..673c79496e4 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -20,8 +20,8 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -249,7 +249,7 @@ def __init__( self.delay_actor = delay_actor self.delay_qvalue = delay_qvalue - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -260,7 +260,7 @@ def __init__( policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -321,10 +321,10 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qvalue_network00 = _vmap_func( + self._vmap_qvalue_network00 = _maybe_vmap_maybe_func( self.qvalue_network, randomness=self.vmap_randomness ) - self._vmap_actor_network00 = _vmap_func( + self._vmap_actor_network00 = _maybe_vmap_maybe_func( self.actor_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index cd40ac1e029..ee088473534 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -19,8 +19,8 @@ from torchrl.objectives.utils import ( _cache_values, + _maybe_vmap_maybe_func, _reduce, - _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -264,7 +264,7 @@ def __init__( self.delay_actor = delay_actor self.delay_qvalue = delay_qvalue - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, @@ -275,7 +275,7 @@ def __init__( policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional( + self.maybe_convert_to_functional( qvalue_network, "qvalue_network", num_qvalue_nets, @@ -335,10 +335,10 @@ def __init__( self.reduction = reduction def _make_vmap(self): - self._vmap_qvalue_network00 = _vmap_func( + self._vmap_qvalue_network00 = _maybe_vmap_maybe_func( self.qvalue_network, randomness=self.vmap_randomness ) - self._vmap_actor_network00 = _vmap_func( + self._vmap_actor_network00 = _maybe_vmap_maybe_func( self.actor_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index affc39d7ff4..a9a47f2fe6d 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -482,27 +482,53 @@ def new_fun(self, netname=None): return new_fun -def _vmap_func(module, *args, func=None, call_vmap: bool = True, **kwargs): - try: +def _capture_vmap_error(func): + @functools.wraps(func) + def new_func(*args, **kwargs): + try: + return func(*args, **kwargs) + except RuntimeError as err: + if re.match( + r"vmap: called random operation while in randomness error mode", + str(err), + ): + raise RuntimeError( + "Please use .set_vmap_randomness('different') to handle random operations during vmap." + ) from err + + return new_func + + +def _maybe_vmap_maybe_func( + module, *args, func=None, use_vmap: bool = True, functional: bool = True, **kwargs +): + randomness = kwargs.pop("randomness", "error") + if functional: + if func is None: + func = "forward" + func = getattr(module, func) def decorated_module(*module_args_params): params = module_args_params[-1] module_args = module_args_params[:-1] with params.to_module(module): - if func is None: - return module(*module_args) - else: - return getattr(module, func)(*module_args) - - return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 + return func(*module_args) - except RuntimeError as err: - if re.match( - r"vmap: called random operation while in randomness error mode", str(err) - ): - raise RuntimeError( - "Please use .set_vmap_randomness('different') to handle random operations during vmap." - ) from err + if use_vmap: + out = vmap( # noqa: TOR101 + decorated_module, *args, randomness=randomness, **kwargs + ) + return _capture_vmap_error(out) + else: + return _LoopVmapModule(module, *args, func=func, **kwargs, functional=True) + else: + if use_vmap: + # This should be rarely reached - we don't allow vmap + non functional in losses + out = vmap(func, *args, randomness=randomness, **kwargs) # noqa: TOR101 + return _capture_vmap_error(out) + else: + # we still need to iterate over the module + return _LoopVmapModule(module, *args, func=func, **kwargs, functional=False) class _LoopVmapModule(nn.Module): @@ -511,6 +537,7 @@ def __init__( module: nn.Module, in_dims: Tuple[int | None] = None, out_dims: Tuple[int | None] = None, + func=None, register_module: bool = False, functional: bool = False, ): @@ -520,6 +547,9 @@ def __init__( self.__dict__["module"] = module else: self.module = module + if func is None: + func = "forward" + self.func = func self.in_dims = in_dims if out_dims is not None: raise NotImplementedError("out_dims not implemented yet.") @@ -531,8 +561,17 @@ def forward(self, *args): to_rep = [] if self.in_dims is None: self.in_dims = [0] * len(args) + + if not self.functional: + args = args[:-1] + in_dims = self.in_dims[:-1] + else: + in_dims = self.in_dims + args = list(args) - for i, (arg, in_dim) in enumerate(_zip_strict(args, self.in_dims)): + for i, (arg, in_dim) in enumerate(_zip_strict(args, in_dims)): + if not self.functional and n is None: + n = len(self.module) if in_dim is not None: arg = arg.unbind(in_dim) if n is None: @@ -550,12 +589,13 @@ def forward(self, *args): ] out = [] n_out = None - for _args in zip(*args): + for i, _args in enumerate(zip(*args)): if self.functional: with _args[-1].to_module(self.module): out.append(self.module(*_args[:-1])) else: - out.append(self.module(*_args)) + # Ignore the last param, which must be a TD containing params + out.append(self.module[i](*_args)) if n_out is None: n_out = len(out[-1]) if isinstance(out[-1], tuple) else 1 if n_out > 1: diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b7db2e8242e..015afa76f0c 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -27,7 +27,11 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST +from torchrl.objectives.utils import ( + _maybe_vmap_maybe_func, + hold_out_net, + RANDOM_MODULE_LIST, +) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -142,9 +146,9 @@ def _call_value_nets( ) elif params is not None: params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) + data_out = _maybe_vmap_maybe_func( + value_net, (0, 0), randomness=vmap_randomness + )(data_in, params_stack) else: data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) value_est = data_out.get(value_key) @@ -217,6 +221,7 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] + target_value_network: Union[TensorDictModule, Callable] _vmap_randomness = None @property @@ -290,6 +295,7 @@ def __init__( self, *, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, shifted: bool = False, differentiable: bool = False, skip_existing: bool | None = None, @@ -302,6 +308,9 @@ def __init__( self.differentiable = differentiable self.skip_existing = skip_existing self.__dict__["value_network"] = value_network + self.__dict__["target_value_network"] = ( + value_network if target_value_network is None else target_value_network + ) self.dep_keys = {} self.shifted = shifted @@ -429,8 +438,10 @@ def _next_value(self, tensordict, target_params, kwargs): step_td = step_mdp(tensordict, keep_other=False) if self.value_network is not None: with hold_out_net( - self.value_network - ) if target_params is None else target_params.to_module(self.value_network): + self.target_value_network + ) if target_params is None else target_params.to_module( + self.target_value_network + ): self.value_network(step_td) next_value = step_td.get(self.tensor_keys.value) return next_value @@ -519,6 +530,7 @@ def __init__( *, gamma: float | torch.Tensor, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, shifted: bool = False, average_rewards: bool = False, differentiable: bool = False, @@ -530,6 +542,7 @@ def __init__( ): super().__init__( value_network=value_network, + target_value_network=target_value_network, differentiable=differentiable, shifted=shifted, advantage_key=advantage_key, @@ -732,6 +745,7 @@ def __init__( *, gamma: float | torch.Tensor, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, average_rewards: bool = False, differentiable: bool = False, skip_existing: bool | None = None, @@ -744,6 +758,7 @@ def __init__( ): super().__init__( value_network=value_network, + target_value_network=target_value_network, differentiable=differentiable, advantage_key=advantage_key, value_target_key=value_target_key, @@ -954,6 +969,7 @@ def __init__( gamma: float | torch.Tensor, lmbda: float | torch.Tensor, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, average_rewards: bool = False, differentiable: bool = False, vectorized: bool = True, @@ -967,6 +983,7 @@ def __init__( ): super().__init__( value_network=value_network, + target_value_network=target_value_network, differentiable=differentiable, advantage_key=advantage_key, value_target_key=value_target_key, @@ -1209,6 +1226,7 @@ def __init__( gamma: float | torch.Tensor, lmbda: float | torch.Tensor, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, average_gae: bool = False, differentiable: bool = False, vectorized: bool = True, @@ -1223,6 +1241,7 @@ def __init__( super().__init__( shifted=shifted, value_network=value_network, + target_value_network=target_value_network, differentiable=differentiable, advantage_key=advantage_key, value_target_key=value_target_key, @@ -1517,6 +1536,7 @@ def __init__( gamma: float | torch.Tensor, actor_network: TensorDictModule, value_network: TensorDictModule, + target_value_network: TensorDictModule | None = None, rho_thresh: float | torch.Tensor = 1.0, c_thresh: float | torch.Tensor = 1.0, average_adv: bool = False, @@ -1532,6 +1552,7 @@ def __init__( super().__init__( shifted=shifted, value_network=value_network, + target_value_network=target_value_network, differentiable=differentiable, advantage_key=advantage_key, value_target_key=value_target_key, diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 869f0f980b3..cde9d23edf6 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -198,12 +198,12 @@ def _init( ) -> None: super(type(self), self).__init__() - self.convert_to_functional( + self.maybe_convert_to_functional( actor_network, "actor_network", create_target_params=True, ) - self.convert_to_functional( + self.maybe_convert_to_functional( value_network, "value_network", create_target_params=True,