diff --git a/README.md b/README.md
new file mode 100644
index 0000000..93520de
--- /dev/null
+++ b/README.md
@@ -0,0 +1,65 @@
+# Image Generation with Gated PixelCNN Decoders
+
+This is a TensorFlow implementation of [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328), which introduces the Gated PixelCNN model, building on the PixelCNN architecture originally proposed in [Pixel Recurrent Neural Networks](https://arxiv.org/abs/1601.06759). The model can be conditioned on a latent representation of labels or images to generate images accordingly; images can also be modelled unconditionally. It can further act as a powerful decoder, replacing deconvolution (transposed convolution) in autoencoders and GANs. A detailed summary of the paper can be found [here](https://gist.github.com/anantzoid/b2dca657003998027c2861f3121c43b7).
+
+These are some conditioned samples generated by the authors of the paper:
+
+![Paper Sample](images/conditioned_samples.png)
+
+## Architecture
+
+This is the architecture for the Gated PixelCNN used in the model:
+
+![Gated PCNN](images/gated_cnn.png)
+
+The gating lets the network remember context and model more complex interactions, much like an LSTM. The stack on the left is the vertical stack, which removes the blind spot that masked convolutions introduce (refer to the Pixel RNN paper for more on masking). Residual connections significantly improve the model's performance.
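+
+As a minimal sketch (omitting the masked convolutions and biases used in `layers.py`), the gated activation combines two convolution outputs, with the sigmoid branch acting as a per-feature gate:
+
+```python
+import tensorflow as tf
+
+def gated_activation(conv_f, conv_g):
+    # conv_f, conv_g: outputs of two masked convolutions of identical shape.
+    # The sigmoid branch gates the tanh features, much like an LSTM gate.
+    return tf.tanh(conv_f) * tf.sigmoid(conv_g)
+```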
+
+## Usage
+
+This implementation consists of the following models based on the Gated PixelCNN architecture:
+
+- **Unconditional image generation**:
+  ```
+  python main.py
+  ```
+
+  Sample generated after training on the MNIST dataset for 70 epochs, reaching a cross-entropy loss of 0.104610:
+
+  ![Unconditional image](images/sample.jpg)
+
+- **Conditional image generation based on class labels**:
+  ```
+  python main.py --model=conditional
+  ```
+
+  As mentioned in the paper, conditionally generated images are more visually appealing even though the difference in loss is marginal. The conditional model reaches a loss of 0.102719 after 40 epochs (a sketch of how the conditioning enters the gates follows this list):
+
+  ![Conditional image](images/conditional.gif)
+
+- **Autoencoder with PixelCNN decoder**:
+  ```
+  python main.py --model=autoencoder
+  ```
+
+  The encoder part of the autoencoder follows the original architecture from [Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction](https://pdfs.semanticscholar.org/1c6d/990c80e60aa0b0059415444cdf94b3574f0f.pdf). The input is encoded into a 10-dimensional representation, which conditions the PixelCNN decoder. The image generated after 10 epochs with a loss of 0.115306:
+
+  ![AE image](images/ae_sample.jpg)
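+
+In both conditional variants, the conditioning vector `h` (a one-hot class label fed through the `model.h` placeholder, or the 10-dimensional encoder output in the autoencoder model) is projected and added to both halves of every gated layer, so each layer sees the conditioning information. Here is a minimal sketch of this mechanism; the helper name `conditional_gate` and the variables `V_f`, `V_g`, `f_map` are illustrative rather than taken verbatim from the code:
+
+```python
+import tensorflow as tf
+
+def conditional_gate(conv_f, conv_g, h, f_map):
+    # h: [batch, d] conditioning vector; its projections are broadcast
+    # over the spatial dimensions of the feature maps.
+    d = int(h.get_shape()[1])
+    V_f = tf.get_variable("V_f", [d, f_map])
+    V_g = tf.get_variable("V_g", [d, f_map])
+    f = conv_f + tf.reshape(tf.matmul(h, V_f), [-1, 1, 1, f_map])
+    g = conv_g + tf.reshape(tf.matmul(h, V_g), [-1, 1, 1, f_map])
+    return tf.tanh(f) * tf.sigmoid(g)
+```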
+
+To only generate images, without training, append the `--epochs=0` flag to the command. Generation proceeds pixel by pixel, as sketched below.
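+
+Sampling follows the loop in `utils.generate_samples`: pixels are drawn one at a time and the partial image is fed back in, so every pixel is conditioned on all previously generated ones. A simplified sketch (conditioning and MNIST binarization omitted; `sample_images` is an illustrative name):
+
+```python
+import numpy as np
+
+def sample_images(sess, X, pred, conf, n=100):
+    samples = np.zeros((n, conf.img_height, conf.img_width, conf.channel),
+                       dtype=np.float32)
+    for i in range(conf.img_height):
+        for j in range(conf.img_width):
+            for k in range(conf.channel):
+                # One forward pass per pixel; keep only position (i, j, k).
+                next_sample = sess.run(pred, feed_dict={X: samples})
+                samples[:, i, j, k] = next_sample[:, i, j, k]
+    return samples
+```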
+
+To train any model on the CIFAR-10 dataset, add the `--data=cifar` flag.
+
+Refer to `main.py` for the other flags available for hyperparameter tuning.
+
+## Training Details
+
+The model was trained on a single AWS p2.xlarge spot instance, and only on the MNIST dataset. For the samples conditioned on CIFAR-10 images, the authors of the paper report training on 32 GPUs for 60 hours.
+
+To visualize the graph and loss during training, run:
+```
+tensorboard --logdir=logs
+```
+
+Loss minimization for the autoencoder model:
+
+![Loss](images/loss.png)
diff --git a/autoencoder.py b/autoencoder.py
index 20e632f..16a37c1 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -1,79 +1,37 @@
 import tensorflow as tf
 import numpy as np
 from utils import *
-import matplotlib.pyplot as plt
-from models import PixelCNN
-from layers import conv_op
+from models import *
 
-def get_weights(shape, name):
-    return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1), name=name)
-
-def get_biases(shape, name):
-    return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32), name=name)
-
-def max_pool_2x2(x):
-    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
-
-
-class AE(object):
-    def __init__(self, X):
-        self.num_layers = 6
-        self.fmap_out = [8, 32]
-        self.fmap_in = conf.channel
-        self.fan_in = X
-        self.filter_size = 4
-        self.W = []
-        self.strides = [1, 1, 1, 1]
-
-        W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1")
-        b_conv1 = get_biases([100], "b_conv1")
-        conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
-        pool1 = max_pool_2x2(conv1)
-
-
-        W_conv2 = get_weights([5, 5, 100, 150], "W_conv2")
-        b_conv2 = get_biases([150], "b_conv2")
-        conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
-        pool2 = max_pool_2x2(conv2)
-
-        W_conv3 = get_weights([3, 3, 150, 200], "W_conv3")
-        b_conv3 = get_biases([200], "b_conv3")
-        conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)
-        conv3_reshape = tf.reshape(conv3, (-1, 7*7*200))
-
-        W_fc = get_weights([7*7*200, 10], "W_fc")
-        b_fc = get_biases([10], "b_fc")
-        self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))
-
-def trainPixelCNNAE(conf, data):
+def trainAE(conf, data):
     encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
     decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
 
-    encoder = AE(encoder_X)
+    encoder = ConvolutionalEncoder(encoder_X, conf)
     decoder = PixelCNN(decoder_X, conf, encoder.pred)
     y = decoder.pred
     tf.scalar_summary('loss', decoder.loss)
 
-    #loss = tf.reduce_mean(tf.square(encoder_X - y))
-    #trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
     trainer = tf.train.RMSPropOptimizer(1e-3)
     gradients = trainer.compute_gradients(decoder.loss)
     clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
     optimizer = trainer.apply_gradients(clipped_gradients)
 
     saver = tf.train.Saver(tf.trainable_variables())
     with tf.Session() as sess:
         merged = tf.merge_all_summaries()
-        writer = tf.train.SummaryWriter('/tmp/mnist_ae', sess.graph)
+        writer = tf.train.SummaryWriter(conf.summary_path, sess.graph)
         sess.run(tf.initialize_all_variables())
 
         if os.path.exists(conf.ckpt_file):
             saver.restore(sess, conf.ckpt_file)
             print "Model Restored"
 
+        # TODO The training part below and in main.py could be generalized
+        if conf.epochs > 0:
+            print "Started Model Training..."
         pointer = 0
         step = 0
         for i in range(conf.epochs):
@@ -86,71 +44,12 @@ def trainPixelCNNAE(conf, data):
                 _, l, summary = sess.run([optimizer, decoder.loss, merged], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
                 writer.add_summary(summary, step)
                 step += 1
+            print "Epoch: %d, Cost: %f"%(i, l)
 
             if (i+1)%10 == 0:
                 saver.save(sess, conf.ckpt_file)
                 generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))
 
         writer.close()
-        '''
-        data = input_data.read_data_sets("data/")
-        n_examples = 10
-        if conf.data == 'mnist':
-            test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
-            test_X = binarize(test_X_pure)
-
-        samples = sess.run(y, feed_dict={encoder_X:test_X, decoder_X:test_X})
-        fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
-        for i in range(n_examples):
-            axs[0][i].imshow(np.reshape(test_X_pure[i], (conf.img_height, conf.img_width)))
-            axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
-        fig.show()
-        plt.draw()
-        plt.waitforbuttonpress()
-        '''
-
-
-if __name__ == "__main__":
-    class Conf(object):
-        pass
-
-    conf = Conf()
-    conf.conditional=True
-    conf.data='mnist'
-    conf.data_path='data'
-    conf.f_map=32
-    conf.grad_clip=1
-    conf.layers=5
-    conf.samples_path='samples/ae'
-    conf.ckpt_path='ckpts/ae'
-    conf.summary_path='/tmp/mnist_ae'
-    conf.epochs=50
-    conf.batch_size = 64
-    conf.num_classes = 10
-
-    if conf.data == 'mnist':
-        from tensorflow.examples.tutorials.mnist import input_data
-        data = input_data.read_data_sets("data/")
-        conf.img_height = 28
-        conf.img_width = 28
-        conf.channel = 1
-        train_size = data.train.num_examples
-    else:
-        from keras.datasets import cifar10
-        data = cifar10.load_data()
-        data = data[0][0]
-        data /= 255.0
-        data = np.transpose(data, (0, 2, 3, 1))
-        conf.img_height = 32
-        conf.img_width = 32
-        conf.channel = 3
-        train_size = data.shape[0]
-
-    conf.num_batches = train_size // conf.batch_size
-    conf = makepaths(conf)
-    if tf.gfile.Exists(conf.summary_path):
-        tf.gfile.DeleteRecursively(conf.summary_path)
-    tf.gfile.MakeDirs(conf.summary_path)
-
-    trainPixelCNNAE(conf, data)
+        generate_ae(sess, encoder_X, decoder_X, y, data, conf, '')
diff --git a/images/ae_sample.jpg b/images/ae_sample.jpg
new file mode 100644
index 0000000..255837c
Binary files /dev/null and b/images/ae_sample.jpg differ
diff --git a/images/conditional.gif b/images/conditional.gif
new file mode 100644
index 0000000..6ff4143
Binary files /dev/null and b/images/conditional.gif differ
diff --git a/images/conditioned_samples.png b/images/conditioned_samples.png
new file mode 100644
index 0000000..9dac968
Binary files /dev/null and b/images/conditioned_samples.png differ
diff --git a/images/loss.png b/images/loss.png
new file mode 100644
index 0000000..ca9d894
Binary files /dev/null and b/images/loss.png differ
diff --git a/images/sample.jpg b/images/sample.jpg
new file mode 100644
index 0000000..a4c3ebc
Binary files /dev/null and b/images/sample.jpg differ
diff --git a/layers.py b/layers.py
index b88dc48..d90485a 100644
--- a/layers.py
+++ b/layers.py
@@ -5,6 +5,9 @@ def get_weights(shape, name, mask=None):
     weights_initializer = tf.contrib.layers.xavier_initializer()
     W = tf.get_variable(name, shape, tf.float32, weights_initializer)
 
+    '''
+    Mask the filter weights to hide subsequent pixel values from the convolution
+    '''
     if mask:
         filter_mid_x = shape[0]//2
         filter_mid_y = shape[1]//2
@@ -24,6 +27,9 @@ def get_bias(shape, name):
 def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
 
+def max_pool_2x2(x):
+    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
+
 class GatedCNN():
     def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
         self.fan_in = fan_in
@@ -66,7 +72,6 @@ def gated_conv(self):
             conv_f += self.payload
             conv_g += self.payload
 
-
         self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))
 
     def simple_conv(self):
diff --git a/main.py b/main.py
index f752ccf..1305231 100644
--- a/main.py
+++ b/main.py
@@ -2,9 +2,9 @@
 import numpy as np
 import argparse
 
 from models import PixelCNN
+from autoencoder import *
 from utils import *
 
-tf.set_random_seed(100)
 def train(conf, data):
     X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
     model = PixelCNN(X, conf)
@@ -23,6 +23,8 @@ def train(conf, data):
             saver.restore(sess, conf.ckpt_file)
             print "Model Restored"
 
+        if conf.epochs > 0:
+            print "Started Model Training..."
         pointer = 0
         for i in range(conf.epochs):
             for j in range(conf.num_batches):
@@ -33,31 +35,30 @@ def train(conf, data):
                     batch_y = one_hot(batch_y, conf.num_classes)
                 else:
                     batch_X, pointer = get_batch(data, pointer, conf.batch_size)
-                    #batch_X, batch_y = next(data)
 
                 data_dict = {X:batch_X}
                 if conf.conditional is True:
-                    #TODO extract one-hot classes
                     data_dict[model.h] = batch_y
-                _, cost, _f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict)
+                _, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)
 
             print "Epoch: %d, Cost: %f"%(i, cost)
             if (i+1)%10 == 0:
                 saver.save(sess, conf.ckpt_file)
-                generate_samples(sess, X, model.h, model.pred, conf, "argmax")
+                generate_samples(sess, X, model.h, model.pred, conf, "")
 
+        generate_samples(sess, X, model.h, model.pred, conf, "")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--type', type=str, default='train')
     parser.add_argument('--data', type=str, default='mnist')
     parser.add_argument('--layers', type=int, default=12)
     parser.add_argument('--f_map', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
     parser.add_argument('--batch_size', type=int, default=100)
     parser.add_argument('--grad_clip', type=int, default=1)
-    parser.add_argument('--conditional', type=bool, default=False)
+    parser.add_argument('--model', type=str, default='')
     parser.add_argument('--data_path', type=str, default='data')
     parser.add_argument('--ckpt_path', type=str, default='ckpts')
     parser.add_argument('--samples_path', type=str, default='samples')
+    parser.add_argument('--summary_path', type=str, default='logs')
     conf = parser.parse_args()
 
@@ -81,14 +82,16 @@ def train(conf, data):
         conf.img_height = 32
         conf.img_width = 32
         conf.channel = 3
         conf.num_classes = 10
         conf.num_batches = data.shape[0] // conf.batch_size
-        '''
-        # TODO debug shape
-        from keras.preprocessing.image import ImageDataGenerator
-        datagen = ImageDataGenerator(featurewise_center=True,
-                featurewise_std_normalization=True)
-        datagen.fit(data)
-        data = datagen.flow(data, labels, batch_size=conf.batch_size)
-        '''
 
     conf = makepaths(conf)
-    train(conf, data)
+    if conf.model == '':
+        conf.conditional = False
+        train(conf, data)
+    elif conf.model.lower() == 'conditional':
+        conf.conditional = True
+        train(conf, data)
+    elif conf.model.lower() == 'autoencoder':
+        conf.conditional = True
+        trainAE(conf, data)
diff --git a/models.py b/models.py
index ef732bc..e4b4eeb 100644
--- a/models.py
+++ b/models.py
@@ -1,14 +1,16 @@
 import tensorflow as tf
-from layers import GatedCNN
+from layers import *
 
-class PixelCNN():
+class PixelCNN(object):
     def __init__(self, X, conf, h=None):
         self.X = X
         if conf.data == "mnist":
             self.X_norm = X
         else:
+            '''
+            Image normalization for CIFAR-10 would be done here
+            '''
             self.X_norm = X
-            #self.X_norm = tf.image.per_image_whitening(X)
         v_stack_in, h_stack_in = self.X_norm, self.X_norm
 
         if conf.conditional is True:
@@ -56,6 +58,44 @@ def __init__(self, X, conf, h=None):
 
         self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
 
+        '''
+        Since this code was not run on CIFAR-10, it is unclear which of the
+        two methods below is better suited to generating 3-channel images;
+        sampling from the softmax distribution (self.pred) is the more
+        likely choice, with argmax (self.pred_argmax) as the alternative.
+        '''
+        self.pred = tf.reshape(tf.multinomial(tf.nn.softmax(self.fc2), num_samples=1, seed=100), tf.shape(self.X))
         self.pred_argmax = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
-        self.pred_sample = tf.reshape(tf.multinomial(tf.nn.softmax(self.fc2), num_samples=1, seed=100), tf.shape(self.X))
+
+
+class ConvolutionalEncoder(object):
+    def __init__(self, X, conf):
+        '''
+        This is the 6-layer architecture of the Convolutional Autoencoder
+        introduced in the original paper:
+        Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction
+
+        Note that only the encoder part is implemented here, since the
+        PixelCNN serves as the decoder.
+        '''
+
+        W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1")
+        b_conv1 = get_bias([100], "b_conv1")
+        conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
+        pool1 = max_pool_2x2(conv1)
+
+        W_conv2 = get_weights([5, 5, 100, 150], "W_conv2")
+        b_conv2 = get_bias([150], "b_conv2")
+        conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
+        pool2 = max_pool_2x2(conv2)
+
+        W_conv3 = get_weights([3, 3, 150, 200], "W_conv3")
+        b_conv3 = get_bias([200], "b_conv3")
+        conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)
+        conv3_reshape = tf.reshape(conv3, (-1, 7*7*200))
+
+        W_fc = get_weights([7*7*200, 10], "W_fc")
+        b_fc = get_bias([10], "b_fc")
+        self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))
diff --git a/utils.py b/utils.py
index cb051b1..ad95997 100644
--- a/utils.py
+++ b/utils.py
@@ -2,16 +2,17 @@
 import os
 import scipy.misc
 from datetime import datetime
-
+import tensorflow as tf
 
 def binarize(images):
     return (np.random.uniform(size=images.shape) < images).astype(np.float32)
 
 def generate_samples(sess, X, h, pred, conf, suff):
+    print "Generating Sample Images..."
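+    # Pixels are generated sequentially: for every position (i, j, k) the
+    # partial image is fed back through the network and only that position
+    # of the prediction is kept, so each pixel is conditioned on the
+    # previously generated ones.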
     n_row, n_col = 10,10
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     # TODO make it generic
-    labels = one_hot(np.array([1,2,3,4,5]*5), conf.num_classes)
+    labels = one_hot(np.array([0,1,2,3,4,5,6,7,8,9]*10), conf.num_classes)
 
     for i in xrange(conf.img_height):
         for j in xrange(conf.img_width):
@@ -23,13 +24,13 @@ def generate_samples(sess, X, h, pred, conf, suff):
                 if conf.data == "mnist":
                     next_sample = binarize(next_sample)
                 samples[:, i, j, k] = next_sample[:, i, j, k]
 
-    np.save('preds_'+suff+'.npy', samples)
-    print "height:%d"%i
+    save_images(samples, n_row, n_col, conf, suff)
 
 
 def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
-    n_row, n_col = 5,5
+    print "Generating Sample Images..."
+    n_row, n_col = 10,10
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     if conf.data == 'mnist':
         labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
@@ -44,10 +45,9 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
                     next_sample = binarize(next_sample)
                 samples[:, i, j, k] = next_sample[:, i, j, k]
 
-    np.save('preds_'+suff+'.npy', samples)
-    print "height:%d"%i
     save_images(samples, n_row, n_col, conf, suff)
 
+
 def save_images(samples, n_row, n_col, conf, suff):
     images = samples
     if conf.data == "mnist":
@@ -70,6 +70,7 @@ def get_batch(data, pointer, batch_size):
         pointer += 1
     return [batch, pointer]
 
+
 def one_hot(batch_y, num_classes):
     y_ = np.zeros((batch_y.shape[0], num_classes))
     y_[np.arange(batch_y.shape[0]), batch_y] = 1
@@ -86,4 +87,8 @@ def makepaths(conf):
     if not os.path.exists(conf.samples_path):
         os.makedirs(conf.samples_path)
 
+    if tf.gfile.Exists(conf.summary_path):
+        tf.gfile.DeleteRecursively(conf.summary_path)
+    tf.gfile.MakeDirs(conf.summary_path)
+
     return conf