Fix parametrization fusing
pablomlago committed Jan 17, 2025
1 parent b7209d2 commit 0399f6d
Showing 2 changed files with 25 additions and 8 deletions.
29 changes: 23 additions & 6 deletions src/brevitas/graph/equalize.py
```diff
@@ -17,6 +17,7 @@
 import torch.nn as nn
 import torch.nn.utils.parametrize as parametrize
 
+from brevitas import config
 from brevitas import torch_version
 from brevitas.fx import GraphModule
 from brevitas.fx import Node
```
```diff
@@ -1444,17 +1445,33 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):
     return model
 
 
+def _retrieve_quant_state_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
+    # Retrieve state dict components related to Brevitas quantizers
+    config._FULL_STATE_DICT = True
+    quant_state_dict = {k: v for k, v in module.state_dict().items() if "_quant" in k}
+    config._FULL_STATE_DICT = False
+    return quant_state_dict
+
+
 def fuse_parametrized_rotations(model: nn.Module) -> nn.Module:
     # First of all, parameters that have parametrizations need to be untied
     model = _untie_parameters_with_parametrizations(model)
     # Then, parametrizations can be safely removed
     for module in model.modules():
-        # Names of the tensors that can potentially be parametrized
-        tensor_names = ["weight", "bias"]
-        # Remove parametrizations from each tensor
-        for tensor_name in tensor_names:
-            if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
-                parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
+        if parametrize.is_parametrized(module):
+            # Names of the tensors that can potentially be parametrized
+            tensor_names = ["weight", "bias"]
+            # Get the quantization-related entries of the module state_dict
+            quant_state_dict = _retrieve_quant_state_dict(module)
+            # Remove parametrizations from each tensor
+            for tensor_name in tensor_names:
+                if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
+                    parametrize.remove_parametrizations(
+                        module, tensor_name, leave_parametrized=True)
+            # Restore the state of quantization-related tensors, strict needs to be set to False
+            # as there will be missing keys
+            if len(quant_state_dict) > 0:
+                module.load_state_dict(quant_state_dict, strict=False)
     return model
 
 
```
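For context, here is a minimal, self-contained sketch (not part of the commit) of the fusing mechanics that `fuse_parametrized_rotations` relies on, using plain `torch.nn.utils.parametrize`. The `Rotation` module is a made-up stand-in for Brevitas' rotation parametrizations.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Rotation(nn.Module):
    """Right-multiplies the weight by a fixed orthogonal matrix (toy stand-in)."""

    def __init__(self, rot: torch.Tensor):
        super().__init__()
        self.register_buffer("rot", rot)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot


linear = nn.Linear(4, 4, bias=False)
rot = torch.linalg.qr(torch.randn(4, 4)).Q  # random orthogonal matrix
parametrize.register_parametrization(linear, "weight", Rotation(rot))
assert parametrize.is_parametrized(linear)

# Fusing: the rotated weight is materialized into a plain nn.Parameter and the
# parametrization (with its extra state) is dropped from the module.
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
assert not parametrize.is_parametrized(linear)
```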


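The new `_retrieve_quant_state_dict` helper snapshots every state-dict entry whose key contains "_quant" (with `config._FULL_STATE_DICT` temporarily enabled, presumably so quantizer tensors are included — an assumption about that flag's purpose) and reloads the snapshot after the parametrizations are removed, with `strict=False` because it covers only a subset of the module's keys. Below is a hedged, self-contained sketch of that snapshot/restore pattern; `ToyQuantLinear` is a made-up module whose buffer name contains "_quant", not a real Brevitas layer.

```python
import torch
import torch.nn as nn


class ToyQuantLinear(nn.Linear):
    """Toy layer with a quantizer-like buffer, standing in for Brevitas state."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)
        self.register_buffer("weight_quant_scale", torch.tensor(0.5))


module = ToyQuantLinear(4, 4)

# Snapshot the quantization-related entries (keys containing "_quant") ...
quant_state_dict = {k: v for k, v in module.state_dict().items() if "_quant" in k}

# ... let the module be rebuilt in a way that resets that state; re-assigning the
# buffer mimics how remove_parametrizations(..., leave_parametrized=True)
# re-registers tensors on the module ...
module.weight_quant_scale = torch.tensor(1.0)

# ... and restore the snapshot. strict=False is needed because the snapshot only
# covers a subset of the module's keys (e.g. "weight" is missing from it).
module.load_state_dict(quant_state_dict, strict=False)
assert module.weight_quant_scale.item() == 0.5
```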
4 changes: 2 additions & 2 deletions src/brevitas_examples/llm/main.py
```diff
@@ -162,7 +162,7 @@ def model_export(model, ref_input, args):
 
 def validate(args, extra_args: Optional[List[str]] = None):
     if args.rotation != "fused_no_fx_optimize":
-        assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith("--")]}"
+        assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}"
     if args.functional_sdpa_quant:
         assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
     if args.rotation == 'fx':
```
```diff
@@ -230,7 +230,7 @@ def validate(args, extra_args: Optional[List[str]] = None):
 
 
 def quantize_llm(args, extra_args=None):
-    validate(args)
+    validate(args, extra_args)
    set_seed(args.seed)
     if args.export_prefix is None:
         args.export_prefix = f"{args.model.replace('/', '--')}"
```
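With this change, `quantize_llm` forwards `extra_args` into `validate`, so unknown command-line flags are rejected unless `rotation == "fused_no_fx_optimize"`, where they are consumed later. A hedged sketch of where such `extra_args` typically come from, assuming an argparse-style entry point (the parser below is a stripped-down stand-in, not the script's real argument list):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rotation", default=None)

# parse_known_args splits recognized flags (args) from unrecognized ones (extra_args).
args, extra_args = parser.parse_known_args(
    ["--rotation", "fused_no_fx_optimize", "--learning_rate", "1e-3"])

print(extra_args)  # ['--learning_rate', '1e-3']
```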
