Skip to content

Commit

Permalink
Fix LLM tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 17, 2025
1 parent 0399f6d commit eca8bfd
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
import copy
from dataclasses import dataclass
import logging
import os
Expand Down Expand Up @@ -851,7 +852,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--gradient_accumulation_steps",
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33278.98828125,
"quant_ppl": 33239.33984375,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
Expand All @@ -878,7 +879,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--gradient_accumulation_steps",
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33424.73046875,
"quant_ppl": 33423.0390625,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
Expand All @@ -905,7 +906,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--gradient_accumulation_steps",
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33339.21875,
"quant_ppl": 33286.98828125,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
Expand All @@ -932,15 +933,15 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--gradient_accumulation_steps",
"1"],
"float_ppl": 33238.8984375,
"quant_ppl": 33219.08984375,
"quant_ppl": 33175.3046875,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},])
def rotation_optimization_args_layer_count_and_ppl(default_run_args, request):
args = default_run_args
run_dict = request.param
run_dict = copy.deepcopy(request.param)
extra_args = run_dict["extra_args"]
float_ppl = run_dict["float_ppl"]
quant_ppl = run_dict["quant_ppl"]
Expand Down

0 comments on commit eca8bfd

Please sign in to comment.