Feat (brevitas_examples/llm): support for lighteval #1162

Open · wants to merge 3 commits into base: dev · Changes from 2 commits
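A hedged usage sketch of the new backend option (the checkpoint name and the lighteval task string are illustrative placeholders, not from this PR; parse_args and quantize_llm are the entry points defined in main.py per the diff below, and the sketch assumes parse_args returns the parsed namespace):

    # Drive the new lighteval path through the module's own entry points.
    from brevitas_examples.llm.main import parse_args, quantize_llm

    args = parse_args([
        '--model', 'facebook/opt-125m',  # placeholder HF checkpoint
        '--few-shot-eval', 'lighteval',  # new choice introduced by this PR
        '--few-shot-tasks', 'leaderboard|arc:easy|0|0',  # assumed lighteval task spec
    ])
    quantize_llm(args)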
60 changes: 54 additions & 6 deletions src/brevitas_examples/llm/main.py
@@ -4,12 +4,11 @@
 import argparse
 from contextlib import nullcontext
 from copy import deepcopy
+from datetime import timedelta
 import functools
 import sys
 from warnings import warn

-from lm_eval import evaluator
-from lm_eval.models.huggingface import HFLM
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
@@ -530,7 +529,9 @@ def quantize_llm(args):
                 model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

-    if args.few_shot_eval:
+    if args.few_shot_eval == 'lm_eval':
+        from lm_eval import evaluator
+        from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])
             if args.few_shot_compile:
@@ -552,6 +553,51 @@ def quantize_llm(args):
         results = filter_results(results, args.few_shot_tasks)
         print("Few shot eval results")
         print(results)
+    elif args.few_shot_eval == 'lighteval':
+        from accelerate import Accelerator
+        from accelerate import InitProcessGroupKwargs
+        from lighteval.logging.evaluation_tracker import EvaluationTracker
+        from lighteval.models.transformers.transformers_model import TransformersModelConfig
+        from lighteval.pipeline import ParallelismManager
+        from lighteval.pipeline import Pipeline
+        from lighteval.pipeline import PipelineParameters
+        from lighteval.utils.utils import EnvConfig
+
+        accelerator = Accelerator(
+            kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+        evaluation_tracker = EvaluationTracker(
+            output_dir="./results",
+            save_details=True,
+        )
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.ACCELERATE,
+            env_config=EnvConfig(cache_dir="/scratch/hf_models/"),
+            # Remove the 2 parameters below once your configuration is tested
+            override_batch_size=0,  # max_samples=10
+        )
+        model_config = TransformersModelConfig(
+            pretrained=args.model,
+            dtype=args.dtype,
+            use_chat_template=True,
+            model_parallel=True,
+            accelerator=accelerator,
+            compile=False)
+
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            if args.few_shot_compile:
+                remove_hooks(model)
+                model.cuda()
+                model.forward = torch.compile(model.forward, fullgraph=True)
+        pipeline = Pipeline(
+            tasks=args.few_shot_tasks,
+            pipeline_parameters=pipeline_params,
+            evaluation_tracker=evaluation_tracker,
+            model=model,
+            config=model_config)
+
+        pipeline.evaluate()
Collaborator: Increase indentation? Will likely want to put this under the above context managers.
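A minimal sketch of the restructure this comment suggests (not code from the PR; it reuses the objects defined in the hunk above): the Pipeline construction and evaluate() call move inside the quantized-inference context managers so evaluation runs against the quantized model, while show_results() and the final hook cleanup stay outside:

    with torch.no_grad(), quant_inference_mode(model):
        model(**calibration_loader[0])
        if args.few_shot_compile:
            remove_hooks(model)
            model.cuda()
            model.forward = torch.compile(model.forward, fullgraph=True)
        # Pipeline is now built and run while quantized inference is active.
        pipeline = Pipeline(
            tasks=args.few_shot_tasks,
            pipeline_parameters=pipeline_params,
            evaluation_tracker=evaluation_tracker,
            model=model,
            config=model_config)
        pipeline.evaluate()
    pipeline.show_results()
    remove_hooks(model)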

+        pipeline.show_results()
+        remove_hooks(model)

     if args.checkpoint_name is not None and not args.load_checkpoint:
@@ -888,12 +934,14 @@ def parse_args(args, override_defaults={}):
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     parser.add_argument(
         '--few-shot-eval',
-        action="store_true",
-        help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
+        type=str,
+        default=None,
+        choices=['lm_eval', 'lighteval'],
+        help='Perform zero-shot evaluation with lm_eval or lighteval (default: %(default)s)')
     parser.add_argument(
         '--few-shot-compile',
         action="store_true",
-        help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
+        help='Compile during zero-shot evaluation (default: %(default)s)')
     parser.add_argument(
         '--few-shot-zeroshot',
         action="store_true",
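Note the interface change in the hunk above: --few-shot-eval goes from a boolean switch to a string choice, so existing invocations must now name a backend explicitly. An illustration (not from the PR):

    # Before this PR: flag only, lm_eval implied
    #   python main.py --model <ckpt> --few-shot-eval
    # After this PR: backend is an explicit value
    #   python main.py --model <ckpt> --few-shot-eval lm_eval
    #   python main.py --model <ckpt> --few-shot-eval lighteval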