Skip to content

Commit

Permalink
Merge pull request #200 from roboflow/add-video-inference
Browse files Browse the repository at this point in the history
Feature: Video Inference
  • Loading branch information
capjamesg authored Nov 3, 2023
2 parents ed84065 + aa250cc commit 080d588
Show file tree
Hide file tree
Showing 8 changed files with 542 additions and 6 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ supervision
urllib3>=1.26.6
tqdm>=4.41.0
PyYAML>=5.3.1
requests_toolbelt
requests_toolbelt
python-magic
1 change: 1 addition & 0 deletions roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from roboflow.config import API_URL, APP_URL, DEMO_KEYS, load_roboflow_api_key
from roboflow.core.project import Project
from roboflow.core.workspace import Workspace
from roboflow.models import CLIPModel, GazeModel
from roboflow.util.general import write_line

__version__ = "1.1.7"
Expand Down
2 changes: 2 additions & 0 deletions roboflow/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .clip import CLIPModel
from .gaze import GazeModel
16 changes: 16 additions & 0 deletions roboflow/models/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .inference import InferenceModel


class CLIPModel(InferenceModel):
    """
    Inference client for the CLIP model hosted on Roboflow.
    """

    def __init__(self, api_key: str):
        """
        Create a CLIP inference client.

        CLIP is not tied to a dataset version, so the "BASE_MODEL" marker is
        handed to the parent class as the version id.

        Args:
            api_key: Your Roboflow API key.
        """
        super().__init__(version_id="BASE_MODEL", api_key=api_key)
16 changes: 16 additions & 0 deletions roboflow/models/gaze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .inference import InferenceModel


class GazeModel(InferenceModel):
    """
    Run inference on a gaze detection model, hosted on Roboflow.
    """

    def __init__(self, api_key: str):
        """
        Initialize a gaze detection model.

        Gaze detection is not tied to a dataset version, so the "BASE_MODEL"
        marker is passed as the version id (the same way CLIPModel does it).
        Without it the parent initializer has no version id to work with.

        Args:
            api_key: Your Roboflow API key.
        """
        super().__init__(api_key=api_key, version_id="BASE_MODEL")
266 changes: 262 additions & 4 deletions roboflow/models/inference.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
import io
import json
import os
import time
import urllib
from typing import List
from urllib.parse import urljoin

import requests
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder

from roboflow.config import API_URL
from roboflow.util.image_utils import validate_image_path
from roboflow.util.prediction import PredictionGroup

# Values accepted for the `prediction_type` argument of `predict_video`.
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

# Models that can be requested through `additional_models` in `predict_video`.
# Keys are the names callers pass in; each value is the model descriptor that
# is placed in the "models" list of the video inference request.
SUPPORTED_ADDITIONAL_MODELS = {
"clip": {
"model_id": "clip",
"model_version": "1",
"inference_type": "clip-embed-image",
},
"gaze": {
"model_id": "gaze",
"model_version": "1",
"inference_type": "gaze-detection",
},
}


