Feat (equalize): enable rotation matrix optimization #1155

Merged · 26 commits · Jan 28, 2025
65 changes: 57 additions & 8 deletions src/brevitas/graph/base.py
@@ -5,21 +5,25 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Dict, Type

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

# TODO: Deprecate PyTorch 1.11
try:
from torch.nn.utils.parametrize import is_parametrized
from torch.nn.utils.parametrize import register_parametrization
except ImportError:
from brevitas.utils.torch_utils import is_parametrized
register_parametrization = None

from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
from brevitas.fx import Node
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization

__all__ = [
'Transform',
@@ -109,6 +113,34 @@ def apply(self, model: Module, *model_args, **model_kwargs):
return model


def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
# Keys for values related to parametrizations
keys_to_remove = []
# Keys/values corresponding to the original tensors before parametrizations
keys_values_to_add = []
# Iterate over state_dict identifying the keys related to parametrizations
for key, value in state_dict.items():
split_key = key.split(".")
# Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original"
if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original":
tensor_name = split_key[-2]
# Name of the dictionary entry is "prefix.tensor_name"
keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
# Keys corresponding to the parametrizations attached to the model need to be removed
# to make sure the dictionary can be loaded with no missing/unused keys
# NOTE: For safety, an additional check could be added as this logic would not work if a model
# without parametrizations has any key containing "parametrizations"
if "parametrizations" in split_key:
keys_to_remove.append(key)
# Apply changes in-place to the state_dict
for key in keys_to_remove:
del state_dict[key]
for key, value in keys_values_to_add:
state_dict[key] = value
return state_dict


class ModuleToModule(GraphTransform, ABC):

def __init__(self, new_module_class, **kwargs):
@@ -159,7 +191,25 @@ def _init_new_module(self, old_module: Module, name=None):
def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
replace_module(model, old_module, new_module)
if load_state_dict:
new_module.load_state_dict(old_module.state_dict())
if not is_parametrized(old_module):
new_module.load_state_dict(old_module.state_dict())
else:
old_module_state_dict = old_module.state_dict()
# If parametrizations are present in old_module, the state_dict needs
# to be processed beforehand
old_module_state_dict = _remove_parametrization_entries_state_dict(
old_module_state_dict)
# Strict can be set to True, since potential parametrizations were
# accounted for
new_module.load_state_dict(old_module_state_dict)
# If the old module is parametrized, its parametrizations need to be transferred to the new module.
# The method transfer_parametrizations_and_params is not used, as it can result in parameter ties
# being broken
# NOTE: unsafe is set to True for efficiency, as the checks should have been done
# when first registering the parametrization to old_module
for tensor_name in old_module.parametrizations:
for param_func in old_module.parametrizations[tensor_name]:
register_parametrization(new_module, tensor_name, param_func, unsafe=True)


class InsertModuleCallAfter(GraphTransform):
@@ -211,8 +261,7 @@ def apply(self, model: GraphModule) -> GraphModule:
for module in model.modules():
if module is self.module:
# register the parametrization to module
parametrize.register_parametrization(
module, self.tensor_name, self.transform_module)
register_parametrization(module, self.tensor_name, self.transform_module)
break
return model

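
As background for the state-dict handling above, here is a minimal sketch (not part of the PR) of how registering a parametrization reshapes a module's `state_dict`, which is what `_remove_parametrization_entries_state_dict` undoes before loading into a non-parametrized replacement module; `nn.Identity` stands in for a real rotation parametrization.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

lin = nn.Linear(4, 4)
print(sorted(lin.state_dict().keys()))
# ['bias', 'weight']

# Registering any parametrization moves the original tensor under
# "parametrizations.<tensor_name>.original".
parametrize.register_parametrization(lin, "weight", nn.Identity())
print(sorted(lin.state_dict().keys()))
# ['bias', 'parametrizations.weight.original']

# The helper above maps "parametrizations.weight.original" back to "weight",
# so the resulting state_dict can be loaded strictly into a plain module.
```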
48 changes: 40 additions & 8 deletions src/brevitas/graph/equalize.py
@@ -40,6 +40,8 @@
from brevitas.nn.equalized_layer import INPUT_NAMES
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.torch_utils import KwargsForwardHook
@@ -1444,17 +1446,31 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):
return model


def _fuse_rotations(model: nn.Module) -> nn.Module:
def fuse_parametrized_rotations(model: nn.Module) -> nn.Module:
# First of all, parameters that have parametrizations need to be untied
model = _untie_parameters_with_parametrizations(model)
# Then, parametrizations can be safely removed
for module in model.modules():
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
if parametrize.is_parametrized(module):
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
# Check if the module has any quantization-related children
state_dict = None
for submodule in module.modules():
if isinstance(submodule,
(WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)):
state_dict = submodule.state_dict()
break
# The rotated tensor is saved by setting leave_parametrized=True
parametrize.remove_parametrizations(
module, tensor_name, leave_parametrized=True)
# Restore the state of the quantization modules, as these might have been reset
# when registering the parametrized parameter
if state_dict is not None:
submodule.load_state_dict(state_dict)
return model


@@ -1514,6 +1530,7 @@ def __init__(
orphan_sink: bool = False,
sdpa_regions: bool = False,
rotate_matmul: bool = False,
use_parametrized_rotations: bool = False,
full_rotation_method: str = 'had',
return_rewriters: bool = False) -> None:
super(GraphRotationEqualization, self).__init__()
@@ -1531,6 +1548,17 @@ def __init__(
self.full_rotation_method = full_rotation_method
self.return_rewriters = return_rewriters
self.sdpa_regions = sdpa_regions
if use_parametrized_rotations:
# NOTE: When use_parametrized_rotations=True, rotations are applied as parametrizations. This changes the attribute __class__
# of the parametrized module, e.g. to "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>".
# Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
# instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
# of the weights should not operate on parametrized modules. In this situation, parametrizations
# need to be removed beforehand by invoking fuse_parametrized_rotations
warnings.warn(
"Using parametrized rotations might break type-checking, which could lead to unexpected behaviour."
)
self.use_parametrized_rotations = use_parametrized_rotations

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
@@ -1620,7 +1648,11 @@ def apply(self,
if self.rotate_matmul:
self.rotate_matmuls(graph_model)
if len(regions) > 0:
rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method)
rewriters = _apply_rotate(
graph_model,
regions,
self.full_rotation_method,
fuse_rotations=not self.use_parametrized_rotations)
if self.return_rewriters:
return graph_model, rewriters
else:
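
A rough usage sketch of the new flag (illustrative only: the choice of constructor arguments, the optimization step, and the overall flow are assumptions, not code from this PR). Rotations are kept as trainable parametrizations, optimized, and only then fused back into the weights:

```python
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import fuse_parametrized_rotations


def rotate_then_fuse(graph_model):
    # graph_model is assumed to be an FX GraphModule of the network (tracing not shown)
    eq = GraphRotationEqualization(
        orphan_sink=True,
        full_rotation_method='ort',  # orthogonal rotation matrices that can be optimized
        use_parametrized_rotations=True)  # keep rotations as parametrizations
    graph_model = eq.apply(graph_model)

    # ... optimize the rotation matrices here (e.g. via the HF Trainer) ...

    # Fold the learned rotations back into the weights before anything that
    # modifies weights in place (GPTQ, export, ...).
    graph_model = fuse_parametrized_rotations(graph_model)
    return graph_model
```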
21 changes: 13 additions & 8 deletions src/brevitas/graph/quantize_impl.py
@@ -7,7 +7,12 @@
import torch
import torch.nn as nn

import brevitas
# TODO: Deprecate PyTorch 1.11
try:
from torch.nn.utils.parametrize import type_before_parametrizations
except ImportError:
from brevitas.utils.torch_utils import type_before_parametrizations

from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleToModuleByInstance
@@ -403,8 +408,8 @@ def act_handler(model, layer_map):
if node.op == 'call_module':
module = get_module(model, node.target)
if isinstance(module, tuple(layer_map.keys())):
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
quant_module = quant_module_class(**quant_module_kwargs)
# Check for activation equalization mul nodes
if len(node.users) == 1:
@@ -465,8 +470,8 @@ def layer_handler(
quant_identity_map=quant_identity_map,
quant_act_map=quant_act_map,
unsigned_act_tuple=unsigned_act_tuple)
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
# Quantize the input if is not quantized, input_quant is not specified,
# and the quant_identity_map is provided.
if not are_inputs_quantized_and_aligned(
@@ -511,7 +516,7 @@ def find_module(
Specifically, it allows mapping nn.MultiheadAttention to its quantized counterpart and not its
Linear submodules.
"""
if _module_class_name(type(model)) in layer_map.keys():
if _module_class_name(type_before_parametrizations(model)) in layer_map.keys():

Collaborator:
I suspect we check for typing elsewhere in this file to apply quantization. Would you mind double checking that?

Collaborator (Author):
There's a similar check in layer_handler which seems to be only called for graph models. If I'm not missing something, do we want to have it there too? We currently don't have any use-case, as far as I'm aware, in which we have a graph model with parametrizations.

Collaborator:
There might be cases with graph model + parametrizations in the future, so let's change it there as well so that it works.

Once we have a QuantLinear with a parametrization registered, do we still need to use this new function or can we fall back to type?

Collaborator (Author), @pablomlago, Jan 16, 2025:
When you register the parametrization, the attribute class changes from QuantLinear to ParametrizedQuantLinear, so I guess it's still needed, as type is going to return a class prefixed by Parametrized.

Collaborator:
That could be a problem, since I assume it might break stuff like GPTQ. Would you mind checking?

Collaborator (Author), @pablomlago, Jan 16, 2025:
But GPTQ modifies weights in-place, right? I don't think that fits well with parametrized rotations, since the weights are no longer a static tensor in memory, but instead are dynamically computed by running the forward of the parametrization modules, passing the original weight tensor as input. My understanding is that we would need to fuse the rotations before applying GPTQ, and therefore there won't be any problem, as layers will again be QuantLinear.

Collaborator:
Not sure, there could be a world where you first do GPTQ and then optimize the rotations afterwards.
Let's add a comment/warning when adding parametrized rotations that type checking might break.

module_to_replace.append(model)
else:
for name, module in model.named_children():
@@ -532,8 +537,8 @@ def layerwise_layer_handler(
find_module(model, layer_map, module_to_replace, name_blacklist)
rewriters = []
for module in module_to_replace:
if layer_map[_module_class_name(type(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))]
if layer_map[_module_class_name(type_before_parametrizations(module))] is not None:
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type_before_parametrizations(module))]
rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
rewriters.append(rewriter)
for rewriter in rewriters:
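
A small self-contained illustration of why the layer_map lookups above switch from type to type_before_parametrizations (a sketch: plain `nn.Linear` stands in for a Brevitas layer, `nn.Identity` for a rotation parametrization, and a recent PyTorch is assumed so that `type_before_parametrizations` can be imported from `torch.nn.utils.parametrize`; older versions would use the fallback added in this PR):

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

# Toy stand-in for the real layer_map, keyed by module type
layer_map = {nn.Linear: ("QuantLinear", {})}

lin = nn.Linear(8, 8)
parametrize.register_parametrization(lin, "weight", nn.Identity())

# Registering a parametrization rewrites __class__, so a plain type() lookup misses
assert type(lin).__name__ == "ParametrizedLinear"
assert type(lin) not in layer_map
# The pre-parametrization type still matches the map
assert type_before_parametrizations(lin) is nn.Linear
assert type_before_parametrizations(lin) in layer_map
```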
16 changes: 15 additions & 1 deletion src/brevitas/utils/rotation_utils.py
@@ -1,9 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Callable, Optional
from typing import Callable, List, Optional

import torch
from torch import nn
from torch import Tensor


@@ -47,3 +48,16 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Not supported yet")

return tensor


def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]:
trainable_rotations = []
# IDs of the rotation matrices are tracked, as several modules can share
# the same parametrized rotation
ids_rot = set()
for module in model.modules():
if isinstance(module, RotationWeightParametrization):
if id(module.rot_mat) not in ids_rot:
ids_rot.add(id(module.rot_mat))
trainable_rotations.append(module.rot_mat)
return trainable_rotations
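
A hedged sketch of how the helper above could feed an optimizer during rotation training (the use of Adam and the learning rate are illustrative assumptions, not prescribed by this PR):

```python
import torch

from brevitas.utils.rotation_utils import extract_trainable_rotation_matrices


def build_rotation_optimizer(model: torch.nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    # Deduplicated rotation matrices, possibly shared across several parametrized modules
    rot_mats = extract_trainable_rotation_matrices(model)
    # Only the rotation matrices are trained; all other parameters stay untouched
    return torch.optim.Adam(rot_mats, lr=lr)
```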
36 changes: 36 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -146,3 +146,39 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn

return decorator


# TODO: Remove after deprecating PyTorch 1.11
def is_parametrized(module: torch.nn.Module, tensor_name: Optional[str] = None) -> bool:
r"""Determine if a module has a parametrization.

Args:
module (nn.Module): module to query
tensor_name (str, optional): name of the parameter in the module
Default: ``None``
Returns:
``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
or if it has any parametrization when :attr:`tensor_name` is ``None``;
otherwise ``False``
"""
parametrizations = getattr(module, "parametrizations", None)
if parametrizations is None or not isinstance(parametrizations, torch.nn.ModuleDict):
return False
if tensor_name is None:
# Check that there is at least one parametrized buffer or Parameter
return len(parametrizations) > 0
else:
return tensor_name in parametrizations


# TODO: Remove after deprecating PyTorch 1.11
def type_before_parametrizations(module: torch.nn.Module) -> type:
r"""Return the module type before parametrizations were applied; if the module is not parametrized, return its type.

Args:
module (nn.Module): module to get type of
"""
if is_parametrized(module):
return module.__class__.__bases__[0]
else:
return type(module)
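
A quick check of the fallback helpers' behaviour (a sketch; `nn.Identity` again stands in for a real parametrization):

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.utils.torch_utils import is_parametrized
from brevitas.utils.torch_utils import type_before_parametrizations

conv = nn.Conv2d(3, 3, 1)
assert not is_parametrized(conv)

parametrize.register_parametrization(conv, "weight", nn.Identity())
assert is_parametrized(conv)              # some tensor is parametrized
assert is_parametrized(conv, "weight")    # "weight" specifically is
assert not is_parametrized(conv, "bias")  # "bias" is not
assert type_before_parametrizations(conv) is nn.Conv2d
```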
6 changes: 4 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -12,6 +12,8 @@

Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

When using `--optimize-rotations`, the rotation training procedure relies on the Hugging Face `Trainer` class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the `TrainingArguments` dataclass (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `--per_device_train_batch_size`.
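
For example, a hypothetical invocation (the model and hyperparameter values below are purely illustrative):

```bash
python main.py \
  --model facebook/opt-125m \
  --rotation fx \
  --optimize-rotations \
  --learning_rate 1e-3 \
  --weight_decay 0.0 \
  --per_device_train_batch_size 4
```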

```bash
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
@@ -49,8 +51,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
[--functional-sdpa-quant] [--replace-mha]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--rotation-sdpa-regions]
[--rotation-mode {had,ort}] [--optimize-rotations]
[--rotation-orphan-sink] [--rotation-sdpa-regions]
[--act-equalization {None,layerwise,fx}]
[--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
[--load-awq LOAD_AWQ]
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -49,6 +49,7 @@ model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
optimize_rotations: false
quant_sdpa: false
quantize_input_zero_point: false
quantize_last_layer: false