-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathdataloader.py
59 lines (45 loc) · 1.77 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import os
from skimage import io, transform
class Brain_data(Dataset):
    """Image/mask segmentation dataset read from a directory of patient folders.

    ``path`` is expected to hold one sub-directory per patient. The top-level
    entries ``data.csv`` and ``README.md``, and any other non-directory entry,
    are skipped. Inside a patient folder, a file counts as a mask when one of
    the ``_``-separated tokens of its name (up to the first ``.``) is exactly
    ``mask``; every other file counts as an image. Each image is paired with
    the mask whose name matches once the ``mask`` token is dropped, e.g.
    ``P_7.tif`` <-> ``P_7_mask.tif``.

    Args:
        path: Dataset root directory.
        image_size: ``(height, width)`` that images and masks are resized to.

    Raises:
        ValueError: If some image has no matching mask, or vice versa.
    """

    # Top-level metadata files that are not patient folders.
    _IGNORED_ENTRIES = ("data.csv", "README.md")

    def __init__(self, path, image_size=(256, 256)):
        self.path = path
        self.image_size = tuple(image_size)
        self.patients = [
            entry
            for entry in os.listdir(path)
            if entry not in self._IGNORED_ENTRIES
            and os.path.isdir(os.path.join(path, entry))
        ]
        # Key both kinds of file by (patient, name-without-"mask") so pairing
        # does not depend on two independent sorts lining up. Sorting
        # "P_1.tif", "P_10.tif" and "P_1_mask.tif", "P_10_mask.tif" separately
        # yields different orders, because "." < "0" < "_".
        images, masks = {}, {}
        for patient in self.patients:
            patient_dir = os.path.join(path, patient)
            for file in os.listdir(patient_dir):
                tokens = file.split(".")[0].split("_")
                key = (patient, "_".join(t for t in tokens if t != "mask"))
                target = masks if "mask" in tokens else images
                target[key] = os.path.join(patient_dir, file)
        unmatched = sorted(images.keys() ^ masks.keys())
        if unmatched:
            raise ValueError(
                f"{len(unmatched)} image/mask file(s) have no counterpart, "
                f"e.g. {unmatched[0]!r}"
            )
        # Images stay in sorted path order; each mask follows its own image.
        ordered = sorted(images, key=images.get)
        self.images = [images[key] for key in ordered]
        self.masks = [masks[key] for key in ordered]

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and normalise the ``idx``-th image and its mask.

        Returns:
            ``(image, mask)`` as float tensors (float64 for 8-bit input).
            ``image`` is channels-first ``(C, H, W)`` and ``mask`` is
            ``(1, H, W)``, with ``(H, W)`` equal to ``self.image_size``.
            Assumes images are stored ``H x W x C`` and masks as single-channel
            ``H x W``, both 8-bit (0-255) — TODO confirm for other data.
        """
        # preserve_range=True keeps the raw 0-255 values, so the single /255
        # maps them to [0, 1]. Without it, resize() already rescales integer
        # input to [0, 1], and the /255 would squash everything into [0, 1/255].
        image = io.imread(self.images[idx])
        image = transform.resize(image, self.image_size, preserve_range=True) / 255
        image = image.transpose((2, 0, 1))
        # Nearest-neighbour without smoothing, so mask values are not
        # interpolated into in-between levels.
        mask = io.imread(self.masks[idx])
        mask = (
            transform.resize(
                mask,
                self.image_size,
                order=0,
                preserve_range=True,
                anti_aliasing=False,
            )
            / 255
        )
        mask = np.expand_dims(mask, axis=0)
        return (torch.from_numpy(image), torch.from_numpy(mask))
def get_loader(
    batch_size,
    num_workers,
    data_folder="/content/kaggle_3m",
    val_size=329,
    seed=None,
):
    """Build shuffled training and unshuffled validation loaders.

    The dataset under ``data_folder`` is randomly split into a validation
    subset of ``val_size`` samples and a training subset holding the rest.
    The defaults reproduce the original 3600/329 split on a 3929-sample
    dataset, and any other dataset size now also works.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for both loaders.
        data_folder: Root directory passed to ``Brain_data``.
        val_size: Number of samples in the validation split.
        seed: If given, seeds the split so it is reproducible; otherwise
            torch's default global generator is used, as before.

    Returns:
        ``(train_dl, val_dl)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If ``val_size`` is negative or larger than the dataset.
    """
    data = Brain_data(data_folder)
    if not 0 <= val_size <= len(data):
        raise ValueError(
            f"val_size must be between 0 and {len(data)}, got {val_size}"
        )
    generator = (
        torch.default_generator
        if seed is None
        else torch.Generator().manual_seed(seed)
    )
    train_set, val_set = random_split(
        data, [len(data) - val_size, val_size], generator=generator
    )
    train_dl = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_dl = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers)
    return train_dl, val_dl