From 7a3101510bf84753aacc5960d232358abb682e14 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 14 Jan 2025 14:52:40 +0000 Subject: [PATCH 01/26] Enable quantization of parametrized layers --- src/brevitas/graph/base.py | 48 ++++++++++++- src/brevitas/graph/quantize_impl.py | 8 ++- tests/brevitas/graph/test_quantize.py | 100 ++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index d1631f34e..3931640a7 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -109,6 +109,34 @@ def apply(self, model: Module, *model_args, **model_kwargs): return model +def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + # Keys for values related to parametrizations + keys_to_remove = [] + # Keys/values corresponding to the original tensors before parametrizations + keys_values_to_add = [] + # Iterate over state_dict identifying the keys related to parametrizations + for key, value in state_dict.items(): + split_key = key.split(".") + # Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original" + if len(split_key + ) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original": + tensor_name = split_key[-2] + # Name of dictionary entry is "prefix.tensro_name" + keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value)) + # Keys corresponding to the parametrizations attached to the model need to be removed + # to make sure the dictionary can be loaded with no missing/unused keys + # NOTE: For safety, an additional check could be added as this logic would not work if a model + # without parametrizations has any key containing "parametrizations" + if "parametrizations" in split_key: + keys_to_remove.append(key) + # Apply changes in-place to the state_dict + for key in keys_to_remove: + del state_dict[key] + for key, value in keys_values_to_add: + state_dict[key] = value + return state_dict + + class ModuleToModule(GraphTransform, ABC): def __init__(self, new_module_class, **kwargs): @@ -159,7 +187,25 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - new_module.load_state_dict(old_module.state_dict()) + old_module_state_dict = old_module.state_dict() + # If parametrizations are present in old_module, the state_dict needs + # to be processed beforehand + if parametrize.is_parametrized(old_module): + old_module_state_dict = _remove_parametrization_entries_state_dict( + old_module_state_dict) + # Strict can be set to True, since potential parametrizations were + # accounted for + new_module.load_state_dict(old_module_state_dict) + # If the old module is parametrized, these need to be transferred to the new module. 
+ # The method transfer_parametrizations_and_params as it can result in parameter ties + # being broken + # NOTE: unsafe is set to True for efficiency, as the checks should have been done + # when first registering the parametrization to old_module + if parametrize.is_parametrized(old_module): + for tensor_name in old_module.parametrizations: + for param_func in old_module.parametrizations[tensor_name]: + parametrize.register_parametrization( + new_module, tensor_name, param_func, unsafe=True) class InsertModuleCallAfter(GraphTransform): diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 535f9a8f9..a4d348ab5 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -511,7 +512,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. """ - if _module_class_name(type(model)) in layer_map.keys(): + if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -532,8 +533,9 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name(type(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] + if layer_map[_module_class_name( + parametrize.type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index 45167b673..ec15da6a2 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -1,7 +1,14 @@ +import copy + import pytest_cases +import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize +from brevitas.graph.base import _remove_parametrization_entries_state_dict from brevitas.graph.quantize import layerwise_quantize +from brevitas.utils.python_utils import recurse_getattr +from brevitas.utils.rotation_utils import RotationWeightParametrization @pytest_cases.parametrize( @@ -42,3 +49,96 @@ def test_layerwise_quantize_blacklist(kwargs): assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" checked = True assert checked, f"Layer named {key} not found. 
Layer names are: {found_names}" + + +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_layerwise_quantize_parametrized_modules(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + del kwargs['key'] + del kwargs['expected'] + del kwargs['rot_mat'] + del kwargs['rot_func'] + + model = kwargs["model"] + module = recurse_getattr(model, key) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + qmodel = layerwise_quantize(**kwargs) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. Layer names are: {found_names}" + + +@pytest_cases.parametrize( + 'kwargs', + [{ + 'model': nn.Sequential(nn.Linear(2, 3)), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected_state_dict_keys': ['0.weight', '0.bias'],}]) +def test_remove_parametrization_entries_state_dict(kwargs): + key = kwargs['key'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + expected_state_dict_keys = kwargs['expected_state_dict_keys'] + del kwargs['key'] + del kwargs['rot_mat'] + del kwargs['rot_func'] + del kwargs['expected_state_dict_keys'] + + model = kwargs["model"] + module = recurse_getattr(model, key) + old_state_dict = copy.deepcopy(model.state_dict()) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + # Retrieve state dict after parametrization + state_dict = model.state_dict() + # Remove parametrization entries from state dict + state_dict = _remove_parametrization_entries_state_dict(state_dict) + # Verify that all the expected keys in expected_state_dict_keys + # are present in state_dict + assert len(set(expected_state_dict_keys) - set(state_dict.keys())) == 0 + # Verify that keys match + for key, value in state_dict.items(): + # Verify that key is in the expected keys + assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict" + # Compare tensor values + assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match with that in the original state_dict" From 7d2b9557446757859abca6f69ba046b3e6b82cd0 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 14 Jan 2025 19:19:12 +0000 Subject: [PATCH 02/26] Enable rotation optimization --- src/brevitas/graph/equalize.py | 11 +- src/brevitas/utils/rotation_utils.py | 16 ++- .../llm/llm_quant/rotation_optimization.py | 106 ++++++++++++++++++ src/brevitas_examples/llm/main.py | 39 +++++-- tests/brevitas/graph/test_equalization.py | 4 +- 5 files changed, 159 insertions(+), 17 deletions(-) create mode 100644 
src/brevitas_examples/llm/llm_quant/rotation_optimization.py diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 31b2d4f72..d7e28dee5 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1444,7 +1444,7 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module): return model -def _fuse_rotations(model: nn.Module) -> nn.Module: +def fuse_parametrized_rotations(model: nn.Module) -> nn.Module: # First of all, parameters that have parametrizations need to be untied model = _untie_parameters_with_parametrizations(model) # Then, parametrizations can be safely removed @@ -1591,8 +1591,10 @@ def find_sink(node): m.pre_process_k = functional_rotate_input return regions - def apply(self, - graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1620,7 +1622,8 @@ def apply(self, if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate( + graph_model, regions, self.full_rotation_method, fuse_rotations=fuse_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/utils/rotation_utils.py b/src/brevitas/utils/rotation_utils.py index 6a79d1cc3..e5f92c50b 100644 --- a/src/brevitas/utils/rotation_utils.py +++ b/src/brevitas/utils/rotation_utils.py @@ -1,9 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional +from typing import Callable, List, Optional import torch +from torch import nn from torch import Tensor @@ -47,3 +48,16 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: raise RuntimeError("Not supported yet") return tensor + + +def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]: + trainable_rotations = [] + # IDs of the rotation matrices are tracked, as several modules can share + # the same parametrized rotation + ids_rot = set() + for module in model.modules(): + if isinstance(module, RotationWeightParametrization): + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) + trainable_rotations.append(module.rot_mat) + return trainable_rotations diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py new file mode 100644 index 000000000..f98b5cb02 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from dataclasses import field +import os +from typing import List, Optional + +from datasets import Dataset +import torch +import transformers +from transformers import Trainer +from transformers.tokenization_utils import PreTrainedTokenizerBase + +from brevitas.optim.cailey_sgd import CaileySGD +from brevitas.utils.rotation_utils import extract_trainable_rotation_matrices +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.llm.llm_quant.data_utils import DatasetToDevice + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + # By default, arguments are saved in the current working directory + output_dir: Optional[str] = field(default=os.getcwd()) + + 
+def parse_optimization_rotation_args(unknown_args=None) -> None: + parser = transformers.HfArgumentParser(TrainingArguments) + training_args = parser.parse_args_into_dataclasses(args=unknown_args) + return training_args[0] + + +# Function to create a batch +def collate_fn(kwargs_list, return_tensors="pt"): + kwargs = {} + for curr_dict in kwargs_list: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + return kwargs + + +def _prepare_train_dataset(train_dataset: DatasetToDevice) -> Dataset: + return DatasetToDevice( + data=[{ + "input_ids": train_datapoint["input_ids"], "labels": train_datapoint["input_ids"]} + for train_datapoint in train_dataset.data], + device=None) + + +def _prepare_model(model: torch.nn.Module) -> torch.nn.Module: + # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after + # finishing the optimization. However, this method no longer works after + # registering parametrizations/quantizing, so this method is mocked to prevent + # a crash. + def mock_save_pretrained_fn(*args, **kwargs): + pass + + model.save_pretrained = mock_save_pretrained_fn + # Cache needs to be disabled for training + model.config.use_cache = False + # Loss for training + model.config.loss_type = "ForCausalLM" + + return model + + +def apply_rotation_optimization( + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + train_dataset: DatasetToDevice, + unknown_args: List[str] = None, +) -> None: + + # Prepare dataset and model for training + train_dataset = _prepare_train_dataset(train_dataset) + model = _prepare_model(model) + # Remove hooks and empty cache before starting optimization + remove_hooks(model) + torch.cuda.empty_cache() + # Get training arguments + training_args = parse_optimization_rotation_args(unknown_args) + # Set to False the model parameters + for param in model.parameters(): + param.requires_grad = False + # Collect trainable matrices + trainable_rotations = extract_trainable_rotation_matrices(model) + for rot_mat in trainable_rotations: + rot_mat.requires_grad = True + optimizer = CaileySGD(trainable_rotations, lr=training_args.learning_rate, stiefel=True) + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=collate_fn, + optimizers=(optimizer, None)) + trainer.train() + # After finishing training, set eval mode again + model.eval() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 7567f1727..267c17946 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -22,7 +22,8 @@ from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph import load_quant_model_mode -from brevitas.graph.base import ModuleInstanceWrapModule +from brevitas.graph.base import ModuleInstanceTransformTensor +from brevitas.graph.equalize import fuse_parametrized_rotations from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import functional_quantization_mode @@ -54,6 +55,7 @@ from 
brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -81,7 +83,7 @@ def set_seed(seed): torch.random.manual_seed(seed) -def fused_rotation_no_fx(model, calibration_loader, args): +def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = True): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)): @@ -89,6 +91,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add) apply_layernorm_affine_merge(new_model) + # NOTE: This call breaks ties between the the lm_head and the embedding layer new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -100,13 +103,13 @@ def fused_rotation_no_fx(model, calibration_loader, args): full_rotation_method=args.rotation_mode, return_rewriters=True, sdpa_regions=args.rotation_sdpa_regions) - new_model, rewriters = eq.apply(new_model) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: # The weights between model and new_model are tied, so this check prevents # rotating the weights twice - if isinstance(r, ModuleInstanceWrapModule): - r.apply(model) + if not isinstance(r, ModuleInstanceTransformTensor): + model = r.apply(model) remove_hooks(new_model) @@ -209,7 +212,7 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def quantize_llm(args): +def quantize_llm(args, unknown_args=None): validate(args) set_seed(args.seed) if args.export_prefix is None: @@ -337,6 +340,8 @@ def quantize_llm(args): model = eq.apply(model) elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) + elif args.rotation == 'fused_no_fx_optimize': + fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=False) if args.weight_equalization: print("Apply weight equalization...") @@ -462,6 +467,20 @@ def quantize_llm(args): for k, v in dict_hooks.items(): k._hf_hook.post_forward = v + if args.rotation in ['fused_no_fx_optimize']: + apply_rotation_optimization( + model=model, + tokenizer=tokenizer, + train_dataset=calibration_loader, + unknown_args=unknown_args, + ) + # Remove hooks from optimization + remove_hooks(model) + # Offload model before fusing the rotations + model = offload_model(model) + # Fuse rotations with weights + model = fuse_parametrized_rotations(model) + if args.act_calibration and not args.load_checkpoint: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -814,7 +833,7 @@ def parse_args(args, override_defaults={}): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx'], + choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', @@ -912,13 +931,13 @@ def parse_args(args, override_defaults={}): help='A list of tasks 
for zero_shot evaluation. Default: %(default)s') parser.set_defaults(**override_defaults) - return parser.parse_args(args) + return parser.parse_known_args(args) def main(): overrides = override_defaults(sys.argv[1:]) - args = parse_args(sys.argv[1:], override_defaults=overrides) - quantize_llm(args) + args, unknown_args = parse_args(sys.argv[1:], override_defaults=overrides) + quantize_llm(args, unknown_args) if __name__ == '__main__': diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 41edf0752..b491609d1 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -20,13 +20,13 @@ from brevitas.graph.equalize import _apply_rotate from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions -from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizationIndexes +from brevitas.graph.equalize import fuse_parametrized_rotations from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix @@ -517,7 +517,7 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r isinstance(module, RotatedModule) for module in rotated_model.modules()]) # Optionally fuse the rotations if fuse_rotations: - rotated_model_unfused = _fuse_rotations(rotated_model_unfused) + rotated_model_unfused = fuse_parametrized_rotations(rotated_model_unfused) # Verify that no parametrizations remain after fusing for module in rotated_model_unfused.modules(): assert not parametrize.is_parametrized(module) From ec106489b07ba1c933192be12e089dec0d045548 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 15 Jan 2025 17:12:29 +0000 Subject: [PATCH 03/26] Enable specifying custom number of samples for rotation optimization --- .../llm/llm_quant/rotation_optimization.py | 7 +++++-- src/brevitas_examples/llm/main.py | 21 ++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index f98b5cb02..3865e740b 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -80,11 +80,14 @@ def apply_rotation_optimization( # Prepare dataset and model for training train_dataset = _prepare_train_dataset(train_dataset) model = _prepare_model(model) + # Get training arguments + training_args = parse_optimization_rotation_args(unknown_args) + # Enable skipping optimization + if training_args.max_steps <= 0: + return # Remove hooks and empty cache before starting optimization remove_hooks(model) torch.cuda.empty_cache() - # Get training arguments - training_args = parse_optimization_rotation_args(unknown_args) # Set to False the model parameters for param in model.parameters(): param.requires_grad = False diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 267c17946..e2afb8db5 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -274,6 +274,20 @@ def 
quantize_llm(args, unknown_args=None): device=None, fuse_sequences=args.fuse_sequences) + if args.rotation in ["fused_no_fx_optimize"]: + # Load the data for rotation optimization + rot_calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples_rot_calibration, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=require_fx and args.export_target is not None, + device=None, + fuse_sequences=args.fuse_sequences) + device = next(iter(model.parameters())).device print("Data loaded.") @@ -471,7 +485,7 @@ def quantize_llm(args, unknown_args=None): apply_rotation_optimization( model=model, tokenizer=tokenizer, - train_dataset=calibration_loader, + train_dataset=rot_calibration_loader, unknown_args=unknown_args, ) # Remove hooks from optimization @@ -629,6 +643,11 @@ def parse_args(args, override_defaults={}): type=int, default=128, help='Number of calibration data samples. Default: 128.') + parser.add_argument( + '--nsamples-rot-calibration', + type=int, + default=800, + help='Number of calibration data samples for rotation. Default: %(default)d.') parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.') parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.') parser.add_argument( From 267643c11b4a200ee5a24053e0a6e074f9779f9a Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 11:00:35 +0000 Subject: [PATCH 04/26] Add rotation optimization tests --- tests/brevitas_examples/test_llm.py | 135 +++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 4 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index f6db73924..199a3d29c 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -47,14 +47,14 @@ def transformers_version_ge(required_version: str): # Check that all args in args are used def validate_args(args): a = vars(args) - da = vars(parse_args([])) + da = vars(parse_args([])[0]) for k in a.keys(): assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`" -def validate_args_and_run_main(args): +def validate_args_and_run_main(args, unknown_args=None): validate_args(args) - float_ppl, quant_ppl, model = quantize_llm(args) + float_ppl, quant_ppl, model = quantize_llm(args, unknown_args=unknown_args) return float_ppl, quant_ppl, model @@ -131,7 +131,7 @@ def small_models_with_ppl(request): @pytest_cases.fixture() def default_run_args(request): - args = UpdatableNamespace(**vars(parse_args([]))) + args = UpdatableNamespace(**vars(parse_args([])[0])) args.nsamples = 2 args.seqlen = 2 args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" @@ -156,6 +156,11 @@ def run_test_models_run_args(args, model_with_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) +@pytest.fixture(scope="session", autouse=True) +def set_env(): + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + # yapf: disable @pytest_cases.fixture( ids=[ @@ -825,3 +830,125 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): quant_ppl = quant_ppl.detach().cpu().numpy() assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@pytest_cases.fixture( + ids=[ + "llama_rotation_optimization_ort", + 
"llama_rotation_optimization_ort_no_orphan", + "llama_rotation_optimization_had", + "llama_rotation_optimization_had_no_orphan",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx_optimize", + "rotation_orphan_sink": True, + "rotation_mode": "ort", + "nsamples_rot_calibration": 2, + "no_float16": True, + "unknown_args": [ + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33232.65234375}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx_optimize", + "rotation_orphan_sink": False, + "rotation_mode": "ort", + "nsamples_rot_calibration": 2, + "max_steps": 2, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "no_float16": True, + "unknown_args": [ + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33420.65234375}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx_optimize", + "rotation_orphan_sink": True, + "rotation_mode": "had", + "nsamples_rot_calibration": 2, + "max_steps": 2, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "no_float16": True, + "unknown_args": [ + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33290.48046875}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx_optimize", + "rotation_orphan_sink": False, + "rotation_mode": "had", + "nsamples_rot_calibration": 2, + "max_steps": 2, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "no_float16": True, + "unknown_args": [ + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33204.80859375},]) +def rotation_optimization_args_and_ppl(default_run_args, request): + args = default_run_args + run_dict = request.param + unknown_args = run_dict["unknown_args"] + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + del run_dict["unknown_args"] + args.update(**run_dict) + yield args, unknown_args, float_ppl, quant_ppl + + +@requires_pt_ge('2.4') +def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl): + if platform.system() == "Windows": + pytest.skip("Skipping dynamo + windows") + caplog.set_level(logging.INFO) + args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl + float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant 
PPL {exp_quant_ppl}, measured PPL {quant_ppl}" From 1935e844d2023c8c81997927d53d7cf397403f1e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 11:38:23 +0000 Subject: [PATCH 05/26] Fix compatibility with PT 1.11 --- src/brevitas/graph/base.py | 24 ++++++++++--------- src/brevitas/graph/quantize_impl.py | 8 +++++-- src/brevitas/utils/torch_utils.py | 36 +++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 3931640a7..1660efdba 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -5,21 +5,25 @@ from abc import abstractmethod import inspect from inspect import getcallargs -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Dict, Type import torch -from torch import Tensor from torch.nn import Module -from torch.nn import Parameter -import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides +# TODO: Deprecate PyTorch 1.1 +try: + from torch.nn.utils.parametrize import is_parametrized + from torch.nn.utils.parametrize import register_parametrization +except ImportError: + from brevitas.utils.torch_utils import is_parametrized + register_parametrization = None + from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node from brevitas.graph.utils import * from brevitas.utils.python_utils import islambda -from brevitas.utils.rotation_utils import RotationWeightParametrization __all__ = [ 'Transform', @@ -190,7 +194,7 @@ def _replace_old_module(self, model, old_module, new_module, load_state_dict=Tru old_module_state_dict = old_module.state_dict() # If parametrizations are present in old_module, the state_dict needs # to be processed beforehand - if parametrize.is_parametrized(old_module): + if is_parametrized(old_module): old_module_state_dict = _remove_parametrization_entries_state_dict( old_module_state_dict) # Strict can be set to True, since potential parametrizations were @@ -201,11 +205,10 @@ def _replace_old_module(self, model, old_module, new_module, load_state_dict=Tru # being broken # NOTE: unsafe is set to True for efficiency, as the checks should have been done # when first registering the parametrization to old_module - if parametrize.is_parametrized(old_module): + if is_parametrized(old_module): for tensor_name in old_module.parametrizations: for param_func in old_module.parametrizations[tensor_name]: - parametrize.register_parametrization( - new_module, tensor_name, param_func, unsafe=True) + register_parametrization(new_module, tensor_name, param_func, unsafe=True) class InsertModuleCallAfter(GraphTransform): @@ -257,8 +260,7 @@ def apply(self, model: GraphModule) -> GraphModule: for module in model.modules(): if module is self.module: # register the parametrization to module - parametrize.register_parametrization( - module, self.tensor_name, self.transform_module) + register_parametrization(module, self.tensor_name, self.transform_module) break return model diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index a4d348ab5..54753cabf 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -6,9 +6,13 @@ import torch import torch.nn as nn -import torch.nn.utils.parametrize as parametrize -import brevitas +# TODO: Deprecate PyTorch 1.1 +try: + from torch.nn.utils.parametrize import type_before_parametrizations +except ImportError: + from 
brevitas.utils.torch_utils import type_before_parametrizations + from brevitas.graph.base import InsertModuleCallAfter from brevitas.graph.base import ModuleInstanceToModuleInstance from brevitas.graph.base import ModuleToModuleByInstance diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 93e4435de..684ce439d 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -146,3 +146,39 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn return decorator + + +# TODO: Remove after deprecating PyTorch 1.11 +def is_parametrized(module: torch.nn.Module, tensor_name: Optional[str] = None) -> bool: + r"""Determine if a module has a parametrization. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, torch.nn.ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +# TODO: Remove after deprecating PyTorch 1.11 +def type_before_parametrizations(module: torch.nn.Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. + + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) From 31f8a557cf858623a6438a3fc54408ed256eee10 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 11:51:12 +0000 Subject: [PATCH 06/26] Fix test --- src/brevitas/graph/quantize_impl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 54753cabf..392aa0079 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -516,7 +516,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. 
""" - if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys(): + if _module_class_name(type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -537,9 +537,8 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name( - parametrize.type_before_parametrizations(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] + if layer_map[_module_class_name(type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: From 7fb2fdb519cff977dee30a134bf2274a04a17229 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 13:43:54 +0000 Subject: [PATCH 07/26] Fix arguments --- tests/brevitas_examples/test_llm.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 199a3d29c..69f3106e5 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -156,11 +156,6 @@ def run_test_models_run_args(args, model_with_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) -@pytest.fixture(scope="session", autouse=True) -def set_env(): - os.environ["CUDA_VISIBLE_DEVICES"] = "1" - - # yapf: disable @pytest_cases.fixture( ids=[ @@ -851,6 +846,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "nsamples_rot_calibration": 2, "no_float16": True, "unknown_args": [ + "--learning_rate", + "1.5", "--max_steps", "2", "--per_device_train_batch_size", @@ -869,11 +866,10 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "ort", "nsamples_rot_calibration": 2, - "max_steps": 2, - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 1, "no_float16": True, "unknown_args": [ + "--learning_rate", + "1.5", "--max_steps", "2", "--per_device_train_batch_size", @@ -892,11 +888,10 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_orphan_sink": True, "rotation_mode": "had", "nsamples_rot_calibration": 2, - "max_steps": 2, - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 1, "no_float16": True, "unknown_args": [ + "--learning_rate", + "1.5", "--max_steps", "2", "--per_device_train_batch_size", @@ -915,11 +910,10 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "had", "nsamples_rot_calibration": 2, - "max_steps": 2, - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 1, "no_float16": True, "unknown_args": [ + "--learning_rate", + "1.5", "--max_steps", "2", "--per_device_train_batch_size", From 272451cabad17b04896d23fd0186b83753ab408c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 13:56:58 +0000 Subject: [PATCH 08/26] Minor refactoring --- src/brevitas/graph/base.py | 29 +++++++++++++++-------------- src/brevitas/graph/quantize_impl.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git 
a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 1660efdba..8fd041a5e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -11,7 +11,7 @@ from torch.nn import Module from torch.overrides import get_testing_overrides -# TODO: Deprecate PyTorch 1.1 +# TODO: Deprecate PyTorch 1.11 try: from torch.nn.utils.parametrize import is_parametrized from torch.nn.utils.parametrize import register_parametrization @@ -191,21 +191,22 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - old_module_state_dict = old_module.state_dict() - # If parametrizations are present in old_module, the state_dict needs - # to be processed beforehand - if is_parametrized(old_module): + if not is_parametrized(old_module): + new_module.load_state_dict(old_module.state_dict()) + else: + old_module_state_dict = old_module.state_dict() + # If parametrizations are present in old_module, the state_dict needs + # to be processed beforehand old_module_state_dict = _remove_parametrization_entries_state_dict( old_module_state_dict) - # Strict can be set to True, since potential parametrizations were - # accounted for - new_module.load_state_dict(old_module_state_dict) - # If the old module is parametrized, these need to be transferred to the new module. - # The method transfer_parametrizations_and_params as it can result in parameter ties - # being broken - # NOTE: unsafe is set to True for efficiency, as the checks should have been done - # when first registering the parametrization to old_module - if is_parametrized(old_module): + # Strict can be set to True, since potential parametrizations were + # accounted for + new_module.load_state_dict(old_module_state_dict) + # If the old module is parametrized, these need to be transferred to the new module. 
+ # The method transfer_parametrizations_and_params as it can result in parameter ties + # being broken + # NOTE: unsafe is set to True for efficiency, as the checks should have been done + # when first registering the parametrization to old_module for tensor_name in old_module.parametrizations: for param_func in old_module.parametrizations[tensor_name]: register_parametrization(new_module, tensor_name, param_func, unsafe=True) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 392aa0079..c826481bf 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -# TODO: Deprecate PyTorch 1.1 +# TODO: Deprecate PyTorch 1.11 try: from torch.nn.utils.parametrize import type_before_parametrizations except ImportError: From 758e70d805d56bf9455ea76e2406d7613beb64d3 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 14:44:58 +0000 Subject: [PATCH 09/26] Prevent saves in tests --- .../llm/llm_quant/rotation_optimization.py | 9 +++++++++ tests/brevitas_examples/test_llm.py | 18 +++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 3865e740b..d84a50318 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -3,6 +3,7 @@ import os from typing import List, Optional +from accelerate.utils import DistributedType from datasets import Dataset import torch import transformers @@ -24,6 +25,14 @@ class TrainingArguments(transformers.TrainingArguments): def parse_optimization_rotation_args(unknown_args=None) -> None: parser = transformers.HfArgumentParser(TrainingArguments) training_args = parser.parse_args_into_dataclasses(args=unknown_args) + # If a single-process is running, only one GPU should be available + # for Trainer, to prevent using DataParallel, which was causing an + # error due to tensors in different devices being operated. 
+ # Therefore, DistributedDataParallel should be used to run in + # multiple GPUs + if training_args[0].distributed_state.distributed_type == DistributedType.NO and training_args[ + 0]._n_gpu > 1: + training_args[0]._n_gpu = 1 return training_args[0] diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 69f3106e5..261022280 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -853,7 +853,9 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1"], + "1", + "--save_strategy", + "no"], "float_ppl": 33238.8984375, "quant_ppl": 33232.65234375}, { @@ -875,7 +877,9 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1"], + "1", + "--save_strategy", + "no"], "float_ppl": 33238.8984375, "quant_ppl": 33420.65234375}, { @@ -897,7 +901,9 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1"], + "1", + "--save_strategy", + "no"], "float_ppl": 33238.8984375, "quant_ppl": 33290.48046875}, { @@ -918,8 +924,10 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "2", "--per_device_train_batch_size", "1", - "--gradient_accumulation_steps", - "1"], + "--,gradient_accumulation_steps", + "1", + "--save_strategy", + "no"], "float_ppl": 33238.8984375, "quant_ppl": 33204.80859375},]) def rotation_optimization_args_and_ppl(default_run_args, request): From e77824d93ec8b5ca1f7dce77c799f25593c2c94f Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 15:05:25 +0000 Subject: [PATCH 10/26] Fix typo --- tests/brevitas_examples/test_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 261022280..d4af0861e 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -924,7 +924,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "2", "--per_device_train_batch_size", "1", - "--,gradient_accumulation_steps", + "--gradient_accumulation_steps", "1", "--save_strategy", "no"], From ee5ceb2da5facd8fcfbd08560b416bfeea083071 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 16 Jan 2025 19:20:55 +0000 Subject: [PATCH 11/26] Address comments and new tests --- src/brevitas/graph/quantize_impl.py | 8 +- src/brevitas_examples/llm/main.py | 13 ++++ tests/brevitas/graph/test_quantize.py | 49 ++++++++++++ tests/brevitas_examples/test_llm.py | 104 +++++++++++++++++--------- 4 files changed, 136 insertions(+), 38 deletions(-) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index c826481bf..d0a0f4be8 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -408,8 +408,8 @@ def act_handler(model, layer_map): if node.op == 'call_module': module = get_module(model, node.target) if isinstance(module, tuple(layer_map.keys())): - if layer_map[type(module)] is not None: - quant_module_class, quant_module_kwargs = layer_map[type(module)] + if layer_map[type_before_parametrizations(module)] is not None: + quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)] quant_module = quant_module_class(**quant_module_kwargs) # Check for activation equalization mul nodes if 
len(node.users) == 1: @@ -470,8 +470,8 @@ def layer_handler( quant_identity_map=quant_identity_map, quant_act_map=quant_act_map, unsigned_act_tuple=unsigned_act_tuple) - if layer_map[type(module)] is not None: - quant_module_class, quant_module_kwargs = layer_map[type(module)] + if layer_map[type_before_parametrizations(module)] is not None: + quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)] # Quantize the input if is not quantized, input_quant is not specified, # and the quant_identity_map is provided. if not are_inputs_quantized_and_aligned( diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index e2afb8db5..a2728fc1e 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -103,6 +103,19 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = full_rotation_method=args.rotation_mode, return_rewriters=True, sdpa_regions=args.rotation_sdpa_regions) + # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of + # selected modules stop being attributes but, instead, properties, and their value is + # computed by passing the original value of the tensor through the forward passes of the + # parametrization modules. Parametrizations are registered using + # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__ + # attribute of the parametrized module, e.g. "" + # changes to "". Therefore, + # algorithms that do type checking might need to use type_before_parametrizations(module), + # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance, + # the "weight" attribute is parametrized, it will be removed from the attributes + # of the class. Consequently, quantization algorithms that rely on in-place modifications + # of the weights should not operate on parametrized modules. 
In this situation, parametrizations + # need to be removed beforehand by invoking fuse_parametrized_rotations new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index ec15da6a2..62ce405c9 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -7,8 +7,10 @@ from brevitas.graph.base import _remove_parametrization_entries_state_dict from brevitas.graph.quantize import layerwise_quantize +from brevitas.graph.quantize import quantize from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization +from tests.marker import requires_pt_ge @pytest_cases.parametrize( @@ -142,3 +144,50 @@ def test_remove_parametrization_entries_state_dict(kwargs): assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict" # Compare tensor values assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match with that in the original state_dict" + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'sample_input': torch.tensor([[0.8, -0.6]]), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_quantize_parametrized_modules(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + sample_input = kwargs['sample_input'] + model = kwargs["model"] + + graph_model, _ = torch._dynamo.export(model)(sample_input) + orig_module = recurse_getattr(model, key) + # Use tied weights to identify equivalent model + key, module = [(key, module) for key, module in graph_model.named_modules() if hasattr(module, "weight") and module.weight is orig_module.weight][0] + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + qmodel = quantize(graph_model) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. 
Layer names are: {found_names}" diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index d4af0861e..33d46a179 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -7,6 +7,7 @@ import os import platform import shutil +from unittest.mock import patch import numpy as np import onnx @@ -23,21 +24,16 @@ from tests.marker import jit_disabled_for_export from tests.marker import requires_pt_ge +ATOL_PPL = 2e+02 +RTOL_PPL = 1e-04 + def ptid2pathname(string): return string.replace("/", "-").replace(":", "-") -def allclose(x, y): - return np.allclose(x, y, rtol=1e-03, atol=1e+01, equal_nan=False) - - -def allveryclose(x, y): - return np.allclose(x, y, rtol=1e-04, atol=2e+02, equal_nan=False) - - -def allexact(x, y): - return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False) +def allclose(x, y, rtol=RTOL_PPL, atol=ATOL_PPL): + return np.allclose(x, y, rtol=rtol, atol=atol, equal_nan=False) def transformers_version_ge(required_version: str): @@ -252,8 +248,8 @@ def test_small_models_acc(caplog, acc_args_and_acc): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -294,8 +290,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -738,8 +734,8 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -760,7 +756,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": True, "rotation_mode": "ort", "float_ppl": 33238.8984375, - "quant_ppl": 33232.65234375}, + "quant_ppl": 33232.65234375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -771,7 +767,7 @@ def 
test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "ort", "float_ppl": 33238.8984375, - "quant_ppl": 33420.65234375}, + "quant_ppl": 33420.65234375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -782,7 +778,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": True, "rotation_mode": "had", "float_ppl": 33238.8984375, - "quant_ppl": 33290.48046875}, + "quant_ppl": 33290.48046875,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -793,7 +789,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "had", "float_ppl": 33238.8984375, - "quant_ppl": 33204.80859375}, + "quant_ppl": 33204.80859375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -802,7 +798,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "replace_rmsnorm": True, "rotation": "layerwise", "float_ppl": 33238.8984375, - "quant_ppl": 33446.734375},]) + "quant_ppl": 33446.734375,},]) def rotation_ppl_args_and_ppl(default_run_args, request): args = default_run_args run_dict = request.param @@ -823,8 +819,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -857,7 +853,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--save_strategy", "no"], "float_ppl": 33238.8984375, - "quant_ppl": 33232.65234375}, + "quant_ppl": 33278.98828125, + "exp_layer_types_count": { + "": 4, + "": 1, + "": 1, + "": 14,}}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -881,7 +882,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--save_strategy", "no"], "float_ppl": 33238.8984375, - "quant_ppl": 33420.65234375}, + "quant_ppl": 33424.73046875, + "exp_layer_types_count": { + "": 0, + "": 1, + "": 1, + "": 14,}}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -905,7 +911,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--save_strategy", "no"], "float_ppl": 33238.8984375, - "quant_ppl": 33290.48046875}, + "quant_ppl": 33339.21875, + "exp_layer_types_count": { + "": 4, + "": 1, + "": 1, + "": 14,}}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -929,28 +940,53 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--save_strategy", "no"], "float_ppl": 33238.8984375, - "quant_ppl": 33204.80859375},]) -def rotation_optimization_args_and_ppl(default_run_args, request): + "quant_ppl": 33219.08984375, + "exp_layer_types_count": { + "": 0, + "": 1, + "": 1, + "": 14,}},]) +def 
rotation_optimization_args_layer_count_and_ppl(default_run_args, request): args = default_run_args run_dict = request.param unknown_args = run_dict["unknown_args"] float_ppl = run_dict["float_ppl"] quant_ppl = run_dict["quant_ppl"] + exp_layer_types_count = run_dict["exp_layer_types_count"] del run_dict["float_ppl"] del run_dict["quant_ppl"] del run_dict["unknown_args"] + del run_dict["exp_layer_types_count"] args.update(**run_dict) - yield args, unknown_args, float_ppl, quant_ppl + yield args, unknown_args, float_ppl, quant_ppl, exp_layer_types_count @requires_pt_ge('2.4') -def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl): +def test_small_models_rotation_optimization_ppl( + caplog, rotation_optimization_args_layer_count_and_ppl): if platform.system() == "Windows": pytest.skip("Skipping dynamo + windows") + # Tolerances are stricter for this test, to ensure that it does not pass + # with non-optimized quantized perplexities + RTOL_ROT, ATOL_ROT = 1e-05, 2. caplog.set_level(logging.INFO) - args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl - float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args) + args, unknown_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl + float_ppl, quant_ppl, _ = validate_args_and_run_main(args, unknown_args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl, rtol=RTOL_ROT, atol=ATOL_ROT), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@requires_pt_ge('2.4') +def test_small_models_rotation_optimization_layer_count( + caplog, rotation_optimization_args_layer_count_and_ppl): + if platform.system() == "Windows": + pytest.skip("Skipping dynamo + windows") + # Tolerances are stricter for this test, to ensure that it does not pass + # with non-optimized quantized perplexities + caplog.set_level(logging.INFO) + args, unknown_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl + with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model): + _, _, model = validate_args_and_run_main(args, unknown_args) + assert_layer_types_count(model, exp_layer_types_count) From 94bd3807e870a5b9442f1c7f3375d23132c572f0 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 17 Jan 2025 10:46:10 +0000 Subject: [PATCH 12/26] Rename HF args --- .../llm/llm_quant/rotation_optimization.py | 11 ++--- src/brevitas_examples/llm/main.py | 16 ++++--- tests/brevitas_examples/test_llm.py | 42 ++++++++----------- 3 files changed, 34 insertions(+), 35 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index d84a50318..84446ad80 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -20,11 +20,14 @@ class TrainingArguments(transformers.TrainingArguments): # By default, arguments are saved in the current working directory output_dir: Optional[str] = field(default=os.getcwd()) + # NOTE: 
Currently, there is no infrastructure to resume training + # from a checkpoint, so related files are not save by default + save_strategy: Optional[str] = field(default="no") -def parse_optimization_rotation_args(unknown_args=None) -> None: +def parse_rotation_optimization_args(extra_args: Optional[List[str]] = None) -> TrainingArguments: parser = transformers.HfArgumentParser(TrainingArguments) - training_args = parser.parse_args_into_dataclasses(args=unknown_args) + training_args = parser.parse_args_into_dataclasses(args=extra_args) # If a single-process is running, only one GPU should be available # for Trainer, to prevent using DataParallel, which was causing an # error due to tensors in different devices being operated. @@ -83,14 +86,12 @@ def apply_rotation_optimization( model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, train_dataset: DatasetToDevice, - unknown_args: List[str] = None, + training_args: TrainingArguments, ) -> None: # Prepare dataset and model for training train_dataset = _prepare_train_dataset(train_dataset) model = _prepare_model(model) - # Get training arguments - training_args = parse_optimization_rotation_args(unknown_args) # Enable skipping optimization if training_args.max_steps <= 0: return diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index a2728fc1e..0d1e2e288 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -6,6 +6,7 @@ from copy import deepcopy import functools import sys +from typing import List, Optional from warnings import warn from lm_eval import evaluator @@ -56,6 +57,7 @@ from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization +from brevitas_examples.llm.llm_quant.rotation_optimization import parse_rotation_optimization_args from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -158,7 +160,9 @@ def model_export(model, ref_input, args): export_torch_qcdq(model, ref_input['input_ids'], export_path=f"{args.export_prefix}.pt") -def validate(args): +def validate(args, extra_args: Optional[List[str]] = None): + if args.rotation != "fused_no_fx_optimize": + assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith("--")]}" if args.functional_sdpa_quant: assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization" if args.rotation == 'fx': @@ -225,7 +229,7 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def quantize_llm(args, unknown_args=None): +def quantize_llm(args, extra_args=None): validate(args) set_seed(args.seed) if args.export_prefix is None: @@ -288,6 +292,8 @@ def quantize_llm(args, unknown_args=None): fuse_sequences=args.fuse_sequences) if args.rotation in ["fused_no_fx_optimize"]: + # Extra arguments should be used as training arguments for rotation optimization + rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args) # Load the data for rotation optimization rot_calibration_loader = get_dataset_for_model( args.model, @@ -499,7 +505,7 @@ def quantize_llm(args, unknown_args=None): model=model, tokenizer=tokenizer, 
train_dataset=rot_calibration_loader, - unknown_args=unknown_args, + training_args=rot_optimization_args, ) # Remove hooks from optimization remove_hooks(model) @@ -968,8 +974,8 @@ def parse_args(args, override_defaults={}): def main(): overrides = override_defaults(sys.argv[1:]) - args, unknown_args = parse_args(sys.argv[1:], override_defaults=overrides) - quantize_llm(args, unknown_args) + args, extra_args = parse_args(sys.argv[1:], override_defaults=overrides) + quantize_llm(args, extra_args) if __name__ == '__main__': diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 33d46a179..a3aab2b24 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -48,9 +48,9 @@ def validate_args(args): assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`" -def validate_args_and_run_main(args, unknown_args=None): +def validate_args_and_run_main(args, extra_args=None): validate_args(args) - float_ppl, quant_ppl, model = quantize_llm(args, unknown_args=unknown_args) + float_ppl, quant_ppl, model = quantize_llm(args, extra_args=extra_args) return float_ppl, quant_ppl, model @@ -841,7 +841,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_mode": "ort", "nsamples_rot_calibration": 2, "no_float16": True, - "unknown_args": [ + "extra_args": [ "--learning_rate", "1.5", "--max_steps", @@ -849,9 +849,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1", - "--save_strategy", - "no"], + "1"], "float_ppl": 33238.8984375, "quant_ppl": 33278.98828125, "exp_layer_types_count": { @@ -870,7 +868,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_mode": "ort", "nsamples_rot_calibration": 2, "no_float16": True, - "unknown_args": [ + "extra_args": [ "--learning_rate", "1.5", "--max_steps", @@ -878,9 +876,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1", - "--save_strategy", - "no"], + "1"], "float_ppl": 33238.8984375, "quant_ppl": 33424.73046875, "exp_layer_types_count": { @@ -899,7 +895,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_mode": "had", "nsamples_rot_calibration": 2, "no_float16": True, - "unknown_args": [ + "extra_args": [ "--learning_rate", "1.5", "--max_steps", @@ -907,9 +903,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1", - "--save_strategy", - "no"], + "1"], "float_ppl": 33238.8984375, "quant_ppl": 33339.21875, "exp_layer_types_count": { @@ -928,7 +922,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "rotation_mode": "had", "nsamples_rot_calibration": 2, "no_float16": True, - "unknown_args": [ + "extra_args": [ "--learning_rate", "1.5", "--max_steps", @@ -936,9 +930,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--per_device_train_batch_size", "1", "--gradient_accumulation_steps", - "1", - "--save_strategy", - "no"], + "1"], "float_ppl": 33238.8984375, "quant_ppl": 33219.08984375, "exp_layer_types_count": { @@ -949,16 +941,16 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): def rotation_optimization_args_layer_count_and_ppl(default_run_args, request): args = default_run_args run_dict = request.param - 
unknown_args = run_dict["unknown_args"] + extra_args = run_dict["extra_args"] float_ppl = run_dict["float_ppl"] quant_ppl = run_dict["quant_ppl"] exp_layer_types_count = run_dict["exp_layer_types_count"] del run_dict["float_ppl"] del run_dict["quant_ppl"] - del run_dict["unknown_args"] + del run_dict["extra_args"] del run_dict["exp_layer_types_count"] args.update(**run_dict) - yield args, unknown_args, float_ppl, quant_ppl, exp_layer_types_count + yield args, extra_args, float_ppl, quant_ppl, exp_layer_types_count @requires_pt_ge('2.4') @@ -970,8 +962,8 @@ def test_small_models_rotation_optimization_ppl( # with non-optimized quantized perplexities RTOL_ROT, ATOL_ROT = 1e-05, 2. caplog.set_level(logging.INFO) - args, unknown_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl - float_ppl, quant_ppl, _ = validate_args_and_run_main(args, unknown_args) + args, extra_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl + float_ppl, quant_ppl, _ = validate_args_and_run_main(args, extra_args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" @@ -986,7 +978,7 @@ def test_small_models_rotation_optimization_layer_count( # Tolerances are stricter for this test, to ensure that it does not pass # with non-optimized quantized perplexities caplog.set_level(logging.INFO) - args, unknown_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl + args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model): - _, _, model = validate_args_and_run_main(args, unknown_args) + _, _, model = validate_args_and_run_main(args, extra_args) assert_layer_types_count(model, exp_layer_types_count) From d808cb285b164df3ca7b85c7d5a1f41632daa194 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 17 Jan 2025 15:54:11 +0000 Subject: [PATCH 13/26] Fix parametrization fusing --- src/brevitas/graph/equalize.py | 29 +++++++++++++++++++++++------ src/brevitas_examples/llm/main.py | 4 ++-- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index d7e28dee5..8188e64cc 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -17,6 +17,7 @@ import torch.nn as nn import torch.nn.utils.parametrize as parametrize +from brevitas import config from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node @@ -1444,17 +1445,33 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module): return model +def _retrieve_quant_state_dict(module: nn.Module) -> Dict[str, torch.Tensor]: + # Retrieve state dict components related to Brevitas quantizers + config._FULL_STATE_DICT = True + quant_state_dict = {k: v for k, v in module.state_dict().items() if "_quant" in k} + config._FULL_STATE_DICT = False + return quant_state_dict + + def fuse_parametrized_rotations(model: nn.Module) -> nn.Module: # First of all, parameters that have parametrizations need to be untied model = _untie_parameters_with_parametrizations(model) # Then, parametrizations can be safely removed for module in model.modules(): - # Names of the tensors that can potentially be parametrized - tensor_names = ["weight", "bias"] - # Remove parametrizations from each 
tensor - for tensor_name in tensor_names: - if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: - parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True) + if parametrize.is_parametrized(module): + # Names of the tensors that can potentially be parametrized + tensor_names = ["weight", "bias"] + # Get the quantization-related entries of the module state_dict + quant_state_dict = _retrieve_quant_state_dict(module) + # Remove parametrizations from each tensor + for tensor_name in tensor_names: + if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True) + # Restore the state of quantization-related tensors, strict needs to be set to False + # as there will be missing keys + if len(quant_state_dict) > 0: + module.load_state_dict(quant_state_dict, strict=False) return model diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 0d1e2e288..e294892be 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -162,7 +162,7 @@ def model_export(model, ref_input, args): def validate(args, extra_args: Optional[List[str]] = None): if args.rotation != "fused_no_fx_optimize": - assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith("--")]}" + assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}" if args.functional_sdpa_quant: assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization" if args.rotation == 'fx': @@ -230,7 +230,7 @@ def validate(args, extra_args: Optional[List[str]] = None): def quantize_llm(args, extra_args=None): - validate(args) + validate(args, extra_args) set_seed(args.seed) if args.export_prefix is None: args.export_prefix = f"{args.model.replace('/', '--')}" From cae27192163f87fdfa67497ac7e196ed1d2227cf Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 17 Jan 2025 16:18:48 +0000 Subject: [PATCH 14/26] Fix LLM tests --- tests/brevitas_examples/test_llm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index a3aab2b24..02fbfc349 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from argparse import Namespace +import copy from dataclasses import dataclass import logging import os @@ -851,7 +852,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--gradient_accumulation_steps", "1"], "float_ppl": 33238.8984375, - "quant_ppl": 33278.98828125, + "quant_ppl": 33239.33984375, "exp_layer_types_count": { "": 4, "": 1, @@ -878,7 +879,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--gradient_accumulation_steps", "1"], "float_ppl": 33238.8984375, - "quant_ppl": 33424.73046875, + "quant_ppl": 33423.0390625, "exp_layer_types_count": { "": 0, "": 1, @@ -905,7 +906,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--gradient_accumulation_steps", "1"], "float_ppl": 33238.8984375, - "quant_ppl": 33339.21875, + "quant_ppl": 33286.98828125, "exp_layer_types_count": { "": 4, "": 1, @@ 
-932,7 +933,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "--gradient_accumulation_steps", "1"], "float_ppl": 33238.8984375, - "quant_ppl": 33219.08984375, + "quant_ppl": 33175.3046875, "exp_layer_types_count": { "": 0, "": 1, @@ -940,7 +941,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "": 14,}},]) def rotation_optimization_args_layer_count_and_ppl(default_run_args, request): args = default_run_args - run_dict = request.param + run_dict = copy.deepcopy(request.param) extra_args = run_dict["extra_args"] float_ppl = run_dict["float_ppl"] quant_ppl = run_dict["quant_ppl"] From f7cf5df593504f23e6bcc310923263c87d6b768e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 17 Jan 2025 17:42:14 +0000 Subject: [PATCH 15/26] Add test for parametrization fusing in quantized module --- src/brevitas/graph/equalize.py | 28 ++++++------ src/brevitas_examples/llm/main.py | 30 ++++++------ tests/brevitas/graph/test_equalization.py | 56 +++++++++++++++++++++++ 3 files changed, 87 insertions(+), 27 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 8188e64cc..097441279 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -41,6 +41,8 @@ from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule from brevitas.nn.quant_scale_bias import ScaleBias +from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization from brevitas.utils.torch_utils import KwargsForwardHook @@ -1445,14 +1447,6 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module): return model -def _retrieve_quant_state_dict(module: nn.Module) -> Dict[str, torch.Tensor]: - # Retrieve state dict components related to Brevitas quantizers - config._FULL_STATE_DICT = True - quant_state_dict = {k: v for k, v in module.state_dict().items() if "_quant" in k} - config._FULL_STATE_DICT = False - return quant_state_dict - - def fuse_parametrized_rotations(model: nn.Module) -> nn.Module: # First of all, parameters that have parametrizations need to be untied model = _untie_parameters_with_parametrizations(model) @@ -1461,17 +1455,23 @@ def fuse_parametrized_rotations(model: nn.Module) -> nn.Module: if parametrize.is_parametrized(module): # Names of the tensors that can potentially be parametrized tensor_names = ["weight", "bias"] - # Get the quantization-related entries of the module state_dict - quant_state_dict = _retrieve_quant_state_dict(module) # Remove parametrizations from each tensor for tensor_name in tensor_names: if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: + # Check if the module has any quantization-related children + state_dict = None + for submodule in module.modules(): + if isinstance(submodule, + (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)): + state_dict = submodule.state_dict() + break + # The rotated tensor is saved by setting leave_parametrized=True parametrize.remove_parametrizations( module, tensor_name, leave_parametrized=True) - # Restore the state of quantization-related tensors, strict needs to be set to False - # as there will be missing keys - if len(quant_state_dict) > 0: - module.load_state_dict(quant_state_dict, strict=False) + # Restore the state of the quantization 
modules, as these might have been reset + # when registering the parametrized parameter + if state_dict is not None: + submodule.load_state_dict(state_dict) return model diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index e294892be..0c201e38a 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -105,19 +105,23 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = full_rotation_method=args.rotation_mode, return_rewriters=True, sdpa_regions=args.rotation_sdpa_regions) - # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of - # selected modules stop being attributes but, instead, properties, and their value is - # computed by passing the original value of the tensor through the forward passes of the - # parametrization modules. Parametrizations are registered using - # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__ - # attribute of the parametrized module, e.g. "" - # changes to "". Therefore, - # algorithms that do type checking might need to use type_before_parametrizations(module), - # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance, - # the "weight" attribute is parametrized, it will be removed from the attributes - # of the class. Consequently, quantization algorithms that rely on in-place modifications - # of the weights should not operate on parametrized modules. In this situation, parametrizations - # need to be removed beforehand by invoking fuse_parametrized_rotations + if not fuse_rotations: + # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of + # selected modules stop being attributes but, instead, properties, and their value is + # computed by passing the original value of the tensor through the forward passes of the + # parametrization modules. Parametrizations are registered using + # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__ + # attribute of the parametrized module, e.g. "" + # changes to "". Therefore, + # algorithms that do type checking might need to use type_before_parametrizations(module), + # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance, + # the "weight" attribute is parametrized, it will be removed from the attributes + # of the class. Consequently, quantization algorithms that rely on in-place modifications + # of the weights should not operate on parametrized modules. In this situation, parametrizations + # need to be removed beforehand by invoking fuse_parametrized_rotations + warn( + "Using parametrized results might break type-checking, which could lead to unexpected behaviour." 
+ ) new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index b491609d1..90ee91891 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -32,10 +32,13 @@ from brevitas.graph.equalize import random_orthogonal_matrix from brevitas.graph.equalize import Region from brevitas.graph.hadamard import get_hadK +from brevitas.graph.quantize import LAYERWISE_COMPUTE_LAYER_MAP +from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module from brevitas.nn.equalized_layer import RotatedModule +from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization from tests.marker import requires_pt_ge @@ -528,3 +531,56 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r # Verify that the weights have changed with respect to the unrotated module for the modules that have received parametrizations # Verify that weights match between the fused and unfused model compare_model_weights(rotated_model_fused, rotated_model_unfused) + + +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'sample_input': torch.tensor([[0.8, -0.6]]), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_fuse_parametrized_modules(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + model = kwargs["model"] + sample_input = kwargs["sample_input"] + module = recurse_getattr(model, key) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + compute_layer_map = copy.deepcopy(LAYERWISE_COMPUTE_LAYER_MAP) + module = recurse_getattr(model, key) + type_quant_module = parametrize.type_before_parametrizations(module) + compute_layer_map[type_quant_module][1]["weight_quant"] = compute_layer_map[type_quant_module][ + 1]["weight_quant"].let(scaling_impl_type='parameter_from_stats') + qmodel = layerwise_quantize(model, compute_layer_map=compute_layer_map) + # Calibration pass to initialize scales + with torch.no_grad(): + output = qmodel(sample_input) + # Fuse parametrizations + qmodel = fuse_parametrized_rotations(qmodel) + # Verify that scales were not lost + module = recurse_getattr(model, key) + assert module.weight_quant.tensor_quant.scaling_impl.init_done + assert not torch.allclose( + module.weight_quant.tensor_quant.scaling_impl.value, + torch.ones_like(module.weight_quant.tensor_quant.scaling_impl.value)) + # Compute output after fusing and check that it matches + with torch.no_grad(): + output_fused = qmodel(sample_input) + assert torch.allclose(output, output_fused, rtol=0.0, atol=0.0) From 15a584d292b70d0ef8dc8602d9d3958e157ca6df Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Sat, 18 Jan 2025 00:17:39 +0000 Subject: [PATCH 16/26] Skip test --- 
tests/brevitas/graph/test_equalization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 90ee91891..6989a35b4 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -533,6 +533,7 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r compare_model_weights(rotated_model_fused, rotated_model_unfused) +@requires_pt_ge('2.3.1') @pytest_cases.parametrize( 'kwargs', [ From c602ddabb20a7acc9372744abf1bc13f3f2daab1 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 11:45:06 +0000 Subject: [PATCH 17/26] Enable passing HF arguments through YAML --- src/brevitas_examples/llm/main.py | 10 +++ tests/brevitas_examples/llm_test_template.yml | 76 +++++++++++++++++++ tests/brevitas_examples/test_llm.py | 32 ++++++++ 3 files changed, 118 insertions(+) create mode 100644 tests/brevitas_examples/llm_test_template.yml diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 0c201e38a..9bd4765e1 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -971,6 +971,16 @@ def parse_args(args, override_defaults={}): type=str, nargs='*', help='A list of tasks for zero_shot evaluation. Default: %(default)s') + if override_defaults: + # Retrieve keys that are known to the parser + parser_keys = set(map(lambda action: action.dest, parser._actions)) + # Extract the entries in override_defaults that correspond to keys not known to the parser + extra_args_keys = [key for key in override_defaults.keys() if key not in parser_keys] + # Remove those entries from override_defaults, to prevent new keys being added to the argument + # parser and add them to args, to mimic as if they were passed by command line + for key in extra_args_keys: + args += [f"--{key}", str(override_defaults[key])] + del override_defaults[key] parser.set_defaults(**override_defaults) return parser.parse_known_args(args) diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml new file mode 100644 index 000000000..a28ae6a09 --- /dev/null +++ b/tests/brevitas_examples/llm_test_template.yml @@ -0,0 +1,76 @@ +act_calibration: false +act_equalization: null +act_equalization_alpha: 0.5 +bias_corr: false +checkpoint_name: null +convert_layernorm_to_rmsnorm: false +dataset: wikitext2 +eval: false +export_prefix: null +export_target: null +few_shot_compile: false +few_shot_eval: false +few_shot_limit: null +few_shot_tasks: +- arc_challenge +- arc_easy +- winogrande +- piqa +few_shot_zeroshot: false +functional_sdpa_quant: false +fuse_sequences: false +gpfq: false +gptq: false +gpxq_act_order: false +gpxq_block_name: null +gpxq_create_weight_orig: false +gpxq_max_accumulator_bit_width: null +gpxq_max_accumulator_tile_size: null +gpxq_use_quant_activations: false +input_bit_width: null +input_group_size: 64 +input_param_method: stats +input_quant_format: int +input_quant_granularity: per_tensor +input_quant_type: asym +input_scale_precision: float_scale +input_scale_type: static +learned_round: null +learned_round_fast_update: false +learned_round_iters: 200 +learned_round_lr: 0.005 +learned_round_scale: false +learned_round_scale_lr: 0.01 +learned_round_scale_momentum: 0.9 +ln_affine_merge: false +load_awq: null +load_checkpoint: false +model: facebook/opt-125m +no_float16: false +no_quantize: false +nsamples: 128 +quant_sdpa: false 
+quantize_input_zero_point: false +quantize_last_layer: false +quantize_weight_zero_point: false +replace_mha: false +replace_rmsnorm: false +rotation: null +rotation_mode: had +rotation_orphan_sink: false +scale_rounding_func_type: null +scaling_min_val: 0.0001 +seed: 0 +seqlen: 2048 +weight_bit_width: 8 +weight_equalization: false +weight_group_dim: null +weight_group_size: 128 +weight_param_method: stats +weight_quant_format: int +weight_quant_granularity: per_group +weight_quant_type: sym +weight_scale_precision: float_scale +learning_rate: 1.5 +lr_scheduler_type: cosine +save_safetensors: false diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 02fbfc349..5e489eac1 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -20,6 +20,7 @@ from brevitas import config from brevitas import torch_version +from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args from brevitas_examples.llm.main import quantize_llm from tests.marker import jit_disabled_for_export @@ -983,3 +984,34 @@ def test_small_models_rotation_optimization_layer_count( with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model): _, _, model = validate_args_and_run_main(args, extra_args) assert_layer_types_count(model, exp_layer_types_count) + + +@pytest_cases.parametrize( + "kwargs", + [ + { + "yaml_file_path": + "./llm_test_template.yml", + "expected_extra_args": [ + "--learning_rate", + "1.5", + "--lr_scheduler_type", + "cosine", + "--save_safetensors", + "False"],},], + ids=lambda kwargs: kwargs["yaml_file_path"]) +def test_parse_yaml_trainer_arguments(caplog, kwargs): + caplog.set_level(logging.INFO) + yaml_file_path = kwargs["yaml_file_path"] + expected_extra_args = kwargs["expected_extra_args"] + extra_args_keys = [expected_extra_args[i][2:] for i in range(0, len(expected_extra_args), 2)] + + def quantize_llm_assert_args(args, extra_args=None): + for key in extra_args_keys: + assert key not in args, f"Key {key} should not be known by the parser" + assert extra_args == expected_extra_args, f"Expected extra arguments {expected_extra_args} but got {extra_args}" + + # Run the argument parsing logic of the LLM entrypoint + with patch("brevitas_examples.llm.main.quantize_llm", quantize_llm_assert_args): + with patch("brevitas_examples.llm.main.sys.argv", ["main.py", "--config", yaml_file_path]): + main() From 851d424f130246814f2b132ef7756d9a191a750c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 11:56:00 +0000 Subject: [PATCH 18/26] Update README --- src/brevitas_examples/llm/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index c1c9d9919..7e2cc335f 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -12,6 +12,8 @@ Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. +When using `--rotation fused_no_fx_optimize`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. 
`--learning_rate`, `--weight_decay`, `per_device_train_batch_size`. + ```bash usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] From f178e6694dcb135c1a81de1a3f9b1ad7ca7424cb Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 11:59:06 +0000 Subject: [PATCH 19/26] Update file path --- tests/brevitas_examples/test_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 5e489eac1..7b17436a6 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -991,7 +991,7 @@ def test_small_models_rotation_optimization_layer_count( [ { "yaml_file_path": - "./llm_test_template.yml", + "./tests/brevitas_examples/llm_test_template.yml", "expected_extra_args": [ "--learning_rate", "1.5", From 97e2ab38d759902372fa3dc97e06190e1145ecfc Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 15:39:01 +0000 Subject: [PATCH 20/26] Add extra comments --- src/brevitas_examples/llm/main.py | 9 ++++++--- tests/brevitas_examples/llm_test_template.yml | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 9bd4765e1..b4235e283 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -971,13 +971,16 @@ def parse_args(args, override_defaults={}): type=str, nargs='*', help='A list of tasks for zero_shot evaluation. Default: %(default)s') - if override_defaults: + if len(override_defaults) > 0: # Retrieve keys that are known to the parser parser_keys = set(map(lambda action: action.dest, parser._actions)) # Extract the entries in override_defaults that correspond to keys not known to the parser extra_args_keys = [key for key in override_defaults.keys() if key not in parser_keys] - # Remove those entries from override_defaults, to prevent new keys being added to the argument - # parser and add them to args, to mimic as if they were passed by command line + # Remove all the keys in override_defaults that are unknown to the parser and, instead, + # include them in args, as if they were passed as arguments to the command line. + # This prevents the keys of HF TrainingArguments from being added as arguments to the parser. 
+ # Consequently, they will be part of the second value returned by parse_known_args (thus being + # used as extra_args in quantize_llm) for key in extra_args_keys: args += [f"--{key}", str(override_defaults[key])] del override_defaults[key] diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml index a28ae6a09..0efe97998 100644 --- a/tests/brevitas_examples/llm_test_template.yml +++ b/tests/brevitas_examples/llm_test_template.yml @@ -71,6 +71,7 @@ weight_quant_format: int weight_quant_granularity: per_group weight_quant_type: sym weight_scale_precision: float_scale +# TrainingArguments for HF Trainer (see https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) learning_rate: 1.5 lr_scheduler_type: cosine save_safetensors: false From e04763b757c53fc3686a22cc6097c15960c9bc3b Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 18:59:20 +0000 Subject: [PATCH 21/26] Minor change to comment --- src/brevitas/graph/equalize.py | 20 +++++++++++++++----- src/brevitas_examples/llm/main.py | 22 +++------------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 097441279..1519d7ba5 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1531,6 +1531,7 @@ def __init__( orphan_sink: bool = False, sdpa_regions: bool = False, rotate_matmul: bool = False, + fuse_rotations: bool = True, full_rotation_method: str = 'had', return_rewriters: bool = False) -> None: super(GraphRotationEqualization, self).__init__() @@ -1548,6 +1549,17 @@ def __init__( self.full_rotation_method = full_rotation_method self.return_rewriters = return_rewriters self.sdpa_regions = sdpa_regions + if not fuse_rotations: + # NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__ + # of the parametrized module, e.g. to"". + # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module), + # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications + # of the weights should not operate on parametrized modules. In this situation, parametrizations + # need to be removed beforehand by invoking fuse_parametrized_rotations + warnings.warn( + "Using parametrized results might break type-checking, which could lead to unexpected behaviour." 
+ ) + self.fuse_rotations = fuse_rotations def rotate_matmuls(self, graph_module): matmul_nodes = list(graph_module.graph.nodes) @@ -1608,10 +1620,8 @@ def find_sink(node): m.pre_process_k = functional_rotate_input return regions - def apply( - self, - graph_model: GraphModule, - fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply(self, + graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1640,7 +1650,7 @@ def apply( self.rotate_matmuls(graph_model) if len(regions) > 0: rewriters = _apply_rotate( - graph_model, regions, self.full_rotation_method, fuse_rotations=fuse_rotations) + graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b4235e283..d7f49e1a3 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -104,25 +104,9 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True, - sdpa_regions=args.rotation_sdpa_regions) - if not fuse_rotations: - # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of - # selected modules stop being attributes but, instead, properties, and their value is - # computed by passing the original value of the tensor through the forward passes of the - # parametrization modules. Parametrizations are registered using - # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__ - # attribute of the parametrized module, e.g. "" - # changes to "". Therefore, - # algorithms that do type checking might need to use type_before_parametrizations(module), - # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance, - # the "weight" attribute is parametrized, it will be removed from the attributes - # of the class. Consequently, quantization algorithms that rely on in-place modifications - # of the weights should not operate on parametrized modules. In this situation, parametrizations - # need to be removed beforehand by invoking fuse_parametrized_rotations - warn( - "Using parametrized results might break type-checking, which could lead to unexpected behaviour." 
- ) - new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) + sdpa_regions=args.rotation_sdpa_regions, + fuse_rotations=fuse_rotations) + new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: # The weights between model and new_model are tied, so this check prevents From d769a7d78437ba4c9ca5776198c07daf11d67b74 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 21 Jan 2025 10:00:01 +0000 Subject: [PATCH 22/26] Add new flag to entrypoint --- src/brevitas/graph/equalize.py | 13 ++++++++----- src/brevitas_examples/llm/README.md | 2 +- src/brevitas_examples/llm/main.py | 25 ++++++++++++++++--------- tests/brevitas_examples/test_llm.py | 12 ++++++++---- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 1519d7ba5..8747e05d0 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1531,7 +1531,7 @@ def __init__( orphan_sink: bool = False, sdpa_regions: bool = False, rotate_matmul: bool = False, - fuse_rotations: bool = True, + use_parametrized_rotations: bool = False, full_rotation_method: str = 'had', return_rewriters: bool = False) -> None: super(GraphRotationEqualization, self).__init__() @@ -1549,8 +1549,8 @@ def __init__( self.full_rotation_method = full_rotation_method self.return_rewriters = return_rewriters self.sdpa_regions = sdpa_regions - if not fuse_rotations: - # NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__ + if use_parametrized_rotations: + # NOTE: When use_parametrized_rotations=False, parametrized rotations are applied. This changes the attribute __class__ # of the parametrized module, e.g. to"". # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module), # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications @@ -1559,7 +1559,7 @@ def __init__( warnings.warn( "Using parametrized results might break type-checking, which could lead to unexpected behaviour." ) - self.fuse_rotations = fuse_rotations + self.use_parametrized_rotations = use_parametrized_rotations def rotate_matmuls(self, graph_module): matmul_nodes = list(graph_module.graph.nodes) @@ -1650,7 +1650,10 @@ def apply(self, self.rotate_matmuls(graph_model) if len(regions) > 0: rewriters = _apply_rotate( - graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations) + graph_model, + regions, + self.full_rotation_method, + fuse_rotations=not self.use_parametrized_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 7e2cc335f..e38218b55 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -12,7 +12,7 @@ Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. -When using `--rotation fused_no_fx_optimize`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. 
`--learning_rate`, `--weight_decay`, `per_device_train_batch_size`. +When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `per_device_train_batch_size`. ```bash usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index d7f49e1a3..f96227930 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -85,7 +85,7 @@ def set_seed(seed): torch.random.manual_seed(seed) -def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = True): +def fused_rotation_no_fx(model, calibration_loader, args): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)): @@ -105,7 +105,7 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = full_rotation_method=args.rotation_mode, return_rewriters=True, sdpa_regions=args.rotation_sdpa_regions, - fuse_rotations=fuse_rotations) + use_parametrized_rotations=args.optimize_rotations) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: @@ -149,7 +149,9 @@ def model_export(model, ref_input, args): def validate(args, extra_args: Optional[List[str]] = None): - if args.rotation != "fused_no_fx_optimize": + if args.optimize_rotations: + assert args.rotation in ['fx', 'fused_no_fx'], f"Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx" + else: assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}" if args.functional_sdpa_quant: assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization" @@ -279,7 +281,7 @@ def quantize_llm(args, extra_args=None): device=None, fuse_sequences=args.fuse_sequences) - if args.rotation in ["fused_no_fx_optimize"]: + if args.optimize_rotations: # Extra arguments should be used as training arguments for rotation optimization rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args) # Load the data for rotation optimization @@ -353,7 +355,8 @@ def quantize_llm(args, extra_args=None): eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, - sdpa_regions=args.rotation_sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions, + use_parametrized_rotations=args.optimize_rotations) model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': @@ -361,8 +364,6 @@ def quantize_llm(args, extra_args=None): model = eq.apply(model) elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) - elif args.rotation == 'fused_no_fx_optimize': - fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=False) if args.weight_equalization: print("Apply weight equalization...") @@ -488,7 +489,7 @@ def quantize_llm(args, extra_args=None): for k, v in dict_hooks.items(): k._hf_hook.post_forward = v - if args.rotation in 
['fused_no_fx_optimize']: + if args.optimize_rotations: apply_rotation_optimization( model=model, tokenizer=tokenizer, @@ -859,8 +860,14 @@ def parse_args(args, override_defaults={}): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], + choices=['fx', 'layerwise', 'fused_no_fx'], help='Apply graph rotation equalization') + parser.add_argument( + "--optimize-rotations", + action="store_true", + default=False, + help="Whether to optimize the rotations (default: %(default)s).", + ) parser.add_argument( '--rotation-mode', default='had', diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 7b17436a6..5c2b60d02 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -838,7 +838,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "weight_bit_width": 4, "input_bit_width": None, "replace_rmsnorm": True, - "rotation": "fused_no_fx_optimize", + "rotation": "fused_no_fx", + "optimize_rotations": True, "rotation_orphan_sink": True, "rotation_mode": "ort", "nsamples_rot_calibration": 2, @@ -865,7 +866,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "weight_bit_width": 4, "input_bit_width": None, "replace_rmsnorm": True, - "rotation": "fused_no_fx_optimize", + "rotation": "fused_no_fx", + "optimize_rotations": True, "rotation_orphan_sink": False, "rotation_mode": "ort", "nsamples_rot_calibration": 2, @@ -892,7 +894,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "weight_bit_width": 4, "input_bit_width": None, "replace_rmsnorm": True, - "rotation": "fused_no_fx_optimize", + "rotation": "fused_no_fx", + "optimize_rotations": True, "rotation_orphan_sink": True, "rotation_mode": "had", "nsamples_rot_calibration": 2, @@ -919,7 +922,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): "weight_bit_width": 4, "input_bit_width": None, "replace_rmsnorm": True, - "rotation": "fused_no_fx_optimize", + "rotation": "fused_no_fx", + "optimize_rotations": True, "rotation_orphan_sink": False, "rotation_mode": "had", "nsamples_rot_calibration": 2, From aefde7ddb3e1f7231aa595e8800383c2c9290277 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 21 Jan 2025 12:16:20 +0000 Subject: [PATCH 23/26] Address final comments --- src/brevitas/graph/equalize.py | 1 - src/brevitas_examples/llm/README.md | 4 ++-- src/brevitas_examples/llm/config/default_template.yml | 1 + tests/brevitas_examples/llm_test_template.yml | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 8747e05d0..5197c571c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -17,7 +17,6 @@ import torch.nn as nn import torch.nn.utils.parametrize as parametrize -from brevitas import config from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index e38218b55..4807a9e50 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -51,8 +51,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa] [--functional-sdpa-quant] [--replace-mha] [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] - [--rotation-mode {had,ort}] [--rotation-orphan-sink] - [--rotation-sdpa-regions] + 
[--rotation-mode {had,ort}] [--optimize-rotations] + [--rotation-orphan-sink] [--rotation-sdpa-regions] [--act-equalization {None,layerwise,fx}] [--act-equalization-alpha ACT_EQUALIZATION_ALPHA] [--load-awq LOAD_AWQ] diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml index b7d1ab864..1dc2df871 100644 --- a/src/brevitas_examples/llm/config/default_template.yml +++ b/src/brevitas_examples/llm/config/default_template.yml @@ -49,6 +49,7 @@ model: facebook/opt-125m no_float16: false no_quantize: false nsamples: 128 +optimize_rotations: false quant_sdpa: false quantize_input_zero_point: false quantize_last_layer: false diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml index 0efe97998..f37c99152 100644 --- a/tests/brevitas_examples/llm_test_template.yml +++ b/tests/brevitas_examples/llm_test_template.yml @@ -49,6 +49,7 @@ model: facebook/opt-125m no_float16: false no_quantize: false nsamples: 128 +optimize_rotations: false quant_sdpa: false quantize_input_zero_point: false quantize_last_layer: false From aaf74bbfb76ac9eb90a9bfd8faa25d780555cbde Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 27 Jan 2025 10:36:25 +0000 Subject: [PATCH 24/26] Fix last tests --- tests/brevitas/graph/test_quantize.py | 2 ++ tests/brevitas_examples/test_llm.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index 62ce405c9..a6c39942e 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -11,6 +11,7 @@ from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization from tests.marker import requires_pt_ge +from tests.marker import requires_pt_lt @pytest_cases.parametrize( @@ -64,6 +65,7 @@ def test_layerwise_quantize_blacklist(kwargs): K: torch.matmul(tensor, rot_mat), 'key': '0', 'expected': ""},]) +@requires_pt_lt('2.2.2', 'Windows') def test_layerwise_quantize_parametrized_modules(kwargs): key = kwargs['key'] exp = kwargs['expected'] diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 5c2b60d02..e9ec0b330 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -962,8 +962,8 @@ def rotation_optimization_args_layer_count_and_ppl(default_run_args, request): @requires_pt_ge('2.4') def test_small_models_rotation_optimization_ppl( caplog, rotation_optimization_args_layer_count_and_ppl): - if platform.system() == "Windows": - pytest.skip("Skipping dynamo + windows") + if platform.system() != "Linux": + pytest.skip("Skipping dynamo + windows/macos") # Tolerances are stricter for this test, to ensure that it does not pass # with non-optimized quantized perplexities RTOL_ROT, ATOL_ROT = 1e-05, 2. 
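The tests touched above all lean on the same `torch.nn.utils.parametrize` mechanics that this series introduces. A minimal standalone sketch of that pattern, assuming a plain `nn.Linear` and a hand-rolled `Rotation` module in place of Brevitas' quant layers and `RotationWeightParametrization` (only the `parametrize` calls mirror the diffs), is:

```python
# Illustrative sketch only, not Brevitas code: a plain nn.Linear and a
# hand-rolled Rotation module stand in for the quant layers and
# RotationWeightParametrization used in the patches above.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Rotation(nn.Module):
    """Right-multiplies the parametrized tensor by a fixed orthogonal matrix."""

    def __init__(self, rot_mat: torch.Tensor):
        super().__init__()
        self.rot_mat = nn.Parameter(rot_mat)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor @ self.rot_mat


layer = nn.Linear(2, 3)
rot_mat = torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.))

# Register the rotation: "weight" becomes a computed property and the class of
# `layer` changes to a parametrized subclass of nn.Linear.
parametrize.register_parametrization(layer, "weight", Rotation(rot_mat))
assert parametrize.is_parametrized(layer)
assert parametrize.type_before_parametrizations(layer) is nn.Linear

# Fuse: bake the rotated value back into "weight" and restore the original
# class, as fuse_parametrized_rotations does for every parametrized tensor.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
assert not parametrize.is_parametrized(layer)
```

Fusing with `leave_parametrized=True` writes the rotated value back into `weight` and restores the module's original class; this is why `fuse_parametrized_rotations` re-loads the state of the weight/bias quant proxies afterwards, so that calibrated scales are not lost.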
From efd6dbd954539a5537b2f1eb41b35a38a6fee334 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 27 Jan 2025 13:40:32 +0000 Subject: [PATCH 25/26] Correct test skipping --- tests/brevitas/graph/test_quantize.py | 3 ++- tests/brevitas_examples/test_llm.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index a6c39942e..8f8ae3c93 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -65,7 +65,6 @@ def test_layerwise_quantize_blacklist(kwargs): K: torch.matmul(tensor, rot_mat), 'key': '0', 'expected': ""},]) -@requires_pt_lt('2.2.2', 'Windows') def test_layerwise_quantize_parametrized_modules(kwargs): key = kwargs['key'] exp = kwargs['expected'] @@ -162,6 +161,8 @@ def test_remove_parametrization_entries_state_dict(kwargs): 'key': '0', 'expected': ""},]) def test_quantize_parametrized_modules(kwargs): + if platform.system() == "Windows": + pytest.skip("Skipping dynamo + windows") key = kwargs['key'] exp = kwargs['expected'] rot_mat = kwargs['rot_mat'] diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index e9ec0b330..5d9bb10fa 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -979,8 +979,8 @@ def test_small_models_rotation_optimization_ppl( @requires_pt_ge('2.4') def test_small_models_rotation_optimization_layer_count( caplog, rotation_optimization_args_layer_count_and_ppl): - if platform.system() == "Windows": - pytest.skip("Skipping dynamo + windows") + if platform.system() != "Linux": + pytest.skip("Skipping dynamo + windows/macos") # Tolerances are stricter for this test, to ensure that it does not pass # with non-optimized quantized perplexities caplog.set_level(logging.INFO) From 0039bf4a3ebdc4999bf36cdc852829e3b187b412 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 27 Jan 2025 15:14:16 +0100 Subject: [PATCH 26/26] Update test_quantize.py --- tests/brevitas/graph/test_quantize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index 8f8ae3c93..3a92a3e28 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -1,4 +1,5 @@ import copy +import platform import pytest_cases import torch
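The other recurring change in the series is how unknown command-line or YAML keys are routed to the HF `TrainingArguments` used for rotation optimization. A rough sketch of that flow, with placeholder parser options rather than the real LLM entrypoint arguments (only the `parse_known_args` split and the re-injection of unknown YAML keys mirror the diffs), is:

```python
# Standalone sketch of the argument routing used for rotation optimization:
# argparse keeps only the keys it knows, everything else ends up in extra_args.
# Option names and defaults below are illustrative, not the Brevitas CLI.
import argparse


def parse_args(argv, override_defaults={}):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--optimize-rotations", action="store_true")
    # Keys coming from a YAML config that the parser does not know about are
    # re-injected as command-line style tokens so they land in extra_args
    # instead of becoming new parser defaults.
    parser_keys = {action.dest for action in parser._actions}
    for key in [k for k in override_defaults if k not in parser_keys]:
        argv += [f"--{key}", str(override_defaults.pop(key))]
    parser.set_defaults(**override_defaults)
    return parser.parse_known_args(argv)


args, extra_args = parse_args(
    ["--optimize-rotations"], override_defaults={"learning_rate": 1.5})
assert args.optimize_rotations
assert extra_args == ["--learning_rate", "1.5"]
```

In `parse_rotation_optimization_args`, the resulting `extra_args` list is then parsed with `transformers.HfArgumentParser(TrainingArguments).parse_args_into_dataclasses(args=extra_args)`, so any `TrainingArguments` field (e.g. `--learning_rate`, `--weight_decay`) can be supplied either on the command line or through the YAML config.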