From 18fff54c1088a40860904eaecb0147ecba195076 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 14 Jan 2025 11:09:30 +0000 Subject: [PATCH] Minor refactor --- src/brevitas/graph/base.py | 45 +++++---------- src/brevitas/graph/equalize.py | 77 +++++++------------------ tests/brevitas/graph/test_transforms.py | 6 +- 3 files changed, 38 insertions(+), 90 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index f30d4ca3a..d1631f34e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -188,14 +188,14 @@ class ModuleInstanceRegisterParametrization(Transform): parametrization tensor_name: (str): name of the :class:`torch.nn.Parameter` of module which is to be parametrized - parametrization_module (nn.Module): the parametrization to + transform_module (nn.Module): the parametrization to register """ - def __init__(self, module: Module, tensor_name: str, parametrization_module: Module) -> None: + def __init__(self, module: Module, tensor_name: str, transform_module: Module) -> None: self.module = module self.tensor_name = tensor_name - self.parametrization_module = parametrization_module + self.transform_module = transform_module # TODO: Unify inferfaces with ModuleInstanceToModuleInstance for # compatibility with fix_rewriter @@ -212,47 +212,31 @@ def apply(self, model: GraphModule) -> GraphModule: if module is self.module: # register the parametrization to module parametrize.register_parametrization( - module, self.tensor_name, self.parametrization_module) + module, self.tensor_name, self.transform_module) break return model -class ModuleInstanceFuseRotationWeights(Transform): - r"""Transform to rotate in-place a given parameter of a module by a - specified axis +class ModuleInstanceTransformTensor(Transform): + r"""Transform to transform in-place a given parameter of a module Args: - module (nn.Module): parent module of the parameter to be rotated - rot_mat (Tensor): orthogonal matrix by which to rotate the tensor - rot_func (Callable): function to apply the rotation. The first - argument corresponds to the tensor to be rotated, while the - second specifies the rotation matrix. The third argument (K) is - useful when rotating by an Hadamard matrix and it corresponds - to the dimensionality of the matrix up to a power of two, - i.e. dim=(2**p)*K. See get_hadK for details - K (int, optional): if rot_mat is an Hadamard matrix, K is the highest - divisor of the dimensionality of the matrix, such that K, itself, - is not divisible by 2 - axis (int): axis by which to rotate the tensor - tensor_name: (str): name of the :class:`torch.nn.Parameter` of - module which is to be rotated + module (nn.Module): parent module of the parameter to be transformed + tensor_name (str): name of the :class:`torch.nn.Parameter` of + module which is to be transformed + transform_module (nn.Module): module defining the transformation to apply + to the tensor """ def __init__( self, module: Module, - rot_mat: Tensor, - rot_func: Callable[[Tensor, Tensor, Optional[int]], Tensor], - K: Optional[int], tensor_name: str, - axis: int, + transform_module: Module, ): self.module = module - self.rot_mat = rot_mat - self.rot_func = rot_func - self.K = K self.tensor_name = tensor_name - self.axis = axis + self.transform_module = transform_module # TODO: Unify inferfaces with ModuleInstanceToModuleInstance for # compatibility with fix_rewriter @@ -273,8 +257,7 @@ def apply(self, model: GraphModule) -> GraphModule: if hasattr(module, 'allocate_params'): module.allocate_params(module) tensor = getattr(module, self.tensor_name).data - tensor = RotationWeightParametrization( - self.rot_mat, self.rot_func, self.axis, self.K)(tensor) + tensor = self.transform_module(tensor) # Modify the weights in-place setattr(module, self.tensor_name, torch.nn.Parameter(tensor)) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0861cd45f..3478ce2a8 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -23,9 +23,9 @@ from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform from brevitas.graph.base import InsertModuleCallAfter -from brevitas.graph.base import ModuleInstanceFuseRotationWeights from brevitas.graph.base import ModuleInstanceRegisterParametrization from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import ModuleInstanceTransformTensor from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import Transform from brevitas.graph.hadamard import get_hadK @@ -1361,78 +1361,43 @@ def _apply_rotate( for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) - axis = _get_output_axis(module) - - if fuse_rotations: - rewriter = ModuleInstanceFuseRotationWeights( + # Rotate "bias" if present + tensor_names_axis = [("weight", _get_output_axis(module))] + ([ + ("bias", 1)] if getattr(module, 'bias', None) is not None else []) + # If rotations are fused, transform is applied directly onto the tensor + rewriter_class = ModuleInstanceTransformTensor if fuse_rotations else ModuleInstanceRegisterParametrization + # Obtain rewriters for applying the rotations + for tensor_name, axis in tensor_names_axis: + rewriter = rewriter_class( module=module, - rot_mat=rot_mat, - rot_func=rot_func, - K=K, - tensor_name="weight", - axis=axis, - ) - rewriters.append(rewriter) - - if getattr(module, 'bias', None) is not None: - rewriter = ModuleInstanceFuseRotationWeights( - module=module, - rot_mat=rot_mat, - rot_func=rot_func, - K=K, - tensor_name="bias", - axis=1, - ) - rewriters.append(rewriter) - else: - rewriter = ModuleInstanceRegisterParametrization( - module=module, - tensor_name="weight", - parametrization_module=RotationWeightParametrization( + tensor_name=tensor_name, + transform_module=RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, axis=axis, K=K, )) rewriters.append(rewriter) - if getattr(module, 'bias', None) is not None: - rewriter = ModuleInstanceRegisterParametrization( - module=module, - tensor_name="bias", - parametrization_module=RotationWeightParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - axis=1, - K=K, - )) - rewriters.append(rewriter) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - axis = _get_input_axis(module) - - if not insert_rotation_module and not fuse_rotations: - rewriter = ModuleInstanceRegisterParametrization( + # Only "weight" is rotated + tensor_names_axis = [("weight", _get_input_axis(module))] + # If rotations are fused or if the module is an orphan sink, transform is applied directly onto the tensor + rewriter_class = ModuleInstanceRegisterParametrization if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization + # Obtain rewriters for applying the rotations + for tensor_name, axis in tensor_names_axis: + rewriter = rewriter_class( module=module, - tensor_name="weight", - parametrization_module=RotationWeightParametrization( + tensor_name=tensor_name, + transform_module=RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, axis=axis, K=K, )) rewriters.append(rewriter) - else: - rewriter = ModuleInstanceFuseRotationWeights( - module=module, - rot_mat=rot_mat, - rot_func=rot_func, - K=K, - tensor_name="weight", - axis=axis, - ) - rewriters.append(rewriter) - + # Replace by RotatedModule in orphan sink if insert_rotation_module and len(region.srcs) == 0: rewriter = ModuleInstanceWrapModule( module, RotatedModule, "layer", { diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index 9e4e63e56..c56294d9a 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -17,8 +17,8 @@ from brevitas.graph import MeanMethodToAdaptiveAvgPool2d from brevitas.graph import MergeBatchNorm from brevitas.graph import MethodToModule -from brevitas.graph.base import ModuleInstanceFuseRotationWeights from brevitas.graph.base import ModuleInstanceRegisterParametrization +from brevitas.graph.base import ModuleInstanceTransformTensor from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import ModuleToModuleByInstance from brevitas.nn import QuantConv1d @@ -358,8 +358,8 @@ def forward(self, x): model_unfused = TestModel() model_unfused.linear.weight.data = model_fused.linear.weight.data - model_fused = ModuleInstanceFuseRotationWeights( - model_fused.linear, rot_mat, rot_func, None, "weight", axis).apply(model_fused) + model_fused = ModuleInstanceTransformTensor( + model_fused.linear, "weight", rot_mat, rot_func, None, axis).apply(model_fused) model_unfused = ModuleInstanceRegisterParametrization( model_unfused.linear, "weight",