From c7900701e3531cbcf85d7aee32ba5f3a6ba83f5b Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 20 Jan 2025 11:45:06 +0000 Subject: [PATCH] Enable passing HF arguments through YAML --- src/brevitas_examples/llm/main.py | 10 +++ tests/brevitas_examples/llm_test_template.yml | 76 +++++++++++++++++++ tests/brevitas_examples/test_llm.py | 32 ++++++++ 3 files changed, 118 insertions(+) create mode 100644 tests/brevitas_examples/llm_test_template.yml diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b3e5c91e4..0958fd6cf 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -967,6 +967,16 @@ def parse_args(args, override_defaults={}): type=str, nargs='*', help='A list of tasks for zero_shot evaluation. Default: %(default)s') + if override_defaults: + # Retrieve keys that are known to the parser + parser_keys = set(map(lambda action: action.dest, parser._actions)) + # Extract the entries in override_defaults that correspond to keys not known to the parser + extra_args_keys = [key for key in override_defaults.keys() if key not in parser_keys] + # Remove those entries from override_defaults, to prevent new keys being added to the argument + # parser and add them to args, to mimic as if they were passed by command line + for key in extra_args_keys: + args += [f"--{key}", str(override_defaults[key])] + del override_defaults[key] parser.set_defaults(**override_defaults) return parser.parse_known_args(args) diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml new file mode 100644 index 000000000..a28ae6a09 --- /dev/null +++ b/tests/brevitas_examples/llm_test_template.yml @@ -0,0 +1,76 @@ +act_calibration: false +act_equalization: null +act_equalization_alpha: 0.5 +bias_corr: false +checkpoint_name: null +convert_layernorm_to_rmsnorm: false +dataset: wikitext2 +eval: false +export_prefix: null +export_target: null +few_shot_compile: false 
+few_shot_eval: false +few_shot_limit: null +few_shot_tasks: +- arc_challenge +- arc_easy +- winogrande +- piqa +few_shot_zeroshot: false +functional_sdpa_quant: false +fuse_sequences: false +gpfq: false +gptq: false +gpxq_act_order: false +gpxq_block_name: null +gpxq_create_weight_orig: false +gpxq_max_accumulator_bit_width: null +gpxq_max_accumulator_tile_size: null +gpxq_use_quant_activations: false +input_bit_width: null +input_group_size: 64 +input_param_method: stats +input_quant_format: int +input_quant_granularity: per_tensor +input_quant_type: asym +input_scale_precision: float_scale +input_scale_type: static +learned_round: null +learned_round_fast_update: false +learned_round_iters: 200 +learned_round_lr: 0.005 +learned_round_scale: false +learned_round_scale_lr: 0.01 +learned_round_scale_momentum: 0.9 +ln_affine_merge: false +load_awq: null +load_checkpoint: false +model: facebook/opt-125m +no_float16: false +no_quantize: false +nsamples: 128 +quant_sdpa: false +quantize_input_zero_point: false +quantize_last_layer: false +quantize_weight_zero_point: false +replace_mha: false +replace_rmsnorm: false +rotation: null +rotation_mode: had +rotation_orphan_sink: false +scale_rounding_func_type: null +scaling_min_val: 0.0001 +seed: 0 +seqlen: 2048 +weight_bit_width: 8 +weight_equalization: false +weight_group_dim: null +weight_group_size: 128 +weight_param_method: stats +weight_quant_format: int +weight_quant_granularity: per_group +weight_quant_type: sym +weight_scale_precision: float_scale +learning_rate: 1.5 +lr_scheduler_type: cosine +save_safetensors: false diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 02fbfc349..5e489eac1 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -20,6 +20,7 @@ from brevitas import config from brevitas import torch_version +from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args from 
brevitas_examples.llm.main import quantize_llm
 from tests.marker import jit_disabled_for_export
@@ -983,3 +984,52 @@ def test_small_models_rotation_optimization_layer_count(
     with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
         _, _, model = validate_args_and_run_main(args, extra_args)
     assert_layer_types_count(model, exp_layer_types_count)
+
+
+@pytest_cases.parametrize(
+    "kwargs",
+    [
+        {
+            "yaml_file_path":
+                "./llm_test_template.yml",
+            "expected_extra_args": [
+                "--learning_rate",
+                "1.5",
+                "--lr_scheduler_type",
+                "cosine",
+                "--save_safetensors",
+                "False"],},],
+    ids=lambda kwargs: kwargs["yaml_file_path"])
+def test_parse_yaml_trainer_arguments(caplog, kwargs):
+    """Check that YAML entries unknown to the parser are forwarded as extra arguments.
+
+    ``main`` is run with ``--config`` pointing at the YAML file, and ``quantize_llm`` is replaced
+    by a stub asserting that the expected keys are absent from the parsed arguments and that
+    the extra arguments match ``expected_extra_args``.
+    """
+    # Function-scope import, so that this test does not rely on module-level imports
+    import os
+
+    caplog.set_level(logging.INFO)
+    # Resolve the YAML file relative to this module rather than the working directory, so that
+    # the test does not depend on the directory from which pytest is invoked
+    tests_dir = os.path.dirname(os.path.abspath(__file__))
+    yaml_file_path = os.path.normpath(os.path.join(tests_dir, kwargs["yaml_file_path"]))
+    expected_extra_args = kwargs["expected_extra_args"]
+    # Flags occupy the even positions of the flat [flag, value, ...] list; drop the leading "--"
+    extra_args_keys = [flag[2:] for flag in expected_extra_args[::2]]
+    # Records each invocation of the stub, to detect the case in which it is never reached
+    stub_calls = []
+
+    def quantize_llm_assert_args(args, extra_args=None):
+        stub_calls.append(extra_args)
+        for key in extra_args_keys:
+            assert key not in args, f"Key {key} should not be known by the parser"
+        assert extra_args == expected_extra_args, f"Expected extra arguments {expected_extra_args} but got {extra_args}"
+
+    # Run the argument parsing logic of the LLM entrypoint
+    with patch("brevitas_examples.llm.main.quantize_llm", quantize_llm_assert_args):
+        with patch("brevitas_examples.llm.main.sys.argv", ["main.py", "--config", yaml_file_path]):
+            main()
+    # Without this check the test would pass vacuously if main() never called the stub
+    assert stub_calls, "quantize_llm was not called by main()"