Feat (brevitas_examples/llm): support for lighteval #1162

Open · wants to merge 3 commits into base: dev
77 changes: 63 additions & 14 deletions src/brevitas_examples/llm/main.py
@@ -4,12 +4,12 @@
import argparse
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
import functools
import pprint
import sys
from warnings import warn

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import numpy as np
from optimum.exporters.onnx import onnx_export_from_model
import torch
@@ -60,14 +60,10 @@


def filter_results(results, tasks):
# filter out what we actually want to track in azureml
# filter out what we actually want to track
eval_results = dict()
for task_name in tasks:
# first, log n_shots for each task
# for subtask, n_shots in results["n-shot"].items():
# name = f"{subtask}_n_shot"
# eval_results[name] = float(n_shots)
# then log all result metrics we have for this task
# log all result metrics we have for this task
for key, val in results["results"][task_name].items():
if not isinstance(val, str):
# for mmlu, we don't log results per subtask, but simply overall results
@@ -200,7 +196,9 @@ def validate(args):
assert args.export_target != 'onnx_qcdq', "Cannot export ONNX QCDQ with FX + MHA replacing"
else:
assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"

if args.few_shot_eval == 'lighteval':
# lighteval expects the tasks as a single comma-separated string, not a list
args.few_shot_tasks = ",".join(args.few_shot_tasks)
if not args.fuse_sequences:
# 350 is approximately the 99th percentile of the sequence length in WikiText2 (train partition, using AutoTokenizer)
if args.seqlen >= 350:
@@ -530,7 +528,9 @@ def quantize_llm(args):
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

if args.few_shot_eval:
if args.few_shot_eval == 'lm_eval':
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
if args.few_shot_compile:
@@ -551,7 +551,54 @@
verbosity="ERROR")
results = filter_results(results, args.few_shot_tasks)
print("Few shot eval results")
print(results)
pprint.pprint(results)
elif args.few_shot_eval == 'lighteval':
from accelerate import Accelerator
from accelerate import InitProcessGroupKwargs
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.transformers.transformers_model import TransformersModelConfig
from lighteval.pipeline import ParallelismManager
from lighteval.pipeline import Pipeline
from lighteval.pipeline import PipelineParameters
from lighteval.utils.utils import EnvConfig

accelerator = Accelerator(
kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
evaluation_tracker = EvaluationTracker(
output_dir="./results",
save_details=True,
)
pipeline_params = PipelineParameters(
launcher_type=ParallelismManager.ACCELERATE,
env_config=EnvConfig(cache_dir="/scratch/hf_models/"),
# Remove the 2 parameters below once your configuration is tested
override_batch_size=0, # max_samples=10
)
model_config = TransformersModelConfig(
pretrained=args.model,
dtype=dtype,
use_chat_template=True,
model_parallel=True,
accelerator=accelerator,
compile=False)

with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
if args.few_shot_compile:
remove_hooks(model)
model.cuda()
model.forward = torch.compile(model.forward, fullgraph=True)
pipeline = Pipeline(
tasks=args.few_shot_tasks,
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model=model,
config=model_config)

pipeline.evaluate()
Collaborator review comment on pipeline.evaluate(): Increase indentation? Will likely want to put this under the above context managers.
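A possible restructuring along the lines of this suggestion (a sketch only, reusing the names already present in this diff; not part of the PR) would keep the pipeline construction and evaluation inside the quantized inference context:

    with torch.no_grad(), quant_inference_mode(model):
        model(**calibration_loader[0])
        if args.few_shot_compile:
            remove_hooks(model)
            model.cuda()
            model.forward = torch.compile(model.forward, fullgraph=True)
        # Build and run the lighteval pipeline while the model is still in
        # quantized inference mode, as the comment above suggests.
        pipeline = Pipeline(
            tasks=args.few_shot_tasks,
            pipeline_parameters=pipeline_params,
            evaluation_tracker=evaluation_tracker,
            model=model,
            config=model_config)
        pipeline.evaluate()
        results = pipeline.get_results()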

results = pipeline.get_results()
results = filter_results(results, list(results["results"].keys()))
pprint.pprint(results)
remove_hooks(model)

if args.checkpoint_name is not None and not args.load_checkpoint:
@@ -888,12 +935,14 @@ def parse_args(args, override_defaults={}):
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
parser.add_argument(
'--few-shot-eval',
action="store_true",
help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
type=str,
default=None,
choices=['lm_eval', 'lighteval'],
help='Perform zero-shot evaluation with lm_eval or lighteval (default: %(default)s)')
parser.add_argument(
'--few-shot-compile',
action="store_true",
help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
help='Compile during zero-shot evaluation (default: %(default)s)')
parser.add_argument(
'--few-shot-zeroshot',
action="store_true",
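For reference, a hypothetical invocation exercising the new option might look like the sketch below. The module path, model name, and lighteval task string are illustrative assumptions and are not taken from this PR.

    # Hypothetical usage sketch; assumes main.py is importable as brevitas_examples.llm.main
    # and that parse_args accepts a list of CLI tokens, as its signature above suggests.
    from brevitas_examples.llm.main import parse_args, quantize_llm

    args = parse_args([
        "--model", "facebook/opt-125m",                       # placeholder model
        "--few-shot-eval", "lighteval",                       # new choice added by this PR
        "--few-shot-tasks", "leaderboard|arc:challenge|0|0",  # assumed lighteval task spec
    ])
    quantize_llm(args)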