diff --git a/autoencoder.py b/autoencoder.py
index fa88969..d909fd2 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -1,143 +1,126 @@
-
-# coding: utf-8
-
-# In[1]:
-
 import tensorflow as tf
 import numpy as np
-from tensorflow.examples.tutorials.mnist import input_data
+import os
 from utils import *
-
-
-# In[2]:
-
 import matplotlib.pyplot as plt
-#get_ipython().magic(u'matplotlib inline')
-
-mnist = input_data.read_data_sets("data/")
-
-# In[3]:
-
-img_height = 28
-img_width = 28
-channel = 1
-
-num_layers = 3
-filter_size = 3
-fmap_in = channel
-fmap_out = 32
-strides = [1, 1, 1, 1]
-
-batch_size = 50
-
 from models import PixelCNN
-class Conf(object):
-    pass
-
-conf = Conf()
-conf.ckpt_path='ckpts'
-conf.conditional=True
-conf.data='mnist'
-conf.data_path='data'
-conf.epochs=50
-conf.f_map=32
-conf.grad_clip=1
-conf.layers=5
-conf.samples_path='samples'
-conf.num_classes = 10
-conf.img_height = 28
-conf.img_width = 28
-conf.channel = 1
-conf.num_batches = mnist.train.num_examples // batch_size
-conf.type='train'
-
-
-
-X = tf.placeholder(shape=[None, img_height, img_width, channel], dtype=tf.float32)
-
-fan_in = X
-W = []
-for i in range(num_layers):
-    if i == num_layers -1 :
-        fmap_out = 10
-    W.append(tf.Variable(tf.truncated_normal(shape=[filter_size, filter_size, fmap_in, fmap_out], stddev=0.1), name="W_%d"%i))
-    b = tf.Variable(tf.ones(shape=[fmap_out], dtype=tf.float32), name="encoder_b_%d"%i)
-    en_conv = tf.nn.conv2d(fan_in, W[i], strides, padding='SAME', name="encoder_conv_%d"%i)
-
-    fan_in = tf.tanh(tf.add(en_conv, b))
-    fmap_in = fmap_out
-
-fan_in = tf.reshape(fan_in, (-1, conf.img_width*conf.img_height*fmap_out))
-conf.num_classes = int(fan_in.get_shape()[1])
-
-# TODO
-# Make X enter from model input
-model = PixelCNN(conf)
-# output is model.pre
-# define loss function here after getting prediction
-y = model.pred
-
-'''
-W.reverse()
-for i in range(num_layers):
-    if i == num_layers-1:
-        fmap_out = channel
-    c = tf.Variable(tf.ones(shape=[fmap_out], dtype=tf.float32), name="decoder_b_%d"%i)
-    de_conv = tf.nn.conv2d_transpose(fan_in, W[i], [tf.shape(X)[0], img_height, img_width, fmap_out], strides, padding='SAME', name="decoder_conv_%d"%i)
-    fan_in = tf.tanh(tf.add(de_conv, c))
-y = fan_in
-'''
-
-
-# In[10]:
-
-loss = tf.reduce_mean(tf.square(X - y))
-trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
-
-
-# In[5]:
-'''
-import cPickle
-data = cPickle.load(open('cifar-100-python/test', 'r'))['data']
-data = np.reshape(data, (data.shape[0], 3, 32, 32))
-data = np.transpose(data, (0, 2, 3, 1))
-#data = (data - np.mean(data))/np.std(data)
-'''
-
-
-# In[ ]:
-epochs = 5
-num_batches = 1#mnist.train.num_examples // batch_size
-
-with tf.Session() as sess:
-    sess.run(tf.initialize_all_variables())
-
-    for i in range(epochs):
-        for j in range(num_batches):
-            batch_X = binarize(mnist.train.next_batch(batch_size)[0].reshape(batch_size, img_height, img_width, channel))
-            condition = sess.run([fan_in], feed_dict={X:batch_X})
-            # TODO shape of condition does not match: (1, 10, 28, 28, 32) for (?, 10)
-            _, l = sess.run([trainer, loss], feed_dict={X:batch_X, model.X:batch_X, model.h: condition[0]})
-            #batch_X = data[:10]/255.0
-            #_, l = sess.run([trainer, loss], feed_dict={X:batch_X})
-            print l
-
-    n_examples = 10
-    #test_X = mnist.train.next_batch(n_examples)[0].reshape(n_examples, img_height, img_width, channel)
-
-    o_test_X = mnist.test.next_batch(10)[0].reshape(10, img_height, img_width, channel)
-    test_X = binarize(o_test_X)
-    condition = sess.run(fan_in, feed_dict={X:test_X})
-    samples = sess.run(y, feed_dict={X: test_X, model.X:test_X, model.h:condition})
-    print samples.shape
-    #test_X = data[:10]
-    #samples = sess.run(y, feed_dict={X:test_X/255.0})
-    fig, axs = plt.subplots(2, n_examples, figsize=(10,2))
-    for i in range(n_examples):
-        axs[0][i].imshow(np.reshape(o_test_X[i], (img_height, img_width)), cmap='binary')
-        axs[1][i].imshow(np.reshape(samples[i], (img_height, img_width)), cmap='binary')
-    fig.show()
-    plt.draw()
-    plt.waitforbuttonpress()
+class AE(object):
+    def __init__(self, conf, X):
+        self.conf = conf
+        self.X = X
+        self.fmap_out = 32
+        self.fmap_in = conf.channel
+        self.fan_in = X
+        self.num_layers = 3
+        self.filter_size = 3
+        self.W = []
+        self.strides = [1, 1, 1, 1]
+
+        for i in range(self.num_layers):
+            if i == self.num_layers - 1:
+                self.fmap_out = 10
+            self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out], stddev=0.1), name="W_%d"%i))
+            b = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="encoder_b_%d"%i)
+            en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i)
+
+            self.fan_in = tf.tanh(tf.add(en_conv, b))
+            self.fmap_in = self.fmap_out
+
+        # Keep the 4-D feature map for the deconv decoder; the flattened
+        # version is what PixelCNN consumes as its conditioning vector.
+        self.conv_out = self.fan_in
+        self.fan_in = tf.reshape(self.fan_in, (-1, conf.img_width*conf.img_height*self.fmap_out))
+
+    def decoder(self):
+        # Legacy deconvolutional decoder; superseded by the PixelCNN decoder below.
+        self.W.reverse()
+        fan_in = self.conv_out
+        for i in range(self.num_layers):
+            if i == self.num_layers - 1:
+                self.fmap_out = self.conf.channel
+            c = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="decoder_b_%d"%i)
+            de_conv = tf.nn.conv2d_transpose(fan_in, self.W[i], [tf.shape(self.X)[0], self.conf.img_height, self.conf.img_width, self.fmap_out], self.strides, padding='SAME', name="decoder_conv_%d"%i)
+            fan_in = tf.tanh(tf.add(de_conv, c))
+        self.y = fan_in
+
+    def generate(self, sess, encoder_X, decoder_X, decoder, data, conf):
+        n_examples = 10
+        if conf.data == 'mnist':
+            test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
+            test_X = binarize(test_X_pure)
+
+        condition = sess.run(self.fan_in, feed_dict={encoder_X: test_X})
+        samples = sess.run(decoder.pred, feed_dict={decoder_X: test_X, decoder.h: condition})
+        fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
+        for i in range(n_examples):
+            axs[0][i].imshow(np.reshape(test_X_pure[i], (conf.img_height, conf.img_width)))
+            axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
+        fig.show()
+        plt.draw()
+        plt.waitforbuttonpress()
+
+
+def trainPixelCNNAE(conf, data):
+    encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
+    decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
+
+    encoder = AE(conf, encoder_X)
+    conf.num_classes = int(encoder.fan_in.get_shape()[1])
+    # TODO keep X out for main.py also
+    decoder = PixelCNN(decoder_X, conf, encoder.fan_in)
+    y = decoder.pred
+
+    loss = tf.reduce_mean(tf.square(encoder_X - y))
+    trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
+
+    saver = tf.train.Saver(tf.trainable_variables())
+    with tf.Session() as sess:
+        sess.run(tf.initialize_all_variables())
+        if os.path.exists(conf.ckpt_file):
+            saver.restore(sess, conf.ckpt_file)
+            print "Model Restored"
+
+        for i in range(conf.epochs):
+            for j in range(conf.num_batches):
+                if conf.data == 'mnist':
+                    batch_X = binarize(data.train.next_batch(conf.batch_size)[0].reshape(conf.batch_size, conf.img_height, conf.img_width, conf.channel))
+                else:
+                    # TODO iterate over batches; this reuses the first batch_size images
+                    batch_X = data[0][0][:conf.batch_size]
+                #condition = sess.run(encoder.fan_in, feed_dict={encoder_X: batch_X})
+                _, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
+                print "Epoch: %d, Cost: %f"%(i, l)
+
+        saver.save(sess, conf.ckpt_file)
+        generate_ae(sess, encoder_X, decoder_X, y, data, conf)
+
+
+if __name__ == "__main__":
+    class Conf(object):
+        pass
+
+    conf = Conf()
+    conf.conditional=True
+    conf.data='mnist'
+    conf.data_path='data'
+    conf.f_map=32
+    conf.grad_clip=1
+    conf.layers=5
+    conf.samples_path='samples/ae'
+    conf.ckpt_path='ckpts/ae'
+    conf.epochs=10
+    conf.batch_size = 100
+
+    if conf.data == 'mnist':
+        from tensorflow.examples.tutorials.mnist import input_data
+        data = input_data.read_data_sets("data/")
+        conf.img_height = 28
+        conf.img_width = 28
+        conf.channel = 1
+        train_size = data.train.num_examples
+    else:
+        from keras.datasets import cifar10
+        # load_data() returns immutable tuples; convert so the arrays can be replaced in place
+        data = [list(batch) for batch in cifar10.load_data()]
+        # TODO normalize pixel values
+        data[0][0] = np.transpose(data[0][0], (0, 2, 3, 1))
+        data[1][0] = np.transpose(data[1][0], (0, 2, 3, 1))
+        # CIFAR-10 images are 32x32 RGB
+        conf.img_height = 32
+        conf.img_width = 32
+        conf.channel = 3
+        train_size = data[0][0].shape[0]
+
+    conf.num_batches = train_size // conf.batch_size
+    conf = makepaths(conf)
+    trainPixelCNNAE(conf, data)
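
Note on the conditioning size: with the encoder above, PixelCNN's conditioning vector is the flattened encoder output rather than a one-hot label, so conf.num_classes is derived from the graph instead of being hard-coded to 10. A quick sanity check of that width for the MNIST settings in this file (the names below are illustrative, not part of the diff):

    # fmap_out is forced to 10 on the last of the 3 encoder layers, and
    # conv2d with strides [1, 1, 1, 1] and SAME padding preserves height/width.
    img_height, img_width, final_fmap = 28, 28, 10
    num_classes = img_height * img_width * final_fmap  # what int(encoder.fan_in.get_shape()[1]) yields
    assert num_classes == 7840
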
diff --git a/main.py b/main.py
index 29d92d0..47b9f12 100644
--- a/main.py
+++ b/main.py
@@ -37,7 +37,7 @@ def train(conf, data):
             print "Epoch: %d, Cost: %f"%(i, cost)
 
     saver.save(sess, conf.ckpt_file)
-    generate_and_save(sess, model.X, model.h, model.pred, conf)
+    generate_samples(sess, model.X, model.h, model.pred, conf)
 
 
 if __name__ == "__main__":
@@ -76,19 +76,10 @@ def train(conf, data):
         data = np.transpose(data, (0, 2, 3, 1))
         raise ValueError("Specify num_classes")
         conf.num_classes = 10
+        conf.num_batches = 10  # data.shape[0] // conf.batch_size
        # Implementing tf.image.per_image_whitening for normalization
        # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0
-        conf.num_batches = 10#data.shape[0] // conf.batch_size
-
-    ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
-    if not os.path.exists(ckpt_full_path):
-        os.makedirs(ckpt_full_path)
-    conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")
-
-    conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map))
-    if not os.path.exists(conf.samples_path):
-        os.makedirs(conf.samples_path)
-
+    conf = makepaths(conf)
     train(conf, data)
diff --git a/models.py b/models.py
index e213a61..354ec68 100644
--- a/models.py
+++ b/models.py
@@ -2,14 +2,17 @@
 from layers import GatedCNN
 
 class PixelCNN():
-    def __init__(self, conf):
-
-        self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
+    def __init__(self, X, conf, h=None):
+        self.X = X
         self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
         v_stack_in, h_stack_in = self.X_norm, self.X_norm
         # TODO norm for multichannel: subtract mean and divide by std feature-wise
 
-        self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
-        if conf.conditional is False:
+        if conf.conditional is True:
+            if h is not None:
+                self.h = h
+            else:
+                self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
+        else:
             self.h = None
 
         for i in range(conf.layers):
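
The models.py change makes PixelCNN usable both standalone and as a decoder: the input tensor is injected, and the conditioning tensor h is optional. A minimal sketch of the two call modes, assuming the constructor reads only the conf attributes that appear in this diff (the Conf values and names below are illustrative):

    import tensorflow as tf
    from models import PixelCNN

    class Conf(object):
        pass

    conf = Conf()
    conf.data, conf.conditional = 'mnist', True
    conf.layers, conf.f_map, conf.grad_clip = 5, 32, 1
    conf.img_height, conf.img_width, conf.channel = 28, 28, 1
    conf.num_classes = 10  # or the flattened encoder width when driven by the AE

    X = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
    h_vec = tf.placeholder(tf.float32, shape=[None, conf.num_classes])  # stands in for encoder.fan_in

    conditioned = PixelCNN(X, conf, h_vec)  # h given: model.h is h_vec, no new placeholder
    # PixelCNN(X, conf) would instead create its own [None, num_classes] placeholder,
    # which is the path main.py's train() still uses.
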
diff --git a/utils.py b/utils.py
index 32b8f24..a8275c3 100644
--- a/utils.py
+++ b/utils.py
@@ -7,11 +7,11 @@ def binarize(images):
     return (np.random.uniform(size=images.shape) < images).astype(np.float32)
 
 
-def generate_and_save(sess, X, h, pred, conf):
+def generate_samples(sess, X, h, pred, conf):
     n_row, n_col = 5, 5
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     # TODO make it generic
-    labels = one_hot([1,2,3,4,5]*5)
+    labels = one_hot([1,2,3,4,5]*5, conf.num_classes)
 
     for i in xrange(conf.img_height):
         for j in xrange(conf.img_width):
@@ -20,6 +20,26 @@ def generate_and_save(sess, X, h, pred, conf):
             if conf.data == "mnist":
                 next_sample = binarize(next_sample)
             samples[:, i, j, k] = next_sample[:, i, j, k]
+    save_images(samples, n_row, n_col, conf)
+
+
+def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
+    n_row, n_col = 3, 3
+    samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
+    # "labels" here are real images; the encoder turns them into conditioning vectors
+    if conf.data == 'mnist':
+        labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
+
+    for i in xrange(conf.img_height):
+        for j in xrange(conf.img_width):
+            for k in xrange(conf.channel):
+                # Full forward pass each step, but only pixel (i, j, k) is committed
+                next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples})
+                if conf.data == 'mnist':
+                    next_sample = binarize(next_sample)
+                samples[:, i, j, k] = next_sample[:, i, j, k]
+
+    save_images(samples, n_row, n_col, conf)
+
+
+def save_images(samples, n_row, n_col, conf):
     images = samples
     images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
     images = images.transpose(1, 2, 0, 3)
@@ -42,3 +62,14 @@ def one_hot(batch_y, num_classes):
     return y_
 
 
+def makepaths(conf):
+    ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
+    if not os.path.exists(ckpt_full_path):
+        os.makedirs(ckpt_full_path)
+    conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")
+
+    conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map))
+    if not os.path.exists(conf.samples_path):
+        os.makedirs(conf.samples_path)
+
+    return conf
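
For reference, generate_ae and generate_samples rely on the same autoregressive sampling trick: each sess.run predicts every pixel, but only the current pixel (i, j, k) is written back into samples, so later predictions are conditioned on the pixels committed so far. A self-contained toy version of that loop, with a random stub standing in for the network (illustrative only, no TensorFlow needed):

    import numpy as np

    rng = np.random.RandomState(0)
    def predict(batch):
        # stand-in for sess.run(y, {encoder_X: labels, decoder_X: batch})
        return rng.uniform(size=batch.shape)

    n, h, w, c = 2, 4, 4, 1
    samples = np.zeros((n, h, w, c), dtype=np.float32)
    for i in range(h):
        for j in range(w):
            for k in range(c):
                next_sample = predict(samples)                 # full forward pass
                samples[:, i, j, k] = next_sample[:, i, j, k]  # commit one pixel only
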