diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 60dd33ac2..ac9bd6ec2 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from argparse import Namespace
+from collections import defaultdict
 from dataclasses import dataclass
 import logging
 import os
@@ -72,6 +73,19 @@ def assert_layer_types(model, exp_layer_types):
         assert matched, f"Layer key: {key} not found in {layer_names}"
 
 
+def assert_layer_types_count(model, exp_layer_types_count):
+    layer_types_count = {}
+    for name, layer in model.named_modules():
+        ltype = str(type(layer))
+        if ltype not in layer_types_count:
+            layer_types_count[ltype] = 0
+        layer_types_count[ltype] += 1
+
+    for name, count in exp_layer_types_count.items():
+        curr_count = 0 if name not in layer_types_count else layer_types_count[name]
+        assert count == curr_count, f"Expect {count} instances of layer type: {name}, found {curr_count}."
+
+
 class UpdatableNamespace(Namespace):
 
     def update(self, **kwargs):
@@ -277,6 +291,15 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
     assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
 
 
+def llama_module_type_count(num_hidden_layers: int):
+    return {
+        "": 1 + num_hidden_layers *
+        (4 + 3),  # lm_head + (k/q/v/o_proj + MLP)
+        "":
+            1 + num_hidden_layers * 2,  # input + post_attention
+    }
+
+
 @pytest_cases.fixture(
     ids=[
         "mistral-int8",
@@ -299,7 +322,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,
+            }},  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "input_bit_width": None,
@@ -310,7 +339,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.input_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 14,
+            }},  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "weight_quant_format": "float_ocp_e4m3",
@@ -323,7 +358,12 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,}},  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "weight_quant_format": "float_fnuz_e4m3",
@@ -336,7 +376,12 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,}},  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
            "weight_quant_format": "float_ocp_e4m3",
@@ -363,7 +408,17 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,  # input_quant/weight_quant
+                "":
+                    14,  # input_quant..input_view_impl/input_quant..scaling_impl.input_view_impl
+                "":
+                    28,  # weight_quant..input_view_impl/weight_quant..scaling_impl.input_view_impl
+                "": 1,  # LM Head
+                "": 5,}},
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "layerwise",
@@ -371,12 +426,22 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj":
                     "",
                 "model.layers.0.self_attn.q_proj.layer":
-                    "",}},
+                    "",},
+            "exp_layer_types_count": {
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 1,  # LM Head
+                "":
+                    15,  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
+                "": 5,}},
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "quantize_last_layer": True,
             "exp_layer_types": {
-                "lm_head": ""}},
+                "lm_head": ""},
+            "exp_layer_types_count": {
+                "": 15,
+            }},  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "ln_affine_merge": True,
@@ -390,7 +455,14 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "L__self___model_layers_0_self_attn_k_proj":
                     "",
                 "L__self___model_layers_0_self_attn_o_proj":
-                    ""}},
+                    "",},
+            "exp_layer_types_count": {
+                "":
+                    4,  # Sinks: O proj + Down proj
+                "":
+                    15,  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
+                "": 5,
+                "": 0,}},
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "ln_affine_merge": True,
@@ -404,28 +476,39 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "L__self___model_layers_0_self_attn_k_proj":
                     "",
                 "L__self___model_layers_0_self_attn_o_proj":
-                    ""}},])
+                    ""},
+            "exp_layer_types_count": {
+                "":
+                    15,  # LM Head + Q/K/V projs + Up/Gate/Down projs
+                "": 5,  # Input + Post attention
+                "": 0,}},])
 def layer_args(default_run_args, request):
     args = default_run_args
     layer_dict = request.param
     exp_layer_types = layer_dict["exp_layer_types"]
+    exp_layer_types_count = layer_dict["exp_layer_types_count"]
     del layer_dict["exp_layer_types"]
+    del layer_dict["exp_layer_types_count"]
     args.update(**layer_dict)
-    yield args, exp_layer_types
+    yield args, exp_layer_types, exp_layer_types_count
 
 
-@pytest.mark.llm
-@requires_pt_ge('2.2')
 def test_small_models_quant_layer(caplog, layer_args):
     caplog.set_level(logging.INFO)
-    args, exp_layer_types = layer_args
+    args, exp_layer_types, exp_layer_types_count = layer_args
     if args.replace_rmsnorm:
         if torch_version < version.parse('2.4'):
             pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
     if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
         pytest.skip("Skipping dynamo + windows")
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
-    assert_layer_types(model, exp_layer_types)
+    # Naming of modules in the GraphModule generated by FX changes across transformers versions, e.g.
+    # (4.45.0)"L__self___model_layers_2_self_attn_k_proj" ->
+    # (4.46.0) 'L__self___model_layers_slice_None__2__None___0_self_attn_q_proj'
+    # Therefore, this check is skipped when rotation="fx".
+    if args.rotation != "fx":
+        assert_layer_types(model, exp_layer_types)
+    assert_layer_types_count(model, exp_layer_types_count)
 
 
 @pytest_cases.fixture(
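
For reference only, not part of the patch: a minimal sketch of how the new assert_layer_types_count helper behaves, using a hypothetical toy module in place of the HF checkpoints exercised by the tests. The import path and the toy model are assumptions made for illustration; if the test module is not importable in your environment, paste the helper from the hunk above instead.

    import torch.nn as nn

    # Assumed import path, for illustration only; the helper is defined in the patched test file.
    from tests.brevitas_examples.test_llm import assert_layer_types_count

    # Hypothetical stand-in for a quantized LLM: named_modules() yields the root Sequential
    # plus its three children, and occurrences are counted per str(type(module)).
    toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    assert_layer_types_count(
        toy,
        {
            "<class 'torch.nn.modules.linear.Linear'>": 2,
            "<class 'torch.nn.modules.activation.ReLU'>": 1,
        })  # passes; a wrong count fails with "Expect N instances of layer type: ..."

Likewise, llama_module_type_count(num_hidden_layers) only encodes the per-layer module arithmetic: with num_hidden_layers=2 it expects 1 + 2 * (4 + 3) = 15 projection/LM-head modules and 1 + 2 * 2 = 5 normalization modules, which matches the 15 and 5 counts used in the parametrized cases above.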