
prefetch examples to showcase performance gain #635

Open · dmpetrov opened this issue on Nov 26, 2024 · 17 comments
Labels: performance, question

@dmpetrov (Member)

Description

We need examples where it's clear how prefetch helps. I tried it in several examples and I don't see any difference.

An example is below. Note that the library utilizes the CPU pretty well (up to 600% on my laptop), so no extra parallelization is needed if prefetch works well.

Results:

  • prefetch=0: 35.144s
  • prefetch=2: 35.126s
  • prefetch=8: 34.418s

Code:

from io import BytesIO

from PIL import Image
from ultralytics import YOLO

from datachain import DataChain, File, model

def process1(yolo: YOLO, file: File) -> model.ultralytics.YoloBBoxes:
    img = Image.open(BytesIO(file.read()))
    return model.ultralytics.YoloBBoxes.from_results(yolo(img, verbose=False))

(
    DataChain
    .from_storage("gs://mpii-human-pose")
    .limit(100)
    .settings(prefetch=0)  # varied across runs: 0, 2, 8
    .setup(
        yolo=lambda: YOLO("yolo11n.pt"),
    )
    .map(my_yolo=process1)
    .save("mpii-human-pose-segment")
)

Version Info

0.7.1
Python 3.9.16
@dmpetrov added the performance and question labels on Nov 26, 2024
@dmpetrov (Member, Author)

@skshetry I'd appreciate it if you could take a look

@skshetry self-assigned this on Nov 26, 2024
@skshetry (Member) commented Nov 28, 2024

Turns out, prefetch is only enabled when caching is enabled, which makes sense.

async def _prefetch(self) -> None:
    if self._caching_enabled:
        client = self._catalog.get_client(self.source)
        await client._download(self, callback=self._download_cb)

For me, cache=True, prefetch=8 reduces runtime to 15s from cache=True, prefetch=0 (which takes 50s).
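
For reference, a minimal sketch of the workaround in the current version: the original example from the description with caching turned on so that prefetch actually kicks in (this assumes `.settings()` accepts a `cache` flag alongside `prefetch`, as implied by the numbers above):

from io import BytesIO

from PIL import Image
from ultralytics import YOLO

from datachain import DataChain, File, model


def process1(yolo: YOLO, file: File) -> model.ultralytics.YoloBBoxes:
    img = Image.open(BytesIO(file.read()))
    return model.ultralytics.YoloBBoxes.from_results(yolo(img, verbose=False))


(
    DataChain
    .from_storage("gs://mpii-human-pose")
    .limit(100)
    # prefetch only takes effect when caching is enabled (see _prefetch above)
    .settings(cache=True, prefetch=8)
    .setup(yolo=lambda: YOLO("yolo11n.pt"))
    .map(my_yolo=process1)
    .save("mpii-human-pose-segment")
)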

@skshetry (Member)

Can we enable caching by default? What is the reason for not doing so already?

@shcheklein (Member)

It might take up a lot of space by default? But yes, it seems reasonable to enable it by default. @dmpetrov wdyt?

@dmpetrov (Member, Author) commented Dec 1, 2024

Prefetch is not related to caching. We need to decouple these options. Created #647.

Caching by default is a separate question, but I don't see any strong reasons for doing it.

@skshetry (Member) commented Dec 1, 2024

> Caching by default is a separate question, but I don't see any strong reasons for doing it.

We have to persist them somewhere for the lifetime of the script's run. IIUC, due to the worker processes, we cannot do it in memory.

@shcheklein (Member)

I think the point here is that, for prefetch, all the things related to the cache should be an implementation detail. E.g. cache files as needed, probably only for the script lifetime (or even per UDF call?), and drop them afterwards. People should not have to care about the cache flag when they use prefetch (even if internally we have to use some caching mechanism).

@dmpetrov (Member, Author) commented Dec 2, 2024

Right. It's OK for prefetch to store data locally (in the cache is fine), but the data has to be removed right after it is used (i.e. kept only for the UDF call).
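
A minimal sketch of the lifecycle being described, just to make the intent concrete (this is an illustration, not DataChain's actual implementation; `download` and `udf` are hypothetical callables):

import os
import tempfile


def run_udf_with_prefetch(download, udf, uri: str):
    # Prefetched data lives in a temporary local cache only for the
    # duration of the UDF call; the directory (and the cached file)
    # is removed as soon as the `with` block exits.
    with tempfile.TemporaryDirectory() as cache_dir:
        local_path = os.path.join(cache_dir, os.path.basename(uri))
        download(uri, local_path)   # fetch ahead of time (prefetch)
        return udf(local_path)      # UDF reads the local copy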

@skshetry (Member) commented Dec 4, 2024

I tested on this script, which has 1800 rows, with the following DataLoader.

train_loader = DataLoader(
    ds.to_pytorch(transform=transform),
    batch_size=36,
    num_workers=4,
)
| prefetch      | warm cache | wall time |
|---------------|------------|-----------|
| disabled (=0) | no         | 6m 24s    |
| 36            | no         | 6m 20s    |
| disabled      | no         | 5m 57s    |
| disabled      | yes        | 41s       |
| 36            | no         | 5m 57s    |
| 36            | yes        | 43s       |

As you can see, if the cache is warm, the script takes just 41s to run. If the cache is empty, it takes ~6 minutes. The 20s difference is not very meaningful; if anything, it could just be my unstable hotel Wi-Fi. Maybe prefetching is helping.

The gap here is too large: a run that could potentially be completed in 41s takes 5 more minutes just to download files.
And at least in the above script, the DataLoader seems to consume 50 rows * 4 workers in <1s, so we should be able to prefetch at least 200 files/sec to make it as fast as when the cache is warm.
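
Roughly, the back-of-the-envelope numbers behind that estimate (all figures taken from the runs above):

# Back-of-the-envelope check using the figures quoted above.
rows = 1800                     # rows in the test script
warm_s = 41                     # wall time with a warm cache
cold_s = 6 * 60                 # ~6 minutes with an empty cache

# Consumption: ~50 rows per worker consumed in under a second, 4 workers.
needed_rate = 50 * 4            # >= 200 files/s to keep the loader fed

# Observed download rate during the cold run.
observed_rate = rows / (cold_s - warm_s)   # ~5.6 files/s

print(f"needed ~{needed_rate} files/s, observed ~{observed_rate:.1f} files/s")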

Benchmark script
#! /bin/bash

trap "exit" INT

rm_cache () {
  rm -rf .datachain/cache
}

run () {
  echo "$@"
  gtime -v python ./examples/get_started/torch-loader.py "$@"
}

rm_cache
run --prefetch 0
run --prefetch 36

rm_cache
run --prefetch 0 --cache
run --prefetch 0 --cache

rm_cache
run --prefetch 36 --cache
run --prefetch 36 --cache

@dmpetrov (Member, Author) commented Dec 4, 2024

So, based on this data, prefetch does not give any perf improvements.

prefetch=36 and batch_size=36 is supposed to prefetch 36*36 (1K+) items. Is this what happened?

Any ideas how to make a cleaner case? Like single-threaded, no batching, prefetch=4?

@skshetry (Member) commented Dec 5, 2024

> prefetch=36 and batch_size=36 is supposed to prefetch 36*36 (1K+) items. Is this what happened?

prefetch and batch_size are independent: the former is for our async downloader, and the latter is for PyTorch and determines how many samples to load in a single batch.

With both set to the same value, whatever we have prefetched will be loaded as a single batch into PyTorch.
Note that both of these settings are per worker, so we are prefetching 36 items in each worker, which is 36 * 4 workers = 144 items.
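
In other words, with the benchmark's settings the accounting looks like this:

# Per-worker accounting for the run above.
prefetch = 36        # DataChain async downloader, per worker
batch_size = 36      # PyTorch DataLoader, per worker
num_workers = 4

in_flight = prefetch * num_workers   # 144 items prefetched in total,
print(in_flight)                     # not 36 * 36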

@skshetry (Member) commented Dec 5, 2024

Apologies for the misleading benchmarks. Turns out, prefetching was not working in the to_pytorch case.

Now, with #664, the above script finishes in <1m15s (so ~30s of overhead), down from almost 6 minutes.
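
For completeness, a sketch of the pattern this benchmarks, i.e. prefetch flowing through to_pytorch (the dataset name and parameters here are illustrative, not the exact torch-loader.py script):

from torch.utils.data import DataLoader

from datachain import DataChain

# Hypothetical, previously saved dataset name.
ds = DataChain.from_dataset("my-dataset").settings(prefetch=36)

train_loader = DataLoader(
    ds.to_pytorch(),   # prefetching now applies on this path as well (per #664)
    batch_size=36,
    num_workers=4,
)

for _ in train_loader:
    pass   # training step would go here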

@shcheklein (Member)

@skshetry could you please share the table with the results?

@dmpetrov (Member, Author) commented Dec 5, 2024

Great news!

1m15s instead of 6m20s - that's what's needed!

@skshetry (Member) commented Dec 6, 2024

> could you please share the table with the results?

| prefetch      | warm cache | wall time |
|---------------|------------|-----------|
| disabled (=0) | no         | 6m 24s    |
| disabled      | yes        | 41s       |
| 36            | no         | 1m 15s    |
| 36            | yes        | 43s       |

@shcheklein (Member)

Can we consider this to be done?

@skshetry (Member)

I tested on the CIFAR10 dataset (about 60,000 images) with datachain and with streaming from mosaicml.

This was with no-op code like the following:

for _ in data_loader:
    pass

| Tool      | Config                     | Time   |
|-----------|----------------------------|--------|
| DataChain | prefetch=25, num_workers=6 | 1m 50s |
| DataChain | warm cache, num_workers=6  | 15s    |
| streaming | remote files               | 53s    |
| streaming | local shard files          | 36s    |

FYI, mosaicml-streaming requires converting the dataset to its own format before it can be used for training. As an example, the CIFAR10 dataset gets sharded into files totaling ~178M, which look something like this:

   - local
   - ├── test
 417 │   ├── index.json
 31M │   └── shard.00000.mds
   - └── train
2.0k     ├── index.json
 34M     ├── shard.00000.mds
 34M     ├── shard.00001.mds
 34M     ├── shard.00002.mds
 34M     ├── shard.00003.mds
 21M     └── shard.00004.mds

You can find an example guide here in this repository.

These shards are pushed to the remote, and streaming caches them locally as needed.
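
For context, here is roughly how such shards get written with mosaicml-streaming's MDSWriter (a sketch based on their CIFAR10 guide; the column schema and paths are illustrative, not taken from the guide verbatim):

from streaming import MDSWriter
from torchvision.datasets import CIFAR10

# Illustrative schema: a PIL image and an integer class label per sample.
columns = {"x": "pil", "y": "int"}

dataset = CIFAR10(root="./dataset", train=True, download=True)

# Write the samples into MDS shard files (shard.00000.mds, ...) under ./sds/train,
# capping each shard at 1 << 25 bytes (~32M), as in the example code below.
with MDSWriter(out="./sds/train", columns=columns, size_limit=1 << 25) as writer:
    for x, y in dataset:
        writer.write({"x": x, "y": y})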

Example DataChain Code

import multiprocessing

from torch.utils.data import DataLoader
from tqdm import tqdm

from datachain import DataChain

source = "gs://datachain-cifar10/"
name = "cifar10"


if __name__ == "__main__":
    try:
        ds = DataChain.from_dataset(name)
    except:  # noqa: E722
        ds = DataChain.from_storage(source).save(name)
        print("created dataset", name)
    else:
        print("using existing dataset")

    ds = ds.settings(prefetch=50)
    train_loader = DataLoader(
        ds.to_pytorch(),
        batch_size=25,
        num_workers=6,
        persistent_workers=True,
        multiprocessing_context=multiprocessing.get_context("spawn"),
    )

    with tqdm(
        train_loader, disable=True, desc="Loading dataset", leave=True, position=100
    ) as loader:
        for _ in loader:
            pass

Example Mosaicml code

import os
from typing import Any, Callable

import torch
from streaming import StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# the location of our dataset
in_root = "./dataset"

# the location of the "remote" streaming dataset (`sds`).
# Upload `out_root` to your cloud storage provider of choice.
out_root = "./sds"
out_train = "./sds/train"
out_test = "./sds/test"

# the location to download the streaming dataset during training
local = "./local"
local_train = "./local/train"
local_test = "./local/test"

# toggle shuffling in dataloader
shuffle_train = True
shuffle_test = False

# shard size limit, in bytes
size_limit = 1 << 25

# training batch size
batch_size = 32

# training hardware parameters
device = "cuda" if torch.cuda.is_available() else "cpu"

# number of training epochs
train_epochs = 2  # increase the number of epochs for greater accuracy

# Hashing algorithm to use for dataset
hashes = ["sha1", "xxh64"]

# upload location for the dataset splits (change this if you want to upload to a different location, for example, AWS S3 bucket location)
upload_location = "gs://datachain-imagenet/sds"
upload_train_location = os.path.join(upload_location, "train")
upload_test_location = os.path.join(upload_location, "test")
remote_train = upload_train_location
remote_test = upload_test_location


class CIFAR10Dataset(StreamingDataset):
    def __init__(
        self,
        remote: str,
        local: str,
        shuffle: bool,
        batch_size: int,
        transforms: Callable,
    ) -> None:
        super().__init__(
            local=local, remote=remote, shuffle=shuffle, batch_size=batch_size
        )
        self.transforms = transforms

    def __getitem__(self, idx: int) -> Any:
        obj = super().__getitem__(idx)
        x = obj["x"]
        y = obj["y"]
        return self.transforms(x), y


transformation = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_dataset = CIFAR10Dataset(
    remote_train,
    local_train,
    shuffle_train,
    batch_size=batch_size,
    transforms=transformation,
)
test_dataset = CIFAR10Dataset(
    remote_test,
    local_test,
    shuffle_test,
    batch_size=batch_size,
    transforms=transformation,
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

for _ in tqdm(train_dataloader, desc="train"):
    pass

for _ in tqdm(test_dataloader, desc="test"):
    pass
