From c1329830854f296f4095dd8d43564306f5576b85 Mon Sep 17 00:00:00 2001 From: Anant Gupta Date: Wed, 2 Nov 2016 10:10:54 +0530 Subject: [PATCH] file rename, model save and restore fix residual conn --- pixelcnn_tf.py => main.py | 66 +++++++++++++-------------------------- utils.py | 30 ++++++++++++++++++ 2 files changed, 52 insertions(+), 44 deletions(-) rename pixelcnn_tf.py => main.py (61%) create mode 100644 utils.py diff --git a/pixelcnn_tf.py b/main.py similarity index 61% rename from pixelcnn_tf.py rename to main.py index 59ce0a4..c88b917 100644 --- a/pixelcnn_tf.py +++ b/main.py @@ -1,43 +1,36 @@ -# TODO -# kundan: concat payload instead of add -# : arch: without 1X1 in 1st layer and last 2 layers -# : replaces masking with n/2 filter -# check network arch -# autoencoder -# cost on test set -# make for imagenet data: upscale-downscale-q_level, mean pixel value if not mnist -# stats -# logger - import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_data -from datetime import datetime -import scipy.misc -import os from models import PixelCNN +from utils import * mnist = input_data.read_data_sets("data/") -epochs = 1 -batch_size = 50 +epochs = 50 +batch_size = 100 grad_clip = 1 +num_batches = mnist.train.num_examples // batch_size +ckpt_dir = "ckpts" +samples_dir = "samples" +if not os.path.isdir(ckpt_dir): + os.makedirs(ckpt_dir) +ckpt_file = os.path.join(ckpt_dir, "model.ckpt") img_height = 28 img_width = 28 channel = 1 -LAYERS = 3 +LAYERS = 12 F_MAP = 32 FILTER_SIZE = 7 X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel]) v_stack_in, h_stack_in = X, X -# TODO encapsulate for i in range(LAYERS): FILTER_SIZE = 3 if i > 0 else FILTER_SIZE in_dim = F_MAP if i > 0 else channel mask = 'b' if i > 0 else 'a' + residual = True if i > 0 else False i = str(i) with tf.variable_scope("v_stack"+i): v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output() @@ -51,7 +44,8 @@ with tf.variable_scope("h_stack_1"+i): h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output() - #h_stack_1 += h_stack_in # Residual connection + if residual: + h_stack_1 += h_stack_in # Residual connection h_stack_in = h_stack_1 with tf.variable_scope("fc_1"): @@ -70,29 +64,14 @@ clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients] optimizer = trainer.apply_gradients(clipped_gradients) - -def binarize(images): - return (0.0 < images).astype(np.float32) - -def generate_and_save(sess): - n_row, n_col = 5, 5 - samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32) - for i in xrange(img_height): - for j in xrange(img_width): - for k in xrange(1): - next_sample = binarize(sess.run(pred, {X:samples})) - samples[:, i, j, k] = next_sample[:, i, j, k] - images = samples - images = images.reshape((n_row, n_col, img_height, img_width)) - images = images.transpose(1, 2, 0, 3) - images = images.reshape((img_height * n_row, img_width * n_col)) - - filename = '%s_%s.jpg' % ("sample", str(datetime.now())) - scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join("samples", filename)) - -num_batches = mnist.train.num_examples // batch_size +saver = tf.train.Saver() with tf.Session() as sess: - sess.run(tf.initialize_all_variables()) + if os.path.exists(ckpt_file): + saver.restore(sess, ckpt_file) + print "Model Restored" + else: + sess.run(tf.initialize_all_variables()) + for i in range(epochs): for j in range(num_batches): batch_X = binarize(mnist.train.next_batch(batch_size)[0] \ @@ -100,7 +79,6 @@ def generate_and_save(sess): _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X}) print "Epoch: %d, Cost: %f"%(i, cost) - generate_and_save(sess) - - generate_and_save(sess) + generate_and_save(sess, X, pred, img_height, img_width, epochs, samples_dir) + saver.save(sess, ckpt_file) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..34299f3 --- /dev/null +++ b/utils.py @@ -0,0 +1,30 @@ +import numpy as np +import os +import scipy.misc +from datetime import datetime + + +def binarize(images): + return (np.random.uniform(size=images.shape) < images).astype(np.float32) + +def generate_and_save(sess, X, pred, img_height, img_width, epoch, samples_dir): + sample_save_dir = samples_dir.rstrip("/")+"_"+str(epoch) + if not os.path.isdir(sample_save_dir): + os.makedirs(sample_save_dir) + + n_row, n_col = 5, 5 + samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32) + for i in xrange(img_height): + for j in xrange(img_width): + for k in xrange(1): + next_sample = binarize(sess.run(pred, {X:samples})) + samples[:, i, j, k] = next_sample[:, i, j, k] + images = samples + images = images.reshape((n_row, n_col, img_height, img_width)) + images = images.transpose(1, 2, 0, 3) + images = images.reshape((img_height * n_row, img_width * n_col)) + + filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg" + scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(sample_save_dir, filename)) + +