From 024d4a7da4dd562dc290a8315cfecc5035396737 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:01:49 +0000 Subject: [PATCH] Feat (equalize): enable rotation matrix optimization (#1155) --------- Co-authored-by: Giuseppe Franco --- src/brevitas/graph/base.py | 65 ++++- src/brevitas/graph/equalize.py | 48 +++- src/brevitas/graph/quantize_impl.py | 21 +- src/brevitas/utils/rotation_utils.py | 16 +- src/brevitas/utils/torch_utils.py | 36 +++ src/brevitas_examples/llm/README.md | 6 +- .../llm/config/default_template.yml | 1 + .../llm/llm_quant/rotation_optimization.py | 119 +++++++++ src/brevitas_examples/llm/main.py | 87 +++++- tests/brevitas/graph/test_equalization.py | 61 ++++- tests/brevitas/graph/test_quantize.py | 153 +++++++++++ tests/brevitas_examples/llm_test_template.yml | 78 ++++++ tests/brevitas_examples/test_llm.py | 248 ++++++++++++++++-- 13 files changed, 872 insertions(+), 67 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/rotation_optimization.py create mode 100644 tests/brevitas_examples/llm_test_template.yml diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index d1631f34e..8fd041a5e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -5,21 +5,25 @@ from abc import abstractmethod import inspect from inspect import getcallargs -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Dict, Type import torch -from torch import Tensor from torch.nn import Module -from torch.nn import Parameter -import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides +# TODO: Deprecate PyTorch 1.11 +try: + from torch.nn.utils.parametrize import is_parametrized + from torch.nn.utils.parametrize import register_parametrization +except ImportError: + from brevitas.utils.torch_utils import is_parametrized + register_parametrization = None + from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node from brevitas.graph.utils import * from brevitas.utils.python_utils import islambda -from brevitas.utils.rotation_utils import RotationWeightParametrization __all__ = [ 'Transform', @@ -109,6 +113,34 @@ def apply(self, model: Module, *model_args, **model_kwargs): return model +def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + # Keys for values related to parametrizations + keys_to_remove = [] + # Keys/values corresponding to the original tensors before parametrizations + keys_values_to_add = [] + # Iterate over state_dict identifying the keys related to parametrizations + for key, value in state_dict.items(): + split_key = key.split(".") + # Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original" + if len(split_key + ) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original": + tensor_name = split_key[-2] + # Name of dictionary entry is "prefix.tensro_name" + keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value)) + # Keys corresponding to the parametrizations attached to the model need to be removed + # to make sure the dictionary can be loaded with no missing/unused keys + # NOTE: For safety, an additional check could be added as this logic would not work if a model + # without parametrizations has any key containing "parametrizations" + if "parametrizations" in split_key: + keys_to_remove.append(key) + # Apply changes 
in-place to the state_dict + for key in keys_to_remove: + del state_dict[key] + for key, value in keys_values_to_add: + state_dict[key] = value + return state_dict + + class ModuleToModule(GraphTransform, ABC): def __init__(self, new_module_class, **kwargs): @@ -159,7 +191,25 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - new_module.load_state_dict(old_module.state_dict()) + if not is_parametrized(old_module): + new_module.load_state_dict(old_module.state_dict()) + else: + old_module_state_dict = old_module.state_dict() + # If parametrizations are present in old_module, the state_dict needs + # to be processed beforehand + old_module_state_dict = _remove_parametrization_entries_state_dict( + old_module_state_dict) + # Strict can be set to True, since potential parametrizations were + # accounted for + new_module.load_state_dict(old_module_state_dict) + # If the old module is parametrized, these need to be transferred to the new module. + # The method transfer_parametrizations_and_params as it can result in parameter ties + # being broken + # NOTE: unsafe is set to True for efficiency, as the checks should have been done + # when first registering the parametrization to old_module + for tensor_name in old_module.parametrizations: + for param_func in old_module.parametrizations[tensor_name]: + register_parametrization(new_module, tensor_name, param_func, unsafe=True) class InsertModuleCallAfter(GraphTransform): @@ -211,8 +261,7 @@ def apply(self, model: GraphModule) -> GraphModule: for module in model.modules(): if module is self.module: # register the parametrization to module - parametrize.register_parametrization( - module, self.tensor_name, self.transform_module) + register_parametrization(module, self.tensor_name, self.transform_module) break return model diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 31b2d4f72..5197c571c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -40,6 +40,8 @@ from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule from brevitas.nn.quant_scale_bias import ScaleBias +from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization from brevitas.utils.torch_utils import KwargsForwardHook @@ -1444,17 +1446,31 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module): return model -def _fuse_rotations(model: nn.Module) -> nn.Module: +def fuse_parametrized_rotations(model: nn.Module) -> nn.Module: # First of all, parameters that have parametrizations need to be untied model = _untie_parameters_with_parametrizations(model) # Then, parametrizations can be safely removed for module in model.modules(): - # Names of the tensors that can potentially be parametrized - tensor_names = ["weight", "bias"] - # Remove parametrizations from each tensor - for tensor_name in tensor_names: - if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: - parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True) + if parametrize.is_parametrized(module): + # Names of the tensors that can potentially be parametrized + tensor_names = ["weight", 
"bias"] + # Remove parametrizations from each tensor + for tensor_name in tensor_names: + if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: + # Check if the module has any quantization-related children + state_dict = None + for submodule in module.modules(): + if isinstance(submodule, + (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)): + state_dict = submodule.state_dict() + break + # The rotated tensor is saved by setting leave_parametrized=True + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True) + # Restore the state of the quantization modules, as these might have been reset + # when registering the parametrized parameter + if state_dict is not None: + submodule.load_state_dict(state_dict) return model @@ -1514,6 +1530,7 @@ def __init__( orphan_sink: bool = False, sdpa_regions: bool = False, rotate_matmul: bool = False, + use_parametrized_rotations: bool = False, full_rotation_method: str = 'had', return_rewriters: bool = False) -> None: super(GraphRotationEqualization, self).__init__() @@ -1531,6 +1548,17 @@ def __init__( self.full_rotation_method = full_rotation_method self.return_rewriters = return_rewriters self.sdpa_regions = sdpa_regions + if use_parametrized_rotations: + # NOTE: When use_parametrized_rotations=False, parametrized rotations are applied. This changes the attribute __class__ + # of the parametrized module, e.g. to"". + # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module), + # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications + # of the weights should not operate on parametrized modules. In this situation, parametrizations + # need to be removed beforehand by invoking fuse_parametrized_rotations + warnings.warn( + "Using parametrized results might break type-checking, which could lead to unexpected behaviour." 
+ ) + self.use_parametrized_rotations = use_parametrized_rotations def rotate_matmuls(self, graph_module): matmul_nodes = list(graph_module.graph.nodes) @@ -1620,7 +1648,11 @@ def apply(self, if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate( + graph_model, + regions, + self.full_rotation_method, + fuse_rotations=not self.use_parametrized_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 535f9a8f9..d0a0f4be8 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -7,7 +7,12 @@ import torch import torch.nn as nn -import brevitas +# TODO: Deprecate PyTorch 1.11 +try: + from torch.nn.utils.parametrize import type_before_parametrizations +except ImportError: + from brevitas.utils.torch_utils import type_before_parametrizations + from brevitas.graph.base import InsertModuleCallAfter from brevitas.graph.base import ModuleInstanceToModuleInstance from brevitas.graph.base import ModuleToModuleByInstance @@ -403,8 +408,8 @@ def act_handler(model, layer_map): if node.op == 'call_module': module = get_module(model, node.target) if isinstance(module, tuple(layer_map.keys())): - if layer_map[type(module)] is not None: - quant_module_class, quant_module_kwargs = layer_map[type(module)] + if layer_map[type_before_parametrizations(module)] is not None: + quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)] quant_module = quant_module_class(**quant_module_kwargs) # Check for activation equalization mul nodes if len(node.users) == 1: @@ -465,8 +470,8 @@ def layer_handler( quant_identity_map=quant_identity_map, quant_act_map=quant_act_map, unsigned_act_tuple=unsigned_act_tuple) - if layer_map[type(module)] is not None: - quant_module_class, quant_module_kwargs = layer_map[type(module)] + if layer_map[type_before_parametrizations(module)] is not None: + quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)] # Quantize the input if is not quantized, input_quant is not specified, # and the quant_identity_map is provided. if not are_inputs_quantized_and_aligned( @@ -511,7 +516,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. 
""" - if _module_class_name(type(model)) in layer_map.keys(): + if _module_class_name(type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -532,8 +537,8 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name(type(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] + if layer_map[_module_class_name(type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: diff --git a/src/brevitas/utils/rotation_utils.py b/src/brevitas/utils/rotation_utils.py index 6a79d1cc3..e5f92c50b 100644 --- a/src/brevitas/utils/rotation_utils.py +++ b/src/brevitas/utils/rotation_utils.py @@ -1,9 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional +from typing import Callable, List, Optional import torch +from torch import nn from torch import Tensor @@ -47,3 +48,16 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: raise RuntimeError("Not supported yet") return tensor + + +def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]: + trainable_rotations = [] + # IDs of the rotation matrices are tracked, as several modules can share + # the same parametrized rotation + ids_rot = set() + for module in model.modules(): + if isinstance(module, RotationWeightParametrization): + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) + trainable_rotations.append(module.rot_mat) + return trainable_rotations diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 93e4435de..684ce439d 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -146,3 +146,39 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn return decorator + + +# TODO: Remove after deprecating PyTorch 1.11 +def is_parametrized(module: torch.nn.Module, tensor_name: Optional[str] = None) -> bool: + r"""Determine if a module has a parametrization. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, torch.nn.ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +# TODO: Remove after deprecating PyTorch 1.11 +def type_before_parametrizations(module: torch.nn.Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. 
+ + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index c1c9d9919..4807a9e50 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -12,6 +12,8 @@ Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. +When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `per_device_train_batch_size`. + ```bash usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] @@ -49,8 +51,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa] [--functional-sdpa-quant] [--replace-mha] [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] - [--rotation-mode {had,ort}] [--rotation-orphan-sink] - [--rotation-sdpa-regions] + [--rotation-mode {had,ort}] [--optimize-rotations] + [--rotation-orphan-sink] [--rotation-sdpa-regions] [--act-equalization {None,layerwise,fx}] [--act-equalization-alpha ACT_EQUALIZATION_ALPHA] [--load-awq LOAD_AWQ] diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml index b7d1ab864..1dc2df871 100644 --- a/src/brevitas_examples/llm/config/default_template.yml +++ b/src/brevitas_examples/llm/config/default_template.yml @@ -49,6 +49,7 @@ model: facebook/opt-125m no_float16: false no_quantize: false nsamples: 128 +optimize_rotations: false quant_sdpa: false quantize_input_zero_point: false quantize_last_layer: false diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py new file mode 100644 index 000000000..84446ad80 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from dataclasses import field +import os +from typing import List, Optional + +from accelerate.utils import DistributedType +from datasets import Dataset +import torch +import transformers +from transformers import Trainer +from transformers.tokenization_utils import PreTrainedTokenizerBase + +from brevitas.optim.cailey_sgd import CaileySGD +from brevitas.utils.rotation_utils import extract_trainable_rotation_matrices +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.llm.llm_quant.data_utils import DatasetToDevice + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + # By default, arguments are saved in the current working directory + output_dir: Optional[str] = field(default=os.getcwd()) + # NOTE: Currently, there is no infrastructure to resume training + # from a checkpoint, so related files are not save by default + save_strategy: Optional[str] = field(default="no") + + +def parse_rotation_optimization_args(extra_args: Optional[List[str]] = None) -> TrainingArguments: + parser = 
transformers.HfArgumentParser(TrainingArguments) + training_args = parser.parse_args_into_dataclasses(args=extra_args) + # If a single-process is running, only one GPU should be available + # for Trainer, to prevent using DataParallel, which was causing an + # error due to tensors in different devices being operated. + # Therefore, DistributedDataParallel should be used to run in + # multiple GPUs + if training_args[0].distributed_state.distributed_type == DistributedType.NO and training_args[ + 0]._n_gpu > 1: + training_args[0]._n_gpu = 1 + return training_args[0] + + +# Function to create a batch +def collate_fn(kwargs_list, return_tensors="pt"): + kwargs = {} + for curr_dict in kwargs_list: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + return kwargs + + +def _prepare_train_dataset(train_dataset: DatasetToDevice) -> Dataset: + return DatasetToDevice( + data=[{ + "input_ids": train_datapoint["input_ids"], "labels": train_datapoint["input_ids"]} + for train_datapoint in train_dataset.data], + device=None) + + +def _prepare_model(model: torch.nn.Module) -> torch.nn.Module: + # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after + # finishing the optimization. However, this method no longer works after + # registering parametrizations/quantizing, so this method is mocked to prevent + # a crash. + def mock_save_pretrained_fn(*args, **kwargs): + pass + + model.save_pretrained = mock_save_pretrained_fn + # Cache needs to be disabled for training + model.config.use_cache = False + # Loss for training + model.config.loss_type = "ForCausalLM" + + return model + + +def apply_rotation_optimization( + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + train_dataset: DatasetToDevice, + training_args: TrainingArguments, +) -> None: + + # Prepare dataset and model for training + train_dataset = _prepare_train_dataset(train_dataset) + model = _prepare_model(model) + # Enable skipping optimization + if training_args.max_steps <= 0: + return + # Remove hooks and empty cache before starting optimization + remove_hooks(model) + torch.cuda.empty_cache() + # Set to False the model parameters + for param in model.parameters(): + param.requires_grad = False + # Collect trainable matrices + trainable_rotations = extract_trainable_rotation_matrices(model) + for rot_mat in trainable_rotations: + rot_mat.requires_grad = True + optimizer = CaileySGD(trainable_rotations, lr=training_args.learning_rate, stiefel=True) + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=collate_fn, + optimizers=(optimizer, None)) + trainer.train() + # After finishing training, set eval mode again + model.eval() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 7567f1727..f96227930 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -6,6 +6,7 @@ from copy import deepcopy import functools import sys +from typing import List, Optional from warnings import warn from lm_eval import evaluator @@ -22,7 +23,8 @@ from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager 
from brevitas.graph import load_quant_model_mode -from brevitas.graph.base import ModuleInstanceWrapModule +from brevitas.graph.base import ModuleInstanceTransformTensor +from brevitas.graph.equalize import fuse_parametrized_rotations from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import functional_quantization_mode @@ -54,6 +56,8 @@ from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization +from brevitas_examples.llm.llm_quant.rotation_optimization import parse_rotation_optimization_args from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -89,6 +93,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add) apply_layernorm_affine_merge(new_model) + # NOTE: This call breaks ties between the the lm_head and the embedding layer new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -99,14 +104,15 @@ def fused_rotation_no_fx(model, calibration_loader, args): orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True, - sdpa_regions=args.rotation_sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions, + use_parametrized_rotations=args.optimize_rotations) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: # The weights between model and new_model are tied, so this check prevents # rotating the weights twice - if isinstance(r, ModuleInstanceWrapModule): - r.apply(model) + if not isinstance(r, ModuleInstanceTransformTensor): + model = r.apply(model) remove_hooks(new_model) @@ -142,7 +148,11 @@ def model_export(model, ref_input, args): export_torch_qcdq(model, ref_input['input_ids'], export_path=f"{args.export_prefix}.pt") -def validate(args): +def validate(args, extra_args: Optional[List[str]] = None): + if args.optimize_rotations: + assert args.rotation in ['fx', 'fused_no_fx'], f"Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx" + else: + assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}" if args.functional_sdpa_quant: assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization" if args.rotation == 'fx': @@ -209,8 +219,8 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def quantize_llm(args): - validate(args) +def quantize_llm(args, extra_args=None): + validate(args, extra_args) set_seed(args.seed) if args.export_prefix is None: args.export_prefix = f"{args.model.replace('/', '--')}" @@ -271,6 +281,22 @@ def quantize_llm(args): device=None, fuse_sequences=args.fuse_sequences) + if args.optimize_rotations: + # Extra arguments should be used as training arguments for rotation optimization + rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args) + # Load the 
data for rotation optimization + rot_calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples_rot_calibration, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=require_fx and args.export_target is not None, + device=None, + fuse_sequences=args.fuse_sequences) + device = next(iter(model.parameters())).device print("Data loaded.") @@ -329,7 +355,8 @@ def quantize_llm(args): eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, - sdpa_regions=args.rotation_sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions, + use_parametrized_rotations=args.optimize_rotations) model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': @@ -462,6 +489,20 @@ def quantize_llm(args): for k, v in dict_hooks.items(): k._hf_hook.post_forward = v + if args.optimize_rotations: + apply_rotation_optimization( + model=model, + tokenizer=tokenizer, + train_dataset=rot_calibration_loader, + training_args=rot_optimization_args, + ) + # Remove hooks from optimization + remove_hooks(model) + # Offload model before fusing the rotations + model = offload_model(model) + # Fuse rotations with weights + model = fuse_parametrized_rotations(model) + if args.act_calibration and not args.load_checkpoint: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -610,6 +651,11 @@ def parse_args(args, override_defaults={}): type=int, default=128, help='Number of calibration data samples. Default: 128.') + parser.add_argument( + '--nsamples-rot-calibration', + type=int, + default=800, + help='Number of calibration data samples for rotation. Default: %(default)d.') parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.') parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.') parser.add_argument( @@ -816,6 +862,12 @@ def parse_args(args, override_defaults={}): default=None, choices=['fx', 'layerwise', 'fused_no_fx'], help='Apply graph rotation equalization') + parser.add_argument( + "--optimize-rotations", + action="store_true", + default=False, + help="Whether to optimize the rotations (default: %(default)s).", + ) parser.add_argument( '--rotation-mode', default='had', @@ -910,15 +962,28 @@ def parse_args(args, override_defaults={}): type=str, nargs='*', help='A list of tasks for zero_shot evaluation. Default: %(default)s') + if len(override_defaults) > 0: + # Retrieve keys that are known to the parser + parser_keys = set(map(lambda action: action.dest, parser._actions)) + # Extract the entries in override_defaults that correspond to keys not known to the parser + extra_args_keys = [key for key in override_defaults.keys() if key not in parser_keys] + # Remove all the keys in override_defaults that are unknown to the parser and, instead, + # include them in args, as if they were passed as arguments to the command line. + # This prevents the keys of HF TrainingArguments from being added as arguments to the parser. 
+ # Consequently, they will be part of the second value returned by parse_known_args (thus being + # used as extra_args in quantize_llm) + for key in extra_args_keys: + args += [f"--{key}", str(override_defaults[key])] + del override_defaults[key] parser.set_defaults(**override_defaults) - return parser.parse_args(args) + return parser.parse_known_args(args) def main(): overrides = override_defaults(sys.argv[1:]) - args = parse_args(sys.argv[1:], override_defaults=overrides) - quantize_llm(args) + args, extra_args = parse_args(sys.argv[1:], override_defaults=overrides) + quantize_llm(args, extra_args) if __name__ == '__main__': diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 41edf0752..6989a35b4 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -20,22 +20,25 @@ from brevitas.graph.equalize import _apply_rotate from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions -from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizationIndexes +from brevitas.graph.equalize import fuse_parametrized_rotations from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix from brevitas.graph.equalize import Region from brevitas.graph.hadamard import get_hadK +from brevitas.graph.quantize import LAYERWISE_COMPUTE_LAYER_MAP +from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module from brevitas.nn.equalized_layer import RotatedModule +from brevitas.utils.python_utils import recurse_getattr from brevitas.utils.rotation_utils import RotationWeightParametrization from tests.marker import requires_pt_ge @@ -517,7 +520,7 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r isinstance(module, RotatedModule) for module in rotated_model.modules()]) # Optionally fuse the rotations if fuse_rotations: - rotated_model_unfused = _fuse_rotations(rotated_model_unfused) + rotated_model_unfused = fuse_parametrized_rotations(rotated_model_unfused) # Verify that no parametrizations remain after fusing for module in rotated_model_unfused.modules(): assert not parametrize.is_parametrized(module) @@ -528,3 +531,57 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r # Verify that the weights have changed with respect to the unrotated module for the modules that have received parametrizations # Verify that weights match between the fused and unfused model compare_model_weights(rotated_model_fused, rotated_model_unfused) + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'sample_input': torch.tensor([[0.8, -0.6]]), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_fuse_parametrized_modules(kwargs): 
+ key = kwargs['key'] + exp = kwargs['expected'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + model = kwargs["model"] + sample_input = kwargs["sample_input"] + module = recurse_getattr(model, key) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + compute_layer_map = copy.deepcopy(LAYERWISE_COMPUTE_LAYER_MAP) + module = recurse_getattr(model, key) + type_quant_module = parametrize.type_before_parametrizations(module) + compute_layer_map[type_quant_module][1]["weight_quant"] = compute_layer_map[type_quant_module][ + 1]["weight_quant"].let(scaling_impl_type='parameter_from_stats') + qmodel = layerwise_quantize(model, compute_layer_map=compute_layer_map) + # Calibration pass to initialize scales + with torch.no_grad(): + output = qmodel(sample_input) + # Fuse parametrizations + qmodel = fuse_parametrized_rotations(qmodel) + # Verify that scales were not lost + module = recurse_getattr(model, key) + assert module.weight_quant.tensor_quant.scaling_impl.init_done + assert not torch.allclose( + module.weight_quant.tensor_quant.scaling_impl.value, + torch.ones_like(module.weight_quant.tensor_quant.scaling_impl.value)) + # Compute output after fusing and check that it matches + with torch.no_grad(): + output_fused = qmodel(sample_input) + assert torch.allclose(output, output_fused, rtol=0.0, atol=0.0) diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py index 45167b673..3a92a3e28 100644 --- a/tests/brevitas/graph/test_quantize.py +++ b/tests/brevitas/graph/test_quantize.py @@ -1,7 +1,18 @@ +import copy +import platform + import pytest_cases +import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize +from brevitas.graph.base import _remove_parametrization_entries_state_dict from brevitas.graph.quantize import layerwise_quantize +from brevitas.graph.quantize import quantize +from brevitas.utils.python_utils import recurse_getattr +from brevitas.utils.rotation_utils import RotationWeightParametrization +from tests.marker import requires_pt_ge +from tests.marker import requires_pt_lt @pytest_cases.parametrize( @@ -42,3 +53,145 @@ def test_layerwise_quantize_blacklist(kwargs): assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" checked = True assert checked, f"Layer named {key} not found. 
Layer names are: {found_names}" + + +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_layerwise_quantize_parametrized_modules(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + del kwargs['key'] + del kwargs['expected'] + del kwargs['rot_mat'] + del kwargs['rot_func'] + + model = kwargs["model"] + module = recurse_getattr(model, key) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + qmodel = layerwise_quantize(**kwargs) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. Layer names are: {found_names}" + + +@pytest_cases.parametrize( + 'kwargs', + [{ + 'model': nn.Sequential(nn.Linear(2, 3)), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected_state_dict_keys': ['0.weight', '0.bias'],}]) +def test_remove_parametrization_entries_state_dict(kwargs): + key = kwargs['key'] + rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + expected_state_dict_keys = kwargs['expected_state_dict_keys'] + del kwargs['key'] + del kwargs['rot_mat'] + del kwargs['rot_func'] + del kwargs['expected_state_dict_keys'] + + model = kwargs["model"] + module = recurse_getattr(model, key) + old_state_dict = copy.deepcopy(model.state_dict()) + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + # Retrieve state dict after parametrization + state_dict = model.state_dict() + # Remove parametrization entries from state dict + state_dict = _remove_parametrization_entries_state_dict(state_dict) + # Verify that all the expected keys in expected_state_dict_keys + # are present in state_dict + assert len(set(expected_state_dict_keys) - set(state_dict.keys())) == 0 + # Verify that keys match + for key, value in state_dict.items(): + # Verify that key is in the expected keys + assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict" + # Compare tensor values + assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match with that in the original state_dict" + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'sample_input': torch.tensor([[0.8, -0.6]]), + 'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)), + 'rot_func': lambda tensor, + rot_mat, + K: torch.matmul(tensor, rot_mat), + 'key': '0', + 'expected': ""},]) +def test_quantize_parametrized_modules(kwargs): + if platform.system() == "Windows": + pytest.skip("Skipping dynamo + windows") + key = kwargs['key'] + exp = kwargs['expected'] + 
rot_mat = kwargs['rot_mat'] + rot_func = kwargs['rot_func'] + sample_input = kwargs['sample_input'] + model = kwargs["model"] + + graph_model, _ = torch._dynamo.export(model)(sample_input) + orig_module = recurse_getattr(model, key) + # Use tied weights to identify equivalent model + key, module = [(key, module) for key, module in graph_model.named_modules() if hasattr(module, "weight") and module.weight is orig_module.weight][0] + # Register rotation parametrization to module + parametrize.register_parametrization( + module=module, + tensor_name="weight", + parametrization=RotationWeightParametrization( + rot_mat=nn.Parameter(rot_mat), + rot_func=rot_func, + axis=1, + K=None, + )) + qmodel = quantize(graph_model) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. Layer names are: {found_names}" diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml new file mode 100644 index 000000000..f37c99152 --- /dev/null +++ b/tests/brevitas_examples/llm_test_template.yml @@ -0,0 +1,78 @@ +act_calibration: false +act_equalization: null +act_equalization_alpha: 0.5 +bias_corr: false +checkpoint_name: null +convert_layernorm_to_rmsnorm: false +dataset: wikitext2 +eval: false +export_prefix: null +export_target: null +few_shot_compile: false +few_shot_eval: false +few_shot_limit: null +few_shot_tasks: +- arc_challenge +- arc_easy +- winogrande +- piqa +few_shot_zeroshot: false +functional_sdpa_quant: false +fuse_sequences: false +gpfq: false +gptq: false +gpxq_act_order: false +gpxq_block_name: null +gpxq_create_weight_orig: false +gpxq_max_accumulator_bit_width: null +gpxq_max_accumulator_tile_size: null +gpxq_use_quant_activations: false +input_bit_width: null +input_group_size: 64 +input_param_method: stats +input_quant_format: int +input_quant_granularity: per_tensor +input_quant_type: asym +input_scale_precision: float_scale +input_scale_type: static +learned_round: null +learned_round_fast_update: false +learned_round_iters: 200 +learned_round_lr: 0.005 +learned_round_scale: false +learned_round_scale_lr: 0.01 +learned_round_scale_momentum: 0.9 +ln_affine_merge: false +load_awq: null +load_checkpoint: false +model: facebook/opt-125m +no_float16: false +no_quantize: false +nsamples: 128 +optimize_rotations: false +quant_sdpa: false +quantize_input_zero_point: false +quantize_last_layer: false +quantize_weight_zero_point: false +replace_mha: false +replace_rmsnorm: false +rotation: null +rotation_mode: had +rotation_orphan_sink: false +scale_rounding_func_type: null +scaling_min_val: 0.0001 +seed: 0 +seqlen: 2048 +weight_bit_width: 8 +weight_equalization: false +weight_group_dim: null +weight_group_size: 128 +weight_param_method: stats +weight_quant_format: int +weight_quant_granularity: per_group +weight_quant_type: sym +weight_scale_precision: float_scale +# TrainingArguments for HF Trainer (see https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) +learning_rate: 1.5 +lr_scheduler_type: cosine +save_safetensors: false diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index f6db73924..5d9bb10fa 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -2,11 +2,13 @@ # SPDX-License-Identifier: BSD-3-Clause from 
argparse import Namespace +import copy from dataclasses import dataclass import logging import os import platform import shutil +from unittest.mock import patch import numpy as np import onnx @@ -18,26 +20,22 @@ from brevitas import config from brevitas import torch_version +from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args from brevitas_examples.llm.main import quantize_llm from tests.marker import jit_disabled_for_export from tests.marker import requires_pt_ge +ATOL_PPL = 2e+02 +RTOL_PPL = 1e-04 + def ptid2pathname(string): return string.replace("/", "-").replace(":", "-") -def allclose(x, y): - return np.allclose(x, y, rtol=1e-03, atol=1e+01, equal_nan=False) - - -def allveryclose(x, y): - return np.allclose(x, y, rtol=1e-04, atol=2e+02, equal_nan=False) - - -def allexact(x, y): - return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False) +def allclose(x, y, rtol=RTOL_PPL, atol=ATOL_PPL): + return np.allclose(x, y, rtol=rtol, atol=atol, equal_nan=False) def transformers_version_ge(required_version: str): @@ -47,14 +45,14 @@ def transformers_version_ge(required_version: str): # Check that all args in args are used def validate_args(args): a = vars(args) - da = vars(parse_args([])) + da = vars(parse_args([])[0]) for k in a.keys(): assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`" -def validate_args_and_run_main(args): +def validate_args_and_run_main(args, extra_args=None): validate_args(args) - float_ppl, quant_ppl, model = quantize_llm(args) + float_ppl, quant_ppl, model = quantize_llm(args, extra_args=extra_args) return float_ppl, quant_ppl, model @@ -131,7 +129,7 @@ def small_models_with_ppl(request): @pytest_cases.fixture() def default_run_args(request): - args = UpdatableNamespace(**vars(parse_args([]))) + args = UpdatableNamespace(**vars(parse_args([])[0])) args.nsamples = 2 args.seqlen = 2 args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" @@ -252,8 +250,8 @@ def test_small_models_acc(caplog, acc_args_and_acc): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -294,8 +292,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -738,8 +736,8 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) 
float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" @pytest_cases.fixture( @@ -760,7 +758,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": True, "rotation_mode": "ort", "float_ppl": 33238.8984375, - "quant_ppl": 33232.65234375}, + "quant_ppl": 33232.65234375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -771,7 +769,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "ort", "float_ppl": 33238.8984375, - "quant_ppl": 33420.65234375}, + "quant_ppl": 33420.65234375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -782,7 +780,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": True, "rotation_mode": "had", "float_ppl": 33238.8984375, - "quant_ppl": 33290.48046875}, + "quant_ppl": 33290.48046875,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -793,7 +791,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "rotation_orphan_sink": False, "rotation_mode": "had", "float_ppl": 33238.8984375, - "quant_ppl": 33204.80859375}, + "quant_ppl": 33204.80859375,}, { "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", "act_calibration": False, @@ -802,7 +800,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): "replace_rmsnorm": True, "rotation": "layerwise", "float_ppl": 33238.8984375, - "quant_ppl": 33446.734375},]) + "quant_ppl": 33446.734375,},]) def rotation_ppl_args_and_ppl(default_run_args, request): args = default_run_args run_dict = request.param @@ -823,5 +821,201 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@pytest_cases.fixture( + ids=[ + "llama_rotation_optimization_ort", + "llama_rotation_optimization_ort_no_orphan", + "llama_rotation_optimization_had", + "llama_rotation_optimization_had_no_orphan",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "optimize_rotations": True, + "rotation_orphan_sink": True, + "rotation_mode": "ort", + "nsamples_rot_calibration": 2, + "no_float16": True, + "extra_args": [ + "--learning_rate", + 
"1.5", + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33239.33984375, + "exp_layer_types_count": { + "": 4, + "": 1, + "": 1, + "": 14,}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "optimize_rotations": True, + "rotation_orphan_sink": False, + "rotation_mode": "ort", + "nsamples_rot_calibration": 2, + "no_float16": True, + "extra_args": [ + "--learning_rate", + "1.5", + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33423.0390625, + "exp_layer_types_count": { + "": 0, + "": 1, + "": 1, + "": 14,}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "optimize_rotations": True, + "rotation_orphan_sink": True, + "rotation_mode": "had", + "nsamples_rot_calibration": 2, + "no_float16": True, + "extra_args": [ + "--learning_rate", + "1.5", + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33286.98828125, + "exp_layer_types_count": { + "": 4, + "": 1, + "": 1, + "": 14,}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "optimize_rotations": True, + "rotation_orphan_sink": False, + "rotation_mode": "had", + "nsamples_rot_calibration": 2, + "no_float16": True, + "extra_args": [ + "--learning_rate", + "1.5", + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--gradient_accumulation_steps", + "1"], + "float_ppl": 33238.8984375, + "quant_ppl": 33175.3046875, + "exp_layer_types_count": { + "": 0, + "": 1, + "": 1, + "": 14,}},]) +def rotation_optimization_args_layer_count_and_ppl(default_run_args, request): + args = default_run_args + run_dict = copy.deepcopy(request.param) + extra_args = run_dict["extra_args"] + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + exp_layer_types_count = run_dict["exp_layer_types_count"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + del run_dict["extra_args"] + del run_dict["exp_layer_types_count"] + args.update(**run_dict) + yield args, extra_args, float_ppl, quant_ppl, exp_layer_types_count + + +@requires_pt_ge('2.4') +def test_small_models_rotation_optimization_ppl( + caplog, rotation_optimization_args_layer_count_and_ppl): + if platform.system() != "Linux": + pytest.skip("Skipping dynamo + windows/macos") + # Tolerances are stricter for this test, to ensure that it does not pass + # with non-optimized quantized perplexities + RTOL_ROT, ATOL_ROT = 1e-05, 2. 
+ caplog.set_level(logging.INFO) + args, extra_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl + float_ppl, quant_ppl, _ = validate_args_and_run_main(args, extra_args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allclose(exp_quant_ppl, quant_ppl, rtol=RTOL_ROT, atol=ATOL_ROT), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@requires_pt_ge('2.4') +def test_small_models_rotation_optimization_layer_count( + caplog, rotation_optimization_args_layer_count_and_ppl): + if platform.system() != "Linux": + pytest.skip("Skipping dynamo + windows/macos") + # Tolerances are stricter for this test, to ensure that it does not pass + # with non-optimized quantized perplexities + caplog.set_level(logging.INFO) + args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl + with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model): + _, _, model = validate_args_and_run_main(args, extra_args) + assert_layer_types_count(model, exp_layer_types_count) + + +@pytest_cases.parametrize( + "kwargs", + [ + { + "yaml_file_path": + "./tests/brevitas_examples/llm_test_template.yml", + "expected_extra_args": [ + "--learning_rate", + "1.5", + "--lr_scheduler_type", + "cosine", + "--save_safetensors", + "False"],},], + ids=lambda kwargs: kwargs["yaml_file_path"]) +def test_parse_yaml_trainer_arguments(caplog, kwargs): + caplog.set_level(logging.INFO) + yaml_file_path = kwargs["yaml_file_path"] + expected_extra_args = kwargs["expected_extra_args"] + extra_args_keys = [expected_extra_args[i][2:] for i in range(0, len(expected_extra_args), 2)] + + def quantize_llm_assert_args(args, extra_args=None): + for key in extra_args_keys: + assert key not in args, f"Key {key} should not be known by the parser" + assert extra_args == expected_extra_args, f"Expected extra arguments {expected_extra_args} but got {extra_args}" + + # Run the argument parsing logic of the LLM entrypoint + with patch("brevitas_examples.llm.main.quantize_llm", quantize_llm_assert_args): + with patch("brevitas_examples.llm.main.sys.argv", ["main.py", "--config", yaml_file_path]): + main()
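
As a quick illustration of the parametrized-rotation flow introduced by this patch, the sketch below registers a `RotationWeightParametrization` on a plain `nn.Linear`, shows the state-dict key remapping handled by the new `_remove_parametrization_entries_state_dict` helper, and then fuses the rotation with `fuse_parametrized_rotations`. It is a minimal sketch that mirrors the new unit tests (the 2x2 rotation and `rot_func` are taken from them) and assumes a recent PyTorch, as the tests gate on >= 2.3.1.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.graph.base import _remove_parametrization_entries_state_dict
from brevitas.graph.equalize import fuse_parametrized_rotations
from brevitas.utils.rotation_utils import RotationWeightParametrization

model = nn.Sequential(nn.Linear(2, 3))
# Orthogonal 2x2 rotation, as in the new unit tests
rot_mat = torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.))

# Register the rotation as a parametrization: the stored ".original" tensor
# stays unrotated, while model[0].weight is recomputed on the fly
parametrize.register_parametrization(
    model[0],
    "weight",
    RotationWeightParametrization(
        rot_mat=nn.Parameter(rot_mat),
        rot_func=lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
        axis=1,
        K=None))

# Parametrized tensors are renamed in the state_dict
# ("0.weight" -> "0.parametrizations.weight.original"); the new helper maps
# them back to their plain names and drops the parametrization bookkeeping
state_dict = _remove_parametrization_entries_state_dict(model.state_dict())
assert set(state_dict.keys()) == {"0.weight", "0.bias"}

# Fusing writes the rotated weight back into the module and removes the
# parametrization, so downstream passes see a plain nn.Linear again
model = fuse_parametrized_rotations(model)
assert not parametrize.is_parametrized(model[0])
```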
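For context on what `--optimize-rotations` does under the hood: `apply_rotation_optimization` freezes all model parameters, re-enables gradients only on the shared rotation matrices returned by `extract_trainable_rotation_matrices`, and trains them with `CaileySGD` (Cayley SGD on the Stiefel manifold) through the HF `Trainer`. Below is a condensed, hand-rolled sketch of the same idea without the `Trainer`; the loop, batch format and hyperparameters are illustrative and assume an HF causal-LM model.

```python
import torch

from brevitas.optim.cailey_sgd import CaileySGD
from brevitas.utils.rotation_utils import extract_trainable_rotation_matrices


def optimize_rotations_manually(model, dataloader, lr=1.5, steps=2):
    # Freeze every parameter except the parametrized rotation matrices
    for param in model.parameters():
        param.requires_grad = False
    rotations = extract_trainable_rotation_matrices(model)
    for rot_mat in rotations:
        rot_mat.requires_grad = True
    # stiefel=True keeps the updates on the manifold of orthogonal matrices
    optimizer = CaileySGD(rotations, lr=lr, stiefel=True)
    model.train()
    for _, batch in zip(range(steps), dataloader):
        # Causal-LM loss with the inputs reused as labels, as in the patched
        # dataset preparation (_prepare_train_dataset)
        out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
        out.loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
```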
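Finally, the README note about configuring the rotation training via HF `TrainingArguments` is wired up through the entrypoint: `parse_args` now uses `parse_known_args`, so flags it does not own are returned as `extra_args` and handed to `transformers.HfArgumentParser` inside `quantize_llm` when rotations are optimized. A minimal sketch of driving this from Python follows; the flag values mirror the new rotation-optimization tests, and depending on the model and environment additional flags used there (e.g. `--no-float16`) may be needed.

```python
from brevitas_examples.llm.main import parse_args, quantize_llm

# Known flags are consumed by the LLM entrypoint parser; leftover HF
# TrainingArguments (e.g. --learning_rate, --max_steps) come back as
# extra_args and are forwarded to the rotation-optimization Trainer.
cli = [
    "--model", "hf-internal-testing/tiny-random-LlamaForCausalLM",
    "--replace-rmsnorm",
    "--rotation", "fused_no_fx",
    "--optimize-rotations",
    "--nsamples-rot-calibration", "2",
    "--learning_rate", "1.5",
    "--max_steps", "2",
    "--per_device_train_batch_size", "1",
]
args, extra_args = parse_args(cli)
float_ppl, quant_ppl, model = quantize_llm(args, extra_args=extra_args)
```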