From 8546589a9f2000531b959465da544e5a4ea345c0 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 13 Jan 2025 13:40:43 +0100
Subject: [PATCH] Feat (brevitas_examples/llm): load from checkpoint (#1151)

---
 src/brevitas_examples/llm/README.md          |  6 ++-
 .../llm/config/default_template.yml          |  1 +
 src/brevitas_examples/llm/main.py            | 37 +++++++++++++++----
 3 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 0d6fb5f42..59e17084f 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -55,8 +55,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
                [--export-prefix EXPORT_PREFIX]
-               [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
-               [--learned-round {None,linear_round}]
+               [--checkpoint-name CHECKPOINT_NAME] [--load-checkpoint]
+               [--fuse-sequences] [--learned-round {None,linear_round}]
                [--learned-round-fast-update] [--few-shot-eval]
                [--few-shot-compile] [--few-shot-zeroshot]
                [--few-shot-limit FEW_SHOT_LIMIT]
@@ -202,6 +202,8 @@ options:
   --checkpoint-name CHECKPOINT_NAME
                         Filename to save checkpoint. If `None`, no checkpoint
                         is saved (default: None)
+  --load-checkpoint     Boolean flag to load a saved checkpoint; uses
+                        checkpoint_name (default: False)
   --fuse-sequences      Whether to merge the dataset sequences in case they
                         are shorter than the requested number of samples per
                         sequence. This is useful in case you would like to
diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml
index ef20184ac..f686a7b36 100644
--- a/src/brevitas_examples/llm/config/default_template.yml
+++ b/src/brevitas_examples/llm/config/default_template.yml
@@ -43,6 +43,7 @@ learned_round_scale_lr: 0.01
 learned_round_scale_momentum: 0.9
 ln_affine_merge: false
 load_awq: null
+load_checkpoint: false
 model: facebook/opt-125m
 no_float16: false
 no_quantize: false
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index ed2ebc2c8..cca2172ab 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -20,6 +20,7 @@
 from brevitas.export import export_torch_qcdq
 from brevitas.export.inference.manager import quant_inference_mode
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
+from brevitas.graph import load_quant_model_mode
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
@@ -327,8 +328,12 @@ def quantize_llm(args):
     if args.act_equalization is not None:
         offload_model(model)
         print(f"Apply act equalization (SmoothQuant) with alpha {args.act_equalization_alpha}")
+        if args.load_checkpoint:
+            loader = [calibration_loader[0]]
+        else:
+            loader = calibration_loader
         apply_act_equalization(
-            model, args.act_equalization, calibration_loader, alpha=args.act_equalization_alpha)
+            model, args.act_equalization, loader, alpha=args.act_equalization_alpha)
         print("Act equalization applied.")
         remove_hooks(model)
 
@@ -423,18 +428,24 @@
         for k, v in dict_hooks.items():
             k._hf_hook.post_forward = v
 
-    if args.act_calibration:
+    if args.act_calibration and not args.load_checkpoint:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
applied.") if args.learned_round: print("Applying learned round...") + if args.load_checkpoint: + iters = 1 + loader = [calibration_loader[0]] + else: + iters = args.learned_round_iters + loader = calibration_loader remove_hooks(model) apply_learned_round( model, - calibration_loader, - iters=args.learned_round_iters, + loader, + iters=iters, block_name_attribute=args.gpxq_block_name, learn_scale=args.learned_round_scale, scale_optimizer_class='sgd', @@ -446,7 +457,13 @@ def quantize_llm(args): model = offload_model(model) - if args.gptq: + if args.load_checkpoint: + remove_hooks(model) + with load_quant_model_mode(model): + model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu')) + model = offload_model(model) + + if args.gptq and not args.load_checkpoint: print("Applying GPTQ...") apply_gptq( model, @@ -459,7 +476,7 @@ def quantize_llm(args): max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPTQ applied.") - if args.gpfq: + if args.gpfq and not args.load_checkpoint: print("Applying GPFQ...") apply_gpfq( model, @@ -470,7 +487,7 @@ def quantize_llm(args): max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPFQ applied.") - if args.bias_corr: + if args.bias_corr and not args.load_checkpoint: print("Applying bias correction...") apply_bias_correction(model, calibration_loader) print("Bias correction applied.") @@ -507,7 +524,7 @@ def quantize_llm(args): print(results) remove_hooks(model) - if args.checkpoint_name is not None: + if args.checkpoint_name is not None and not args.load_checkpoint: print(f"Saving checkpoint to {args.checkpoint_name}") torch.save(model.state_dict(), args.checkpoint_name) @@ -808,6 +825,10 @@ def parse_args(args, override_defaults={}): default=None, help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)" ) + parser.add_argument( + '--load-checkpoint', + action="store_true", + help='Boolean flag to load_checkpoint, uses checkpoint_name. Default %(default)s)') parser.add_argument( "--fuse-sequences", action="store_true",