Skip to content

Commit

Permalink
Merge pull request #198 from roboflow/feature/bump_ultralitics_pin
Browse files Browse the repository at this point in the history
Suggest users new version of ultralytics package
  • Loading branch information
PawelPeczek-Roboflow authored Oct 23, 2023
2 parents 0d1527a + 4feaf95 commit 79c7013
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
21 changes: 10 additions & 11 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ def __init__(
else:
self.__api_key = api_key
self.name = name

# FIXME: the version argument is inconsistently passed into this object.
# Sometimes it is passed as: test-workspace/test-project/2
# Other times, it is passed as: 2
self.version = version
self.version = unwrap_version_id(version_id=version)
self.type = type
self.augmentation = version_dict["augmentation"]
self.created = version_dict["created"]
Expand Down Expand Up @@ -139,7 +135,6 @@ def __check_if_generating(self):
url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
response = requests.get(url, params={"api_key": self.__api_key})
response.raise_for_status()

if response.json()["version"]["progress"] == None:
progress = 0.0
else:
Expand Down Expand Up @@ -197,11 +192,11 @@ def download(self, model_format=None, location=None, overwrite: bool = True):
try:
import_module("ultralytics")
print_warn_for_wrong_dependencies_versions(
[("ultralytics", "==", "8.0.134")]
[("ultralytics", "==", "8.0.196")]
)
except ImportError as e:
print(
"[WARNING] we noticed you are downloading a `yolov8` datasets but you don't have `ultralytics` installed. Roboflow `.deploy` supports only models trained with `ultralytics==8.0.134`, to intall it `pip install ultralytics==8.0.134`."
"[WARNING] we noticed you are downloading a `yolov8` datasets but you don't have `ultralytics` installed. Roboflow `.deploy` supports only models trained with `ultralytics==8.0.196`, to intall it `pip install ultralytics==8.0.196`."
)
# silently fail
pass
Expand Down Expand Up @@ -460,7 +455,7 @@ def live_plot(epochs, mAP, loss, title=""):
# return the model object
return self.model

# @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.134")])
# @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")])
def deploy(self, model_type: str, model_path: str) -> None:
"""Uploads provided weights file to Roboflow
Expand Down Expand Up @@ -490,7 +485,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
)

print_warn_for_wrong_dependencies_versions(
[("ultralytics", "==", "8.0.134")], ask_to_continue=True
[("ultralytics", "==", "8.0.196")], ask_to_continue=True
)

elif "yolov5" in model_type or "yolov7" in model_type:
Expand Down Expand Up @@ -775,7 +770,7 @@ def data_yaml_callback(content: dict) -> dict:
try:
# get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
if format == "yolov8" and not get_wrong_dependencies_versions(
dependencies_versions=[("ultralytics", "==", "8.0.134")]
dependencies_versions=[("ultralytics", "==", "8.0.196")]
):
content["train"] = "train/images"
content["val"] = "valid/images"
Expand All @@ -802,3 +797,7 @@ def __str__(self):
"workspace": self.workspace,
}
return json.dumps(json_value, indent=2)


def unwrap_version_id(version_id: str) -> str:
return version_id if "/" not in str(version_id) else version_id.split("/")[-1]
18 changes: 17 additions & 1 deletion tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest.mock import patch

from .helpers import get_version
from roboflow.core.version import Version
from roboflow.core.version import Version, unwrap_version_id


class TestDownload(unittest.TestCase):
Expand Down Expand Up @@ -168,3 +168,19 @@ def test_raises_runtime_error_if_model_format_is_none(self):
self.version.model_format = None
with self.assertRaises(RuntimeError):
self.get_format_identifier(None)


def test_unwrap_version_id_when_full_identifier_is_given() -> None:
# when
result = unwrap_version_id(version_id="some-workspace/some-project/3")

# then
assert result == "3"


def test_unwrap_version_id_when_only_version_id_is_given() -> None:
# when
result = unwrap_version_id(version_id="3")

# then
assert result == "3"

0 comments on commit 79c7013

Please sign in to comment.