Skip to content

Commit

Permalink
Enable passing HF arguments through YAML
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 20, 2025
1 parent a5d8188 commit c790070
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,16 @@ def parse_args(args, override_defaults={}):
type=str,
nargs='*',
help='A list of tasks for zero_shot evaluation. Default: %(default)s')
if override_defaults:
# Retrieve keys that are known to the parser
parser_keys = set(map(lambda action: action.dest, parser._actions))
# Extract the entries in override_defaults that correspond to keys not known to the parser
extra_args_keys = [key for key in override_defaults.keys() if key not in parser_keys]
# Remove those entries from override_defaults, to prevent new keys being added to the argument
# parser and add them to args, to mimic as if they were passed by command line
for key in extra_args_keys:
args += [f"--{key}", str(override_defaults[key])]
del override_defaults[key]
parser.set_defaults(**override_defaults)

return parser.parse_known_args(args)
Expand Down
76 changes: 76 additions & 0 deletions tests/brevitas_examples/llm_test_template.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Template config for the Brevitas LLM entrypoint (consumed via --config).
# Most keys below mirror defaults of the argument parser in
# brevitas_examples/llm/main.py; the trailing keys are deliberately unknown
# to that parser so tests can check they are forwarded as extra arguments.

# --- Calibration / equalization ---
act_calibration: false
act_equalization: null
act_equalization_alpha: 0.5
bias_corr: false
checkpoint_name: null
convert_layernorm_to_rmsnorm: false
dataset: wikitext2
eval: false
export_prefix: null
export_target: null

# --- Few-shot evaluation ---
few_shot_compile: false
few_shot_eval: false
few_shot_limit: null
few_shot_tasks:
- arc_challenge
- arc_easy
- winogrande
- piqa
few_shot_zeroshot: false
functional_sdpa_quant: false
fuse_sequences: false

# --- GPxQ (GPTQ/GPFQ) options ---
gpfq: false
gptq: false
gpxq_act_order: false
gpxq_block_name: null
gpxq_create_weight_orig: false
gpxq_max_accumulator_bit_width: null
gpxq_max_accumulator_tile_size: null
gpxq_use_quant_activations: false

# --- Input (activation) quantization ---
input_bit_width: null
input_group_size: 64
input_param_method: stats
input_quant_format: int
input_quant_granularity: per_tensor
input_quant_type: asym
input_scale_precision: float_scale
input_scale_type: static

# --- Learned-round options ---
learned_round: null
learned_round_fast_update: false
learned_round_iters: 200
learned_round_lr: 0.005
learned_round_scale: false
learned_round_scale_lr: 0.01
learned_round_scale_momentum: 0.9
ln_affine_merge: false
load_awq: null
load_checkpoint: false

# --- Model / data ---
model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
quant_sdpa: false
quantize_input_zero_point: false
quantize_last_layer: false
quantize_weight_zero_point: false
replace_mha: false
replace_rmsnorm: false

# --- Rotation options ---
rotation: null
rotation_mode: had
rotation_orphan_sink: false
scale_rounding_func_type: null
scaling_min_val: 0.0001
seed: 0
seqlen: 2048

# --- Weight quantization ---
weight_bit_width: 8
weight_equalization: false
weight_group_dim: null
weight_group_size: 128
weight_param_method: stats
weight_quant_format: int
weight_quant_granularity: per_group
weight_quant_type: sym
weight_scale_precision: float_scale

# --- Keys NOT known to the parser (presumably HF trainer args; the test
# expects these to be passed through as extra command-line arguments) ---
learning_rate: 1.5
lr_scheduler_type: cosine
save_safetensors: false
32 changes: 32 additions & 0 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from brevitas import config
from brevitas import torch_version
from brevitas_examples.llm.main import main
from brevitas_examples.llm.main import parse_args
from brevitas_examples.llm.main import quantize_llm
from tests.marker import jit_disabled_for_export
Expand Down Expand Up @@ -983,3 +984,34 @@ def test_small_models_rotation_optimization_layer_count(
with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
_, _, model = validate_args_and_run_main(args, extra_args)
assert_layer_types_count(model, exp_layer_types_count)


@pytest_cases.parametrize(
    "kwargs",
    [{
        "yaml_file_path": "./llm_test_template.yml",
        "expected_extra_args": [
            "--learning_rate",
            "1.5",
            "--lr_scheduler_type",
            "cosine",
            "--save_safetensors",
            "False"]}],
    ids=lambda kwargs: kwargs["yaml_file_path"])
def test_parse_yaml_trainer_arguments(caplog, kwargs):
    """YAML keys unknown to the parser must be forwarded as extra CLI args."""
    caplog.set_level(logging.INFO)
    config_path = kwargs["yaml_file_path"]
    expected_extra_args = kwargs["expected_extra_args"]
    # Even positions hold "--<key>"; strip the leading dashes to recover the keys
    extra_args_keys = [flag[2:] for flag in expected_extra_args[::2]]

    def quantize_llm_assert_args(args, extra_args=None):
        # The parsed namespace must not have absorbed any of the extra keys
        for key in extra_args_keys:
            assert key not in args, f"Key {key} should not be known by the parser"
        assert extra_args == expected_extra_args, f"Expected extra arguments {expected_extra_args} but got {extra_args}"

    # Drive the LLM entrypoint's argument-parsing path with a stubbed quantizer,
    # as if it had been invoked as `main.py --config <yaml>` from the command line
    with patch("brevitas_examples.llm.main.quantize_llm", quantize_llm_assert_args), \
            patch("brevitas_examples.llm.main.sys.argv", ["main.py", "--config", config_path]):
        main()

0 comments on commit c790070

Please sign in to comment.