Skip to content

Commit

Permalink
torch-loader(example): use persistent workers to reduce test time (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Dec 12, 2024
1 parent 250ce97 commit eaf0412
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
7 changes: 4 additions & 3 deletions examples/get_started/torch-loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datachain.torch import label_to_int

STORAGE = "gs://datachain-demo/dogs-and-cats/"
NUM_EPOCHS = os.getenv("NUM_EPOCHS", "3")
NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "3"))

# Define transformation for data preprocessing
transform = v2.Compose(
Expand Down Expand Up @@ -68,7 +68,8 @@ def forward(self, x):
train_loader = DataLoader(
ds.to_pytorch(transform=transform),
batch_size=25,
num_workers=4,
num_workers=max(4, os.cpu_count() or 2),
persistent_workers=True,
multiprocessing_context=multiprocessing.get_context("spawn"),
)

Expand All @@ -77,7 +78,7 @@ def forward(self, x):
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
for epoch in range(int(NUM_EPOCHS)):
for epoch in range(NUM_EPOCHS):
with tqdm(
train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
) as loader:
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def examples(session: nox.Session) -> None:
session.install(".[examples]")
session.run(
"pytest",
"--durations=0",
"tests/examples",
"-m",
"examples",
*session.posargs,
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def smoke_test(example: str, env: Optional[dict] = None):
@pytest.mark.get_started
@pytest.mark.parametrize("example", get_started_examples)
def test_get_started_examples(example):
smoke_test(example, {"NUM_EPOCHS": "1"})
smoke_test(example)


@pytest.mark.examples
Expand Down

0 comments on commit eaf0412

Please sign in to comment.