diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index f6db73924..199a3d29c 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -47,14 +47,14 @@ def transformers_version_ge(required_version: str):
 # Check that all args in args are used
 def validate_args(args):
     a = vars(args)
-    da = vars(parse_args([]))
+    da = vars(parse_args([])[0])
     for k in a.keys():
         assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `quantize_llm`"
 
 
-def validate_args_and_run_main(args):
+def validate_args_and_run_main(args, unknown_args=None):
     validate_args(args)
-    float_ppl, quant_ppl, model = quantize_llm(args)
+    float_ppl, quant_ppl, model = quantize_llm(args, unknown_args=unknown_args)
     return float_ppl, quant_ppl, model
 
 
@@ -131,7 +131,7 @@ def small_models_with_ppl(request):
 
 @pytest_cases.fixture()
 def default_run_args(request):
-    args = UpdatableNamespace(**vars(parse_args([])))
+    args = UpdatableNamespace(**vars(parse_args([])[0]))
     args.nsamples = 2
     args.seqlen = 2
     args.model = "hf-internal-testing/tiny-random-MistralForCausalLM"
@@ -156,6 +156,12 @@ def run_test_models_run_args(args, model_with_ppl):
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
 
 
+@pytest.fixture(scope="session", autouse=True)
+def set_env():
+    # Expose only GPU 1 to CUDA for the entire test session
+    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+
 # yapf: disable
 @pytest_cases.fixture(
     ids=[
@@ -825,3 +831,118 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
     quant_ppl = quant_ppl.detach().cpu().numpy()
     assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
     assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+
+
+@pytest_cases.fixture(
+    ids=[
+        "llama_rotation_optimization_ort",
+        "llama_rotation_optimization_ort_no_orphan",
+        "llama_rotation_optimization_had",
+        "llama_rotation_optimization_had_no_orphan",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "ort",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33232.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "ort",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33420.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "had",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33290.48046875},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx_optimize",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "had",
+            "nsamples_rot_calibration": 2,
+            "no_float16": True,
+            "unknown_args": [
+                "--max_steps",
+                "2",
+                "--per_device_train_batch_size",
+                "1",
+                "--gradient_accumulation_steps",
+                "1"],
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33204.80859375},])
+def rotation_optimization_args_and_ppl(default_run_args, request):
+    args = default_run_args
+    run_dict = request.param
+    # Split the expected PPLs and the extra training CLI flags (passed on to
+    # quantize_llm as unknown_args) out of the quantize_llm argument overrides
+    unknown_args = run_dict["unknown_args"]
+    float_ppl = run_dict["float_ppl"]
+    quant_ppl = run_dict["quant_ppl"]
+    del run_dict["float_ppl"]
+    del run_dict["quant_ppl"]
+    del run_dict["unknown_args"]
+    args.update(**run_dict)
+    yield args, unknown_args, float_ppl, quant_ppl
+
+
+@requires_pt_ge('2.4')
+def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl):
+    if platform.system() == "Windows":
+        pytest.skip("Skipping dynamo + windows")
+    caplog.set_level(logging.INFO)
+    args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args)
+    float_ppl = float_ppl.detach().cpu().numpy()
+    quant_ppl = quant_ppl.detach().cpu().numpy()
+    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"