Skip to content

Commit

Permalink
get processor
Browse files Browse the repository at this point in the history
  • Loading branch information
lrosemberg committed Dec 16, 2024
1 parent 019b764 commit 1b1f988
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,16 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
model_path (str): File path to the model weights to be uploaded.
filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
"""
processor = self._get_processor_function(model_type)

zip_file_name = processor(model_type, model_path, filename)

if zip_file_name is None:
raise RuntimeError("Failed to process model")

self._upload_zip(model_type, model_path, zip_file_name)

def _get_processor_function(self, model_type: str) -> callable:
if model_type.startswith("yolo11"):
model_type = model_type.replace("yolo11", "yolov11")

Expand All @@ -497,8 +507,6 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
if not any(supported_model in model_type for supported_model in supported_models):
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))

zip_file_name = None

if model_type.startswith(("paligemma", "paligemma2", "florence-2")):
if any(model in model_type for model in ["paligemma", "paligemma2", "florence-2"]):
supported_hf_types = [
Expand All @@ -516,19 +524,14 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
f"{model_type} not supported for this type of upload."
f"Supported upload types are {supported_hf_types}"
)
zip_file_name = self.process_huggingface(model_type, model_path, filename)
return self._process_huggingface

if "yolonas" in model_type:
zip_file_name = self.process_yolonas(model_type, model_path, filename)

zip_file_name = self.process_yolo(model_type, model_path, filename)

if zip_file_name is None:
raise RuntimeError("Failed to process model")
return self._process_yolonas

self.upload_zip(model_type, model_path, zip_file_name)
return self._process_yolo

def process_yolo(self, model_type: str, model_path: str, filename: str) -> str:
def _process_yolo(self, model_type: str, model_path: str, filename: str) -> str:
if "yolov8" in model_type:
try:
import torch
Expand Down Expand Up @@ -662,7 +665,7 @@ def process_yolo(self, model_type: str, model_path: str, filename: str) -> str:

return zip_file_name

def process_huggingface(
def _process_huggingface(
self, model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
) -> str:
# Check if model_path exists
Expand Down Expand Up @@ -713,7 +716,7 @@ def process_huggingface(

return tar_file_name

def process_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> str:
def _process_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> str:
try:
import torch
except ImportError:
Expand Down Expand Up @@ -780,7 +783,7 @@ def process_yolonas(self, model_type: str, model_path: str, filename: str = "wei

self.upload_zip(model_type, model_path, zip_file_name)

def upload_zip(self, model_type: str, model_path: str, model_file_name: str):
def _upload_zip(self, model_type: str, model_path: str, model_file_name: str):
res = requests.get(
f"{API_URL}/{self.workspace}/{self.project}/{self.version}"
f"/uploadModel?api_key={self.__api_key}&modelType={model_type}&nocache=true"
Expand Down

0 comments on commit 1b1f988

Please sign in to comment.