diff --git a/models/model_upload/llms/openai-gpt4/1/model.py b/models/model_upload/llms/openai-gpt4/1/model.py
index e5484db..6a3e347 100644
--- a/models/model_upload/llms/openai-gpt4/1/model.py
+++ b/models/model_upload/llms/openai-gpt4/1/model.py
@@ -3,12 +3,12 @@
 
 from clarifai.runners.models.model_runner import ModelRunner
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
-from clarifai_grpc.grpc.api.status import status_code_pb2
+from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
 from google.protobuf import json_format
 from openai import OpenAI
 
 # Set your OpenAI API key here
-API_KEY = 'OPENAI_API_KEY'
+API_KEY = 'YOUR_OPENAI_API_KEY'
 
 
 def get_inference_params(request) -> dict:
@@ -63,7 +63,7 @@ def predict(self,
     inference_params = get_inference_params(request)
 
     streams = []
-    for inp in request.inputs:
+    for input in request.inputs:
       output = resources_pb2.Output()
 
       # it contains the input data for the model
@@ -74,12 +74,21 @@ def predict(self,
     outputs = [resources_pb2.Output() for _ in request.inputs]
     for output in outputs:
       output.status.code = status_code_pb2.SUCCESS
-    for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
-      for idx, chunk in enumerate(chunk_batch):
-        outputs[idx].data.text.raw += chunk.choices[0].delta.content if (
-            chunk and chunk.choices[0].delta.content) is not None else ''
-
-    return service_pb2.MultiOutputResponse(outputs=outputs,)
+    try:
+      for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
+        for idx, chunk in enumerate(chunk_batch):
+          outputs[idx].data.text.raw += chunk.choices[0].delta.content if (
+              chunk and chunk.choices[0].delta.content) is not None else ''
+      response = service_pb2.MultiOutputResponse(
+          outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
+    except Exception as e:
+      for output in outputs:
+        output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
+        output.status.description = str(e)
+      response = service_pb2.MultiOutputResponse(
+          outputs=outputs, status=status_pb2.Status(code=status_code_pb2.MODEL_PREDICTION_FAILED))
+
+    return response
 
   def generate(self, request: service_pb2.PostModelOutputsRequest
               ) -> Iterator[service_pb2.MultiOutputResponse]:
@@ -92,16 +101,23 @@ def generate(self, request: service_pb2.PostModelOutputsRequest
       input_data = input.data
       stream = stream_completion(self.model, self.client, input_data, inference_params)
       streams.append(stream)
-
-    for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
-      resp = service_pb2.MultiOutputResponse()
-
-      for chunk in chunk_batch:
-        output = resp.outputs.add()
-        output.data.text.raw = (chunk.choices[0].delta.content
-                                if (chunk and chunk.choices[0].delta.content) is not None else '')
-        output.status.code = status_code_pb2.SUCCESS
-      yield resp
+    try:
+      for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
+        resp = service_pb2.MultiOutputResponse()
+        resp.status.code = status_code_pb2.SUCCESS
+        for chunk in chunk_batch:
+          output = resp.outputs.add()
+          output.data.text.raw = (chunk.choices[0].delta.content if
+                                  (chunk and chunk.choices[0].delta.content) is not None else '')
+          output.status.code = status_code_pb2.SUCCESS
+        yield resp
+    except Exception as e:
+      outputs = [resources_pb2.Output() for _ in request.inputs]
+      for output in outputs:
+        output.status.code = status_code_pb2.MODEL_PREDICTION_FAILED
+        output.status.description = str(e)
+      yield service_pb2.MultiOutputResponse(
+          outputs=outputs, status=status_pb2.Status(code=status_code_pb2.MODEL_PREDICTION_FAILED))
 
   def stream(self,
              request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
            ) -> Iterator[service_pb2.MultiOutputResponse]:
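For reference, the error-handling pattern this diff adds to predict() and generate() can be exercised on its own. The sketch below is a minimal, self-contained illustration, not the Clarifai runner itself: FakeDelta, FakeChoice, FakeChunk, make_stream, and collect are hypothetical stand-ins for the OpenAI stream chunks and the runner methods. It shows the same itertools.zip_longest fan-in over per-input streams, the per-index accumulation of delta.content, and the blanket try/except fallback to a failure result.

import itertools
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class FakeDelta:
  content: Optional[str]


@dataclass
class FakeChoice:
  delta: FakeDelta


@dataclass
class FakeChunk:
  choices: List[FakeChoice]


def make_stream(text: str) -> Iterator[FakeChunk]:
  # Yield one chunk per character, mimicking a streamed completion.
  for ch in text:
    yield FakeChunk(choices=[FakeChoice(delta=FakeDelta(content=ch))])


def collect(streams: List[Iterator[FakeChunk]]) -> List[str]:
  # One accumulator per input, analogous to outputs[idx].data.text.raw in predict().
  texts = [''] * len(streams)
  try:
    # Streams can have different lengths; zip_longest pads exhausted ones
    # with None so every stream is fully drained before returning.
    for chunk_batch in itertools.zip_longest(*streams, fillvalue=None):
      for idx, chunk in enumerate(chunk_batch):
        if chunk and chunk.choices[0].delta.content is not None:
          texts[idx] += chunk.choices[0].delta.content
  except Exception as e:
    # Mirror of the diff's fallback: mark every output failed with str(e).
    return ['FAILED: ' + str(e)] * len(streams)
  return texts


print(collect([make_stream('hello'), make_stream('hi')]))  # ['hello', 'hi']

Because zip_longest only stops once every stream is exhausted, a single failing stream aborts the whole batch mid-iteration; that is why the fallback marks every output as failed rather than only the input whose stream raised.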