diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 8c4fe1968..93e4a97b1 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -4,12 +4,12 @@
 import argparse
 from contextlib import nullcontext
 from copy import deepcopy
+from datetime import timedelta
 import functools
+import pprint
 import sys
 from warnings import warn
 
-from lm_eval import evaluator
-from lm_eval.models.huggingface import HFLM
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
@@ -60,14 +60,10 @@ def filter_results(results, tasks):
-    # filter out what we actually want to track in azureml
+    # filter out what we actually want to track
     eval_results = dict()
     for task_name in tasks:
-        # first, log n_shots for each task
-        # for subtask, n_shots in results["n-shot"].items():
-        #     name = f"{subtask}_n_shot"
-        #     eval_results[name] = float(n_shots)
-        # then log all result metrics we have for this task
+        # log all result metrics we have for this task
         for key, val in results["results"][task_name].items():
             if not isinstance(val, str):
                 # for mmlu, we don't log results per subtask, but simply overall results
@@ -200,7 +196,9 @@ def validate(args):
             assert args.export_target != 'onnx_qcdq', "Cannot export ONNX QCDQ with FX + MHA replacing"
         else:
             assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"
-
+    if args.few_shot_eval == 'lighteval':
+        # expects a list
+        args.few_shot_tasks = ",".join(args.few_shot_tasks)
     if not args.fuse_sequences:
         # 350 is approximately the 99% percentile for the sequence length in WikiText2 (train partition, using AutoTokenizer)
         if args.seqlen >= 350:
@@ -530,7 +528,9 @@ def quantize_llm(args):
                 model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
 
-    if args.few_shot_eval:
+    if args.few_shot_eval == 'lm_eval':
+        from lm_eval import evaluator
+        from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])
             if args.few_shot_compile:
@@ -551,7 +551,54 @@ def quantize_llm(args):
                 verbosity="ERROR")
         results = filter_results(results, args.few_shot_tasks)
         print("Few shot eval results")
-        print(results)
+        pprint.pprint(results)
+    elif args.few_shot_eval == 'lighteval':
+        from accelerate import Accelerator
+        from accelerate import InitProcessGroupKwargs
+        from lighteval.logging.evaluation_tracker import EvaluationTracker
+        from lighteval.models.transformers.transformers_model import TransformersModelConfig
+        from lighteval.pipeline import ParallelismManager
+        from lighteval.pipeline import Pipeline
+        from lighteval.pipeline import PipelineParameters
+        from lighteval.utils.utils import EnvConfig
+
+        accelerator = Accelerator(
+            kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+        evaluation_tracker = EvaluationTracker(
+            output_dir="./results",
+            save_details=True,
+        )
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.ACCELERATE,
+            env_config=EnvConfig(cache_dir="/scratch/hf_models/"),
+            # Remove the 2 parameters below once your configuration is tested
+            override_batch_size=0,  # max_samples=10
+        )
+        model_config = TransformersModelConfig(
+            pretrained=args.model,
+            dtype=dtype,
+            use_chat_template=True,
+            model_parallel=True,
+            accelerator=accelerator,
+            compile=False)
+
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            if args.few_shot_compile:
+                remove_hooks(model)
+                model.cuda()
+                model.forward = torch.compile(model.forward, fullgraph=True)
+            pipeline = Pipeline(
+                tasks=args.few_shot_tasks,
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=model,
+                config=model_config)
+
+            pipeline.evaluate()
+            results = pipeline.get_results()
+        results = filter_results(results, list(results["results"].keys()))
+        pprint.pprint(results)
     remove_hooks(model)
 
     if args.checkpoint_name is not None and not args.load_checkpoint:
@@ -888,12 +935,14 @@ def parse_args(args, override_defaults={}):
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     parser.add_argument(
         '--few-shot-eval',
-        action="store_true",
-        help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
+        type=str,
+        default=None,
+        choices=['lm_eval', 'lighteval'],
+        help='Perform zero_shot evaluation with lm_eval or lighteval. Default %(default)s)')
     parser.add_argument(
         '--few-shot-compile',
         action="store_true",
-        help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
+        help='Compile during zero_shot evaluation. Default %(default)s)')
     parser.add_argument(
         '--few-shot-zeroshot',
         action="store_true",
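
Below is a small, self-contained sketch (not part of the patch) of how the two halves of this change fit together: --few-shot-eval turns from a boolean flag into a backend selector, and validate() collapses the task list into the single comma-separated string that lighteval expects. The --few-shot-tasks definition and the "suite|task|fewshot|truncate" task specs are illustrative assumptions, not taken from this diff.

import argparse

# Mirror of the new --few-shot-eval definition added in parse_args() above.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--few-shot-eval',
    type=str,
    default=None,
    choices=['lm_eval', 'lighteval'],
    help='Perform zero_shot evaluation with lm_eval or lighteval. Default %(default)s)')
# Assumed shape of the pre-existing --few-shot-tasks flag (a list of task names).
parser.add_argument('--few-shot-tasks', type=str, nargs='+', default=['arc_challenge', 'arc_easy'])

args = parser.parse_args([
    '--few-shot-eval', 'lighteval',
    '--few-shot-tasks', 'leaderboard|arc:challenge|0|0', 'leaderboard|hellaswag|0|0'])

if args.few_shot_eval == 'lighteval':
    # Same normalization as the new validate() branch: lighteval's Pipeline
    # takes one comma-separated task string rather than a Python list.
    args.few_shot_tasks = ",".join(args.few_shot_tasks)

print(args.few_shot_tasks)
# -> leaderboard|arc:challenge|0|0,leaderboard|hellaswag|0|0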