-
Notifications
You must be signed in to change notification settings - Fork 166
/
Copy pathutils.py
56 lines (43 loc) · 1.62 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from tensorflow.contrib import layers
def encoder(input_tensor, output_size, is_training=True):
    '''Create encoder network.

    Maps a batch of 28x28 single-channel images to a vector of
    `output_size` values through three strided convolutions, dropout and
    a final linear (fully connected) layer.

    Args:
        input_tensor: a batch of images with 28*28 values per example,
            e.g. flattened images [batch_size, 28*28]; it is reshaped to
            [batch_size, 28, 28, 1] before the convolutions.
        output_size: number of units in the final fully connected layer.
        is_training: forwarded to the dropout layer. Pass False at
            evaluation/sampling time so dropout becomes a no-op. Defaults to
            True, which preserves the previous always-on dropout behavior.
    Returns:
        A tensor of shape [batch_size, output_size]. No activation is
        applied to it (activation_fn=None), so it holds raw scores/logits.
    '''
    net = tf.reshape(input_tensor, [-1, 28, 28, 1])
    # Convolutions use the tf.contrib.layers defaults (padding='SAME',
    # ReLU activation) unless overridden below.
    net = layers.conv2d(net, 32, 5, stride=2)  # -> [batch, 14, 14, 32]
    net = layers.conv2d(net, 64, 5, stride=2)  # -> [batch, 7, 7, 64]
    # VALID padding: floor((7 - 5) / 2) + 1 = 2 -> [batch, 2, 2, 128]
    net = layers.conv2d(net, 128, 5, stride=2, padding='VALID')
    # Without is_training the contrib dropout layer stays active forever,
    # including at inference; expose it so callers can switch it off.
    net = layers.dropout(net, keep_prob=0.9, is_training=is_training)
    net = layers.flatten(net)  # -> [batch, 2 * 2 * 128] = [batch, 512]
    return layers.fully_connected(net, output_size, activation_fn=None)
def discriminator(input_tensor):
    '''Create a network that discriminates between images from a dataset and
    generated ones.

    This is the encoder network with a single output unit, so it shares
    the encoder's architecture.

    Args:
        input_tensor: a batch of images with 28*28 values per example
            (e.g. flattened [batch_size, 28*28]); the encoder reshapes it
            to [batch_size, 28, 28, 1].
    Returns:
        A tensor of shape [batch_size, 1] holding one raw score (logit) per
        image. No sigmoid is applied here; apply one, or use a
        sigmoid-cross-entropy loss, to obtain probabilities.
    '''
    return encoder(input_tensor, 1)
def decoder(input_tensor):
    '''Create decoder network.

    Decodes a batch of vectors into flattened 28x28 single-channel images
    using a stack of transposed convolutions. The input tensor is required;
    this function does not sample vectors itself, so callers must supply
    them (e.g. draws from a prior).

    Args:
        input_tensor: a batch of vectors to decode, shape
            [batch_size, latent_dim] (it must be rank 2, because two
            expand_dims calls turn it into a 4-D image tensor).
    Returns:
        A tensor of shape [batch_size, 28*28] whose values lie in (0, 1),
        because the last layer uses a sigmoid activation.
    '''
    # Treat each vector as a 1x1 "image" with latent_dim channels.
    net = tf.expand_dims(input_tensor, 1)  # -> [batch, 1, latent_dim]
    net = tf.expand_dims(net, 1)  # -> [batch, 1, 1, latent_dim]
    # Transposed convolutions use the tf.contrib.layers defaults
    # (stride=1, padding='SAME', ReLU activation) unless overridden.
    net = layers.conv2d_transpose(net, 128, 3, padding='VALID')  # -> 3x3x128
    net = layers.conv2d_transpose(net, 64, 5, padding='VALID')  # -> 7x7x64
    net = layers.conv2d_transpose(net, 32, 5, stride=2)  # -> 14x14x32
    # Final layer: single channel, sigmoid keeps pixel values in (0, 1).
    net = layers.conv2d_transpose(
        net, 1, 5, stride=2, activation_fn=tf.nn.sigmoid)  # -> 28x28x1
    net = layers.flatten(net)  # -> [batch, 784]
    return net