Commit
Support opensource models, e.g., deepseek
libowen2121 committed Mar 25, 2024
1 parent aa45fef commit 5005cfb
Showing 3 changed files with 129 additions and 8 deletions.
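
The new code path targets any OpenAI-compatible deployment, configured through the DEPLOYMENT_API_KEY and DEPLOYMENT_URL environment variables introduced in run_api.py. Below is a minimal sketch (an editor's illustration, not part of the commit) of that interaction; the endpoint URL, dummy key, and prompt are placeholders.

# Sketch (not part of the commit): querying an OpenAI-compatible deployment
# the way the new open-source path in run_api.py does.
# The endpoint URL, dummy key, and prompt below are illustrative placeholders.
import os
from openai import OpenAI

os.environ.setdefault("DEPLOYMENT_API_KEY", "EMPTY")  # many local servers accept a dummy key
os.environ.setdefault("DEPLOYMENT_URL", "http://localhost:8000/v1")

client = OpenAI(
    api_key=os.environ["DEPLOYMENT_API_KEY"],
    base_url=os.environ["DEPLOYMENT_URL"],
)
model_name = client.models.list().data[0].id  # the deployment reports the served model
response = client.chat.completions.create(
    model=model_name,
    messages=[
        {"role": "system", "content": "You are a helpful programming assistant."},
        {"role": "user", "content": "Write a unified diff that fixes an off-by-one error."},
    ],
    temperature=0.2,
    top_p=0.95,
)
print(response.choices[0].message.content)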
2 changes: 1 addition & 1 deletion inference/make_datasets/create_text_dataset.py
@@ -92,7 +92,7 @@ def main(
assert tokenizer_name is not None
if push_to_hub_user is None and not Path(output_dir).exists():
Path(output_dir).mkdir(parents=True)
output_file = f"SWE-bench__{prompt_style}__fs-{file_source}"
output_file = f"{dataset_name_or_path}__{prompt_style}__fs-{file_source}__tok-{tokenizer_name}"
if k is not None:
assert file_source not in {
"all",
9 changes: 5 additions & 4 deletions inference/make_datasets/tokenize_dataset.py
@@ -11,7 +11,7 @@
import tiktoken
from datasets import disable_caching, load_from_disk, load_dataset
from tqdm.auto import tqdm
- from transformers import LlamaTokenizer
+ from transformers import LlamaTokenizer, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
@@ -23,15 +23,16 @@ def cl100k(text, tokenizer):
return tokenizer.encode(text, disallowed_special=())


- def llama(text, tokenizer):
+ def hf_tokenize(text, tokenizer):
return tokenizer(text, add_special_tokens=False, return_attention_mask=False)[
"input_ids"
]


TOKENIZER_FUNCS = {
"cl100k": (tiktoken.get_encoding("cl100k_base"), cl100k),
"llama": (LlamaTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K"), llama),
"llama": (LlamaTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K"), hf_tokenize),
"deepseek-coder-33b-instruct": (AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-33b-instruct"), hf_tokenize),
"deepseek-coder-6.7b-instruct": (AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-instruct"), hf_tokenize),
}
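
For illustration (not part of the commit), each entry above is consumed as a (tokenizer, function) pair; run_api.py uses them to drop instances that exceed a model's context limit. A minimal sketch, assuming the deepseek tokenizer can be downloaded from the Hugging Face Hub when the module is imported:

# Sketch: counting prompt tokens with one of the new TOKENIZER_FUNCS entries.
from make_datasets.tokenize_dataset import TOKENIZER_FUNCS

tokenizer, tokenize_fn = TOKENIZER_FUNCS["deepseek-coder-33b-instruct"]
text = "def add(a, b):\n    return a + b\n"
num_tokens = len(tokenize_fn(text, tokenizer))  # hf_tokenize returns a list of input_ids
print(num_tokens)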


126 changes: 123 additions & 3 deletions inference/run_api.py
@@ -13,6 +13,7 @@
from tqdm.auto import tqdm
import numpy as np
import openai
from openai import OpenAI
import tiktoken
from anthropic import HUMAN_PROMPT, AI_PROMPT, Anthropic
from tenacity import (
@@ -22,8 +23,10 @@
)
from datasets import load_dataset, load_from_disk
from make_datasets.utils import extract_diff
from make_datasets.tokenize_dataset import TOKENIZER_FUNCS
from argparse import ArgumentParser
import logging
from icecream import ic

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
@@ -42,6 +45,8 @@
"gpt-4-0613": 8_192,
"gpt-4-1106-preview": 128_000,
"gpt-4-0125-preview": 128_000,
"deepseek-coder-6.7b-instruct": 32_000,
"deepseek-coder-33b-instruct": 32_000,
}

# The cost per token for each model input.
@@ -138,7 +143,7 @@ def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model
**model_args,
)
else:
- response = openai.ChatCompletion.create(
+ response = openai.chat.completions.create(
model=model_name_or_path,
messages=[
{"role": "system", "content": system_messages},
@@ -152,7 +157,45 @@ def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model
output_tokens = response.usage.completion_tokens
cost = calc_cost(response.model, input_tokens, output_tokens)
return response, cost
- except openai.error.InvalidRequestError as e:
+ except openai.BadRequestError as e:  # openai>=1.0 removed InvalidRequestError; BadRequestError exposes .code
if e.code == "context_length_exceeded":
print("Context length exceeded")
return None
raise e


@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(3))
def call_opensource_chat(client, inputs, temperature, top_p, **model_args):
"""
Calls the openai API to generate completions for the given inputs.
Args:
model_name_or_path (str): The name or path of the model to use.
inputs (str): The inputs to generate completions for.
use_azure (bool): Whether to use the azure API.
temperature (float): The temperature to use.
top_p (float): The top_p to use.
**model_args (dict): A dictionary of model arguments.
"""
model_name = client.models.list().data[0].id # Get model name from the deployment service
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]

try:
response = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": system_messages},
{"role": "user", "content": user_message},
],
temperature=temperature,
top_p=top_p,
**model_args,
)
ic(response)  # debug logging; icecream is already imported at module level
return response
except openai.BadRequestError as e:  # openai>=1.0 removed InvalidRequestError; BadRequestError exposes .code
if e.code == "context_length_exceeded":
print("Context length exceeded")
return None
@@ -165,6 +208,16 @@ def gpt_tokenize(string: str, encoding) -> int:
return num_tokens


def get_model_id(model_name: str) -> str:
"""Returns a normalized model id from an arbitrary serving model name.
E.g., arbitrary_path/deepseek-coder-33b -> deepseek.
"""
if "deepseek" in model_name:
return "deepseek"
if "codellama" in model_name:
return "codellama"


def claude_tokenize(string: str, api) -> int:
"""Returns the number of tokens in a text string."""
num_tokens = api.count_tokens(string)
@@ -231,7 +284,7 @@ def openai_inference(
temperature,
top_p,
)
completion = response.choices[0]["message"]["content"]
completion = response.choices[0].message.content
total_cost += cost
print(f"Total Cost: {total_cost:.2f}")
output_dict["full_output"] = completion
@@ -242,6 +295,71 @@
break


def opensource_inference(
test_dataset,
model_name_or_path,
output_file,
model_args,
existing_ids,
max_cost,
):
"""
Runs inference on a dataset against an OpenAI-compatible deployment serving an open-source model.
Args:
test_dataset (datasets.Dataset): The dataset to run inference on.
model_name_or_path (str): The open-source model name (used to select the tokenizer and context limit).
output_file (str): The path to the output file.
model_args (dict): A dictionary of model arguments.
existing_ids (set): A set of ids that have already been processed.
max_cost (float): The maximum cost to spend on inference.
"""
openai_key = os.environ.get("DEPLOYMENT_API_KEY", None)
url = os.environ.get("DEPLOYMENT_URL", None)
if openai_key is None:
raise ValueError(
"Must provide an api key. Expected in DEPLOYMENT_API_KEY environment variable."
)
client = OpenAI(
api_key=openai_key,
base_url=url
)
# model_name = client.models.list().data[0].id
# ic(model_name)

encoding = TOKENIZER_FUNCS[model_name_or_path][0]
test_dataset = test_dataset.filter(
lambda x: len(encoding(x["text"])) <= MODEL_LIMITS[model_name_or_path],
desc="Filtering",
load_from_cache_file=False,
)
print(f"Using api key {openai_key}")
temperature = model_args.pop("temperature", 0.2)
top_p = model_args.pop("top_p", 0.95 if temperature > 0 else 1)
print(f"Using temperature={temperature}, top_p={top_p}")
basic_args = {
"model_name_or_path": model_name_or_path,
}
print(f"Filtered to {len(test_dataset)} instances")
with open(output_file, "a+") as f:
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
instance_id = datum["instance_id"]
if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}
output_dict.update(basic_args)
output_dict["text"] = f"{datum['text']}\n\n"
response = call_opensource_chat(
client,
output_dict["text"],
temperature,
top_p,
)
completion = response.choices[0].message.content
output_dict["full_output"] = completion
output_dict["model_patch"] = extract_diff(completion)
print(json.dumps(output_dict), file=f, flush=True)


@retry(wait=wait_random_exponential(min=60, max=600), stop=stop_after_attempt(6))
def call_anthropic(
inputs, anthropic, model_name_or_path, temperature, top_p, **model_args
@@ -503,6 +621,8 @@ def main(
anthropic_inference(**inference_args)
elif model_name_or_path.startswith("gpt"):
openai_inference(**inference_args)
elif get_model_id(model_name_or_path) in ("deepseek", "codellama",):
opensource_inference(**inference_args)
else:
raise ValueError(f"Invalid model name or path {model_name_or_path}")
logger.info(f"Done!")
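
Once an open-source run finishes, opensource_inference has appended one JSON record per instance to the output file, carrying instance_id, model_name_or_path, text, full_output, and model_patch. A small sketch (editor's illustration) of reading those predictions back; the file path is a hypothetical example, not a name the commit defines:

# Sketch: reading back the predictions JSONL written by opensource_inference.
import json

predictions_path = "outputs/deepseek-coder-33b-instruct__preds.jsonl"  # hypothetical path
with open(predictions_path) as f:
    for line in f:
        pred = json.loads(line)
        print(pred["instance_id"], len(pred["full_output"] or ""))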
