diff --git a/models.py b/models.py new file mode 100644 index 0000000..192a43b --- /dev/null +++ b/models.py @@ -0,0 +1,70 @@ +import tensorflow as tf +import numpy as np + +def get_weights(shape, mask=None): + # TODO set init bounds + W = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1)) + + if mask: + filter_mid_x = shape[0]//2 + filter_mid_y = shape[1]//2 + mask_filter = np.ones(shape, dtype=np.float32) + mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0. + mask_filter[filter_mid_x+1:, :, :, :] = 0. + + if mask == 'a': + mask_filter[filter_mid_x, filter_mid_y, :, :] = 0. + + W *= mask_filter + return W + + +def get_bias(shape): + return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32)) + +def conv_op(x, W): + return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') + +class PixelCNN(): + def __init__(self, W_shape, b_shape, fan_in, gated=True, payload=None, mask=None, activation=True): + self.W_shape = W_shape + self.b_shape = b_shape + self.fan_in = fan_in + self.payload = payload + self.mask = mask + self.activation = activation + + if gated: + self.gated_conv() + else: + self.simple_conv() + + def gated_conv(self): + W_f = get_weights(self.W_shape, mask=self.mask) + b_f = get_bias(self.b_shape) + W_g = get_weights(self.W_shape, mask=self.mask) + b_g = get_bias(self.b_shape) + + conv_f = conv_op(self.fan_in, W_f) + conv_g = conv_op(self.fan_in, W_g) + + if self.payload is not None: + conv_f += self.payload + conv_g += self.payload + + self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) + + def simple_conv(self): + W = get_weights(self.W_shape, mask=self.mask) + b = get_bias(self.b_shape) + conv = conv_op(self.fan_in, W) + if self.activation: + self.fan_out = tf.nn.relu(conv + b) + else: + self.fan_out = conv + b + + + def output(self): + return self.fan_out + + diff --git a/pixelcnn_tf.py b/pixelcnn_tf.py index cee4c4a..2685545 100644 --- a/pixelcnn_tf.py +++ b/pixelcnn_tf.py @@ -1,12 +1,17 @@ +# TODO sanity check pred, +# generate and plot, +# upscale-downscale-q_level +# validation + import tensorflow as tf import numpy as np from models import PixelCNN from tensorflow.examples.tutorials.mnist import input_data -# TODO get mean if pixel value >1 mnist = input_data.read_data_sets("data/") epochs = 10 batch_size = 50 +grad_clip = 1 LAYERS = 3 F_MAP = 32 @@ -15,6 +20,7 @@ X = tf.placeholder(tf.float32, shape=[None, 784]) X_image = tf.reshape(X, [-1, 28, 28, CHANNEL]) +# TODO mean pixel value if not mnist v_stack_in, h_stack_in = X_image, X_image for i in range(LAYERS): @@ -38,24 +44,33 @@ h_stack_1 += h_stack_in h_stack_in = h_stack_1 -with tf.name_scope("f_layer"): - pred = PixelCNN([1, 1, F_MAP, 1],[1], h_stack_in, gated=False, mask='b', activation='sigmoid').output() - pred = tf.reshape(pred, [batch_size, 784]) +with tf.name_scope("fc_1"): + fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output() +# handle Imagenet differently +with tf.name_scope("fc_2"): + fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output() +pred = tf.nn.sigmoid(fc2) + +loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X_image)) + +trainer = tf.train.RMSPropOptimizer(1e-3) +gradients = trainer.compute_gradients(loss) + +clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients] +optimizer = trainer.apply_gradients(clipped_gradients) -cross_entropy = tf.reduce_mean(-tf.reduce_sum(X * tf.log(pred), reduction_indices=[1])) -#TODO gradient clipping -trainer = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy) -correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1)) -accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32)) +#correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1)) +#accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32)) #summary = tf.train.SummaryWriter('logs', sess.graph) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for i in range(epochs): - batch_X, batch_y = mnist.train.next_batch(batch_size) - sess.run(trainer, feed_dict={X:batch_X}) + batch_X = mnist.train.next_batch(batch_size)[0] + _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X}) + print cost - if i%1 == 0: - print accuracy.eval(feed_dict={X:batch_X}) - print accuracy.eval(feed_dict={X:mnist.test.images}) + #if i%1 == 0: + #print accuracy.eval(feed_dict={X:batch_X}) + #print accuracy.eval(feed_dict={X:mnist.test.images})