diff --git a/autoencoder.py b/autoencoder.py index 16a37c1..1f3b21a 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -27,11 +27,11 @@ def trainAE(conf, data): if os.path.exists(conf.ckpt_file): saver.restore(sess, conf.ckpt_file) - print "Model Restored" + print("Model Restored") # TODO The training part below and in main.py could be generalized if conf.epochs > 0: - print "Started Model Training..." + print("Started Model Training...") pointer = 0 step = 0 for i in range(conf.epochs): @@ -45,7 +45,7 @@ def trainAE(conf, data): writer.add_summary(summary, step) step += 1 - print "Epoch: %d, Cost: %f"%(i, l) + print("Epoch: %d, Cost: %f"%(i, l)) if (i+1)%10 == 0: saver.save(sess, conf.ckpt_file) generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i)) diff --git a/layers.py b/layers.py index 209f45b..7c270d3 100644 --- a/layers.py +++ b/layers.py @@ -1,7 +1,7 @@ import tensorflow as tf import numpy as np -def get_weights(shape, name, mask=None): +def get_weights(shape, name, horizontal, mask_mode='noblind', mask=None): weights_initializer = tf.contrib.layers.xavier_initializer() W = tf.get_variable(name, shape, tf.float32, weights_initializer) @@ -9,15 +9,35 @@ def get_weights(shape, name, mask=None): Use of masking to hide subsequent pixel values ''' if mask: - filter_mid_x = shape[0]//2 - filter_mid_y = shape[1]//2 + filter_mid_y = shape[0]//2 + filter_mid_x = shape[1]//2 mask_filter = np.ones(shape, dtype=np.float32) - mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0. - mask_filter[filter_mid_x+1:, :, :, :] = 0. + if mask_mode == 'noblind': + if horizontal: + # All rows after center must be zero + mask_filter[filter_mid_y+1:, :, :, :] = 0.0 + # All columns after center in center row must be zero + mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.0 + else: + if mask == 'a': + # In the first layer, can ONLY access pixels above it + mask_filter[filter_mid_y:, :, :, :] = 0.0 + else: + # In the second layer, can access pixels above or even with it. + # Reason being that the pixels to the right or left of the current pixel + # only have a receptive field of the layer above the current layer and up. + mask_filter[filter_mid_y+1:, :, :, :] = 0.0 + + if mask == 'a': + # Center must be zero in first layer + mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0 + else: + mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0. + mask_filter[filter_mid_y+1:, :, :, :] = 0. - if mask == 'a': - mask_filter[filter_mid_x, filter_mid_y, :, :] = 0. - + if mask == 'a': + mask_filter[filter_mid_y, filter_mid_x, :, :] = 0. + W *= mask_filter return W @@ -31,16 +51,19 @@ def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') class GatedCNN(): - def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None): + def __init__(self, W_shape, fan_in, horizontal, gated=True, payload=None, mask=None, activation=True, conditional=None, conditional_image=None): self.fan_in = fan_in in_dim = self.fan_in.get_shape()[-1] self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]] self.b_shape = W_shape[2] + self.in_dim = in_dim self.payload = payload self.mask = mask self.activation = activation self.conditional = conditional + self.conditional_image = conditional_image + self.horizontal = horizontal if gated: self.gated_conv() @@ -48,22 +71,28 @@ def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activat self.simple_conv() def gated_conv(self): - W_f = get_weights(self.W_shape, "v_W", mask=self.mask) - W_g = get_weights(self.W_shape, "h_W", mask=self.mask) + W_f = get_weights(self.W_shape, "v_W", self.horizontal, mask=self.mask) + W_g = get_weights(self.W_shape, "h_W", self.horizontal, mask=self.mask) + + b_f_total = get_bias(self.b_shape, "v_b") + b_g_total = get_bias(self.b_shape, "h_b") if self.conditional is not None: h_shape = int(self.conditional.get_shape()[1]) - V_f = get_weights([h_shape, self.W_shape[3]], "v_V") + V_f = get_weights([h_shape, self.W_shape[3]], "v_V", self.horizontal) b_f = tf.matmul(self.conditional, V_f) - V_g = get_weights([h_shape, self.W_shape[3]], "h_V") + V_g = get_weights([h_shape, self.W_shape[3]], "h_V", self.horizontal) b_g = tf.matmul(self.conditional, V_g) b_f_shape = tf.shape(b_f) b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1])) b_g_shape = tf.shape(b_g) b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1])) - else: - b_f = get_bias(self.b_shape, "v_b") - b_g = get_bias(self.b_shape, "h_b") + + b_f_total = b_f_total + b_f + b_g_total = b_g_total + b_g + if self.conditional_image is not None: + b_f_total = b_f_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_f") + b_g_total = b_g_total + tf.layers.conv2d(self.conditional_image, self.in_dim, 1, use_bias=False, name="ci_g") conv_f = conv_op(self.fan_in, W_f) conv_g = conv_op(self.fan_in, W_g) @@ -72,10 +101,10 @@ def gated_conv(self): conv_f += self.payload conv_g += self.payload - self.fan_out = tf.multiply(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) + self.fan_out = tf.multiply(tf.tanh(conv_f + b_f_total), tf.sigmoid(conv_g + b_g_total)) def simple_conv(self): - W = get_weights(self.W_shape, "W", mask=self.mask) + W = get_weights(self.W_shape, "W", self.horizontal, mask_mode="standard", mask=self.mask) b = get_bias(self.b_shape, "b") conv = conv_op(self.fan_in, W) if self.activation: diff --git a/main.py b/main.py index 905e60f..116d882 100644 --- a/main.py +++ b/main.py @@ -21,10 +21,10 @@ def train(conf, data): sess.run(tf.initialize_all_variables()) if os.path.exists(conf.ckpt_file): saver.restore(sess, conf.ckpt_file) - print "Model Restored" + print("Model Restored") if conf.epochs > 0: - print "Started Model Training..." + print("Started Model Training...") pointer = 0 for i in range(conf.epochs): for j in range(conf.num_batches): @@ -39,7 +39,7 @@ def train(conf, data): if conf.conditional is True: data_dict[model.h] = batch_y _, cost = sess.run([optimizer, model.loss], feed_dict=data_dict) - print "Epoch: %d, Cost: %f"%(i, cost) + print("Epoch: %d, Cost: %f"%(i, cost)) if (i+1)%10 == 0: saver.save(sess, conf.ckpt_file) generate_samples(sess, X, model.h, model.pred, conf, "") diff --git a/models.py b/models.py index 6b0e36a..6f7df4a 100644 --- a/models.py +++ b/models.py @@ -2,7 +2,7 @@ from layers import * class PixelCNN(object): - def __init__(self, X, conf, h=None): + def __init__(self, X, conf, full_horizontal=True, h=None): self.X = X if conf.data == "mnist": self.X_norm = X @@ -27,33 +27,33 @@ def __init__(self, X, conf, h=None): residual = True if i > 0 else False i = str(i) with tf.variable_scope("v_stack"+i): - v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask, conditional=self.h).output() + v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, False, mask=mask, conditional=self.h).output() v_stack_in = v_stack with tf.variable_scope("v_stack_1"+i): - v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output() + v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=mask).output() with tf.variable_scope("h_stack"+i): - h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask, conditional=self.h).output() + h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask, conditional=self.h).output() with tf.variable_scope("h_stack_1"+i): - h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output() + h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=mask).output() if residual: h_stack_1 += h_stack_in # Residual connection h_stack_in = h_stack_1 with tf.variable_scope("fc_1"): - fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output() + fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, True, gated=False, mask='b').output() if conf.data == "mnist": with tf.variable_scope("fc_2"): - self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output() + self.fc2 = GatedCNN([1, 1, 1], fc1, True, gated=False, mask='b', activation=False).output() self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fc2, labels=self.X)) self.pred = tf.nn.sigmoid(self.fc2) else: color_dim = 256 with tf.variable_scope("fc_2"): - self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output() + self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, True, gated=False, mask='b', activation=False).output() self.fc2 = tf.reshape(self.fc2, (-1, color_dim)) self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32))) diff --git a/utils.py b/utils.py index ad95997..d25e839 100644 --- a/utils.py +++ b/utils.py @@ -8,15 +8,15 @@ def binarize(images): return (np.random.uniform(size=images.shape) < images).astype(np.float32) def generate_samples(sess, X, h, pred, conf, suff): - print "Generating Sample Images..." + print("Generating Sample Images...") n_row, n_col = 10,10 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) # TODO make it generic labels = one_hot(np.array([0,1,2,3,4,5,6,7,8,9]*10), conf.num_classes) - for i in xrange(conf.img_height): - for j in xrange(conf.img_width): - for k in xrange(conf.channel): + for i in range(conf.img_height): + for j in range(conf.img_width): + for k in range(conf.channel): data_dict = {X:samples} if conf.conditional is True: data_dict[h] = labels @@ -29,7 +29,7 @@ def generate_samples(sess, X, h, pred, conf, suff): def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''): - print "Generating Sample Images..." + print("Generating Sample Images...") n_row, n_col = 10,10 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) if conf.data == 'mnist': @@ -37,9 +37,9 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''): else: labels = get_batch(data, 0, n_row*n_col) - for i in xrange(conf.img_height): - for j in xrange(conf.img_width): - for k in xrange(conf.channel): + for i in range(conf.img_height): + for j in range(conf.img_width): + for k in range(conf.channel): next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples}) if conf.data == 'mnist': next_sample = binarize(next_sample)