diff --git a/models.py b/models.py
index 9fe331d..e3e8e4a 100644
--- a/models.py
+++ b/models.py
@@ -25,13 +25,16 @@ def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
 
 
 class PixelCNN():
-    def __init__(self, W_shape, b_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
-        self.W_shape = W_shape
-        self.b_shape = b_shape
+    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
         self.fan_in = fan_in
+        in_dim = self.fan_in.get_shape()[-1]
+        self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
+        self.b_shape = W_shape[2]
+
         self.payload = payload
         self.mask = mask
         self.activation = activation
+
         if gated:
             self.gated_conv()
@@ -62,7 +65,6 @@ def simple_conv(self):
         else:
             self.fan_out = tf.add(conv, b)
 
-
     def output(self):
         return self.fan_out
 
diff --git a/pixelcnn_tf.py b/pixelcnn_tf.py
index b94c5b5..59ce0a4 100644
--- a/pixelcnn_tf.py
+++ b/pixelcnn_tf.py
@@ -1,10 +1,13 @@
 # TODO
-# try changing 0.0 to np.random
+# kundan: concat payload instead of add
+#       : arch: without 1X1 in 1st layer and last 2 layers
+#       : replaces masking with n/2 filter
 # check network arch
-# upscale-downscale-q_level
 # autoencoder
 # cost on test set
-# make for imagenet data
+# make for imagenet data: upscale-downscale-q_level, mean pixel value if not mnist
+# stats
+# logger
 
 import tensorflow as tf
 import numpy as np
@@ -15,7 +18,7 @@ from models import PixelCNN
 
 mnist = input_data.read_data_sets("data/")
 
-epochs = 50
+epochs = 1
 batch_size = 50
 grad_clip = 1
 
@@ -28,36 +31,35 @@ FILTER_SIZE = 7
 
 X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel])
 
-# TODO mean pixel value if not mnist
 v_stack_in, h_stack_in = X, X
 
+# TODO encapsulate
 for i in range(LAYERS):
     FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
     in_dim = F_MAP if i > 0 else channel
     mask = 'b' if i > 0 else 'a'
     i = str(i)
-
     with tf.variable_scope("v_stack"+i):
-        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask='a').output()
+        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output()
         v_stack_in = v_stack
 
     with tf.variable_scope("v_stack_1"+i):
-        v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False, mask=mask).output()
-
+        v_stack_1 = PixelCNN([1, 1, F_MAP], v_stack_in, gated=False, mask=mask).output()
+
     with tf.variable_scope("h_stack"+i):
-        h_stack = PixelCNN([1, FILTER_SIZE, in_dim, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
+        h_stack = PixelCNN([1, FILTER_SIZE, F_MAP], h_stack_in, payload=v_stack_1, mask=mask).output()
 
     with tf.variable_scope("h_stack_1"+i):
-        h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False, mask=mask).output()
-        h_stack_1 += h_stack_in
+        h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output()
+        #h_stack_1 += h_stack_in # Residual connection
         h_stack_in = h_stack_1
 
 with tf.variable_scope("fc_1"):
-    fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output()
+    fc1 = PixelCNN([1, 1, F_MAP], h_stack_in, gated=False, mask='b').output()
 
 # handle Imagenet differently
 with tf.variable_scope("fc_2"):
-    fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output()
+    fc2 = PixelCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
 
 pred = tf.nn.sigmoid(fc2)
 loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X))
@@ -97,7 +99,7 @@ def generate_and_save(sess):
             .reshape([batch_size, img_height, img_width, 1]))
         _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
-        print "Epoch: %d, Cost: %f"%(i, cost)
+        print "Epoch: %d, Cost: %f"%(i, cost)
         generate_and_save(sess)
 
     generate_and_save(sess)
 
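
A minimal usage sketch (not part of the patch) of how a layer is built after this refactor of the PixelCNN constructor. It assumes the updated models.py above is importable and the TF 0.x/1.x-style API used elsewhere in this repo; F_MAP = 32 and the 28x28 input size are illustrative values, not taken from the source.

# sketch.py -- illustrative only, assumes the patched models.py is on the path
import tensorflow as tf
from models import PixelCNN

F_MAP, channel = 32, 1                                   # illustrative sizes
X = tf.placeholder(tf.float32, shape=[None, 28, 28, channel])

with tf.variable_scope("example_layer"):
    # Before this change the caller passed the full 4-D kernel shape plus a
    # bias shape: PixelCNN([7, 7, channel, F_MAP], [F_MAP], X, mask='a')
    # Now the input depth and bias shape are derived from fan_in (X) inside
    # __init__, so only [filter_h, filter_w, out_channels] is supplied.
    out = PixelCNN([7, 7, F_MAP], X, mask='a').output()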