Skip to content

Commit

Permalink
fix clip, allow access to signed url expires
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Nov 1, 2023
1 parent 5187587 commit 7af4637
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion roboflow/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def __init__(self, api_key: str):
Args:
api_key: Your Roboflow API key.
"""
super().__init__(api_key=api_key)
super().__init__(api_key=api_key, version_id="BASE_MODEL")
21 changes: 14 additions & 7 deletions roboflow/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ def __init__(
self.__api_key = api_key
self.id = version_id

version_info = self.id.rsplit("/")
self.dataset_id = version_info[1]
self.version = version_info[2]
self.colors = {} if colors is None else colors
if version_id != "BASE_MODEL":
version_info = self.id.rsplit("/")
self.dataset_id = version_info[1]
self.version = version_info[2]
self.colors = {} if colors is None else colors

def __get_image_params(self, image_path):
"""
Expand Down Expand Up @@ -162,9 +163,11 @@ def predict_video(
>>> model = project.version("1").model
>>> prediction = model.predict("video.mp4", fps=5, inference_type="object-detection")
>>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4", fps=5, inference_type="object-detection")
"""

signed_url_expires = None

url = urljoin(API_URL, "/video_upload_signed_url?api_key=" + self.__api_key)

if fps > 5:
Expand All @@ -187,7 +190,7 @@ def predict_video(
self.type = "instance-segmentation"
elif model_class == "GazeModel":
self.type = "gaze-detection"
elif model_class == "CLIP":
elif model_class == "CLIPModel":
self.type = "clip-embed-image"
else:
raise Exception("Model type not supported for video inference.")
Expand All @@ -211,6 +214,10 @@ def predict_video(

signed_url = response.json()["signed_url"]

signed_url_expires = (
signed_url.split("&X-Goog-Expires")[1].split("&")[0].strip("=")
)

# make a POST request to the signed URL
headers = {"Content-Type": "application/octet-stream"}

Expand Down Expand Up @@ -273,7 +280,7 @@ def predict_video(

self.job_id = job_id

return job_id, signed_url
return job_id, signed_url, signed_url_expires

def poll_for_video_results(self, job_id: str = None) -> dict:
"""
Expand Down

0 comments on commit 7af4637

Please sign in to comment.