From 80d247aa9a8ef81c210e6c6022398f094ece3cee Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Thu, 16 Jan 2025 19:20:55 +0000
Subject: [PATCH] Address comments and new tests

---
 src/brevitas/graph/quantize_impl.py   |   8 +-
 src/brevitas_examples/llm/main.py     |  13 ++++
 tests/brevitas/graph/test_quantize.py |  49 ++++++++++++
 tests/brevitas_examples/test_llm.py   | 104 +++++++++++++++++---------
 4 files changed, 136 insertions(+), 38 deletions(-)

diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py
index c826481bf..d0a0f4be8 100644
--- a/src/brevitas/graph/quantize_impl.py
+++ b/src/brevitas/graph/quantize_impl.py
@@ -408,8 +408,8 @@ def act_handler(model, layer_map):
         if node.op == 'call_module':
             module = get_module(model, node.target)
             if isinstance(module, tuple(layer_map.keys())):
-                if layer_map[type(module)] is not None:
-                    quant_module_class, quant_module_kwargs = layer_map[type(module)]
+                if layer_map[type_before_parametrizations(module)] is not None:
+                    quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
                     quant_module = quant_module_class(**quant_module_kwargs)
                     # Check for activation equalization mul nodes
                     if len(node.users) == 1:
@@ -470,8 +470,8 @@ def layer_handler(
                         quant_identity_map=quant_identity_map,
                         quant_act_map=quant_act_map,
                         unsigned_act_tuple=unsigned_act_tuple)
-            if layer_map[type(module)] is not None:
-                quant_module_class, quant_module_kwargs = layer_map[type(module)]
+            if layer_map[type_before_parametrizations(module)] is not None:
+                quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
                 # Quantize the input if is not quantized, input_quant is not specified,
                 # and the quant_identity_map is provided.
                 if not are_inputs_quantized_and_aligned(
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 8a0aec5c9..3ada642c3 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -103,6 +103,19 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
         full_rotation_method=args.rotation_mode,
         return_rewriters=True,
         sdpa_regions=args.rotation_sdpa_regions)
+    # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of
+    # selected modules stop being plain attributes and become properties, whose value is
+    # computed by passing the original tensor through the forward passes of the
+    # parametrization modules. Parametrizations are registered using
+    # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__
+    # attribute of the parametrized module, e.g. "<class 'torch.nn.modules.linear.Linear'>"
+    # changes to "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>". Therefore,
+    # algorithms that do type checking might need to use type_before_parametrizations(module)
+    # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance,
+    # the "weight" attribute is parametrized, it is removed from the module's
+    # parameters. Consequently, quantization algorithms that rely on in-place modifications
+    # of the weights should not operate on parametrized modules. In this situation,
+    # parametrizations need to be removed beforehand by invoking fuse_parametrized_rotations.
     new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations)
     rewriters = fix_rewriter(rewriters, model, 'weight')
     for r in rewriters:
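Reviewer note, not part of the patch: the behavior motivating both the layer_map
lookup change and the NOTE above can be reproduced with plain PyTorch. This is a
minimal sketch; the toy Double parametrization stands in for brevitas'
RotationWeightParametrization, and type_before_parametrizations is assumed to
behave like torch.nn.utils.parametrize's helper of the same name:

    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize
    from torch.nn.utils.parametrize import type_before_parametrizations


    class Double(nn.Module):
        # Toy parametrization: "weight" is recomputed as 2 * original on every access
        def forward(self, weight):
            return 2.0 * weight


    linear = nn.Linear(2, 3)
    assert type(linear) is nn.Linear

    parametrize.register_parametrization(linear, "weight", Double())
    # register_parametrization rewrites the module's __class__ in place: type(linear)
    # is now a dynamically generated ParametrizedLinear, so a layer_map keyed on
    # nn.Linear no longer matches...
    assert type(linear) is not nn.Linear
    assert type(linear).__name__ == "ParametrizedLinear"
    # ...while type_before_parametrizations still returns the original class.
    assert type_before_parametrizations(linear) is nn.Linear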
diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py
index ec15da6a2..62ce405c9 100644
--- a/tests/brevitas/graph/test_quantize.py
+++ b/tests/brevitas/graph/test_quantize.py
@@ -7,8 +7,10 @@
 from brevitas.graph.base import _remove_parametrization_entries_state_dict
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.graph.quantize import quantize
 from brevitas.utils.python_utils import recurse_getattr
 from brevitas.utils.rotation_utils import RotationWeightParametrization
+from tests.marker import requires_pt_ge
 
 
 @pytest_cases.parametrize(
@@ -142,3 +144,50 @@ def test_remove_parametrization_entries_state_dict(kwargs):
         assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict"
         # Compare tensor values
         assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match with that in the original state_dict"
+
+
+@requires_pt_ge('2.3.1')
+@pytest_cases.parametrize(
+    'kwargs',
+    [
+        {
+            'model': nn.Sequential(nn.Linear(2, 3)),
+            'sample_input': torch.tensor([[0.8, -0.6]]),
+            'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
+            'rot_func': lambda tensor,
+            rot_mat,
+            K: torch.matmul(tensor, rot_mat),
+            'key': '0',
+            'expected': "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>"},])
+def test_quantize_parametrized_modules(kwargs):
+    key = kwargs['key']
+    exp = kwargs['expected']
+    rot_mat = kwargs['rot_mat']
+    rot_func = kwargs['rot_func']
+    sample_input = kwargs['sample_input']
+    model = kwargs["model"]
+
+    graph_model, _ = torch._dynamo.export(model)(sample_input)
+    orig_module = recurse_getattr(model, key)
+    # Use tied weights to identify the equivalent module in the exported graph model
+    key, module = [(key, module) for key, module in graph_model.named_modules() if hasattr(module, "weight") and module.weight is orig_module.weight][0]
+    # Register rotation parametrization to module
+    parametrize.register_parametrization(
+        module=module,
+        tensor_name="weight",
+        parametrization=RotationWeightParametrization(
+            rot_mat=nn.Parameter(rot_mat),
+            rot_func=rot_func,
+            axis=1,
+            K=None,
+        ))
+    qmodel = quantize(graph_model)
+    checked = False
+    found_names = []
+    for n, m in qmodel.named_modules():
+        found_names.append(n)
+        if n == key:
+            mt = str(type(m))
+            assert mt == exp, f"Expected module {n} to be of type {exp}, found type {mt}"
+            checked = True
+    assert checked, f"Layer named {key} not found. Layer names are: {found_names}"
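Reviewer note, not part of the patch: the tied-weight lookup in
test_quantize_parametrized_modules works because torch._dynamo.export rebuilds the
module hierarchy under new names while aliasing the eager model's parameter
tensors. A minimal sketch, using only what the test itself relies on (the exported
module name is version- and backend-dependent):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(2, 3))
    graph_model, _ = torch._dynamo.export(model)(torch.tensor([[0.8, -0.6]]))

    orig_module = model[0]
    # Identity comparison ("is") on the weight tensor recovers the renamed submodule
    matches = [
        name for name, module in graph_model.named_modules()
        if getattr(module, "weight", None) is orig_module.weight]
    print(matches)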
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index d4af0861e..33d46a179 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -7,6 +7,7 @@
 import os
 import platform
 import shutil
+from unittest.mock import patch
 
 import numpy as np
 import onnx
@@ -23,21 +24,16 @@
 from tests.marker import jit_disabled_for_export
 from tests.marker import requires_pt_ge
 
+ATOL_PPL = 2e+02
+RTOL_PPL = 1e-04
+
 
 def ptid2pathname(string):
     return string.replace("/", "-").replace(":", "-")
 
 
-def allclose(x, y):
-    return np.allclose(x, y, rtol=1e-03, atol=1e+01, equal_nan=False)
-
-
-def allveryclose(x, y):
-    return np.allclose(x, y, rtol=1e-04, atol=2e+02, equal_nan=False)
-
-
-def allexact(x, y):
-    return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)
+def allclose(x, y, rtol=RTOL_PPL, atol=ATOL_PPL):
+    return np.allclose(x, y, rtol=rtol, atol=atol, equal_nan=False)
 
 
 def transformers_version_ge(required_version: str):
@@ -252,8 +248,8 @@ def test_small_models_acc(caplog, acc_args_and_acc):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
     float_ppl = float_ppl.detach().cpu().numpy()
     quant_ppl = quant_ppl.detach().cpu().numpy()
-    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
-    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+    assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
 
 
 @pytest_cases.fixture(
@@ -294,8 +290,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
     float_ppl = float_ppl.detach().cpu().numpy()
     quant_ppl = quant_ppl.detach().cpu().numpy()
-    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
-    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+    assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
 
 
 @pytest_cases.fixture(
@@ -738,8 +734,8 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
     float_ppl = float_ppl.detach().cpu().numpy()
     quant_ppl = quant_ppl.detach().cpu().numpy()
-    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
-    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+    assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
 
 
 @pytest_cases.fixture(
@@ -760,7 +756,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
     "rotation_orphan_sink": True,
     "rotation_mode": "ort",
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33232.65234375},
+    "quant_ppl": 33232.65234375,},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": False,
     "rotation_mode": "ort",
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33420.65234375},
+    "quant_ppl": 33420.65234375,},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": True,
     "rotation_mode": "had",
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33290.48046875},
+    "quant_ppl": 33290.48046875,},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": False,
     "rotation_mode": "had",
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33204.80859375},
+    "quant_ppl": 33204.80859375,},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "layerwise",
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33446.734375},])
+    "quant_ppl": 33446.734375,},])
 def rotation_ppl_args_and_ppl(default_run_args, request):
     args = default_run_args
     run_dict = request.param
@@ -823,8 +819,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
     float_ppl = float_ppl.detach().cpu().numpy()
     quant_ppl = quant_ppl.detach().cpu().numpy()
-    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
-    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+    assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
 
 
 @pytest_cases.fixture(
@@ -857,7 +853,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "--save_strategy",
         "no"],
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33232.65234375},
+    "quant_ppl": 33278.98828125,
+    "exp_layer_types_count": {
+        "<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
+        "<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": False,
     "rotation_mode": "ort",
     "unknown_args": [
         "--learning_rate",
         "1.5",
         "--max_steps",
         "600",
         "--per_device_train_batch_size",
         "1",
         "--gradient_accumulation_steps",
         "1",
         "--save_strategy",
         "no"],
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33420.65234375},
+    "quant_ppl": 33424.73046875,
+    "exp_layer_types_count": {
+        "<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
+        "<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": True,
     "rotation_mode": "had",
     "unknown_args": [
         "--learning_rate",
         "1.5",
         "--max_steps",
         "600",
         "--per_device_train_batch_size",
         "1",
         "--gradient_accumulation_steps",
         "1",
         "--save_strategy",
         "no"],
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33290.48046875},
+    "quant_ppl": 33339.21875,
+    "exp_layer_types_count": {
+        "<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
+        "<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
 {
     "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
     "act_calibration": False,
     "weight_bit_width": 4,
     "input_bit_width": None,
     "replace_rmsnorm": True,
     "rotation": "fx",
     "rotation_orphan_sink": False,
     "rotation_mode": "had",
     "unknown_args": [
         "--learning_rate",
         "1.5",
         "--max_steps",
         "600",
         "--per_device_train_batch_size",
         "1",
         "--gradient_accumulation_steps",
         "1",
         "--save_strategy",
         "no"],
     "float_ppl": 33238.8984375,
-    "quant_ppl": 33204.80859375},])
-def rotation_optimization_args_and_ppl(default_run_args, request):
+    "quant_ppl": 33219.08984375,
+    "exp_layer_types_count": {
+        "<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
+        "<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
+        "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},])
+def rotation_optimization_args_layer_count_and_ppl(default_run_args, request):
     args = default_run_args
     run_dict = request.param
     unknown_args = run_dict["unknown_args"]
     float_ppl = run_dict["float_ppl"]
     quant_ppl = run_dict["quant_ppl"]
+    exp_layer_types_count = run_dict["exp_layer_types_count"]
     del run_dict["float_ppl"]
     del run_dict["quant_ppl"]
     del run_dict["unknown_args"]
+    del run_dict["exp_layer_types_count"]
     args.update(**run_dict)
-    yield args, unknown_args, float_ppl, quant_ppl
+    yield args, unknown_args, float_ppl, quant_ppl, exp_layer_types_count
 
 
 @requires_pt_ge('2.4')
-def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl):
+def test_small_models_rotation_optimization_ppl(
+        caplog, rotation_optimization_args_layer_count_and_ppl):
     if platform.system() == "Windows":
         pytest.skip("Skipping dynamo + windows")
+    # Tolerances are stricter for this test, to ensure that it does not pass
+    # with non-optimized quantized perplexities
+    RTOL_ROT, ATOL_ROT = 1e-05, 2.
     caplog.set_level(logging.INFO)
-    args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl
-    float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args)
+    args, unknown_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl
+    float_ppl, quant_ppl, _ = validate_args_and_run_main(args, unknown_args)
     float_ppl = float_ppl.detach().cpu().numpy()
     quant_ppl = quant_ppl.detach().cpu().numpy()
-    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
-    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+    assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allclose(exp_quant_ppl, quant_ppl, rtol=RTOL_ROT, atol=ATOL_ROT), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+
+
+@requires_pt_ge('2.4')
+def test_small_models_rotation_optimization_layer_count(
+        caplog, rotation_optimization_args_layer_count_and_ppl):
+    if platform.system() == "Windows":
+        pytest.skip("Skipping dynamo + windows")
+    # fuse_parametrized_rotations is patched with an identity function, so the
+    # parametrizations are kept and the parametrized module types can be counted
+    caplog.set_level(logging.INFO)
+    args, unknown_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
+    with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
+        _, _, model = validate_args_and_run_main(args, unknown_args)
+    assert_layer_types_count(model, exp_layer_types_count)
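Reviewer note, not part of the patch: assert_layer_types_count is an existing
helper in test_llm.py. As a hypothetical reimplementation to document what the
exp_layer_types_count entries assert, it amounts to the following sketch:

    from collections import Counter

    import torch.nn as nn


    def assert_layer_types_count_sketch(model: nn.Module, exp_layer_types_count: dict):
        # Count stringified module classes over the whole hierarchy; the keys are
        # the same "<class '...'>" strings used in exp_layer_types_count.
        counts = Counter(str(type(module)) for _, module in model.named_modules())
        for layer_type, expected in exp_layer_types_count.items():
            found = counts.get(layer_type, 0)
            assert found == expected, (
                f"Expected {expected} modules of type {layer_type}, found {found}")

Patching fuse_parametrized_rotations with an identity function keeps the rotations
parametrized, which is why the Parametrized* classes remain visible to this count;
with fusing enabled, the modules would revert to their original classes.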