Skip to content

Commit

Permalink
add missing torch dependencies test (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Jul 16, 2024
1 parent 9ca80fb commit d39d9af
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/unit/test_module_exports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# flake8: noqa: F401

import builtins
import sys

import pytest


Expand Down Expand Up @@ -35,3 +38,56 @@ def test_module_exports():
)
except Exception as e: # noqa: BLE001
pytest.fail(f"Importing raised an exception: {e}")


@pytest.mark.parametrize("dep", ["torch", "torchvision", "transformers"])
def test_no_torch_deps(monkeypatch, dep):
real_import = builtins.__import__

def monkey_import_importerror(
name, globals=None, locals=None, fromlist=(), level=0
):
if name.startswith(dep):
raise ImportError(f"Mocked import error {name}")
return real_import(
name, globals=globals, locals=locals, fromlist=fromlist, level=level
)

for module in list(sys.modules):
if module.startswith((dep, "datachain")):
monkeypatch.delitem(sys.modules, module)
monkeypatch.setattr(builtins, "__import__", monkey_import_importerror)

try:
from datachain import (
AbstractUDF,
Aggregator,
BaseUDF,
C,
Column,
DataChain,
DataChainError,
DataModel,
File,
FileBasic,
FileError,
Generator,
ImageFile,
IndexedFile,
Mapper,
Session,
TarVFile,
TextFile,
)
except Exception as e: # noqa: BLE001
pytest.fail(f"Importing raised an exception: {e}")

with pytest.raises(ImportError):
from datachain.torch import (
PytorchDataset,
clip_similarity_scores,
convert_image,
convert_images,
convert_text,
label_to_int,
)

0 comments on commit d39d9af

Please sign in to comment.