From d41a8c6ca1de605b35dd6012a2073c18f6554378 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Mon, 4 Nov 2024 17:05:11 +0000
Subject: [PATCH 1/8] Fix LLM tests

---
 src/brevitas/nn/quant_mha.py                      |  3 ++-
 src/brevitas_examples/llm/llm_quant/mha_layers.py | 11 ++++++++---
 .../llm/llm_quant/prepare_for_quantize.py         |  7 ++++++-
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/nn/quant_mha.py b/src/brevitas/nn/quant_mha.py
index 6720fe280..0effdf68e 100644
--- a/src/brevitas/nn/quant_mha.py
+++ b/src/brevitas/nn/quant_mha.py
@@ -602,7 +602,8 @@ def forward(
             key_padding_mask: Optional[Tensor] = None,
             need_weights: bool = True,
             attn_mask: Optional[Tensor] = None,
-            average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
+            average_attn_weights: bool = True,
+            position_ids: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
     Args:
         query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index cf694d4eb..eb8078fff 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -152,6 +152,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        position_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if key_value_states is None:
             key_value_states = hidden_states
@@ -164,14 +165,18 @@ def forward(
         query_seq_length, batch_size = hidden_states.shape[:2]
         key_value_seq_length = key_value_states.shape[0]
         num_heads = self.num_heads
-        attention_mask = attention_mask_handler(
-            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+        attention_mask = (
+            attention_mask_handler(
+                attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+            if attention_mask else None)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
             key_value_states,
             attn_mask=attention_mask,
             need_weights=output_attentions,
-            average_attn_weights=False)
+            average_attn_weights=False,
+            position_ids=position_ids,
+        )
         past_key_value = None
         return attn_output, attn_output_weights, past_key_value
diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
index d22b2eff1..2a71546e4 100644
--- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
@@ -2,11 +2,16 @@

 import torch
 from transformers.models.opt.modeling_opt import OPTAttention
+from transformers.models.opt.modeling_opt import OPTSdpaAttention

 from brevitas.graph import ModuleToModuleByClass
 from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

-QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})}
+QUANTIZABLE_MHA_MAP = {
+    OPTAttention: (QuantizableOPTAttention, {
+        'batch_first': True}),
+    OPTSdpaAttention: (QuantizableOPTAttention, {
+        'batch_first': True}),}


 def replace_mha_with_quantizable_layers(model, dtype):

From 63a42462f9e57ce1ead105dd949ddfe7a0d37bbe Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 5 Nov 2024 15:42:13 +0000
Subject: [PATCH 2/8] Replace legacy values

---
 .../llm/llm_quant/mha_layers.py     |  2 +-
 tests/brevitas_examples/test_llm.py | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index eb8078fff..fb4abb3e4 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -168,7 +168,7 @@ def forward(
         attention_mask = (
             attention_mask_handler(
                 attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
-            if attention_mask else None)
+            if attention_mask is not None else None)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 576af04b1..b1793d043 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -203,18 +203,18 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
         "llama",
         "mistral",],
     params=[
-        {
-            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-            "act_equalization": "layerwise",
-            "gptq": True,
-            "float_ppl": 31274.05078125,
-            "quant_ppl": 33139.23046875},
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "fx",
             "bias_corr": True,
-            "float_ppl": 33239.5,
-            "quant_ppl": 33283.75390625},])
+            "float_ppl": 33312.0,  # 33239.5,
+            "quant_ppl": 33056.0},  # 33283.75390625},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "act_equalization": "layerwise",
+            "gptq": True,
+            "float_ppl": 31056.0,  # 31274.05078125
+            "quant_ppl": 33056.0},])  # 33139.23046875},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param

From 593aa27b620268483fdb55f05a45817e554068c1 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 5 Nov 2024 15:42:47 +0000
Subject: [PATCH 3/8] Bump version

---
 requirements/requirements-llm.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt
index 9bc21d251..524f15f18 100644
--- a/requirements/requirements-llm.txt
+++ b/requirements/requirements-llm.txt
@@ -1,3 +1,3 @@
 # optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
 tqdm
-transformers[sentencepiece]==4.45.2
+transformers[sentencepiece]>=4.46.0

From e48db06838490ff79775803ac2f7efd9103227dd Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 5 Nov 2024 16:52:27 +0000
Subject: [PATCH 4/8] Revert back change to signature

---
 src/brevitas/nn/quant_mha.py                      | 3 +--
 src/brevitas_examples/llm/llm_quant/mha_layers.py | 1 -
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/brevitas/nn/quant_mha.py b/src/brevitas/nn/quant_mha.py
index 0effdf68e..6720fe280 100644
--- a/src/brevitas/nn/quant_mha.py
+++ b/src/brevitas/nn/quant_mha.py
@@ -602,8 +602,7 @@ def forward(
             key_padding_mask: Optional[Tensor] = None,
             need_weights: bool = True,
             attn_mask: Optional[Tensor] = None,
-            average_attn_weights: bool = True,
-            position_ids: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+            average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
     Args:
         query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index fb4abb3e4..67eeb8738 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -176,7 +176,6 @@ def forward(
             attn_mask=attention_mask,
             need_weights=output_attentions,
             average_attn_weights=False,
-            position_ids=position_ids,
         )
         past_key_value = None
         return attn_output, attn_output_weights, past_key_value

From 5e4149c5d4d961bff268ab033725c3e30dfa7a66 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Thu, 21 Nov 2024 17:24:39 +0000
Subject: [PATCH 5/8] Unpin version and make ground-truth values dependent on version

---
 requirements/requirements-llm.txt   |  2 +-
 tests/brevitas_examples/test_llm.py | 13 +++++++++----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt
index 524f15f18..8935cb159 100644
--- a/requirements/requirements-llm.txt
+++ b/requirements/requirements-llm.txt
@@ -1,3 +1,3 @@
 # optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
 tqdm
-transformers[sentencepiece]>=4.46.0
+transformers
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index b1793d043..29a153011 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -14,6 +14,7 @@
 import pytest
 import pytest_cases
 import torch
+import transformers

 from brevitas import config
 from brevitas import torch_version
@@ -40,6 +41,10 @@ def allexact(x, y):
     return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)


+def transformers_version_ge(required_version: str):
+    return version.parse(required_version) >= version.parse(transformers.__version__)
+
+
 # Check that all args in args are used
 def validate_args(args):
     a = vars(args)
@@ -207,14 +212,14 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "fx",
             "bias_corr": True,
-            "float_ppl": 33312.0,  # 33239.5,
-            "quant_ppl": 33056.0},  # 33283.75390625},
+            "float_ppl": 33312.0 if transformers_version_ge('4.46.0') else 33239.5,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33283.75390625},
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "act_equalization": "layerwise",
             "gptq": True,
-            "float_ppl": 31056.0,  # 31274.05078125
-            "quant_ppl": 33056.0},])  # 33139.23046875},])
+            "float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param

From eca928929772ea8fa0a627c63af977cb485f062d Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Wed, 27 Nov 2024 15:03:30 +0000
Subject: [PATCH 6/8] Add sentencepiece to requirements

---
 requirements/requirements-llm.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt
index 8935cb159..bb8c823dc 100644
--- a/requirements/requirements-llm.txt
+++ b/requirements/requirements-llm.txt
@@ -1,3 +1,3 @@
 # optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
 tqdm
-transformers
+transformers[sentencepiece]

From 35c709669cf59019203c5a7f89aca51f880a35af Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 27 Nov 2024 17:27:13 +0000
Subject: [PATCH 7/8] Tentative fp32 accuracy checks

---
 tests/brevitas_examples/test_llm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 29a153011..af24bbd71 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -131,6 +131,7 @@ def default_run_args(request):
     args.weight_quant_granularity = "per_channel"  # "per_tensor", "per_channel", "per_group".
     args.input_bit_width = 8
     args.act_calibration = True
+    args.no_float16 = True
     return args


@@ -219,7 +220,7 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
             "act_equalization": "layerwise",
             "gptq": True,
             "float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
-            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 31278.166015625},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param

From cf55c1281ffeef85d917e1dcc69801d8f26d9288 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 27 Nov 2024 17:31:42 +0000
Subject: [PATCH 8/8] restore ppl

---
 tests/brevitas_examples/test_llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index af24bbd71..79b8536d4 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -220,7 +220,7 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
             "act_equalization": "layerwise",
             "gptq": True,
             "float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
-            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 31278.166015625},])
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param
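
Note on the attention-mask guard introduced in PATCH 1/8 and corrected in PATCH 2/8: the mask must be tested for presence with "is not None" rather than used directly in a boolean context, because evaluating a multi-element tensor as a boolean raises at runtime. A minimal standalone sketch of this behaviour follows (not part of the patches; the mask shape is an arbitrary illustrative choice):

import torch

# Illustrative additive attention mask of shape (batch, num_heads, query_len, kv_len).
mask = torch.zeros(1, 2, 4, 4)

# Safe guard, as used after PATCH 2/8: only checks whether a mask was provided.
if mask is not None:
    print("mask provided:", tuple(mask.shape))

# Relying on truthiness fails: the truth value of a tensor with more than one
# element is ambiguous, so PyTorch raises a RuntimeError.
try:
    if mask:
        pass
except RuntimeError as err:
    print("truthiness check raises:", err)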