Feat (equalize): enable rotation matrix optimization (#1155)
---------

Co-authored-by: Giuseppe Franco <[email protected]>
pablomlago and Giuseppe5 authored Jan 28, 2025
1 parent e5d03a8 commit 024d4a7
Showing 13 changed files with 872 additions and 67 deletions.
65 changes: 57 additions & 8 deletions src/brevitas/graph/base.py
@@ -5,21 +5,25 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Dict, Type

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

# TODO: Deprecate PyTorch 1.11
try:
from torch.nn.utils.parametrize import is_parametrized
from torch.nn.utils.parametrize import register_parametrization
except ImportError:
from brevitas.utils.torch_utils import is_parametrized
register_parametrization = None

from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
from brevitas.fx import Node
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization

__all__ = [
'Transform',
@@ -109,6 +113,34 @@ def apply(self, model: Module, *model_args, **model_kwargs):
return model


def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
# Keys for values related to parametrizations
keys_to_remove = []
# Keys/values corresponding to the original tensors before parametrizations
keys_values_to_add = []
# Iterate over state_dict identifying the keys related to parametrizations
for key, value in state_dict.items():
split_key = key.split(".")
# Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original"
if len(split_key
) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original":
tensor_name = split_key[-2]
# Name of dictionary entry is "prefix.tensor_name"
keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
# Keys corresponding to the parametrizations attached to the model need to be removed
# to make sure the dictionary can be loaded with no missing/unused keys
# NOTE: For safety, an additional check could be added as this logic would not work if a model
# without parametrizations has any key containing "parametrizations"
if "parametrizations" in split_key:
keys_to_remove.append(key)
# Apply changes in-place to the state_dict
for key in keys_to_remove:
del state_dict[key]
for key, value in keys_values_to_add:
state_dict[key] = value
return state_dict
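
For illustration, a minimal sketch of the key rewriting performed by this helper (the toy module and no-op parametrization below are hypothetical, not part of the commit): PyTorch stores the pre-parametrization tensor under `parametrizations.<tensor_name>.original`, which the helper collapses back to the plain key while dropping the parametrization entries.

```python
# Minimal sketch, assuming a toy no-op parametrization; not part of the commit.
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Noop(nn.Module):

    def forward(self, w):
        return w


linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", _Noop())
print(sorted(linear.state_dict().keys()))
# ['bias', 'parametrizations.weight.original']

cleaned = _remove_parametrization_entries_state_dict(linear.state_dict())
print(sorted(cleaned.keys()))
# ['bias', 'weight']
```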


class ModuleToModule(GraphTransform, ABC):

def __init__(self, new_module_class, **kwargs):
@@ -159,7 +191,25 @@ def _init_new_module(self, old_module: Module, name=None):
def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
replace_module(model, old_module, new_module)
if load_state_dict:
new_module.load_state_dict(old_module.state_dict())
if not is_parametrized(old_module):
new_module.load_state_dict(old_module.state_dict())
else:
old_module_state_dict = old_module.state_dict()
# If parametrizations are present in old_module, the state_dict needs
# to be processed beforehand
old_module_state_dict = _remove_parametrization_entries_state_dict(
old_module_state_dict)
# Strict can be set to True, since potential parametrizations were
# accounted for
new_module.load_state_dict(old_module_state_dict)
# If the old module is parametrized, its parametrizations need to be transferred to the new module.
# The method transfer_parametrizations_and_params is not used, as it can result in parameter ties
# being broken
# NOTE: unsafe is set to True for efficiency, as the checks should have been done
# when first registering the parametrization to old_module
for tensor_name in old_module.parametrizations:
for param_func in old_module.parametrizations[tensor_name]:
register_parametrization(new_module, tensor_name, param_func, unsafe=True)
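
As a hedged illustration of the transfer loop above (toy modules, not part of the commit): because the very same parametrization instances are re-registered on the replacement module, any `nn.Parameter` they hold, such as a shared rotation matrix, stays tied across modules, which is why the manual loop is preferred over `transfer_parametrizations_and_params`.

```python
# Hedged sketch with a toy parametrization standing in for RotationWeightParametrization.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Scale(nn.Module):

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, w):
        return w * self.scale


old_module = nn.Linear(4, 4)
parametrize.register_parametrization(old_module, "weight", _Scale(), unsafe=True)

new_module = nn.Linear(4, 4)
for tensor_name in old_module.parametrizations:
    for param_func in old_module.parametrizations[tensor_name]:
        parametrize.register_parametrization(new_module, tensor_name, param_func, unsafe=True)

# The same parametrization instance (and its Parameter) is now attached to both modules.
assert new_module.parametrizations["weight"][0] is old_module.parametrizations["weight"][0]
```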


class InsertModuleCallAfter(GraphTransform):
@@ -211,8 +261,7 @@ def apply(self, model: GraphModule) -> GraphModule:
for module in model.modules():
if module is self.module:
# register the parametrization to module
parametrize.register_parametrization(
module, self.tensor_name, self.transform_module)
register_parametrization(module, self.tensor_name, self.transform_module)
break
return model

48 changes: 40 additions & 8 deletions src/brevitas/graph/equalize.py
@@ -40,6 +40,8 @@
from brevitas.nn.equalized_layer import INPUT_NAMES
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.torch_utils import KwargsForwardHook
@@ -1444,17 +1446,31 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):
return model


def _fuse_rotations(model: nn.Module) -> nn.Module:
def fuse_parametrized_rotations(model: nn.Module) -> nn.Module:
# First of all, parameters that have parametrizations need to be untied
model = _untie_parameters_with_parametrizations(model)
# Then, parametrizations can be safely removed
for module in model.modules():
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
if parametrize.is_parametrized(module):
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
# Check if the module has any quantization-related children
state_dict = None
for submodule in module.modules():
if isinstance(submodule,
(WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)):
state_dict = submodule.state_dict()
break
# The rotated tensor is saved by setting leave_parametrized=True
parametrize.remove_parametrizations(
module, tensor_name, leave_parametrized=True)
# Restore the state of the quantization modules, as these might have been reset
# when registering the parametrized parameter
if state_dict is not None:
submodule.load_state_dict(state_dict)
return model
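
A hedged sketch of the fusion semantics relied on above (toy module and rotation, not part of the commit): `remove_parametrizations(..., leave_parametrized=True)` evaluates the parametrization once and stores the result as the plain tensor, which is how the rotation ends up baked into the weight.

```python
# Minimal sketch of leave_parametrized=True, assuming a toy rotation parametrization.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Rotate(nn.Module):

    def __init__(self, rot_mat):
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, w):
        return w @ self.rot_mat


linear = nn.Linear(4, 4)
rot_mat = nn.Parameter(torch.linalg.qr(torch.randn(4, 4))[0])  # random orthogonal matrix
parametrize.register_parametrization(linear, "weight", _Rotate(rot_mat))

parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
assert not parametrize.is_parametrized(linear)  # the rotated weight is now a plain Parameter
```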


@@ -1514,6 +1530,7 @@ def __init__(
orphan_sink: bool = False,
sdpa_regions: bool = False,
rotate_matmul: bool = False,
use_parametrized_rotations: bool = False,
full_rotation_method: str = 'had',
return_rewriters: bool = False) -> None:
super(GraphRotationEqualization, self).__init__()
@@ -1531,6 +1548,17 @@ def __init__(
self.full_rotation_method = full_rotation_method
self.return_rewriters = return_rewriters
self.sdpa_regions = sdpa_regions
if use_parametrized_rotations:
# NOTE: When use_parametrized_rotations=True, rotations are applied as parametrizations. This changes the
# __class__ attribute of the parametrized module, e.g. to "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>".
# Therefore, algorithms that do type checking might need to use type_before_parametrizations(module)
# instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
# of the weights should not operate on parametrized modules. In this situation, parametrizations
# need to be removed beforehand by invoking fuse_parametrized_rotations
warnings.warn(
"Using parametrized rotations might break type-checking, which could lead to unexpected behaviour."
)
self.use_parametrized_rotations = use_parametrized_rotations
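
As a hedged end-to-end sketch of how the pieces introduced in this commit are meant to compose (the traced `graph_model` and the user-supplied `train_step` loss function are assumptions, and a plain SGD step is shown only for brevity; keeping the rotation matrices orthogonal during optimization is out of scope here):

```python
# Hedged sketch; not the exact training flow used by the LLM entry point, which relies
# on the Hugging Face Trainer (see the README changes below).
import torch

from brevitas.graph.equalize import fuse_parametrized_rotations
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.utils.rotation_utils import extract_trainable_rotation_matrices


def rotate_optimize_fuse(graph_model, train_step, num_steps=100, lr=1e-3):
    # Register rotations as trainable parametrizations instead of fusing them immediately.
    eq = GraphRotationEqualization(use_parametrized_rotations=True)
    graph_model = eq.apply(graph_model)
    # Optimize only the (deduplicated) rotation matrices.
    optimizer = torch.optim.SGD(extract_trainable_rotation_matrices(graph_model), lr=lr)
    for _ in range(num_steps):
        loss = train_step(graph_model)  # hypothetical user-supplied loss computation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # Bake the optimized rotations back into the weights.
    return fuse_parametrized_rotations(graph_model)
```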

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
@@ -1620,7 +1648,11 @@ def apply(self,
if self.rotate_matmul:
self.rotate_matmuls(graph_model)
if len(regions) > 0:
rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method)
rewriters = _apply_rotate(
graph_model,
regions,
self.full_rotation_method,
fuse_rotations=not self.use_parametrized_rotations)
if self.return_rewriters:
return graph_model, rewriters
else:
21 changes: 13 additions & 8 deletions src/brevitas/graph/quantize_impl.py
@@ -7,7 +7,12 @@
import torch
import torch.nn as nn

import brevitas
# TODO: Deprecate PyTorch 1.11
try:
from torch.nn.utils.parametrize import type_before_parametrizations
except ImportError:
from brevitas.utils.torch_utils import type_before_parametrizations

from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleToModuleByInstance
@@ -403,8 +408,8 @@ def act_handler(model, layer_map):
if node.op == 'call_module':
module = get_module(model, node.target)
if isinstance(module, tuple(layer_map.keys())):
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
quant_module = quant_module_class(**quant_module_kwargs)
# Check for activation equalization mul nodes
if len(node.users) == 1:
@@ -465,8 +470,8 @@ def layer_handler(
quant_identity_map=quant_identity_map,
quant_act_map=quant_act_map,
unsigned_act_tuple=unsigned_act_tuple)
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
# Quantize the input if is not quantized, input_quant is not specified,
# and the quant_identity_map is provided.
if not are_inputs_quantized_and_aligned(
@@ -511,7 +516,7 @@ def find_module(
Specifically, it allows mapping nn.MultiheadAttention to its quantized counterpart and not its
Linear submodules.
"""
if _module_class_name(type(model)) in layer_map.keys():
if _module_class_name(type_before_parametrizations(model)) in layer_map.keys():
module_to_replace.append(model)
else:
for name, module in model.named_children():
@@ -532,8 +537,8 @@ def layerwise_layer_handler(
find_module(model, layer_map, module_to_replace, name_blacklist)
rewriters = []
for module in module_to_replace:
if layer_map[_module_class_name(type(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))]
if layer_map[_module_class_name(type_before_parametrizations(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type_before_parametrizations(module))]
rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
rewriters.append(rewriter)
for rewriter in rewriters:
16 changes: 15 additions & 1 deletion src/brevitas/utils/rotation_utils.py
@@ -1,9 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Callable, Optional
from typing import Callable, List, Optional

import torch
from torch import nn
from torch import Tensor


@@ -47,3 +48,16 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Not supported yet")

return tensor


def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]:
trainable_rotations = []
# IDs of the rotation matrices are tracked, as several modules can share
# the same parametrized rotation
ids_rot = set()
for module in model.modules():
if isinstance(module, RotationWeightParametrization):
if id(module.rot_mat) not in ids_rot:
ids_rot.add(id(module.rot_mat))
trainable_rotations.append(module.rot_mat)
return trainable_rotations
36 changes: 36 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -146,3 +146,39 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn

return decorator


# TODO: Remove after deprecating PyTorch 1.11
def is_parametrized(module: torch.nn.Module, tensor_name: Optional[str] = None) -> bool:
r"""Determine if a module has a parametrization.
Args:
module (nn.Module): module to query
tensor_name (str, optional): name of the parameter in the module
Default: ``None``
Returns:
``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
or if it has any parametrization when :attr:`tensor_name` is ``None``;
otherwise ``False``
"""
parametrizations = getattr(module, "parametrizations", None)
if parametrizations is None or not isinstance(parametrizations, torch.nn.ModuleDict):
return False
if tensor_name is None:
# Check that there is at least one parametrized buffer or Parameter
return len(parametrizations) > 0
else:
return tensor_name in parametrizations


# TODO: Remove after deprecating PyTorch 1.11
def type_before_parametrizations(module: torch.nn.Module) -> type:
r"""Return the module type before parametrizations were applied and if not, then it returns the module type.
Args:
module (nn.Module): module to get type of
"""
if is_parametrized(module):
return module.__class__.__bases__[0]
else:
return type(module)
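
A hedged illustration of why the quantization handlers above switch to `type_before_parametrizations` (toy module and no-op parametrization, not part of the commit): registering a parametrization replaces the module's `__class__` with a dynamically generated subclass, so plain `type(module)` lookups in the layer maps would miss it.

```python
# Minimal sketch using the fallbacks defined above (or their PyTorch counterparts).
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Noop(nn.Module):

    def forward(self, w):
        return w


linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", _Noop())

print(type(linear).__name__)  # ParametrizedLinear
print(is_parametrized(linear, "weight"))  # True
print(type_before_parametrizations(linear) is nn.Linear)  # True
```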
6 changes: 4 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -12,6 +12,8 @@

Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

When using `--optimize-rotations`, the rotation training procedure relies on the Hugging Face Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Training can therefore be further configured by passing any argument accepted by the TrainingArguments dataclass (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `--per_device_train_batch_size`.

```bash
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
@@ -49,8 +51,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
[--functional-sdpa-quant] [--replace-mha]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--rotation-sdpa-regions]
[--rotation-mode {had,ort}] [--optimize-rotations]
[--rotation-orphan-sink] [--rotation-sdpa-regions]
[--act-equalization {None,layerwise,fx}]
[--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
[--load-awq LOAD_AWQ]
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -49,6 +49,7 @@ model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
optimize_rotations: false
quant_sdpa: false
quantize_input_zero_point: false
quantize_last_layer: false