class InferenceModel:
def __init__(
Expand All @@ -25,13 +46,15 @@ def __init__(
api_key (str): private roboflow api key
version_id (str): the ID of the dataset version to use for inference
"""

self.__api_key = api_key
self.id = version_id

version_info = self.id.rsplit("/")
self.dataset_id = version_info[1]
self.version = version_info[2]
self.colors = {} if colors is None else colors
if version_id != "BASE_MODEL":
version_info = self.id.rsplit("/")
self.dataset_id = version_info[1]
self.version = version_info[2]
self.colors = {} if colors is None else colors

def __get_image_params(self, image_path):
"""
Expand Down Expand Up @@ -111,3 +134,238 @@ def predict(self, image_path, prediction_type=None, **kwargs):
image_dims=image_dims,
colors=self.colors,
)

def predict_video(
    self,
    video_path: str,
    fps: int = 5,
    additional_models: list = None,
    prediction_type: str = "batch-video",
) -> tuple:
    """
    Start a video inference job for this model.

    A local file is first uploaded to a signed URL obtained from the Roboflow
    API; a path that already starts with http:// or https:// is used as the
    input URL directly. The job is then submitted to the video inference
    endpoint and its id is stored on `self.job_id`.

    Args:
        video_path (str): local path or http(s) URL of the video to run inference on
        fps (int): frames per second to run inference at (at most 5)
        additional_models (list): names of extra models to run on the same
            video; each must be a key of SUPPORTED_ADDITIONAL_MODELS
        prediction_type (str): type of the prediction job; must be listed in
            SUPPORTED_ROBOFLOW_MODELS
    Returns:
        A tuple `(job_id, signed_url, signed_url_expires)`.
        `signed_url_expires` is the value of the `X-Goog-Expires` query
        parameter of the upload URL, or None when the video was passed as a
        URL or the parameter is absent.
    Raises:
        Exception: if an argument is not supported, the model class is not
            supported for video inference, the video cannot be read, or a
            request to the API fails.
    Example:
        >>> import roboflow
        >>> rf = roboflow.Roboflow(api_key="")
        >>> project = rf.workspace().project("PROJECT_ID")
        >>> model = project.version("1").model
        >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4", fps=5)
    """
    from urllib.parse import parse_qs, urlsplit

    # None stands in for "no extra models": a list default would be shared
    # between calls.
    additional_models = [] if additional_models is None else list(additional_models)

    if fps > 5:
        raise Exception("FPS must be less than or equal to 5.")

    for model in additional_models:
        if model not in SUPPORTED_ADDITIONAL_MODELS:
            raise Exception(f"Model {model} is not supported for video inference.")

    if prediction_type not in SUPPORTED_ROBOFLOW_MODELS:
        raise Exception(f"{prediction_type} is not supported for video inference.")

    model_class = self.__class__.__name__

    inference_types = {
        "ObjectDetectionModel": "object-detection",
        "ClassificationModel": "classification",
        "InstanceSegmentationModel": "instance-segmentation",
        "GazeModel": "gaze-detection",
        "CLIPModel": "clip-embed-image",
    }

    if model_class not in inference_types:
        raise Exception("Model type not supported for video inference.")

    self.type = inference_types[model_class]

    signed_url_expires = None

    if video_path.startswith(("http://", "https://")):
        # Already reachable by URL: nothing to upload.
        signed_url = video_path
    else:
        upload_url = urljoin(
            API_URL, "/video_upload_signed_url?api_key=" + self.__api_key
        )
        payload = json.dumps({"file_name": os.path.basename(video_path)})

        try:
            response = requests.request(
                "POST",
                upload_url,
                headers={"Content-Type": "application/json"},
                data=payload,
            )
        except Exception as e:
            raise Exception(f"Error uploading video: {e}") from e

        if not response.ok:
            raise Exception(f"Error uploading video: {response.text}")

        signed_url = response.json()["signed_url"]

        # Read the expiry from the parsed query string so a URL without the
        # parameter yields None instead of an IndexError.
        query = parse_qs(urlsplit(signed_url).query)
        signed_url_expires = query.get("X-Goog-Expires", [None])[0]

        try:
            with open(video_path, "rb") as f:
                video_data = f.read()
        except Exception as e:
            raise Exception(f"Error reading video: {e}") from e

        # Upload the raw bytes to the signed URL.
        try:
            result = requests.put(
                signed_url,
                data=video_data,
                headers={"Content-Type": "application/octet-stream"},
            )
        except Exception as e:
            raise Exception(f"There was an error uploading the video: {e}") from e

        if not result.ok:
            raise Exception(f"There was an error uploading the video: {result.text}")

    if model_class in ("CLIPModel", "GazeModel"):
        base_model = "clip" if model_class == "CLIPModel" else "gaze"
        models = [dict(SUPPORTED_ADDITIONAL_MODELS[base_model])]
    else:
        # Project models are identified by the dataset id and version parsed
        # from the version id in __init__.
        models = [
            {
                "model_id": self.dataset_id,
                "model_version": self.version,
                "inference_type": self.type,
            }
        ]

    # Copy the descriptors so the module-level table is never shared with
    # the request payload.
    models.extend(dict(SUPPORTED_ADDITIONAL_MODELS[name]) for name in additional_models)

    infer_url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)
    payload = json.dumps(
        {"input_url": signed_url, "infer_fps": fps, "models": models}
    )

    try:
        response = requests.request(
            "POST",
            infer_url,
            headers={"Content-Type": "application/json"},
            data=payload,
        )
    except Exception as e:
        raise Exception(f"Error starting video inference: {e}") from e

    if not response.ok:
        raise Exception(f"Error starting video inference: {response.text}")

    job_id = response.json()["job_id"]

    # Remembered so the polling helpers can be called without arguments.
    self.job_id = job_id

    return job_id, signed_url, signed_url_expires

