Skip to content

Commit

Permalink
[Feature] non-functional SAC loss
Browse files Browse the repository at this point in the history
ghstack-source-id: fd766d1a4f0868435920c8fe8caffb75425c2a05
Pull Request resolved: #2393
  • Loading branch information
vmoens committed Aug 13, 2024
1 parent 63a1457 commit 0dbb1a5
Show file tree
Hide file tree
Showing 21 changed files with 591 additions and 181 deletions.
140 changes: 125 additions & 15 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import (
_vmap_func,
_maybe_vmap_maybe_func,
HardUpdate,
hold_out_net,
SoftUpdate,
Expand Down Expand Up @@ -254,11 +254,11 @@ def __init__(self):
layers.append(nn.Linear(4, 4))
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self.maybe_convert_to_functional(model, "model", expand_dim=4)
self._make_vmap()

def _make_vmap(self):
self.vmap_model = _vmap_func(
self.vmap_model = _maybe_vmap_maybe_func(
self.model,
(None, 0),
randomness=(
Expand Down Expand Up @@ -3876,6 +3876,116 @@ def test_sac_vmap_equiv(

assert_allclose_td(loss_vmap, loss_novmap)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("as_list", [True, False])
@pytest.mark.parametrize("provide_target", [True, False])
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
def test_sac_nofunc(
    self,
    device,
    version,
    as_list,
    provide_target,
    delay_value,
    delay_actor,
    delay_qvalue,
    num_qvalue=4,
    td_est=None,
):
    """Non-functional SACLoss must behave like its functional counterpart.

    Builds two SACLoss instances over the same modules — one with
    ``functional=False`` and one with ``functional=True`` — seeding both
    constructions with the same RNG state, then checks that (a) the
    q-value parameters coincide and (b) a forward pass over identical
    data produces identical losses.
    """
    # A delayed actor or q-value network without a delayed value network
    # is not a valid SAC configuration.
    if (delay_actor or delay_qvalue) and not delay_value:
        pytest.skip("incompatible config")

    torch.manual_seed(self.seed)
    td = self._create_mock_data_sac(device=device)

    kwargs = {}

    actor = self._create_mock_actor(device=device)
    if delay_actor:
        kwargs["delay_actor"] = True
        if provide_target:
            # Provide an explicit target actor, synced with the live one.
            kwargs["target_actor_network"] = self._create_mock_actor(device=device)
            kwargs["target_actor_network"].load_state_dict(actor.state_dict())
    if as_list:
        qvalue = [
            self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
        ]
    else:
        qvalue = self._create_mock_qvalue(device=device)

    if delay_qvalue:
        kwargs["delay_qvalue"] = True
        if provide_target:
            if as_list:
                # One explicit target per q-value net, synced pairwise.
                kwargs["target_qvalue_network"] = [
                    self._create_mock_qvalue(device=device) for _ in qvalue
                ]
                for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
                    qval_t.load_state_dict(qval.state_dict())

            else:
                kwargs["target_qvalue_network"] = self._create_mock_qvalue(
                    device=device
                )
                kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())

    if version == 1:
        value = self._create_mock_value(device=device)
    else:
        value = None
    if delay_value:
        kwargs["delay_value"] = True
        if provide_target and version == 1:
            kwargs["target_value_network"] = self._create_mock_value(device=device)
            kwargs["target_value_network"].load_state_dict(value.state_dict())

    # A single (non-list) q-value module with an explicit target is
    # expected to warn that the provided target network is ignored.
    expect_ignored_target = delay_qvalue and not as_list and provide_target
    warn_ctx = (
        pytest.warns(UserWarning, match="The target network is ignored as the")
        if expect_ignored_target
        else contextlib.nullcontext()
    )
    # Snapshot the RNG state so both constructions consume identical
    # randomness (e.g. for parameter expansion).
    rng_state = torch.random.get_rng_state()
    with warn_ctx:
        loss_fn_nofunc = SACLoss(
            actor_network=actor,
            qvalue_network=qvalue,
            value_network=value,
            num_qvalue_nets=num_qvalue,
            loss_function="l2",
            use_vmap=False,
            functional=False,
            **kwargs,
        )
    torch.random.set_rng_state(rng_state)
    loss_fn_func = SACLoss(
        actor_network=actor,
        qvalue_network=qvalue,
        value_network=value,
        num_qvalue_nets=num_qvalue,
        loss_function="l2",
        use_vmap=False,
        functional=True,
        **kwargs,
    )
    # The stacked parameters of the non-functional q-value nets must match
    # the functional loss's parameter TensorDict.
    assert_allclose_td(
        torch.stack(
            list(
                TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
                    "module"
                ].values()
            )
        ),
        loss_fn_func.qvalue_network_params.data,
    )
    with torch.no_grad(), _check_td_steady(td), pytest.warns(
        UserWarning, match="No target network updater"
    ):
        # Restore the RNG state between the two forward passes so any
        # sampling inside the loss is identical for both instances.
        rng_state = torch.random.get_rng_state()
        # NOTE(review): variable names were swapped in the original
        # (loss_func held the non-functional output and vice versa);
        # fixed here — the comparison itself is unchanged.
        loss_nofunc = loss_fn_nofunc(td.clone())
        torch.random.set_rng_state(rng_state)
        loss_func = loss_fn_func(td.clone())

    assert_allclose_td(loss_nofunc, loss_func)

@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
Expand Down Expand Up @@ -12383,7 +12493,7 @@ class MyLoss(LossModule):

def __init__(self, actor_network):
super().__init__()
self.convert_to_functional(
self.maybe_convert_to_functional(
actor_network,
"actor_network",
create_target_params=create_target_params,
Expand Down Expand Up @@ -12532,7 +12642,7 @@ class custom_module(LossModule):
def __init__(self, delay_module=True):
super().__init__()
module1 = torch.nn.BatchNorm2d(10).eval()
self.convert_to_functional(
self.maybe_convert_to_functional(
module1, "module1", create_target_params=delay_module
)

Expand Down Expand Up @@ -14296,12 +14406,12 @@ class MyLoss(LossModule):

def __init__(self, actor_network, qvalue_network):
super().__init__()
self.convert_to_functional(
self.maybe_convert_to_functional(
actor_network,
"actor_network",
create_target_params=True,
)
self.convert_to_functional(
self.maybe_convert_to_functional(
qvalue_network,
"qvalue_network",
3,
Expand Down Expand Up @@ -14869,8 +14979,8 @@ def __init__(self, compare_against, expand_dim):
module_b = TensorDictModule(
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
)
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
self.maybe_convert_to_functional(module_a, "module_a")
self.maybe_convert_to_functional(
module_b,
"module_b",
compare_against=module_a.parameters() if compare_against else [],
Expand Down Expand Up @@ -14918,8 +15028,8 @@ def __init__(self, expand_dim=2):
module_b = TensorDictModule(
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
)
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
self.maybe_convert_to_functional(module_a, "module_a")
self.maybe_convert_to_functional(
module_b,
"module_b",
compare_against=module_a.parameters(),
Expand Down Expand Up @@ -14967,8 +15077,8 @@ class MyLoss(LossModule):

def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
super().__init__()
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
self.maybe_convert_to_functional(module_a, "module_a")
self.maybe_convert_to_functional(
[module_b0, module_b1],
"module_b",
# This will be ignored
Expand Down Expand Up @@ -15337,14 +15447,14 @@ def __init__(self):
TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
)
super().__init__()
self.convert_to_functional(
self.maybe_convert_to_functional(
actor,
"actor",
expand_dim=None,
create_target_params=False,
compare_against=None,
)
self.convert_to_functional(
self.maybe_convert_to_functional(
value,
"value",
expand_dim=2,
Expand Down
4 changes: 2 additions & 2 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(
)

if functional:
self.convert_to_functional(
self.maybe_convert_to_functional(
actor_network,
"actor_network",
)
Expand All @@ -292,7 +292,7 @@ def __init__(
else:
policy_params = None
if functional:
self.convert_to_functional(
self.maybe_convert_to_functional(
critic_network, "critic_network", compare_against=policy_params
)
else:
Expand Down
Loading

0 comments on commit 0dbb1a5

Please sign in to comment.