diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index d222488d0..d16c0fb65 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1443,7 +1443,39 @@ def _apply_rotate(
     return rewriters
 
 
+from brevitas.utils.python_utils import recurse_getattr
+
+
+def _untie_parameters_with_parametrizations(model: torch.nn.Module):
+    # Get ALL model parameters and their names
+    all_named_parameters = {
+        name: param for name, param in model.named_parameters(remove_duplicate=False)}
+
+    # Get ONLY unique named parameters:
+    # if a parameter is tied and has multiple names, it is included only once
+    no_duplicate_named_parameters = {
+        name: param for name, param in model.named_parameters(remove_duplicate=True)}
+
+    # The difference of the two sets gives us the tied parameters
+    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
+
+    for tied_param_name in tied_param_names:
+        tied_param_name_split = tied_param_name.split(".")
+        # Check if the tied parameter is the original parameter in the module
+        if len(tied_param_name_split) >= 3 and tied_param_name_split[
+                -3] == "parametrizations" and tied_param_name_split[-1] == "original":
+            # If that is the case, retrieve the parent module
+            parent_module = recurse_getattr(model, ".".join(tied_param_name_split[:-1]))
+            # And set it to a new parameter, thus breaking the tie
+            setattr(parent_module, "original", nn.Parameter(all_named_parameters[tied_param_name]))
+
+    return model
+
+
 def _fuse_rotations(model: nn.Module) -> nn.Module:
+    # First of all, parameters that have parametrizations need to be untied
+    model = _untie_parameters_with_parametrizations(model)
+    # Then, parametrizations can be safely removed
     for module in model.modules():
         # Names of the tensors that can potentially be parametrized
         tensor_names = ["weight", "bias"]
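The detection logic above is easiest to see on a toy model. Below is a minimal, standalone sketch (not Brevitas code): it ties two weights, registers a parametrization on both modules, and recovers the tied alias via the same `remove_duplicate` set difference. The `DoubleWeight` parametrization is a hypothetical stand-in for a real rotation; `recurse_getattr` is assumed to behave like `functools.reduce(getattr, name.split("."), model)`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class DoubleWeight(nn.Module):
    """Toy parametrization standing in for a rotation: w -> 2 * w."""

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return 2.0 * weight


embedding = nn.Linear(4, 4, bias=False)
head = nn.Linear(4, 4, bias=False)
head.weight = embedding.weight  # tie, as embedding/lm_head often are in LLMs
model = nn.ModuleDict({"embedding": embedding, "head": head})

# Parametrizing a module moves its Parameter to
# <module>.parametrizations.weight.original; the tie survives the move
parametrize.register_parametrization(embedding, "weight", DoubleWeight())
parametrize.register_parametrization(head, "weight", DoubleWeight())

all_names = set(dict(model.named_parameters(remove_duplicate=False)))
unique_names = set(dict(model.named_parameters(remove_duplicate=True)))
print(all_names - unique_names)
# Expected: {'head.parametrizations.weight.original'} -- exactly the kind of
# name that _untie_parameters_with_parametrizations rewrites above
```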
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 7e7321b62..014637414 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import functools
+
 from packaging import version
 import pytest
 import pytest_cases
@@ -540,34 +542,49 @@ def forward(self, x):
         "srcs": [],
         "sinks": ["block2_linear2"]},]
 
 
-@pytest_cases.fixture
-def block_residual_model():
+class BlockResidualModel(nn.Module):
 
-    class BlockResidualModel(nn.Module):
+    def __init__(self, is_tied: bool = False) -> None:
+        super().__init__()
+        self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
 
-        def __init__(self) -> None:
-            super().__init__()
-            self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+        self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
 
-            self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
-            self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.act = nn.SiLU()
+        self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
 
-            self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
-            self.act = nn.SiLU()
-            self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+        self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        if is_tied:
+            self.head.weight = self.embedding.weight
 
-            self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+    def forward(self, x):
+        x = self.embedding(x)
+        r = x
+        x = self.block1_linear1(x)
+        x = self.block1_linear2(x) + r
+        r = x
+        x = self.block2_linear1(x)
+        x = self.act(x)
+        x = self.block2_linear2(x) + r
+        x = self.head(x)
+        return x
 
-        def forward(self, x):
-            x = self.embedding(x)
-            r = x
-            x = self.block1_linear1(x)
-            x = self.block1_linear2(x) + r
-            r = x
-            x = self.block2_linear1(x)
-            x = self.act(x)
-            x = self.block2_linear2(x) + r
-            x = self.head(x)
-            return x
 
-    return BlockResidualModel
+@pytest_cases.fixture
+def block_residual_model():
+    return functools.partial(BlockResidualModel, is_tied=False)
+
+
+@pytest_cases.fixture
+def block_residual_model_tied():
+    return functools.partial(BlockResidualModel, is_tied=True)
+
+
+list_of_rotation_fixtures = [
+    "block_residual_model",
+    "block_residual_model_tied",]
+
+rotation_model = fixture_union(
+    'rotation_model', list_of_rotation_fixtures, ids=list_of_rotation_fixtures)
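As a quick sanity check of what `is_tied=True` buys, the sketch below verifies that tying makes `head.weight` and `embedding.weight` the very same `nn.Parameter` object, which is exactly the situation the untying logic in `equalize.py` has to handle. It is standalone apart from the import; the module path is an assumption about how the fixtures file is reachable in a checkout.

```python
# Assumes the fixtures module is importable; adjust the path to your checkout.
from equalization_fixtures import BlockResidualModel

tied = BlockResidualModel(is_tied=True)
untied = BlockResidualModel(is_tied=False)

# Tying makes the two attributes reference the SAME Parameter object
assert tied.head.weight is tied.embedding.weight

# named_parameters() de-duplicates by default, so the tied model
# reports one entry fewer than the untied one
assert len(dict(tied.named_parameters())) == len(dict(untied.named_parameters())) - 1
```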
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index 8cacad219..21fe6ac79 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -413,12 +413,11 @@ def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Lin
     ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el]))
 @pytest_cases.parametrize('full_rotation_method', ['ort', 'had'])
 @pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])
-@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["fused", "unfused"])
+@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["unfused", "fused"])
 @pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
-def test_apply_rotate(
-        block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
+def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
     # Instantiate a residual model for which a collection of regions is available
-    model = block_residual_model()
+    model = rotation_model()
     device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
     model.to(device)
     # Sample input to pass through the models
@@ -433,11 +432,16 @@ def test_apply_rotate(
     # The module names in the original model need to be mapped to the ones
     # in graph_model
     map_model_graph = {}
+    assigned_graph_modules = set()
     for graph_module_name, graph_module in graph_model.named_modules():
         if hasattr(graph_module, "weight"):
             for name, module in model.named_modules():
-                if hasattr(module, "weight") and graph_module.weight is module.weight:
+                # The `name not in map_model_graph` check prevents mapping the same module
+                # name twice when tied parameters are present
+                if name not in map_model_graph and graph_module_name not in assigned_graph_modules and hasattr(
+                        module, "weight") and graph_module.weight is module.weight:
                     map_model_graph[name] = graph_module_name
+                    assigned_graph_modules.add(graph_module_name)
     # Replace the names of the modules in sources/sinks by the names of the modules in the FX model
     regions_dicts = [{
         k: list(map(lambda x: map_model_graph[x], v))
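For readers who want the intent of that guarded loop in isolation, here is a distilled sketch (a hypothetical helper, not part of the PR) of the one-to-one, identity-based matching it performs. With tied parameters, `graph_module.weight is module.weight` alone is ambiguous because two module names share one tensor; consuming each name on both sides at most once restores a bijection, under the assumption that `named_modules()` enumerates the original and FX-traced models in the same order.

```python
import torch.nn as nn


def map_modules_one_to_one(model: nn.Module, graph_model: nn.Module) -> dict:
    """Map module names in `model` to names in `graph_model` whose `weight`
    is the very same tensor, using each name on both sides at most once."""
    mapping, assigned = {}, set()
    for g_name, g_module in graph_model.named_modules():
        if not hasattr(g_module, "weight"):
            continue
        for name, module in model.named_modules():
            if (name not in mapping and g_name not in assigned and
                    hasattr(module, "weight") and g_module.weight is module.weight):
                mapping[name] = g_name
                assigned.add(g_name)
    return mapping
```

With the tied fixture, this maps `embedding` and `head` to their own FX counterparts even though both share one weight tensor, because each graph module is consumed exactly once in traversal order.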