def poll_for_video_results(self, job_id: str = None) -> dict:
    """
    Check once whether a video inference job is complete.

    Args:
        job_id (str): id of the job to check; defaults to the job started by
            the most recent `predict_video` call on this object
    Returns:
        The inference results as a dict when the job reports status 0,
        otherwise an empty dict.
    Raises:
        Exception: if the status request or the results download fails.
    Example:
        >>> import roboflow
        >>> rf = roboflow.Roboflow(api_key="")
        >>> project = rf.workspace().project("PROJECT_ID")
        >>> model = project.version("1").model
        >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4")
        >>> results = model.poll_for_video_results(job_id)
    """

    if job_id is None:
        job_id = self.job_id

    # Query the job that was asked for, not unconditionally the stored one.
    url = urljoin(
        API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + job_id
    )

    try:
        response = requests.get(url, headers={"Content-Type": "application/json"})
    except Exception as e:
        raise Exception(f"Error getting video inference results: {e}") from e

    if not response.ok:
        raise Exception(f"Error getting video inference results: {response.text}")

    data = response.json()

    # Any status other than 0 is treated as "no results yet".
    if data.get("status") != 0:
        return {}

    output_signed_url = data["output_signed_url"]

    try:
        inference_data = requests.get(
            output_signed_url, headers={"Content-Type": "application/json"}
        )
    except Exception as e:
        raise Exception(f"Error getting video inference results: {e}") from e

    if not inference_data.ok:
        raise Exception(
            f"Error getting video inference results: {inference_data.text}"
        )

    # frame_offset and model name are top-level keys
    return inference_data.json()

def poll_until_video_results(self, job_id: str = None, poll_interval: int = 60) -> dict:
    """
    Poll the Roboflow API until video inference results are available.

    Blocks, checking with `poll_for_video_results` and sleeping
    `poll_interval` seconds between checks, until a non-empty result is
    returned.

    Args:
        job_id (str): id of the job to wait for; defaults to the job started
            by the most recent `predict_video` call on this object
        poll_interval (int): seconds to wait between checks
    Returns:
        Inference results as a dict
    Example:
        >>> import roboflow
        >>> rf = roboflow.Roboflow(api_key="")
        >>> project = rf.workspace().project("PROJECT_ID")
        >>> model = project.version("1").model
        >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4")
        >>> results = model.poll_until_video_results(job_id)
    """
    if job_id is None:
        job_id = self.job_id

    attempts = 0

    while True:
        print(f"({attempts * poll_interval}s): Checking for inference results")

        response = self.poll_for_video_results(job_id)

        # Return as soon as results exist instead of sleeping first.
        if response != {}:
            return response

        time.sleep(poll_interval)

        attempts += 1
4 changes: 3 additions & 1 deletion roboflow/models/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from PIL import Image

from roboflow.config import API_URL, OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL
from roboflow.models.inference import InferenceModel
from roboflow.util.image_utils import check_image_url
from roboflow.util.prediction import PredictionGroup
from roboflow.util.versions import print_warn_for_wrong_dependencies_versions


class ObjectDetectionModel:
class ObjectDetectionModel(InferenceModel):
"""
Run inference on an object detection model hosted on Roboflow or served through Roboflow Inference.
"""
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
"""
# Instantiate different API URL parameters
# To be moved to predict
super(ObjectDetectionModel, self).__init__(api_key, id)
self.__api_key = api_key
self.id = id
self.name = name
Expand Down
Loading

0 comments on commit 080d588

Please sign in to comment.