diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0d20b64
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pyc
diff --git a/models.py b/models.py
index 192a43b..4ff615e 100644
--- a/models.py
+++ b/models.py
@@ -2,8 +2,8 @@
 import numpy as np
 
 def get_weights(shape, mask=None):
-    # TODO set init bounds
-    W = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))
+    weights_initializer = tf.contrib.layers.xavier_initializer()
+    W = tf.get_variable("weights", shape, tf.float32, weights_initializer)
 
     if mask:
         filter_mid_x = shape[0]//2
@@ -20,7 +20,7 @@ def get_weights(shape, mask=None):
 
 
 def get_bias(shape):
-    return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32))
+    return tf.get_variable("biases", shape, tf.float32, tf.zeros_initializer)
 
 def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
@@ -59,9 +59,9 @@ def simple_conv(self):
        b = get_bias(self.b_shape)
        conv = conv_op(self.fan_in, W)
        if self.activation:
-           self.fan_out = tf.nn.relu(conv + b)
+           self.fan_out = tf.nn.relu(tf.add(conv, b))
        else:
-           self.fan_out = conv + b
+           self.fan_out = tf.add(conv, b)
 
 
     def output(self):
diff --git a/pixelcnn_tf.py b/pixelcnn_tf.py
index 2685545..5feb336 100644
--- a/pixelcnn_tf.py
+++ b/pixelcnn_tf.py
@@ -1,57 +1,66 @@
-# TODO sanity check pred,
-# generate and plot,
+# TODO
+# try changing 0.0 to np.random
+# make for imagenet data
+# check network arch
 # upscale-downscale-q_level
-# validation
+# cost on test set
+# autoencoder
 import tensorflow as tf
 import numpy as np
-from models import PixelCNN
 from tensorflow.examples.tutorials.mnist import input_data
+from datetime import datetime
+import scipy.misc
+import os
+from models import PixelCNN
 
 mnist = input_data.read_data_sets("data/")
 
-epochs = 10
+epochs = 50
 batch_size = 50
 grad_clip = 1
 
+img_height = 28
+img_width = 28
+channel = 1
+
 LAYERS = 3
 F_MAP = 32
 FILTER_SIZE = 7
-CHANNEL = 1
 
-X = tf.placeholder(tf.float32, shape=[None, 784])
-X_image = tf.reshape(X, [-1, 28, 28, CHANNEL])
+X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel]) # TODO mean pixel value if not mnist
 
-v_stack_in, h_stack_in = X_image, X_image
+v_stack_in, h_stack_in = X, X
 
 for i in range(LAYERS):
     FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
-    CHANNEL = F_MAP if i > 0 else CHANNEL
+    in_dim = F_MAP if i > 0 else channel
     mask = 'b' if i > 0 else 'a'
     i = str(i)
-    with tf.name_scope("v_stack"+i):
-        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
+    with tf.variable_scope("v_stack"+i):
+        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
         v_stack_in = v_stack
 
-    with tf.name_scope("v_stack_1"+i):
+    with tf.variable_scope("v_stack_1"+i):
         v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False, mask=mask).output()
 
-    with tf.name_scope("h_stack"+i):
-        h_stack = PixelCNN([1, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
+    with tf.variable_scope("h_stack"+i):
+        h_stack = PixelCNN([1, FILTER_SIZE, in_dim, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
 
-    with tf.name_scope("h_stack_1"+i):
+    with tf.variable_scope("h_stack_1"+i):
         h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False, mask=mask).output()
         h_stack_1 += h_stack_in
         h_stack_in = h_stack_1
 
-with tf.name_scope("fc_1"):
+with tf.variable_scope("fc_1"):
     fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output()
+
 # handle Imagenet differently
-with tf.name_scope("fc_2"):
+with tf.variable_scope("fc_2"):
     fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output()
 
 pred = tf.nn.sigmoid(fc2)
 
-loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X_image))
+loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X))
 
 trainer = tf.train.RMSPropOptimizer(1e-3)
 gradients = trainer.compute_gradients(loss)
@@ -59,18 +68,37 @@ clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients]
 optimizer = trainer.apply_gradients(clipped_gradients)
 
-#correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1))
-#accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
-#summary = tf.train.SummaryWriter('logs', sess.graph)
+def binarize(images):
+    return (0.0 < images).astype(np.float32)
+
+def generate_and_save(sess):
+    n_row, n_col = 5, 5
+    samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32)
+    for i in xrange(img_height):
+        for j in xrange(img_width):
+            for k in xrange(1):
+                next_sample = binarize(sess.run(pred, {X:samples}))
+                samples[:, i, j, k] = next_sample[:, i, j, k]
+    images = samples
+    images = images.reshape((n_row, n_col, img_height, img_width))
+    images = images.transpose(1, 2, 0, 3)
+    images = images.reshape((img_height * n_row, img_width * n_col))
+    filename = '%s_%s.jpg' % ("sample", str(datetime.now()))
+    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join("samples", filename))
+
+num_batches = mnist.train.num_examples // batch_size
 
 
 with tf.Session() as sess:
     sess.run(tf.initialize_all_variables())
     for i in range(epochs):
-        batch_X = mnist.train.next_batch(batch_size)[0]
-        _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
-        print cost
-
-        #if i%1 == 0:
-            #print accuracy.eval(feed_dict={X:batch_X})
-            #print accuracy.eval(feed_dict={X:mnist.test.images})
+        for j in range(num_batches):
+            batch_X = binarize(mnist.train.next_batch(batch_size)[0] \
+                    .reshape([batch_size, img_height, img_width, 1]))
+            _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
+
+        print "Epoch: %d, Cost: %.2f"%(i, cost)
+        generate_and_save(sess)
+
+generate_and_save(sess)
+
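
Note on the tf.name_scope to tf.variable_scope switch above: tf.get_variable, which the patched get_weights/get_bias now use, ignores tf.name_scope, so without variable scopes every layer would try to create the same "weights"/"biases" variables and raise a collision error. A minimal sketch of the difference, assuming the TF 1.x-era API this patch targets; the scope names and filter shapes here are illustrative only, not taken from the patch:

import tensorflow as tf

# Under name_scope, get_variable ignores the scope prefix, so a second
# layer asking for a variable named "weights" would collide with this one.
with tf.name_scope("v_stack0"):
    w_a = tf.get_variable("weights", [7, 7, 1, 32])
print(w_a.name)   # -> "weights:0" (no scope prefix)

# Under variable_scope, each layer owns a uniquely named variable.
with tf.variable_scope("v_stack1"):
    w_b = tf.get_variable("weights", [3, 3, 32, 32])
print(w_b.name)   # -> "v_stack1/weights:0"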