From a4632ea0f0aded7f00a366d494ec7f91c4535ed7 Mon Sep 17 00:00:00 2001 From: Anant Gupta Date: Mon, 14 Nov 2016 12:12:32 +0530 Subject: [PATCH] pending changes --- autoencoder.py | 20 +++++++++++--------- main.py | 20 +++++++++----------- utils.py | 4 +++- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/autoencoder.py b/autoencoder.py index d909fd2..ff02194 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -60,7 +60,6 @@ def trainPixelCNNAE(conf, data): encoder = AE(encoder_X) conf.num_classes = int(encoder.fan_in.get_shape()[1]) - # TODO keep X out for main.py also decoder = PixelCNN(decoder_X, conf, encoder.fan_in) y = decoder.pred @@ -74,16 +73,16 @@ def trainPixelCNNAE(conf, data): saver.restore(sess, conf.ckpt_file) print "Model Restored" + pointer = 0 for i in range(conf.epochs): for j in range(conf.num_batches): if conf.data == 'mnist': batch_X = binarize(data.train.next_batch(conf.batch_size)[0].reshape(conf.batch_size, conf.img_height, conf.img_width, conf.channel)) else: - # TODO move batch - batch_X = data[0][0][:conf.batch_size] - #condition = sess.run(encoder.fan_in, feed_dict={encoder_X:batch_X}) - _, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X:batch_X}) - print l + batch_X, pointer = get_batch(data, pointer, conf.batch_size) + + _, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X}) + print "Epoch: %d, Cost: %f"%(i, l) saver.save(sess, conf.ckpt_file) generate_ae(sess, encoder_X, decoder_X, y, data, conf) @@ -116,9 +115,12 @@ class Conf(object): from keras.datasets import cifar10 data = cifar10.load_data() # TODO normalize pixel values - data[0][0] = np.transpose(data[0][0], (0, 2, 3, 1)) - data[1][0] = np.transpose(data[1][0], (0, 2, 3, 1)) - train_size = data[0][0].shape[0] + data = data[0][0] + data = np.transpose(data, (0, 2, 3, 1)) + conf.img_height = 32 + conf.img_width = 32 + conf.channel = 3 + train_size = data.shape[0] conf.num_batches = train_size // conf.batch_size conf = makepaths(conf) diff --git a/main.py b/main.py index 47b9f12..e95b4ce 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,8 @@ from utils import * def train(conf, data): - model = PixelCNN(conf) + X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) + model = PixelCNN(conf, X) trainer = tf.train.RMSPropOptimizer(1e-3) gradients = trainer.compute_gradients(model.loss) @@ -32,12 +33,12 @@ def train(conf, data): else: batch_X, pointer = get_batch(data, pointer, conf.batch_size) - _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X, model.h:batch_y}) + _, cost = sess.run([optimizer, model.loss], feed_dict={X:batch_X, model.h:batch_y}) print "Epoch: %d, Cost: %f"%(i, cost) saver.save(sess, conf.ckpt_file) - generate_samples(sess, model.X, model.h, model.pred, conf) + generate_samples(sess, X, model.h, model.pred, conf) if __name__ == "__main__": @@ -66,20 +67,17 @@ def train(conf, data): conf.channel = 1 conf.num_batches = data.train.num_examples // conf.batch_size else: - import cPickle - data = cPickle.load(open('cifar-100-python/train', 'r'))['data'] + from keras.datasets import cifar10 + data = cifar10.load_data() + data = data[0][0] + data = np.transpose(data, (0, 2, 3, 1)) conf.img_height = 32 conf.img_width = 32 conf.channel = 3 - data = np.reshape(data, (data.shape[0], conf.channel, \ - conf.img_height, conf.img_width)) - data = np.transpose(data, (0, 2, 3, 1)) raise ValueError("Specify num_classes") conf.num_classes = 10 - conf.num_batches = 10#data.shape[0] // conf.batch_size + conf.num_batches = data.shape[0] // conf.batch_size - # Implementing tf.image.per_image_whitening for normalization - # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0 conf = makepaths(conf) train(conf, data) diff --git a/utils.py b/utils.py index a8275c3..71af225 100644 --- a/utils.py +++ b/utils.py @@ -24,10 +24,12 @@ def generate_samples(sess, X, h, pred, conf): def generate_ae(sess, encoder_X, decoder_X, y, data, conf): - n_row, n_col = 3, 3 + n_row, n_col = 5, 5 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) if conf.data == 'mnist': labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel)) + else: + labels = get_batch(data, 0, n_row*n_col) for i in xrange(conf.img_height): for j in xrange(conf.img_width):