Skip to content

Commit

Permalink
Address final comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 13, 2025
1 parent cd04df6 commit 32c9087
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 49 deletions.
51 changes: 6 additions & 45 deletions src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from brevitas.fx import Node
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization

__all__ = [
'Transform',
Expand Down Expand Up @@ -196,7 +197,8 @@ def __init__(self, module: Module, tensor_name: str, parametrization_module: Mod
self.tensor_name = tensor_name
self.parametrization_module = parametrization_module

# Property to have a common interface with ModuleInstanceToModuleInstance
# TODO: Unify interfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
@property
def old_module_instance(self):
    """Return ``self.module``.

    Exposes the stored module under the ``old_module_instance`` name so
    that this transform can be handled like ModuleInstanceToModuleInstance
    (e.g. by fix_rewriter, per the accompanying TODO).
    """
    return self.module
Expand All @@ -210,53 +212,11 @@ def apply(self, model: GraphModule) -> GraphModule:
if module is self.module:
# register the parametrization to module
parametrize.register_parametrization(
module, self.tensor_name, self.parametrization_module, unsafe=True)
module, self.tensor_name, self.parametrization_module)
break
return model


class RotationWeightParametrization(torch.nn.Module):
    r"""Rotates a tensor by a specified axis

    Args:
        rot_mat (Tensor): orthogonal matrix by which to rotate the tensor
        rot_func (Callable): function to apply the rotation. The first
            argument corresponds to the tensor to be rotated, while the
            second specifies the rotation matrix. The third argument (K) is
            useful when rotating by a Hadamard matrix and it corresponds
            to the dimensionality of the matrix up to a power of two,
            i.e. dim=(2**p)*K. See get_hadK for details
        axis (int): axis by which to rotate the tensor. Only 0 and 1 are
            supported
        K (int, optional): if rot_mat is a Hadamard matrix, K is the highest
            divisor of the dimensionality of the matrix, such that K, itself,
            is not divisible by 2

    Raises:
        RuntimeError: raised by ``forward`` when ``axis`` is neither 0 nor 1
    """

    def __init__(
        self,
        rot_mat: Tensor,
        rot_func: Callable[[Tensor, Tensor, Optional[int]], Tensor],
        axis: int,
        K: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.rot_mat = rot_mat
        self.rot_func = rot_func
        self.axis = axis
        self.K = K

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` rotated by ``rot_mat`` along ``self.axis``."""
        if self.axis == 0:
            # rot_func is applied to the transposed tensor and the result is
            # transposed back, so the rotation acts on the first axis.
            return self.rot_func(tensor.t(), self.rot_mat, self.K).t()
        if self.axis == 1:
            return self.rot_func(tensor, self.rot_mat, self.K)
        raise RuntimeError("Not supported yet")


class ModuleInstanceFuseRotationWeights(Transform):
r"""Transform to rotate in-place a given parameter of a module by a
specified axis
Expand Down Expand Up @@ -294,7 +254,8 @@ def __init__(
self.tensor_name = tensor_name
self.axis = axis

# Property to have a common interface with ModuleInstanceToModuleInstance
# TODO: Unify interfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
@property
def old_module_instance(self):
    """Return ``self.module``.

    Exposes the stored module under the ``old_module_instance`` name so
    that this transform can be handled like ModuleInstanceToModuleInstance
    (e.g. by fix_rewriter, per the accompanying TODO).
    """
    return self.module
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
Expand All @@ -40,6 +39,7 @@
from brevitas.nn.equalized_layer import INPUT_NAMES
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.torch_utils import KwargsForwardHook

# External optional dependency
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/graph/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def random_hadamard_matrix(size, device):
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
Q = Q * 2 - 1
Q = torch.diag(Q)
return matmul_hadU(Q).to(device)
# Set to float32 for consistency with random_orthogonal_matrix and get_hadK
return matmul_hadU(Q).to(device).float()


def matmul_hadU_cuda(X, hadK, K):
Expand Down
49 changes: 49 additions & 0 deletions src/brevitas/utils/rotation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Callable, Optional

import torch
from torch import Tensor


class RotationWeightParametrization(torch.nn.Module):
    r"""Rotates a tensor by a specified axis

    Args:
        rot_mat (Tensor): orthogonal matrix by which to rotate the tensor
        rot_func (Callable): function to apply the rotation. The first
            argument corresponds to the tensor to be rotated, while the
            second specifies the rotation matrix. The third argument (K) is
            useful when rotating by a Hadamard matrix and it corresponds
            to the dimensionality of the matrix up to a power of two,
            i.e. dim=(2**p)*K. See brevitas.graph.hadamard.get_hadK for details
        axis (int): axis by which to rotate the tensor. Only 0 and 1 are
            supported
        K (int, optional): if rot_mat is a Hadamard matrix, K is the highest
            divisor of the dimensionality of the matrix, such that K, itself,
            is not divisible by 2

    Raises:
        RuntimeError: raised by ``forward`` when ``axis`` is neither 0 nor 1
    """

    def __init__(
        self,
        rot_mat: Tensor,
        rot_func: Callable[[Tensor, Tensor, Optional[int]], Tensor],
        axis: int,
        K: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.rot_mat = rot_mat
        self.rot_func = rot_func
        self.axis = axis
        self.K = K

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` rotated by ``rot_mat`` along ``self.axis``."""
        if self.axis == 0:
            # rot_func is applied to the transposed tensor and the result is
            # transposed back, so the rotation acts on the first axis.
            return self.rot_func(tensor.t(), self.rot_mat, self.K).t()
        if self.axis == 1:
            return self.rot_func(tensor, self.rot_mat, self.K)
        raise RuntimeError("Not supported yet")
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from brevitas import torch_version
from brevitas.fx import symbolic_trace
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.equalize import _apply_had_device
from brevitas.graph.equalize import _apply_ort_device
from brevitas.graph.equalize import _apply_rotate
Expand All @@ -37,6 +36,7 @@
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.utils.rotation_utils import RotationWeightParametrization
from tests.marker import requires_pt_ge

from .equalization_fixtures import *
Expand Down
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import ModuleToModuleByInstance
from brevitas.graph.base import RotationWeightParametrization
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.utils.rotation_utils import RotationWeightParametrization

SEED = 123456
INPUT_SIZE = (1, 3, 224, 224)
Expand Down

0 comments on commit 32c9087

Please sign in to comment.