Skip to content

Commit

Permalink
changed name of package, added transform parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
facundoq committed Jul 12, 2022
1 parent f4f68e7 commit 83bc94f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
long_description = f.read()


url="https://github.com/facundoq/torchvision-tinyimagenet"
VERSION="0.2"
url="https://github.com/facundoq/tinyimagenet"
VERSION="0.1"

class UploadCommand(Command):
"""Support setup.py upload."""
Expand Down Expand Up @@ -53,7 +53,7 @@ def run(self):


setup(
name="torchvision-tinyimagenet",
name="tinyimagenet",
version=VERSION,
python_requires='>=3.6',
packages=find_packages(),
Expand All @@ -77,7 +77,7 @@ def run(self):
# metadata to display on PyPI
author="Facundo Manuel Quiroga",
author_email="[email protected]",
description="Dataset class for PyTorch and the TinyImageNet dataset.",
description="Dataset class for PyTorch and the TinyImageNet dataset, with automated download and extraction.",
keywords="TinyImageNet ImageNet Dataset PyTorch torch torchvision",
url=url, # project home page, if any
project_urls={
Expand Down
17 changes: 11 additions & 6 deletions tinyimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class TinyImageNet(ImageFolder):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def __init__(self, root: Path, split: str = "train") -> None:
def __init__(self, root: Path, split: str = "train",transform=None, target_transform=None) -> None:
if isinstance(root,str):
root = Path(root)
assert split in ["train","val","test"]
Expand All @@ -112,7 +112,7 @@ def __init__(self, root: Path, split: str = "train") -> None:
if not images_root.exists():
download_resources(root,mirrors,resources)
preprocess_val(images_root)
super().__init__(images_root/split)
super().__init__(images_root/split,transform=transform,target_transform=target_transform)
self.idx_to_words,self.idx_to_class = self.load_words_classes(images_root)

def load_words_classes(self,root:Path):
Expand All @@ -131,16 +131,21 @@ def load_words_classes(self,root:Path):


if __name__ == '__main__':

from torchvision import transforms

logging.basicConfig(level=logging.INFO)

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(TinyImageNet.mean,TinyImageNet.std)]
)

split ="val"
dataset = TinyImageNet(Path("~/.torchvision/tinyimagenet/"),split=split)
dataset = TinyImageNet(Path("~/.torchvision/tinyimagenet/"),split=split,transform=transform)
n = len(dataset)
print(f"TinyImageNet, split {split}, has {n} samples.")
print("Showing some samples")
for i in range(0,n,n//5):
image,klass = dataset[i]
print(f"Sample of class {klass:3d}, image {image}, words {dataset.idx_to_words[klass]}")
print(f"Sample of class {klass:3d}, image shape {image.shape}, words {dataset.idx_to_words[klass]}")


0 comments on commit 83bc94f

Please sign in to comment.