From 41b07234c7f0e013e003556b30d11357a2afdedf Mon Sep 17 00:00:00 2001
From: Luv Bansal <70321430+luv-bansal@users.noreply.github.com>
Date: Fri, 10 Jan 2025 16:36:13 +0530
Subject: [PATCH] fix predictions methods (#475)

---
 clarifai/client/model.py | 85 ++++++++++++++++++++++++++++++++++------
 1 file changed, 72 insertions(+), 13 deletions(-)

diff --git a/clarifai/client/model.py b/clarifai/client/model.py
index 020842d6..713b6e98 100644
--- a/clarifai/client/model.py
+++ b/clarifai/client/model.py
@@ -503,6 +503,7 @@ def predict_by_filepath(self,
                           compute_cluster_id: str = None,
                           nodepool_id: str = None,
                           deployment_id: str = None,
+                          user_id: str = None,
                           inference_params: Dict = {},
                           output_config: Dict = {}):
     """Predicts the model based on the given filepath.
@@ -534,7 +535,7 @@ def predict_by_filepath(self,
       file_bytes = f.read()
 
     return self.predict_by_bytes(file_bytes, input_type, compute_cluster_id, nodepool_id,
-                                 deployment_id, inference_params, output_config)
+                                 deployment_id, user_id, inference_params, output_config)
 
   def predict_by_bytes(self,
                        input_bytes: bytes,
@@ -542,6 +543,7 @@ def predict_by_bytes(self,
                        compute_cluster_id: str = None,
                        nodepool_id: str = None,
                        deployment_id: str = None,
+                       user_id: str = None,
                        inference_params: Dict = {},
                        output_config: Dict = {}):
     """Predicts the model based on the given bytes.
@@ -581,11 +583,19 @@ def predict_by_bytes(self,
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
     elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.predict(
         inputs=[input_proto],
@@ -599,6 +609,7 @@ def predict_by_url(self,
                      compute_cluster_id: str = None,
                      nodepool_id: str = None,
                      deployment_id: str = None,
+                     user_id: str = None,
                      inference_params: Dict = {},
                      output_config: Dict = {}):
     """Predicts the model based on the given URL.
@@ -639,11 +650,19 @@ def predict_by_url(self,
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
     elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.predict(
         inputs=[input_proto],
@@ -712,6 +731,7 @@ def generate_by_filepath(self,
                            compute_cluster_id: str = None,
                            nodepool_id: str = None,
                            deployment_id: str = None,
+                           user_id: str = None,
                            inference_params: Dict = {},
                            output_config: Dict = {}):
     """Generate the stream output on model based on the given filepath.
@@ -748,6 +768,7 @@ def generate_by_filepath(self,
         compute_cluster_id=compute_cluster_id,
         nodepool_id=nodepool_id,
         deployment_id=deployment_id,
+        user_id=user_id,
         inference_params=inference_params,
         output_config=output_config)
 
@@ -757,6 +778,7 @@ def generate_by_bytes(self,
                         compute_cluster_id: str = None,
                         nodepool_id: str = None,
                         deployment_id: str = None,
+                        user_id: str = None,
                         inference_params: Dict = {},
                         output_config: Dict = {}):
     """Generate the stream output on model based on the given bytes.
@@ -798,11 +820,19 @@ def generate_by_bytes(self,
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
     elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.generate(
         inputs=[input_proto],
@@ -816,6 +846,7 @@ def generate_by_url(self,
                       compute_cluster_id: str = None,
                       nodepool_id: str = None,
                       deployment_id: str = None,
+                      user_id: str = None,
                       inference_params: Dict = {},
                       output_config: Dict = {}):
     """Generate the stream output on model based on the given URL.
@@ -857,11 +888,19 @@ def generate_by_url(self,
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
     elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.generate(
         inputs=[input_proto],
@@ -930,6 +969,7 @@ def stream_by_filepath(self,
                          compute_cluster_id: str = None,
                          nodepool_id: str = None,
                          deployment_id: str = None,
+                         user_id: str = None,
                          inference_params: Dict = {},
                          output_config: Dict = {}):
     """Stream the model output based on the given filepath.
@@ -964,6 +1004,7 @@ def stream_by_filepath(self,
         compute_cluster_id=compute_cluster_id,
         nodepool_id=nodepool_id,
         deployment_id=deployment_id,
+        user_id=user_id,
         inference_params=inference_params,
         output_config=output_config)
 
@@ -973,6 +1014,7 @@ def stream_by_bytes(self,
                       compute_cluster_id: str = None,
                       nodepool_id: str = None,
                       deployment_id: str = None,
+                      user_id: str = None,
                       inference_params: Dict = {},
                       output_config: Dict = {}):
     """Stream the model output based on the given bytes.
@@ -1016,11 +1058,19 @@ def input_generator():
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
    elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.stream(
         inputs=input_generator(),
@@ -1034,6 +1084,7 @@ def stream_by_url(self,
                     compute_cluster_id: str = None,
                     nodepool_id: str = None,
                     deployment_id: str = None,
+                    user_id: str = None,
                     inference_params: Dict = {},
                     output_config: Dict = {}):
     """Stream the model output based on the given URL.
@@ -1075,11 +1126,19 @@ def input_generator():
 
     runner_selector = None
     if deployment_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        )
       runner_selector = Deployment.get_runner_selector(
-          user_id=self.user_id, deployment_id=deployment_id)
+          user_id=user_id, deployment_id=deployment_id)
     elif compute_cluster_id and nodepool_id:
+      if not user_id:
+        raise UserError(
+            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        )
       runner_selector = Nodepool.get_runner_selector(
-          user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
+          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
 
     return self.stream(
         inputs=input_generator(),
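
A minimal usage sketch of the changed call pattern (not part of the patch): after this change, predictions routed through a deployment or a compute cluster/nodepool must pass user_id explicitly, otherwise UserError is raised instead of falling back to self.user_id. The model URL, deployment, compute cluster, nodepool, and user IDs below are placeholders, and authentication via the CLARIFAI_PAT environment variable is assumed.

from clarifai.client.model import Model

# Placeholder model URL; any model the caller has access to works the same way.
model = Model(url="https://clarifai.com/some-user/some-app/models/some-model")

# Deployment-based prediction: user_id is now required alongside deployment_id.
prediction = model.predict_by_url(
    url="https://samples.clarifai.com/metro-north.jpg",
    input_type="image",
    deployment_id="my-deployment",   # placeholder deployment ID
    user_id="deployment-owner-id",   # owner of the deployment (newly required)
)

# Compute cluster + nodepool routing: user_id is likewise required.
prediction = model.predict_by_bytes(
    b"Write a haiku about GPUs.",
    input_type="text",
    compute_cluster_id="my-compute-cluster",  # placeholder IDs
    nodepool_id="my-nodepool",
    user_id="cluster-owner-id",
)

The same user_id argument applies to the generate_by_* and stream_by_* variants touched by this patch, since they build their runner_selector through the same code path.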