# Patch: wire masking into the PixelCNN stacks, add a final 1x1 layer, and
# add an MNIST training loop to pixelcnn_tf.py.
#
# NOTE(format): this text sits before the first "diff --git" header, where
# patch tools skip leading text. The patch had its whitespace collapsed; the
# line structure below was restored so each hunk matches its @@ line counts.
# Indentation and the position of blank context lines are reconstructed --
# verify against the target file before applying.
#
# NOTE(review) on the added code (not changed here):
#   * pred is reshaped to [batch_size, 784], but the final accuracy eval
#     feeds mnist.test.images -- presumably more than batch_size rows, so
#     the reshape would fail there; confirm, and consider [-1, 784].
#   * The loss is -sum(X * log(pred)); there is no (1 - X) * log(1 - pred)
#     term and no guard against log(0).
#   * range(epochs) runs `epochs` iterations of one batch each, not full
#     passes over the training set.
#   * accuracy compares argmax over axis 1, which for pred is the 784-wide
#     axis produced by the reshape.
#   * `print accuracy.eval(...)` is a Python 2 print statement.
diff --git a/pixelcnn_tf.py b/pixelcnn_tf.py
index 4f139d4..cee4c4a 100644
--- a/pixelcnn_tf.py
+++ b/pixelcnn_tf.py
@@ -1,8 +1,12 @@
-# TODO masking, pred:final layer (X256)
-
 import tensorflow as tf
 import numpy as np
 from models import PixelCNN
+from tensorflow.examples.tutorials.mnist import input_data
+
+# TODO get mean if pixel value >1
+mnist = input_data.read_data_sets("data/")
+epochs = 10
+batch_size = 50
 
 LAYERS = 3
 F_MAP = 32
@@ -20,26 +24,38 @@
     i = str(i)
 
     with tf.name_scope("v_stack"+i):
-        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], v_stack_in).output()
+        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
         v_stack_in = v_stack
 
     with tf.name_scope("v_stack_1"+i):
-        v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False).output()
+        v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False, mask=mask).output()
 
     with tf.name_scope("h_stack"+i):
-        h_stack = PixelCNN([1, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1).output()
+        h_stack = PixelCNN([1, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
 
     with tf.name_scope("h_stack_1"+i):
-        h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False).output()
+        h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False, mask=mask).output()
         h_stack_1 += h_stack_in
         h_stack_in = h_stack_1
 
-pred = None
-softmax = tf.nn.softmax(pred)
-cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(softmax), reduction_indices=[1]))
+with tf.name_scope("f_layer"):
+    pred = PixelCNN([1, 1, F_MAP, 1],[1], h_stack_in, gated=False, mask='b', activation='sigmoid').output()
+    pred = tf.reshape(pred, [batch_size, 784])
+
+cross_entropy = tf.reduce_mean(-tf.reduce_sum(X * tf.log(pred), reduction_indices=[1]))
 #TODO gradient clipping
 trainer = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
+correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1))
+accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
 
-sess = tf.Session()
 #summary = tf.train.SummaryWriter('logs', sess.graph)
-#Combine and Quantize into 255
+
+with tf.Session() as sess:
+    sess.run(tf.initialize_all_variables())
+    for i in range(epochs):
+        batch_X, batch_y = mnist.train.next_batch(batch_size)
+        sess.run(trainer, feed_dict={X:batch_X})
+
+        if i%1 == 0:
+            print accuracy.eval(feed_dict={X:batch_X})
+            print accuracy.eval(feed_dict={X:mnist.test.images})