From deab1b8452144d67cb2a9edd59bc6c37137bad28 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Wed, 5 Feb 2025 18:29:51 +0000
Subject: [PATCH] Tests with unfused parametrizations

---
 src/brevitas/graph/equalize.py                |  6 +-
 src/brevitas_examples/llm/main.py             |  4 +-
 tests/brevitas/graph/equalization_fixtures.py |  7 ++-
 tests/brevitas/graph/test_equalization.py     | 57 ++++++++++++++++++-
 tests/brevitas_examples/test_llm.py           |  2 +-
 5 files changed, 65 insertions(+), 11 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 8a0a9c057..d51f0e89e 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1526,14 +1526,14 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):
     return model
 
 
-def fuse_parametrized_rotations(model: nn.Module) -> nn.Module:
+def fuse_parametrizations(model: nn.Module) -> nn.Module:
     # First of all, parameters that have parametrizations need to be untied
     model = _untie_parameters_with_parametrizations(model)
     # Then, parametrizations can be safely removed
     for module in model.modules():
         if parametrize.is_parametrized(module):
             # Names of the tensors that can potentially be parametrized
-            tensor_names = ["weight", "bias"]
+            tensor_names = ["weight", "in_proj_weight", "bias"]
             # Remove parametrizations from each tensor
             for tensor_name in tensor_names:
                 if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
@@ -1634,7 +1634,7 @@ def __init__(
         # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
         # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
         # of the weights should not operate on parametrized modules. In this situation, parametrizations
-        # need to be removed beforehand by invoking fuse_parametrized_rotations
+        # need to be removed beforehand by invoking fuse_parametrizations
         warnings.warn(
             "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
         )
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index f96227930..ee51235c7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -24,7 +24,7 @@
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph import load_quant_model_mode
 from brevitas.graph.base import ModuleInstanceTransformTensor
-from brevitas.graph.equalize import fuse_parametrized_rotations
+from brevitas.graph.equalize import fuse_parametrizations
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import functional_quantization_mode
@@ -501,7 +501,7 @@ def quantize_llm(args, extra_args=None):
         # Offload model before fusing the rotations
         model = offload_model(model)
         # Fuse rotations with weights
-        model = fuse_parametrized_rotations(model)
+        model = fuse_parametrizations(model)
 
     if args.act_calibration and not args.load_checkpoint:
         print("Apply act calibration...")
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 371e65603..5a1d51935 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -38,15 +38,18 @@
 IN_SIZE_CONV_SMALL = (1, 3, 32, 32)
 
 
-def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
+def equalize_test(
+        model, regions, merge_bias, bias_shrinkage, scale_computation_type, fuse_scaling=True):
     scale_factors_regions = []
     for i in range(3):
         for region in regions:
             scale_factors_region, _ = _cross_layer_equalization(
+                model,
                 region,
                 merge_bias=merge_bias,
                 bias_shrinkage=bias_shrinkage,
-                scale_computation_type=scale_computation_type)
+                scale_computation_type=scale_computation_type,
+                fuse_scaling=fuse_scaling)
             if i == 0:
                 scale_factors_regions.append(scale_factors_region)
     return scale_factors_regions
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index f59bf48ce..d22e4f963 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -26,7 +26,7 @@
 from brevitas.graph.equalize import _supported_layers
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import EqualizationIndexes
-from brevitas.graph.equalize import fuse_parametrized_rotations
+from brevitas.graph.equalize import fuse_parametrizations
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import MergeLnAffine
 from brevitas.graph.equalize import random_orthogonal_matrix
@@ -145,6 +145,57 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool
     assert all([shape != () for shape in shape_scale_regions])
 
 
+@requires_pt_ge('2.3.1')
+@pytest_cases.parametrize("merge_bias", [True, False])
+def test_equalization_torchvision_models_unfused(model_coverage: tuple, merge_bias: bool):
+    model, _ = model_coverage
+
+    torch.manual_seed(SEED)
+    model.eval()
+    model = symbolic_trace(model)
+    model = TorchFunctionalToModule().apply(model)
+
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
+    # Instantiate model with unfused scales
+    model_unfused = copy.deepcopy(model)
+    # Copy regions and ensure that they point to the modules in the copied model
copied model + regions_unfused = copy.deepcopy(regions) + for r in regions_unfused: + for m_name in r.name_to_module: + r.name_to_module[m_name] = recurse_getattr(model_unfused, m_name) + # Equalize original model + scale_factor_regions = equalize_test( + model, + regions, + merge_bias=merge_bias, + bias_shrinkage='vaiq', + scale_computation_type='maxabs') + # Equalized copied model + scale_factor_regions_unfused = equalize_test( + model_unfused, + regions_unfused, + merge_bias=merge_bias, + bias_shrinkage='vaiq', + scale_computation_type='maxabs', + fuse_scaling=False) + # Ensure that scale factors match + for scale_factor, scale_factor_unfused in zip(scale_factor_regions, scale_factor_regions_unfused): + assert torch.allclose(scale_factor, scale_factor_unfused, atol=0.0, rtol=0.0) + # Ensure that parameters match + for name, param in model.named_parameters(): + param_unfused = recurse_getattr(model_unfused, name) + assert torch.allclose(param, param_unfused, atol=0.0, rtol=0.0) + # Fuse parametrizations and make sure that weights keep matching + model_unfused = fuse_parametrizations(model_unfused) + assert all([not parametrize.is_parametrized(m) for m in model_unfused.modules()]) + for name, param in model.named_parameters(): + param_unfused = recurse_getattr(model_unfused, name) + assert torch.allclose(param, param_unfused, atol=0.0, rtol=0.0) + + @pytest_cases.parametrize("merge_bias", [True, False]) def test_models(toy_model, merge_bias, request): test_id = request.node.callspec.id @@ -528,7 +579,7 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r isinstance(module, RotatedModule) for module in rotated_model.modules()]) # Optionally fuse the rotations if fuse_rotations: - rotated_model_unfused = fuse_parametrized_rotations(rotated_model_unfused) + rotated_model_unfused = fuse_parametrizations(rotated_model_unfused) # Verify that no parametrizations remain after fusing for module in rotated_model_unfused.modules(): assert not parametrize.is_parametrized(module) @@ -582,7 +633,7 @@ def test_fuse_parametrized_modules(kwargs): with torch.no_grad(): output = qmodel(sample_input) # Fuse parametrizations - qmodel = fuse_parametrized_rotations(qmodel) + qmodel = fuse_parametrizations(qmodel) # Verify that scales were not lost module = recurse_getattr(model, key) assert module.weight_quant.tensor_quant.scaling_impl.init_done diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 5d9bb10fa..c9e3f36c1 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -985,7 +985,7 @@ def test_small_models_rotation_optimization_layer_count( # with non-optimized quantized perplexities caplog.set_level(logging.INFO) args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl - with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model): + with patch('brevitas_examples.llm.main.fuse_parametrizations', lambda model: model): _, _, model = validate_args_and_run_main(args, extra_args) assert_layer_types_count(model, exp_layer_types_count)