-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataset.py
executable file
·77 lines (60 loc) · 2.33 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import cv2
import numpy as np
import tensorflow as tf
class Dataset(object):
    """Represent input dataset for Derain model.

    Reads serialized ``tf.Example`` records from TFRecord file(s) and yields
    batches of paired ``(O, B)`` float32 images.
    """

    def __init__(self, split_name, dataset_dir, batch_size,
                 shuffle=False, repeat=False, shuffle_buffer_size=100):
        """Initializes the dataset.

        Args:
            split_name: A train/val split name. Stored on the instance; not
                read by this class itself.
            dataset_dir: Path (or list of paths) of the TFRecord file(s).
                Despite the name, it is passed straight to
                ``tf.data.TFRecordDataset`` as ``filenames``, so it must name
                record files rather than a directory.
            batch_size: Batch size.
            shuffle: Boolean, if should shuffle the input data.
            repeat: Boolean. If True the data is repeated indefinitely;
                otherwise it is read for a single pass.
            shuffle_buffer_size: Number of elements held in the shuffle
                buffer when ``shuffle`` is True. Defaults to 100, the value
                that was previously hard-coded.

        Raises:
            ValueError: If ``shuffle_buffer_size`` is less than 1.
        """
        if shuffle_buffer_size < 1:
            raise ValueError(
                'shuffle_buffer_size must be >= 1, got %r'
                % (shuffle_buffer_size,))
        self.split_name = split_name
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.repeat = repeat
        self.shuffle_buffer_size = shuffle_buffer_size

    def get_data_iterator(self):
        """Builds an initializable iterator over batched (O, B) samples.

        Pipeline: read records -> parse -> optional shuffle -> optional
        infinite repeat -> batch. When ``repeat`` is False the data is read
        for a single pass. When ``repeat`` is True it cycles indefinitely.

        Returns:
            An iterator of type tf.data.Iterator. Its ``initializer`` op must
            be run in a session before ``get_next()`` is used.
        """
        dataset = tf.data.TFRecordDataset(self.dataset_dir)
        dataset = dataset.map(self._parse_function)
        if self.shuffle:
            # Shuffle before repeat so that epoch boundaries stay intact.
            dataset = dataset.shuffle(buffer_size=self.shuffle_buffer_size)
        if self.repeat:
            dataset = dataset.repeat()
        # No else branch: a dataset is read exactly once by default, so the
        # former `repeat(1)` call was a no-op.
        dataset = dataset.batch(self.batch_size)
        return dataset.make_initializable_iterator()

    def _parse_function(self, proto):
        """Parses one serialized tf.Example into a paired (O, B) sample.

        Args:
            proto: Scalar string tensor holding a serialized ``tf.Example``
                with bytes features ``O`` and ``B`` and int64 features
                ``height`` and ``width``.

        Returns:
            Tuple ``(O, B)`` of float32 tensors reshaped to
            ``[height, width, 3]`` and scaled from uint8 to [0, 1].
            (Presumably O is the rainy input and B the clean background --
            NOTE(review): confirm against the script that writes the records.)
        """
        feature_spec = {
            'O': tf.FixedLenFeature((), tf.string, default_value=''),
            'B': tf.FixedLenFeature((), tf.string, default_value=''),
            'height': tf.FixedLenFeature((), tf.int64, default_value=0),
            'width': tf.FixedLenFeature((), tf.int64, default_value=0),
        }
        parsed = tf.parse_single_example(proto, feature_spec)
        image_shape = tf.stack((parsed['height'], parsed['width'], 3))

        def _decode_image(key):
            # Raw bytes -> flat uint8 -> float32 [H, W, 3] in [0, 1].
            raw = tf.decode_raw(parsed[key], tf.uint8)
            return tf.reshape(tf.cast(raw, tf.float32), image_shape) / 255.0

        # Return a tuple rather than a list: tf.data's structure utilities
        # do not recurse into lists, and map() only converts a returned list
        # to a tuple as a special case. A tuple states the structure directly.
        return _decode_image('O'), _decode_image('B')