From 04a7de8bec7e4c4d2dd30c4aa03d64e27aa177bf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 19:03:35 +0100 Subject: [PATCH] Fix local loss tests + JIT --- noxfile.py | 2 +- tests/brevitas_examples/test_llm.py | 2 ++ tests/brevitas_examples/test_quantize_model.py | 3 +++ tests/marker.py | 10 ++++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 17a38789d..8aad90528 100644 --- a/noxfile.py +++ b/noxfile.py @@ -120,7 +120,7 @@ def tests_brevitas_cpu(session, pytorch, jit_status): @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) @nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS) def tests_brevitas_examples_cpu(session, pytorch, jit_status): - session.env['PYTORCH_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) + session.env['BREVITAS_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) install_pytorch(pytorch, session) install_torchvision(pytorch, session) # For CV eval scripts session.install('--upgrade', '.[test, tts, stt, vision]') diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index ee5db176f..90981df29 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -161,6 +161,8 @@ def run_test_models_run_args(args, model_with_ppl): def toggle_run_args(default_run_args, request): args = default_run_args args.update(**request.param) + if args.weight_param_method == 'hqo' and config.JIT_ENABLED: + pytest.skip("Local loss mode requires JIT to be disabled") yield args diff --git a/tests/brevitas_examples/test_quantize_model.py b/tests/brevitas_examples/test_quantize_model.py index 6a7184131..8ab34d5db 100644 --- a/tests/brevitas_examples/test_quantize_model.py +++ b/tests/brevitas_examples/test_quantize_model.py @@ -14,6 +14,8 @@ from brevitas.nn import QuantReLU from brevitas.quant_tensor import QuantTensor from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model +from tests.marker import jit_disabled_for_local_loss +from tests.marker import jit_disabled_for_mock # CONSTANTS IMAGE_DIM = 16 @@ -568,6 +570,7 @@ def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile @pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"]) +@jit_disabled_for_local_loss() def test_layerwise_param_method_mse(simple_model, quant_granularity): """ We test layerwise quantization, with the weight and activation quantization `mse` parameter diff --git a/tests/marker.py b/tests/marker.py index f11dc7a4a..d4ae6d325 100644 --- a/tests/marker.py +++ b/tests/marker.py @@ -50,6 +50,16 @@ def skip_wrapper(f): return skip_wrapper +def jit_disabled_for_local_loss(): + skip = config.JIT_ENABLED + + def skip_wrapper(f): + return pytest.mark.skipif( + skip, reason=f'Local loss functions (e.g., MSE) require JIT to be disabled')(f) + + return skip_wrapper + + def jit_disabled_for_dynamic_quant_act(): skip = config.JIT_ENABLED