Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 14, 2025
1 parent 0575f0e commit 18fff54
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 90 deletions.
45 changes: 14 additions & 31 deletions src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,14 @@ class ModuleInstanceRegisterParametrization(Transform):
parametrization
tensor_name: (str): name of the :class:`torch.nn.Parameter` of
module which is to be parametrized
parametrization_module (nn.Module): the parametrization to
transform_module (nn.Module): the parametrization to
register
"""

def __init__(self, module: Module, tensor_name: str, parametrization_module: Module) -> None:
def __init__(self, module: Module, tensor_name: str, transform_module: Module) -> None:
self.module = module
self.tensor_name = tensor_name
self.parametrization_module = parametrization_module
self.transform_module = transform_module

# TODO: Unify inferfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
Expand All @@ -212,47 +212,31 @@ def apply(self, model: GraphModule) -> GraphModule:
if module is self.module:
# register the parametrization to module
parametrize.register_parametrization(
module, self.tensor_name, self.parametrization_module)
module, self.tensor_name, self.transform_module)
break
return model


class ModuleInstanceFuseRotationWeights(Transform):
r"""Transform to rotate in-place a given parameter of a module by a
specified axis
class ModuleInstanceTransformTensor(Transform):
r"""Transform to transform in-place a given parameter of a module
Args:
module (nn.Module): parent module of the parameter to be rotated
rot_mat (Tensor): orthogonal matrix by which to rotate the tensor
rot_func (Callable): function to apply the rotation. The first
argument corresponds to the tensor to be rotated, while the
second specifies the rotation matrix. The third argument (K) is
useful when rotating by an Hadamard matrix and it corresponds
to the dimensionality of the matrix up to a power of two,
i.e. dim=(2**p)*K. See get_hadK for details
K (int, optional): if rot_mat is an Hadamard matrix, K is the highest
divisor of the dimensionality of the matrix, such that K, itself,
is not divisible by 2
axis (int): axis by which to rotate the tensor
tensor_name: (str): name of the :class:`torch.nn.Parameter` of
module which is to be rotated
module (nn.Module): parent module of the parameter to be transformed
tensor_name (str): name of the :class:`torch.nn.Parameter` of
module which is to be transformed
transform_module (nn.Module): module defining the transformation to apply
to the tensor
"""

def __init__(
self,
module: Module,
rot_mat: Tensor,
rot_func: Callable[[Tensor, Tensor, Optional[int]], Tensor],
K: Optional[int],
tensor_name: str,
axis: int,
transform_module: Module,
):
self.module = module
self.rot_mat = rot_mat
self.rot_func = rot_func
self.K = K
self.tensor_name = tensor_name
self.axis = axis
self.transform_module = transform_module

# TODO: Unify inferfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
Expand All @@ -273,8 +257,7 @@ def apply(self, model: GraphModule) -> GraphModule:
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
tensor = getattr(module, self.tensor_name).data
tensor = RotationWeightParametrization(
self.rot_mat, self.rot_func, self.axis, self.K)(tensor)
tensor = self.transform_module(tensor)
# Modify the weights in-place
setattr(module, self.tensor_name, torch.nn.Parameter(tensor))

Expand Down
77 changes: 21 additions & 56 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from brevitas.graph import ModuleToModuleByInstance
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceTransformTensor
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
Expand Down Expand Up @@ -1361,78 +1361,43 @@ def _apply_rotate(

for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
axis = _get_output_axis(module)

if fuse_rotations:
rewriter = ModuleInstanceFuseRotationWeights(
# Rotate "bias" if present
tensor_names_axis = [("weight", _get_output_axis(module))] + ([
("bias", 1)] if getattr(module, 'bias', None) is not None else [])
# If rotations are fused, transform is applied directly onto the tensor
rewriter_class = ModuleInstanceTransformTensor if fuse_rotations else ModuleInstanceRegisterParametrization
# Obtain rewriters for applying the rotations
for tensor_name, axis in tensor_names_axis:
rewriter = rewriter_class(
module=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
)
rewriters.append(rewriter)

if getattr(module, 'bias', None) is not None:
rewriter = ModuleInstanceFuseRotationWeights(
module=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="bias",
axis=1,
)
rewriters.append(rewriter)
else:
rewriter = ModuleInstanceRegisterParametrization(
module=module,
tensor_name="weight",
parametrization_module=RotationWeightParametrization(
tensor_name=tensor_name,
transform_module=RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)
if getattr(module, 'bias', None) is not None:
rewriter = ModuleInstanceRegisterParametrization(
module=module,
tensor_name="bias",
parametrization_module=RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=1,
K=K,
))
rewriters.append(rewriter)

for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
axis = _get_input_axis(module)

if not insert_rotation_module and not fuse_rotations:
rewriter = ModuleInstanceRegisterParametrization(
# Only "weight" is rotated
tensor_names_axis = [("weight", _get_input_axis(module))]
# If rotations are fused or if the module is an orphan sink, transform is applied directly onto the tensor
rewriter_class = ModuleInstanceRegisterParametrization if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization
# Obtain rewriters for applying the rotations
for tensor_name, axis in tensor_names_axis:
rewriter = rewriter_class(
module=module,
tensor_name="weight",
parametrization_module=RotationWeightParametrization(
tensor_name=tensor_name,
transform_module=RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)
else:
rewriter = ModuleInstanceFuseRotationWeights(
module=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
)
rewriters.append(rewriter)

# Replace by RotatedModule in orphan sink
if insert_rotation_module and len(region.srcs) == 0:
rewriter = ModuleInstanceWrapModule(
module, RotatedModule, "layer", {
Expand Down
6 changes: 3 additions & 3 deletions tests/brevitas/graph/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from brevitas.graph import MeanMethodToAdaptiveAvgPool2d
from brevitas.graph import MergeBatchNorm
from brevitas.graph import MethodToModule
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceTransformTensor
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import ModuleToModuleByInstance
from brevitas.nn import QuantConv1d
Expand Down Expand Up @@ -358,8 +358,8 @@ def forward(self, x):
model_unfused = TestModel()
model_unfused.linear.weight.data = model_fused.linear.weight.data

model_fused = ModuleInstanceFuseRotationWeights(
model_fused.linear, rot_mat, rot_func, None, "weight", axis).apply(model_fused)
model_fused = ModuleInstanceTransformTensor(
model_fused.linear, "weight", rot_mat, rot_func, None, axis).apply(model_fused)
model_unfused = ModuleInstanceRegisterParametrization(
model_unfused.linear,
"weight",
Expand Down

0 comments on commit 18fff54

Please sign in to comment.