From e2be595b274a0381c48e6a228a48d5b34380cf71 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Jan 2025 21:09:46 +0000 Subject: [PATCH] Fix post rebase --- src/brevitas_examples/llm/main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index da553200a..365cd8e1b 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -515,6 +515,14 @@ def quantize_llm(args): print(f"Saving checkpoint to {args.checkpoint_name}") torch.save(model.state_dict(), args.checkpoint_name) + if args.eval and not args.no_quantize: + print("Model eval...") + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + if args.few_shot_eval: with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0])