Fix tests for transformers 4.47.0
pablomlago committed Dec 6, 2024
1 parent 7a5f77d commit 87bf948
Showing 1 changed file with 97 additions and 14 deletions.

tests/brevitas_examples/test_llm.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
import logging
import os
@@ -72,6 +73,19 @@ def assert_layer_types(model, exp_layer_types):
        assert matched, f"Layer key: {key} not found in {layer_names}"


def assert_layer_types_count(model, exp_layer_types_count):
    # Count how many instances of each layer type appear in the model.
    layer_types_count = defaultdict(int)
    for _, layer in model.named_modules():
        layer_types_count[str(type(layer))] += 1

    for name, count in exp_layer_types_count.items():
        curr_count = layer_types_count.get(name, 0)
        assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."
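
# A minimal usage sketch (hypothetical toy module, not part of the test suite):
# the two nn.Linear children below yield a count of 2 for that type string,
# and any mismatch raises an AssertionError.
#   toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
#   assert_layer_types_count(toy, {"<class 'torch.nn.modules.linear.Linear'>": 2})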


class UpdatableNamespace(Namespace):

def update(self, **kwargs):
@@ -277,6 +291,15 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


def llama_module_type_count(num_hidden_layers: int):
    return {
        "<class 'torch.nn.modules.linear.Linear'>":
            1 + num_hidden_layers * (4 + 3),  # lm_head + (k/q/v/o_proj + MLP)
        "<class 'torch.nn.modules.normalization.RMSNorm'>":
            1 + num_hidden_layers * 2,  # input + post_attention
    }
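
# For example, the tiny-random checkpoints exercised below have two hidden
# layers (hence the 14 = 2 * 7 projection counts in the fixtures), so this
# helper would return 1 + 2 * (4 + 3) = 15 Linear and 1 + 2 * 2 = 5 RMSNorm.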


@pytest_cases.fixture(
    ids=[
        "mistral-int8",
@@ -299,7 +322,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 28,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
@@ -310,7 +339,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant":
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 14,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -323,7 +358,12 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
@@ -336,7 +376,12 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -363,20 +408,40 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",}},
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",},
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28, # input_quant/weight_quant
"<class 'brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView'>":
14, # input_quant..input_view_impl/input_quant..scaling_impl.input_view_impl
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>":
28, # weight_quant..input_view_impl/weight_quant..scaling_impl.input_view_impl
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",}},
"<class 'brevitas.nn.quant_linear.QuantLinear'>",},
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"}},
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"},
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>": 15,
}}, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
@@ -390,7 +455,14 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>"}},
"<class 'brevitas.nn.equalized_layer.RotatedModule'>",},
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>":
4, # Sinks: O proj + Down proj
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5,
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
@@ -404,28 +476,39 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'torch.nn.modules.linear.Linear'>"}},])
"<class 'torch.nn.modules.linear.Linear'>"},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5, # Input + Post attention
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
exp_layer_types_count = layer_dict["exp_layer_types_count"]
del layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types_count"]
args.update(**layer_dict)
yield args, exp_layer_types
yield args, exp_layer_types, exp_layer_types_count


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
    caplog.set_level(logging.INFO)
    args, exp_layer_types, exp_layer_types_count = layer_args
    if args.replace_rmsnorm:
        if torch_version < version.parse('2.4'):
            pytest.skip("Replacing RMSNorm requires torch 2.4 or greater")
    if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
        pytest.skip("Skipping dynamo + windows")
    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
    # The names of modules in the GraphModule generated by FX change across
    # transformers versions, e.g.
    # (4.45.0) "L__self___model_layers_2_self_attn_k_proj" ->
    # (4.46.0) "L__self___model_layers_slice_None__2__None___0_self_attn_q_proj"
    # Therefore, the name-based check is skipped when rotation="fx".
    if args.rotation != "fx":
        assert_layer_types(model, exp_layer_types)
    assert_layer_types_count(model, exp_layer_types_count)
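
# When bumping the transformers pin, the expected counts above can be
# regenerated with a sketch along these lines (illustrative, not part of the
# test suite):
#   from collections import Counter
#   print(Counter(str(type(m)) for m in model.modules()))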


@pytest_cases.fixture(
