Skip to content

Commit

Permalink
Merge pull request #7 from kkleidal/noblind
Browse files (browse the repository at this point in the history)
Removed blind spot in masked convolutions (see …)
  • Loading branch information
anantzoid authored Dec 10, 2017
2 parents 9a5c9a3 + dc5e67b commit 055dab6
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 40 deletions.
6 changes: 3 additions & 3 deletions autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def trainAE(conf, data):

if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"
print("Model Restored")

# TODO The training part below and in main.py could be generalized
if conf.epochs > 0:
print "Started Model Training..."
print("Started Model Training...")
pointer = 0
step = 0
for i in range(conf.epochs):
Expand All @@ -45,7 +45,7 @@ def trainAE(conf, data):
writer.add_summary(summary, step)
step += 1

print "Epoch: %d, Cost: %f"%(i, l)
print("Epoch: %d, Cost: %f"%(i, l))
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))
Expand Down
65 changes: 47 additions & 18 deletions layers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
import tensorflow as tf
import numpy as np

def get_weights(shape, name, mask=None):
def get_weights(shape, name, horizontal, mask_mode='noblind', mask=None):
weights_initializer = tf.contrib.layers.xavier_initializer()
W = tf.get_variable(name, shape, tf.float32, weights_initializer)

'''
Use of masking to hide subsequent pixel values
'''
if mask:
filter_mid_x = shape[0]//2
filter_mid_y = shape[1]//2
filter_mid_y = shape[0]//2
filter_mid_x = shape[1]//2
mask_filter = np.ones(shape, dtype=np.float32)
mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
mask_filter[filter_mid_x+1:, :, :, :] = 0.
if mask_mode == 'noblind':
if horizontal:
# All rows after center must be zero
mask_filter[filter_mid_y+1:, :, :, :] = 0.0
# All columns after center in center row must be zero
mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.0
else:
if mask == 'a':
# In the first layer, can ONLY access pixels above it
mask_filter[filter_mid_y:, :, :, :] = 0.0
else:
# In the second layer, can access pixels above it or in the same row as it.
# This is safe because the first layer only looks at rows strictly above, so the
# features to the left or right of the current pixel only depend on the rows above the current row.
mask_filter[filter_mid_y+1:, :, :, :] = 0.0

if mask == 'a':
# Center must be zero in first layer
mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0
else:
mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.
mask_filter[filter_mid_y+1:, :, :, :] = 0.

if mask == 'a':
mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.

if mask == 'a':
mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.
W *= mask_filter
return W

Expand All @@ -31,39 +51,48 @@ def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

class GatedCNN():
def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
def __init__(self, W_shape, fan_in, horizontal, gated=True, payload=None, mask=None, activation=True, conditional=None, conditional_image=None):
self.fan_in = fan_in
in_dim = self.fan_in.get_shape()[-1]
self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
self.b_shape = W_shape[2]

self.in_dim = in_dim
self.payload = payload
self.mask = mask
self.activation = activation
self.conditional = conditional
self.conditional_image = conditional_image
self.horizontal = horizontal

if gated:
self.gated_conv()
else:
self.simple_conv()

def gated_conv(self):
W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
W_f = get_weights(self.W_shape, "v_W", self.horizontal, mask=self.mask)
W_g = get_weights(self.W_shape, "h_W", self.horizontal, mask=self.mask)

b_f_total = get_bias(self.b_shape, "v_b")
b_g_total = get_bias(self.b_shape, "h_b")
if self.conditional is not None:
h_shape = int(self.conditional.get_shape()[1])
V_f = get_weights([h_shape, self.W_shape[3]], "v_V")
V_f = get_weights([h_shape, self.W_shape[3]], "v_V", self.horizontal)
b_f = tf.matmul(self.conditional, V_f)
V_g = get_weights([h_shape, self.W_shape[3]], "h_V")
V_g = get_weights([h_shape, self.W_shape[3]], "h_V", self.horizontal)
b_g = tf.matmul(self.conditional, V_g)

b_f_shape = tf.shape(b_f)
b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1]))
b_g_shape = tf.shape(b_g)
b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1]))
else:
b_f = get_bias(self.b_shape, "v_b")
b_g = get_bias(self.b_shape, "h_b")

