Skip to content

Commit

Permalink
final commit
Browse files · Browse the repository at this point in the history
  • Loading branch information
luv-bansal committed Dec 19, 2024
1 parent 3786d66 commit ebcf3ed
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions models/model_upload/llms/openai-gpt4/1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from google.protobuf import json_format
from openai import OpenAI

# Set your OpenAI API key here
API_KEY = 'OPENAI_API_KEY'
API_KEY = 'YOUR_OPENAI_API_KEY'


def get_inference_params(request) -> dict:
Expand Down Expand Up @@ -63,7 +63,7 @@ def predict(self,

inference_params = get_inference_params(request)
streams = []
for inp in request.inputs:
for input in request.inputs:
output = resources_pb2.Output()

# it contains the input data for the model
Expand All @@ -74,12 +74,21 @@ def predict(self,
outputs = [resources_pb2.Output() for _ in request.inputs]
for output in outputs:
output.status.code = status_code_pb2.SUCCESS
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
for idx, chunk in enumerate(chunk_batch):
outputs[idx].data.text.raw += chunk.choices[0].delta.content if (
chunk and chunk.choices[0].delta.content) is not None else ''

return service_pb2.MultiOutputResponse(outputs=outputs,)
try:
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
for idx, chunk in enumerate(chunk_batch):
outputs[idx].data.text.raw += chunk.choices[0].delta.content if (
chunk and chunk.choices[0].delta.content) is not None else ''
response = service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
except Exception as e:
for output in outputs:
output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
output.status.description = str(e)
response = service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.MODEL_PREDICTION_FAILED))

return response

def generate(self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
Expand All @@ -92,16 +101,23 @@ def generate(self, request: service_pb2.PostModelOutputsRequest
input_data = input.data
stream = stream_completion(self.model, self.client, input_data, inference_params)
streams.append(stream)

for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
resp = service_pb2.MultiOutputResponse()

for chunk in chunk_batch:
output = resp.outputs.add()
output.data.text.raw = (chunk.choices[0].delta.content
if (chunk and chunk.choices[0].delta.content) is not None else '')
output.status.code = status_code_pb2.SUCCESS
yield resp
try:
for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
resp = service_pb2.MultiOutputResponse()
resp.status.code = status_code_pb2.SUCCESS
for chunk in chunk_batch:
output = resp.outputs.add()
output.data.text.raw = (chunk.choices[0].delta.content if
(chunk and chunk.choices[0].delta.content) is not None else '')
output.status.code = status_code_pb2.SUCCESS
yield resp
except Exception as e:
outputs = [resources_pb2.Output() for _ in request.inputs]
for output in outputs:
output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
output.status.description = str(e)
yield service_pb2.MultiOutputResponse(
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))

def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
Expand Down

0 comments on commit ebcf3ed

Please sign in to comment.