From 62b4284e5bebf076af26037b29033c5d35ca3285 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 16 Jan 2025 11:30:11 +0000
Subject: [PATCH 1/3] Feat (brevitas_examples/llm): support for lighteval

---
 src/brevitas_examples/llm/main.py | 60 +++++++++++++++++++++++++++----
 1 file changed, 54 insertions(+), 6 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 8c4fe1968..170d2d735 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -4,12 +4,11 @@
 import argparse
 from contextlib import nullcontext
 from copy import deepcopy
+from datetime import timedelta
 import functools
 import sys
 from warnings import warn
 
-from lm_eval import evaluator
-from lm_eval.models.huggingface import HFLM
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
@@ -530,7 +529,9 @@ def quantize_llm(args):
                 model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
 
-    if args.few_shot_eval:
+    if args.few_shot_eval == 'lm_eval':
+        from lm_eval import evaluator
+        from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])
             if args.few_shot_compile:
@@ -552,6 +553,51 @@ def quantize_llm(args):
             results = filter_results(results, args.few_shot_tasks)
         print("Few shot eval results")
         print(results)
+    elif args.few_shot_eval == 'lighteval':
+        from accelerate import Accelerator
+        from accelerate import InitProcessGroupKwargs
+        from lighteval.logging.evaluation_tracker import EvaluationTracker
+        from lighteval.models.transformers.transformers_model import TransformersModelConfig
+        from lighteval.pipeline import ParallelismManager
+        from lighteval.pipeline import Pipeline
+        from lighteval.pipeline import PipelineParameters
+        from lighteval.utils.utils import EnvConfig
+
+        accelerator = Accelerator(
+            kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+        evaluation_tracker = EvaluationTracker(
+            output_dir="./results",
+            save_details=True,
+        )
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.ACCELERATE,
+            env_config=EnvConfig(cache_dir="/scratch/hf_models/"),
+            # Remove the 2 parameters below once your configuration is tested
+            override_batch_size=0,  # max_samples=10
+        )
+        model_config = TransformersModelConfig(
+            pretrained=args.model,
+            dtype="float16",
+            use_chat_template=True,
+            model_parallel=True,
+            accelerator=accelerator,
+            compile=False)
+
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            if args.few_shot_compile:
+                remove_hooks(model)
+                model.cuda()
+                model.forward = torch.compile(model.forward, fullgraph=True)
+            pipeline = Pipeline(
+                tasks=args.few_shot_tasks,
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=model,
+                config=model_config)
+
+            pipeline.evaluate()
+            pipeline.show_results()
     remove_hooks(model)
 
     if args.checkpoint_name is not None and not args.load_checkpoint:
@@ -888,12 +934,14 @@ def parse_args(args, override_defaults={}):
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     parser.add_argument(
         '--few-shot-eval',
-        action="store_true",
-        help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
+        type=str,
+        default=None,
+        choices=['lm_eval', 'lighteval'],
+        help='Perform zero_shot evaluation with lm_eval or lighteval. Default %(default)s)')
     parser.add_argument(
         '--few-shot-compile',
         action="store_true",
-        help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
+        help='Compile during zero_shot evaluation. Default %(default)s)')
     parser.add_argument(
         '--few-shot-zeroshot',
         action="store_true",

From 4642dcef4e4e28cb11b51c8f0b37f2cfcaeb7995 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 16 Jan 2025 11:33:09 +0000
Subject: [PATCH 2/3] Dtype

---
 src/brevitas_examples/llm/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 170d2d735..371b0e566 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -577,7 +577,7 @@ def quantize_llm(args):
         )
         model_config = TransformersModelConfig(
             pretrained=args.model,
-            dtype="float16",
+            dtype=args.dtype,
             use_chat_template=True,
             model_parallel=True,
             accelerator=accelerator,

From c921db5830c2bfc86d1d118bfdfaca251a1c31a4 Mon Sep 17 00:00:00 2001
From: i-colbert
Date: Wed, 22 Jan 2025 19:18:15 +0000
Subject: [PATCH 3/3] Minor fixes for dtype and task specification

---
 src/brevitas_examples/llm/main.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 371b0e566..93e4a97b1 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -6,6 +6,7 @@
 from copy import deepcopy
 from datetime import timedelta
 import functools
+import pprint
 import sys
 from warnings import warn
 
@@ -59,14 +60,10 @@
 
 
 def filter_results(results, tasks):
-    # filter out what we actually want to track in azureml
+    # filter out what we actually want to track
     eval_results = dict()
     for task_name in tasks:
-        # first, log n_shots for each task
-        # for subtask, n_shots in results["n-shot"].items():
-        #     name = f"{subtask}_n_shot"
-        #     eval_results[name] = float(n_shots)
-        # then log all result metrics we have for this task
+        # log all result metrics we have for this task
         for key, val in results["results"][task_name].items():
             if not isinstance(val, str):
                 # for mmlu, we don't log results per subtask, but simply overall results
@@ -199,7 +196,9 @@ def validate(args):
             assert args.export_target != 'onnx_qcdq', "Cannot export ONNX QCDQ with FX + MHA replacing"
     else:
         assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"
-
+    if args.few_shot_eval == 'lighteval':
+        # expects a list
+        args.few_shot_tasks = ",".join(args.few_shot_tasks)
     if not args.fuse_sequences:
         # 350 is approximately the 99% percentile for the sequence length in WikiText2 (train partition, using AutoTokenizer)
         if args.seqlen >= 350:
@@ -552,7 +551,7 @@ def quantize_llm(args):
                 verbosity="ERROR")
             results = filter_results(results, args.few_shot_tasks)
         print("Few shot eval results")
-        print(results)
+        pprint.pprint(results)
     elif args.few_shot_eval == 'lighteval':
         from accelerate import Accelerator
         from accelerate import InitProcessGroupKwargs
@@ -577,7 +576,7 @@ def quantize_llm(args):
         )
         model_config = TransformersModelConfig(
             pretrained=args.model,
-            dtype=args.dtype,
+            dtype=dtype,
             use_chat_template=True,
             model_parallel=True,
             accelerator=accelerator,
@@ -597,7 +596,9 @@ def quantize_llm(args):
                 config=model_config)
 
             pipeline.evaluate()
-            pipeline.show_results()
+            results = pipeline.get_results()
+            results = filter_results(results, list(results["results"].keys()))
+            pprint.pprint(results)
     remove_hooks(model)
 
     if args.checkpoint_name is not None and not args.load_checkpoint:
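
Usage note (illustrative, not part of the patch series above): a minimal sketch of how the new --few-shot-eval switch could be exercised end to end, assuming parse_args accepts an argv-style list and returns the parsed namespace, and that the --model and --few-shot-tasks flag spellings correspond to the args.model and args.few_shot_tasks attributes referenced in the diffs. The model id and task name are placeholders.

# Hypothetical driver for the quantization + few-shot evaluation flow patched above.
# Anything not shown in the diffs (flag spellings, model id, task name) is an assumption.
from brevitas_examples.llm.main import parse_args, quantize_llm

argv = [
    "--model", "facebook/opt-125m",   # placeholder HF checkpoint
    "--few-shot-eval", "lighteval",   # new choice from PATCH 1/3; "lm_eval" keeps the previous path
    "--few-shot-tasks", "arc_easy",   # PATCH 3/3 joins these into a comma-separated string for lighteval
]
args = parse_args(argv)
quantize_llm(args)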