b_f_total = b_f_total + b_f
b_g_total = b_g_total + b_g
if self.conditional_image is not None:
b_f_total = b_f_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_f")
b_g_total = b_g_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_g")

conv_f = conv_op(self.fan_in, W_f)
conv_g = conv_op(self.fan_in, W_g)
Expand All @@ -72,10 +101,10 @@ def gated_conv(self):
conv_f += self.payload
conv_g += self.payload

self.fan_out = tf.multiply(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))
self.fan_out = tf.multiply(tf.tanh(conv_f + b_f_total), tf.sigmoid(conv_g + b_g_total))

def simple_conv(self):
W = get_weights(self.W_shape, "W", mask=self.mask)
W = get_weights(self.W_shape, "W", self.horizontal, mask_mode="standard", mask=self.mask)
b = get_bias(self.b_shape, "b")
conv = conv_op(self.fan_in, W)
if self.activation:
Expand Down
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def train(conf, data):
sess.run(tf.initialize_all_variables())
if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"
print("Model Restored")

if conf.epochs > 0:
print "Started Model Training..."
print("Started Model Training...")
pointer = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
Expand All @@ -39,7 +39,7 @@ def train(conf, data):
if conf.conditional is True:
data_dict[model.h] = batch_y
_, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)
print "Epoch: %d, Cost: %f"%(i, cost)
print("Epoch: %d, Cost: %f"%(i, cost))
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_samples(sess, X, model.h, model.pred, conf, "")
Expand Down
16 changes: 8 additions & 8 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from layers import *

class PixelCNN(object):
def __init__(self, X, conf, h=None):
def __init__(self, X, conf, full_horizontal=True, h=None):
self.X = X
if conf.data == "mnist":
self.X_norm = X
Expand All @@ -27,33 +27,33 @@ def __init__(self, X, conf, h=None):
residual = True if i > 0 else False
i = str(i)
with tf.variable_scope("v_stack"+i):
v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask, conditional=self.h).output()
v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, False, mask=mask, conditional=self.h).output()
v_stack_in = v_stack

with tf.variable_scope("v_stack_1"+i):
v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output()
v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=mask).output()

with tf.variable_scope("h_stack"+i):
h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask, conditional=self.h).output()
h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask, conditional=self.h).output()

with tf.variable_scope("h_stack_1"+i):
h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output()
h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=mask).output()
if residual:
h_stack_1 += h_stack_in # Residual connection
h_stack_in = h_stack_1

with tf.variable_scope("fc_1"):
fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output()
fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, True, gated=False, mask='b').output()

if conf.data == "mnist":
with tf.variable_scope("fc_2"):
self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
self.fc2 = GatedCNN([1, 1, 1], fc1, True, gated=False, mask='b', activation=False).output()
self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fc2, labels=self.X))
self.pred = tf.nn.sigmoid(self.fc2)
else:
color_dim = 256
with tf.variable_scope("fc_2"):
self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, True, gated=False, mask='b', activation=False).output()
self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
Expand Down
16 changes: 8 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ def binarize(images):
return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_samples(sess, X, h, pred, conf, suff):
print "Generating Sample Images..."
print("Generating Sample Images...")
n_row, n_col = 10,10
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
# TODO make it generic
labels = one_hot(np.array([0,1,2,3,4,5,6,7,8,9]*10), conf.num_classes)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(conf.channel):
for i in range(conf.img_height):
for j in range(conf.img_width):
for k in range(conf.channel):
data_dict = {X:samples}
if conf.conditional is True:
data_dict[h] = labels
Expand All @@ -29,17 +29,17 @@ def generate_samples(sess, X, h, pred, conf, suff):


def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
print "Generating Sample Images..."
print("Generating Sample Images...")
n_row, n_col = 10,10
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
if conf.data == 'mnist':
labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
else:
labels = get_batch(data, 0, n_row*n_col)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(conf.channel):
for i in range(conf.img_height):
for j in range(conf.img_width):
for k in range(conf.channel):
next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples})
if conf.data == 'mnist':
next_sample = binarize(next_sample)
Expand Down

0 comments on commit 055dab6

Please sign in to comment.