
Commit

Merge pull request #102 from Clarifai/refract-openai-gpt4
[EAGLE-5211] Refract Openai GPT4 model
luv-bansal authored Dec 19, 2024
2 parents 2a22bf1 + ebcf3ed commit ee93c76
Showing 2 changed files with 92 additions and 170 deletions.
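In this refactor, the per-input OpenAI calls in predict() are replaced by a shared stream_completion() helper, a new get_inference_params() reads per-request settings from request.model.model_version.output_info.params, and both predict() and generate() now stream completions. A hedged caller-side sketch of how those params might be populated; the model ID, version ID, and values below are placeholders, not part of this commit:

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from google.protobuf import struct_pb2

# Per-request inference params, read by get_inference_params() in model.py.
params = struct_pb2.Struct()
params.update({
    "temperature": 0.2,
    "max_tokens": 256,
    "top_p": 1.0,
    "system_prompt": "You are a terse assistant.",
})

request = service_pb2.PostModelOutputsRequest(
    model_id="openai-gpt4",  # placeholder model ID
    inputs=[
        resources_pb2.Input(
            data=resources_pb2.Data(text=resources_pb2.Text(raw="Hello!")))
    ],
    # get_inference_params() only reads params when model_version.id is non-empty.
    model=resources_pb2.Model(
        model_version=resources_pb2.ModelVersion(
            id="placeholder-version-id",
            output_info=resources_pb2.OutputInfo(params=params))),
)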
258 changes: 90 additions & 168 deletions models/model_upload/llms/openai-gpt4/1/model.py
@@ -1,17 +1,49 @@
import itertools
from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from google.protobuf import json_format
from openai import OpenAI

from clarifai.runners.models.model_runner import ModelRunner
# Set your OpenAI API key here
API_KEY = 'YOUR_OPENAI_API_KEY'


def get_inference_params(request) -> dict:
"""Get the inference params from the request."""
inference_params = {}
if request.model.model_version.id != "":
output_info = request.model.model_version.output_info
output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)

if "params" in output_info:
inference_params = output_info["params"]
return inference_params

# model name
MODEL = "gpt-4-1106-preview"

API_KEY = 'OPENAI_API_KEY'
def stream_completion(model, client, input_data, inference_params):
"""Stream iteratively generates completions for the input data."""

temperature = inference_params.get("temperature", 0.7)
max_tokens = inference_params.get("max_tokens", 512)
top_p = inference_params.get("top_p", 1.0)
system_prompt = "You'r a helpful assistant"
system_prompt = inference_params.get("system_prompt", system_prompt)

prompt = input_data.text.raw
messages = [{"role": "user", "content": prompt}]
kwargs = dict(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True,
)
stream = client.chat.completions.create(**kwargs)
return stream


class MyRunner(ModelRunner):
@@ -21,182 +53,72 @@ class MyRunner(ModelRunner):
def load_model(self):
"""Load the model here."""
self.client = OpenAI(api_key=API_KEY)
self.model = "gpt-4-1106-preview"

def predict(
self, request: service_pb2.PostModelOutputsRequest
) -> service_pb2.MultiOutputResponse:
def predict(self,
request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
"""This is the method that will be called when the runner is run. It takes in an input and
returns an output.
"""

# TODO: Could cache the model and this conversion if the hash is the same.
model = request.model
output_info = None
if request.model.model_version.id != "":
output_info = json_format.MessageToDict(
model.model_version.output_info, preserving_proto_field_name=True
)

outputs = []
# TODO: parallelize this over inputs in a single request.
for inp in request.inputs:
inference_params = get_inference_params(request)
streams = []
for input in request.inputs:
output = resources_pb2.Output()

data = inp.data

# Optional use of output_info
inference_params = {}
if "params" in output_info:
inference_params = output_info["params"]

system_prompt = "You are a helpful assistant"

messages = [{"role": "system", "content": system_prompt}]
temperature = inference_params.get("temperature", 0.7)
max_tokens = inference_params.get("max_tokens", 100)
top_p = inference_params.get("top_p", 1.0)

kwargs = dict(
model=MODEL, messages=messages, temperature=temperature, max_tokens=max_tokens, top_p=top_p
)

if data.text.raw != "":
prompt = data.text.raw
messages.append({"role": "user", "content": prompt})

res = self.client.chat.completions.create(**kwargs)
res = res.choices[0].message.content

output.data.text.raw = res
# the input data for the model
input_data = input.data
stream = stream_completion(self.model, self.client, input_data, inference_params)
streams.append(stream)

outputs = [resources_pb2.Output() for _ in request.inputs]
for output in outputs:
output.status.code = status_code_pb2.SUCCESS
outputs.append(output)
return service_pb2.MultiOutputResponse(
outputs=outputs,
)

def generate(
self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
"""Example yielding a whole batch of streamed stuff back."""

# TODO: Could cache the model and this conversion if the hash is the same.
model = request.model
output_info = None
if request.model.model_version.id != "":
output_info = json_format.MessageToDict(
model.model_version.output_info, preserving_proto_field_name=True
)

# TODO: Could cache the model and this conversion if the hash is the same.
model = request.model
output_info = None
if request.model.model_version.id != "":
output_info = json_format.MessageToDict(
model.model_version.output_info, preserving_proto_field_name=True
)
# Optional use of output_info
inference_params = {}
if "params" in output_info:
inference_params = output_info["params"]

# TODO: parallelize this over inputs in a single request.
try:
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
for idx, chunk in enumerate(chunk_batch):
outputs[idx].data.text.raw += (chunk.choices[0].delta.content
if chunk and chunk.choices[0].delta.content is not None else '')
response = service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
except Exception as e:
for output in outputs:
output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
output.status.description = str(e)
response = service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.MODEL_PREDICTION_FAILED))

return response

def generate(self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
"""This method generates stream of outputs for the given inputs in the request."""

inference_params = get_inference_params(request)
streams = []
# TODO: parallelize this over inputs in a single request.
for inp in request.inputs:
output = resources_pb2.Output()

data = inp.data

system_prompt = "You are a helpful assistant"

messages = [{"role": "system", "content": system_prompt}]
temperature = inference_params.get("temperature", 0.7)
max_tokens = inference_params.get("max_tokens", 100)
top_p = inference_params.get("top_p", 1.0)

if data.text.raw != "":
prompt = data.text.raw
messages.append({"role": "user", "content": prompt})
kwargs = dict(
model=MODEL,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True,
)
stream = self.client.chat.completions.create(**kwargs)

streams.append(stream)
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
resp = service_pb2.MultiOutputResponse()

for chunk in chunk_batch:
output = resp.outputs.add()
output.data.text.raw = (
chunk.choices[0].delta.content
if (chunk and chunk.choices[0].delta.content) is not None
else ''
)
output.status.code = status_code_pb2.SUCCESS
yield resp

def stream(
self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
"""Example yielding a whole batch of streamed stuff back."""

for ri, request in enumerate(request_iterator):
output_info = None
if ri == 0: # only first request has model information.
model = request.model
if request.model.model_version.id != "":
output_info = json_format.MessageToDict(
model.model_version.output_info, preserving_proto_field_name=True
)
# Optional use of output_info
inference_params = {}
if "params" in output_info:
inference_params = output_info["params"]

streams = []
# TODO: parallelize this over inputs in a single request.
for inp in request.inputs:
output = resources_pb2.Output()

data = inp.data

system_prompt = "You are a helpful assistant"

messages = [{"role": "system", "content": system_prompt}]
temperature = inference_params.get("temperature", 0.7)
max_tokens = inference_params.get("max_tokens", 100)
top_p = inference_params.get("top_p", 1.0)

if data.text.raw != "":
prompt = data.text.raw
messages.append({"role": "user", "content": prompt})
kwargs = dict(
model=MODEL,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True,
)
stream = self.client.chat.completions.create(**kwargs)

streams.append(stream)
for input in request.inputs:
# the input data for the model
input_data = input.data
stream = stream_completion(self.model, self.client, input_data, inference_params)
streams.append(stream)
try:
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
resp = service_pb2.MultiOutputResponse()

resp.status.code = status_code_pb2.SUCCESS
for chunk in chunk_batch:
output = resp.outputs.add()
output.data.text.raw = (chunk.choices[0].delta.content
if chunk and chunk.choices[0].delta.content is not None else '')
output.status.code = status_code_pb2.SUCCESS
yield resp
except Exception as e:
outputs = [resources_pb2.Output() for _ in request.inputs]
for output in outputs:
output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
output.status.description = str(e)
yield service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.MODEL_PREDICTION_FAILED))

def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
raise NotImplementedError("Stream is not implemented for this model.")
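A minimal local sketch (not part of the commit) of how the refactored pieces compose: one OpenAI stream per input, then chunks interleaved across inputs with itertools.zip_longest, mirroring the loops in predict() and generate() above. It assumes 1/model.py is importable as a module, a valid OpenAI API key, and made-up prompts:

import itertools

from clarifai_grpc.grpc.api import resources_pb2
from openai import OpenAI

from model import stream_completion  # assumes 1/model.py is on the import path

client = OpenAI(api_key="YOUR_OPENAI_API_KEY")  # placeholder key
model = "gpt-4-1106-preview"

inputs = [
    resources_pb2.Data(text=resources_pb2.Text(raw="Explain token streaming in one sentence.")),
    resources_pb2.Data(text=resources_pb2.Text(raw="Name two uses of gRPC.")),
]
inference_params = {"temperature": 0.2, "max_tokens": 64}

# One stream per input, as in predict()/generate().
streams = [stream_completion(model, client, data, inference_params) for data in inputs]

# Interleave chunks across inputs; shorter streams are padded with None.
texts = ["" for _ in inputs]
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
    for idx, chunk in enumerate(chunk_batch):
        if chunk and chunk.choices[0].delta.content is not None:
            texts[idx] += chunk.choices[0].delta.content

print(texts)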
4 changes: 2 additions & 2 deletions models/model_upload/llms/openai-gpt4/requirements.txt
@@ -1,2 +1,2 @@
-openai==1.2.2
-tenacity==8.1.0
+openai==1.55.3
+tenacity==8.1.0
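requirements.txt keeps tenacity==8.1.0 pinned, although the model.py above does not import it. If retry behaviour around the OpenAI call is wanted, a hedged sketch using tenacity's public decorators (the wrapper name is made up, not part of this commit):

from tenacity import retry, stop_after_attempt, wait_exponential


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
def create_stream_with_retries(client, **kwargs):
    # Retry transient OpenAI failures up to 3 times with exponential backoff;
    # kwargs are the same arguments stream_completion() builds.
    return client.chat.completions.create(**kwargs)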
