From 1b9023428efc8704fbe485af56f9e11d3dea3f9c Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 21 Jan 2025 10:00:01 +0000
Subject: [PATCH] Add --optimize-rotations flag to LLM entrypoint

---
 src/brevitas/graph/equalize.py      | 13 ++++++++-----
 src/brevitas_examples/llm/README.md |  2 +-
 src/brevitas_examples/llm/main.py   | 25 ++++++++++++++++---------
 tests/brevitas_examples/test_llm.py | 12 ++++++++----
 4 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 1519d7ba5..8747e05d0 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1531,7 +1531,7 @@ def __init__(
             orphan_sink: bool = False,
             sdpa_regions: bool = False,
             rotate_matmul: bool = False,
-            fuse_rotations: bool = True,
+            use_parametrized_rotations: bool = False,
             full_rotation_method: str = 'had',
             return_rewriters: bool = False) -> None:
         super(GraphRotationEqualization, self).__init__()
@@ -1549,8 +1549,8 @@ def __init__(
         self.full_rotation_method = full_rotation_method
         self.return_rewriters = return_rewriters
         self.sdpa_regions = sdpa_regions
-        if not fuse_rotations:
-            # NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__
+        if use_parametrized_rotations:
+            # NOTE: When use_parametrized_rotations=True, parametrized rotations are applied. This changes the attribute __class__
             # of the parametrized module, e.g. to"".
             # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
             # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
             warnings.warn(
                 "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
             )
-        self.fuse_rotations = fuse_rotations
+        self.use_parametrized_rotations = use_parametrized_rotations
 
     def rotate_matmuls(self, graph_module):
         matmul_nodes = list(graph_module.graph.nodes)
@@ -1650,7 +1650,10 @@ def apply(self,
             self.rotate_matmuls(graph_model)
         if len(regions) > 0:
             rewriters = _apply_rotate(
-                graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations)
+                graph_model,
+                regions,
+                self.full_rotation_method,
+                fuse_rotations=not self.use_parametrized_rotations)
         if self.return_rewriters:
             return graph_model, rewriters
         else:
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 7e2cc335f..e38218b55 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -12,7 +12,7 @@
 
 Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.
 
-When using `--rotation fused_no_fx_optimize`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `per_device_train_batch_size`.
+When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `--per_device_train_batch_size`.
 
 ```bash
 usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index bdeb74180..bd912ea2f 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -85,7 +85,7 @@ def set_seed(seed):
     torch.random.manual_seed(seed)
 
 
-def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = True):
+def fused_rotation_no_fx(model, calibration_loader, args):
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
     if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
@@ -105,7 +105,7 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
         full_rotation_method=args.rotation_mode,
         return_rewriters=True,
         sdpa_regions=args.rotation_sdpa_regions,
-        fuse_rotations=fuse_rotations)
+        use_parametrized_rotations=args.optimize_rotations)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')
     for r in rewriters:
@@ -149,7 +149,9 @@ def model_export(model, ref_input, args):
 
 
 def validate(args, extra_args: Optional[List[str]] = None):
-    if args.rotation != "fused_no_fx_optimize":
+    if args.optimize_rotations:
+        assert args.rotation in ['fx', 'fused_no_fx'], "Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx"
+    else:
         assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}"
     if args.functional_sdpa_quant:
         assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
@@ -279,7 +281,7 @@ def quantize_llm(args, extra_args=None):
         device=None,
         fuse_sequences=args.fuse_sequences)
 
-    if args.rotation in ["fused_no_fx_optimize"]:
+    if args.optimize_rotations:
         # Extra arguments should be used as training arguments for rotation optimization
         rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args)
         # Load the data for rotation optimization
@@ -353,7 +355,8 @@ def quantize_llm(args, extra_args=None):
         eq = GraphRotationEqualization(
             orphan_sink=args.rotation_orphan_sink,
             full_rotation_method=args.rotation_mode,
-            sdpa_regions=args.rotation_sdpa_regions)
+            sdpa_regions=args.rotation_sdpa_regions,
+            use_parametrized_rotations=args.optimize_rotations)
         model = eq.apply(model)
         remove_hooks(model)
     elif args.rotation == 'layerwise':
@@ -361,8 +364,6 @@ def quantize_llm(args, extra_args=None):
         model = eq.apply(model)
     elif args.rotation == 'fused_no_fx':
         fused_rotation_no_fx(model, calibration_loader, args)
-    elif args.rotation == 'fused_no_fx_optimize':
-        fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=False)
 
     if args.weight_equalization:
         print("Apply weight equalization...")
@@ -484,7 +485,7 @@ def quantize_llm(args, extra_args=None):
         for k, v in dict_hooks.items():
             k._hf_hook.post_forward = v
 
-    if args.rotation in ['fused_no_fx_optimize']:
+    if args.optimize_rotations:
         apply_rotation_optimization(
             model=model,
             tokenizer=tokenizer,
@@ -855,8 +856,14 @@ def parse_args(args, override_defaults={}):
         '--rotation',
         type=str,
         default=None,
-        choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'],
+        choices=['fx', 'layerwise', 'fused_no_fx'],
         help='Apply graph rotation equalization')
+    parser.add_argument(
+        "--optimize-rotations",
+        action="store_true",
+        default=False,
+        help="Whether to optimize the rotations (default: %(default)s).",
+    )
     parser.add_argument(
         '--rotation-mode',
         default='had',
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 7b17436a6..5c2b60d02 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -838,7 +838,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "weight_bit_width": 4,
         "input_bit_width": None,
         "replace_rmsnorm": True,
-        "rotation": "fused_no_fx_optimize",
+        "rotation": "fused_no_fx",
+        "optimize_rotations": True,
         "rotation_orphan_sink": True,
         "rotation_mode": "ort",
         "nsamples_rot_calibration": 2,
@@ -865,7 +866,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "weight_bit_width": 4,
         "input_bit_width": None,
         "replace_rmsnorm": True,
-        "rotation": "fused_no_fx_optimize",
+        "rotation": "fused_no_fx",
+        "optimize_rotations": True,
         "rotation_orphan_sink": False,
         "rotation_mode": "ort",
         "nsamples_rot_calibration": 2,
@@ -892,7 +894,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "weight_bit_width": 4,
         "input_bit_width": None,
         "replace_rmsnorm": True,
-        "rotation": "fused_no_fx_optimize",
+        "rotation": "fused_no_fx",
+        "optimize_rotations": True,
         "rotation_orphan_sink": True,
         "rotation_mode": "had",
         "nsamples_rot_calibration": 2,
@@ -919,7 +922,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "weight_bit_width": 4,
         "input_bit_width": None,
         "replace_rmsnorm": True,
-        "rotation": "fused_no_fx_optimize",
+        "rotation": "fused_no_fx",
+        "optimize_rotations": True,
         "rotation_orphan_sink": False,
         "rotation_mode": "had",
         "nsamples_rot_calibration": 2,
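
A minimal sketch of how the new flag is exercised from the command line, using only options touched by this patch; the model name and the hyperparameter values are illustrative placeholders, not values taken from the patch:

```bash
# Rotation optimization now requires --optimize-rotations together with
# --rotation fx or --rotation fused_no_fx (enforced in validate()).
# Extra arguments such as --learning_rate are forwarded to
# transformers.TrainingArguments by the rotation-optimization Trainer.
python main.py \
    --model facebook/opt-125m \
    --rotation fused_no_fx \
    --optimize-rotations \
    --rotation-mode had \
    --learning_rate 1e-3 \
    --weight_decay 0.0 \
    --per_device_train_batch_size 1
```

Because `validate()` only accepts unknown extra arguments when `--optimize-rotations` is set, the `TrainingArguments` overrides above are legal only in combination with the new flag.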