[Feature] non-functional SAC loss #2393

Open · wants to merge 2 commits into gh/vmoens/19/base
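
A minimal usage sketch of the feature this PR exercises in the new test_sac_nofunc test below. The names actor, qvalue and data are placeholders for the mock networks and mock data built by the helpers in test/test_cost.py, and the functional / use_vmap keyword arguments are taken from the new test rather than from documented API:

from torchrl.objectives import SACLoss

# Non-functional variant added by this PR: the networks are kept as regular
# nn.Modules rather than being converted to functional modules.
loss_nofunc = SACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    num_qvalue_nets=4,
    loss_function="l2",
    use_vmap=False,     # evaluate the Q-value ensemble without torch.vmap
    functional=False,
)

# Reference functional variant, built from the same modules.
loss_func = SACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    num_qvalue_nets=4,
    loss_function="l2",
    use_vmap=False,
    functional=True,
)

# The new test checks that both variants hold the same Q-network parameters
# and return the same losses for the same sampled data.
losses_nofunc = loss_nofunc(data.clone())
losses_func = loss_func(data.clone())
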
test/test_cost.py: 125 additions & 15 deletions
@@ -113,7 +113,7 @@
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import (
    _vmap_func,
    _maybe_vmap_maybe_func,
    HardUpdate,
    hold_out_net,
    SoftUpdate,
@@ -254,11 +254,11 @@ def __init__(self):
            layers.append(nn.Linear(4, 4))
        net = nn.Sequential(*layers).to(device)
        model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
        self.convert_to_functional(model, "model", expand_dim=4)
        self.maybe_convert_to_functional(model, "model", expand_dim=4)
        self._make_vmap()

    def _make_vmap(self):
        self.vmap_model = _vmap_func(
        self.vmap_model = _maybe_vmap_maybe_func(
            self.model,
            (None, 0),
            randomness=(
@@ -3876,6 +3876,116 @@ def test_sac_vmap_equiv(

        assert_allclose_td(loss_vmap, loss_novmap)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("as_list", [True, False])
@pytest.mark.parametrize("provide_target", [True, False])
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
def test_sac_nofunc(
self,
device,
version,
as_list,
provide_target,
delay_value,
delay_actor,
delay_qvalue,
num_qvalue=4,
td_est=None,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")

torch.manual_seed(self.seed)
td = self._create_mock_data_sac(device=device)

kwargs = {}

actor = self._create_mock_actor(device=device)
if delay_actor:
kwargs["delay_actor"] = True
if provide_target:
kwargs["target_actor_network"] = self._create_mock_actor(device=device)
kwargs["target_actor_network"].load_state_dict(actor.state_dict())
if as_list:
qvalue = [
self._create_mock_qvalue(device=device) for _ in range(num_qvalue)
]
else:
qvalue = self._create_mock_qvalue(device=device)

if delay_qvalue:
kwargs["delay_qvalue"] = True
if provide_target:
if as_list:
kwargs["target_qvalue_network"] = [
self._create_mock_qvalue(device=device) for _ in qvalue
]
for qval_t, qval in zip(kwargs["target_qvalue_network"], qvalue):
qval_t.load_state_dict(qval.state_dict())

else:
kwargs["target_qvalue_network"] = self._create_mock_qvalue(
device=device
)
kwargs["target_qvalue_network"].load_state_dict(qvalue.state_dict())

if version == 1:
value = self._create_mock_value(device=device)
else:
value = None
if delay_value:
kwargs["delay_value"] = True
if provide_target and version == 1:
kwargs["target_value_network"] = self._create_mock_value(device=device)
kwargs["target_value_network"].load_state_dict(value.state_dict())

rng_state = torch.random.get_rng_state()
with pytest.warns(
UserWarning, match="The target network is ignored as the"
) if delay_qvalue and not as_list and provide_target else contextlib.nullcontext():
loss_fn_nofunc = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
use_vmap=False,
functional=False,
**kwargs,
)
torch.random.set_rng_state(rng_state)
loss_fn_func = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
use_vmap=False,
functional=True,
**kwargs,
)
assert_allclose_td(
torch.stack(
list(
TensorDict.from_module(loss_fn_nofunc.qvalue_network)[
"module"
].values()
)
),
loss_fn_func.qvalue_network_params.data,
)
        with torch.no_grad(), _check_td_steady(td), pytest.warns(
            UserWarning, match="No target network updater"
        ):
            rng_state = torch.random.get_rng_state()
            loss_nofunc = loss_fn_nofunc(td.clone())
            torch.random.set_rng_state(rng_state)
            loss_func = loss_fn_func(td.clone())

        assert_allclose_td(loss_nofunc, loss_func)

@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12383,7 +12493,7 @@ class MyLoss(LossModule):

    def __init__(self, actor_network):
        super().__init__()
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=create_target_params,
@@ -12532,7 +12642,7 @@ class custom_module(LossModule):
    def __init__(self, delay_module=True):
        super().__init__()
        module1 = torch.nn.BatchNorm2d(10).eval()
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            module1, "module1", create_target_params=delay_module
        )

@@ -14296,12 +14406,12 @@ class MyLoss(LossModule):

    def __init__(self, actor_network, qvalue_network):
        super().__init__()
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=True,
        )
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            qvalue_network,
            "qvalue_network",
            3,
@@ -14869,8 +14979,8 @@ def __init__(self, compare_against, expand_dim):
        module_b = TensorDictModule(
            nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
        )
        self.convert_to_functional(module_a, "module_a")
        self.convert_to_functional(
        self.maybe_convert_to_functional(module_a, "module_a")
        self.maybe_convert_to_functional(
            module_b,
            "module_b",
            compare_against=module_a.parameters() if compare_against else [],
@@ -14918,8 +15028,8 @@ def __init__(self, expand_dim=2):
        module_b = TensorDictModule(
            nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
        )
        self.convert_to_functional(module_a, "module_a")
        self.convert_to_functional(
        self.maybe_convert_to_functional(module_a, "module_a")
        self.maybe_convert_to_functional(
            module_b,
            "module_b",
            compare_against=module_a.parameters(),
@@ -14967,8 +15077,8 @@ class MyLoss(LossModule):

    def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
        super().__init__()
        self.convert_to_functional(module_a, "module_a")
        self.convert_to_functional(
        self.maybe_convert_to_functional(module_a, "module_a")
        self.maybe_convert_to_functional(
            [module_b0, module_b1],
            "module_b",
            # This will be ignored
@@ -15337,14 +15447,14 @@ def __init__(self):
            TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]),
        )
        super().__init__()
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            actor,
            "actor",
            expand_dim=None,
            create_target_params=False,
            compare_against=None,
        )
        self.convert_to_functional(
        self.maybe_convert_to_functional(
            value,
            "value",
            expand_dim=2,
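
For reference, the equivalence that use_vmap toggles can be reproduced in plain PyTorch. The sketch below is a generic illustration of the standard torch.func ensembling pattern versus an explicit Python loop; it is not torchrl's implementation of _maybe_vmap_maybe_func, and the small Linear ensemble only stands in for the mock Q-value networks used in the tests above.

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

# An ensemble of identically shaped networks and a batch of inputs.
qnets = [nn.Linear(4, 1) for _ in range(3)]
x = torch.randn(8, 4)

# vmap path: stack the ensemble parameters and batch the call over them.
params, buffers = stack_module_state(qnets)
base = qnets[0]

def run(p, b, inp):
    return functional_call(base, (p, b), (inp,))

out_vmap = torch.vmap(run, in_dims=(0, 0, None))(params, buffers, x)

# no-vmap path: evaluate each module in a plain Python loop.
out_loop = torch.stack([q(x) for q in qnets], dim=0)

torch.testing.assert_close(out_vmap, out_loop)

Both paths evaluate the same ensemble, so a loss built with use_vmap=False is expected to reproduce the vmapped result, which is what test_sac_vmap_equiv checks at the loss level, while the new test_sac_nofunc makes the analogous comparison for functional versus non-functional construction.
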
torchrl/objectives/a2c.py: 2 additions & 2 deletions
@@ -278,7 +278,7 @@ def __init__(
        )

        if functional:
            self.convert_to_functional(
            self.maybe_convert_to_functional(
                actor_network,
                "actor_network",
            )
@@ -292,7 +292,7 @@
        else:
            policy_params = None
        if functional:
            self.convert_to_functional(
            self.maybe_convert_to_functional(
                critic_network, "critic_network", compare_against=policy_params
            )
        else: