Skip to content

Commit

Permalink
remove unused newline
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyangci committed Feb 7, 2025
1 parent 6d3dcdc commit da93ab3
Showing 1 changed file with 0 additions and 9 deletions.
9 changes: 0 additions & 9 deletions examples/resnet/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from logging import getLogger

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.
from pathlib import Path

import numpy as np
import torchvision.transforms as transforms
from torch import from_numpy
from torch.utils.data import Dataset

from olive.data.registry import Registry

logger = getLogger(__name__)


class ImagenetDataset(Dataset):
def __init__(self, data):
self.images = from_numpy(data["images"])
Expand All @@ -22,12 +19,10 @@ def __len__(self):
def __getitem__(self, idx):
return {"input": self.images[idx]}, self.labels[idx]


@Registry.register_post_process()
def imagenet_post_fun(output):
return output.argmax(axis=1)


preprocess = transforms.Compose(
[
transforms.Resize(256),
Expand All @@ -37,7 +32,6 @@ def imagenet_post_fun(output):
]
)


@Registry.register_pre_process()
def dataset_pre_process(output_data, **kwargs):
cache_key = kwargs.get("cache_key")
Expand All @@ -54,13 +48,10 @@ def dataset_pre_process(output_data, **kwargs):
for i, sample in enumerate(output_data):
if i >= size:
break

image = sample["image"]
label = sample["label"]

image = image.convert("RGB")
image = preprocess(image)

images.append(image)
labels.append(label)

Expand Down

0 comments on commit da93ab3

Please sign in to comment.