Skip to content

Commit

Permalink
fix predictions methods (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
luv-bansal authored Jan 10, 2025
1 parent 2ab31a8 commit 41b0723
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions clarifai/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def predict_by_filepath(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given filepath.
Expand Down Expand Up @@ -534,14 +535,15 @@ def predict_by_filepath(self,
file_bytes = f.read()

return self.predict_by_bytes(file_bytes, input_type, compute_cluster_id, nodepool_id,
deployment_id, inference_params, output_config)
deployment_id, user_id, inference_params, output_config)

def predict_by_bytes(self,
input_bytes: bytes,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given bytes.
Expand Down Expand Up @@ -581,11 +583,19 @@ def predict_by_bytes(self,

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.predict(
inputs=[input_proto],
Expand All @@ -599,6 +609,7 @@ def predict_by_url(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given URL.
Expand Down Expand Up @@ -639,11 +650,19 @@ def predict_by_url(self,

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.predict(
inputs=[input_proto],
Expand Down Expand Up @@ -712,6 +731,7 @@ def generate_by_filepath(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Generate the stream output on model based on the given filepath.
Expand Down Expand Up @@ -748,6 +768,7 @@ def generate_by_filepath(self,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
deployment_id=deployment_id,
user_id=user_id,
inference_params=inference_params,
output_config=output_config)

Expand All @@ -757,6 +778,7 @@ def generate_by_bytes(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Generate the stream output on model based on the given bytes.
Expand Down Expand Up @@ -798,11 +820,19 @@ def generate_by_bytes(self,

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.generate(
inputs=[input_proto],
Expand All @@ -816,6 +846,7 @@ def generate_by_url(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Generate the stream output on model based on the given URL.
Expand Down Expand Up @@ -857,11 +888,19 @@ def generate_by_url(self,

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.generate(
inputs=[input_proto],
Expand Down Expand Up @@ -930,6 +969,7 @@ def stream_by_filepath(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Stream the model output based on the given filepath.
Expand Down Expand Up @@ -964,6 +1004,7 @@ def stream_by_filepath(self,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
deployment_id=deployment_id,
user_id=user_id,
inference_params=inference_params,
output_config=output_config)

Expand All @@ -973,6 +1014,7 @@ def stream_by_bytes(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Stream the model output based on the given bytes.
Expand Down Expand Up @@ -1016,11 +1058,19 @@ def input_generator():

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.stream(
inputs=input_generator(),
Expand All @@ -1034,6 +1084,7 @@ def stream_by_url(self,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""Stream the model output based on the given URL.
Expand Down Expand Up @@ -1075,11 +1126,19 @@ def input_generator():

runner_selector = None
if deployment_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with deployment ID, please provide user_id in the method call."
)
runner_selector = Deployment.get_runner_selector(
user_id=self.user_id, deployment_id=deployment_id)
user_id=user_id, deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
if not user_id:
raise UserError(
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
)
runner_selector = Nodepool.get_runner_selector(
user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)

return self.stream(
inputs=input_generator(),
Expand Down

0 comments on commit 41b0723

Please sign in to comment.