Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (examples/llm): Specify experiments via YAML files #1116

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
[--dataset {wikitext2,c4}] [--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
Expand Down Expand Up @@ -61,6 +61,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]

options:
-h, --help show this help message and exit
--config CONFIG Specify alternative default commandline args (e.g.,
config/default_template.yml). Default: None.
--model MODEL HF model name. Default: facebook/opt-125m.
--seed SEED Seed for sampling the calibration data. Default: 0.
--nsamples NSAMPLES Number of calibration data samples. Default: 128.
Expand Down
62 changes: 62 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Default-argument template for brevitas_examples/llm/main.py.
# Generated by config/gen_template.py (yaml.dump of the parser defaults,
# minus the 'config' key itself); pass it back via `main.py --config <file>`.
# Keys are kept in alphabetical order, matching yaml.dump's default output.
act_calibration: false
act_equalization: null
act_equalization_alpha: 0.5
bias_corr: false
checkpoint_name: null
convert_layernorm_to_rmsnorm: false
dataset: wikitext2
eval: false
export_prefix: null
export_target: null
fuse_sequences: false
gpfq: false
gptq: false
gpxq_act_order: false
gpxq_block_name: null
gpxq_create_weight_orig: false
gpxq_max_accumulator_bit_width: null
gpxq_max_accumulator_tile_size: null
gpxq_use_quant_activations: false
input_bit_width: null
input_group_size: 64
input_param_method: stats
input_quant_format: int
input_quant_granularity: per_tensor
input_quant_type: asym
input_scale_precision: float_scale
input_scale_type: static
learned_round: null
learned_round_fast_update: false
learned_round_iters: 200
learned_round_lr: 0.005
learned_round_scale: false
learned_round_scale_lr: 0.01
learned_round_scale_momentum: 0.9
ln_affine_merge: false
load_awq: null
model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
quant_sdpa: false
quantize_input_zero_point: false
quantize_last_layer: false
quantize_weight_zero_point: false
replace_mha: false
replace_rmsnorm: false
rotation: null
rotation_mode: had
rotation_orphan_sink: false
scale_rounding_func_type: null
scaling_min_val: 0.0001
seed: 0
seqlen: 2048
weight_bit_width: 8
weight_equalization: false
weight_group_dim: null
weight_group_size: 128
weight_param_method: stats
weight_quant_format: int
weight_quant_granularity: per_group
weight_quant_type: sym
weight_scale_precision: float_scale
10 changes: 10 additions & 0 deletions src/brevitas_examples/llm/config/gen_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import yaml

from brevitas_examples.llm.main import parse_args

if __name__ == "__main__":
    # Dump the parser's default values to a YAML template that users can
    # later feed back to main.py through its --config option.
    template = vars(parse_args([]))
    # The config file itself cannot be specified via YAML: drop the key
    # (KeyError here would mean parse_args no longer defines --config).
    template.pop("config")
    with open('default_template.yml', 'w') as f:
        yaml.dump(template, f)
32 changes: 30 additions & 2 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils.fx import _SUPPORTED_MODELS
import yaml

from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import quant_inference_mode
Expand Down Expand Up @@ -477,8 +478,33 @@ def main(args):
return float_ppl, quant_ppl, model


def parse_args(args):
def override_defaults(args):
    """Extract --config from *args* and load the YAML file it points to.

    Runs a minimal, help-less pre-parser so that only ``--config`` is
    consumed; all other options are left for the full parser.

    Args:
        args: List of command-line arguments (e.g. ``sys.argv[1:]``).

    Returns:
        dict: Defaults loaded from the YAML config file, or an empty dict
        when no config file was specified. Intended to be forwarded to
        ``parse_args(..., override_defaults=...)``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        '--config',
        type=str,
        default=None,
        help=
        'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.'
    )
    # Bug fix: parse the *args* argument explicitly instead of implicitly
    # reading sys.argv, so programmatic callers (and tests) are honored.
    # parse_known_args returns a (namespace, remaining_args) tuple.
    known_args = parser.parse_known_args(args)[0]
    if known_args.config is not None:
        with open(known_args.config, 'r') as f:
            defaults = yaml.safe_load(f)
    else:
        defaults = {}
    return defaults


def parse_args(args, override_defaults={}):
parser = argparse.ArgumentParser()
parser.add_argument(
'--config',
type=str,
default=None,
help=
'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.'
)
parser.add_argument(
'--model',
type=str,
Expand Down Expand Up @@ -757,9 +783,11 @@ def parse_args(args):
default=False,
action="store_true",
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
parser.set_defaults(**override_defaults)
return parser.parse_args(args)


if __name__ == '__main__':
    # First pass: pick up --config (if any) to obtain alternative defaults,
    # then run the full parser with those defaults applied.
    cli_args = sys.argv[1:]
    yaml_defaults = override_defaults(cli_args)
    main(parse_args(cli_args, override_defaults=yaml_defaults))
main(args)
Loading