Skip to content

Commit

Permalink
code cleanup & readme
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent dc68105 commit 2a7e1bf
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 140 deletions.
65 changes: 65 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Image Generation with Gated PixelCNN Decoders

This is a Tensorflow implementation of [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328) which introduces the Gated PixelCNN model based on PixelCNN architecture originally mentioned in [Pixel Recurrent Neural Networks](https://arxiv.org/abs/1601.06759). The model can be conditioned on latent representation of labels or images to generate images accordingly. Images can also be modelled unconditionally. It can also act as a powerful decoder and can replace deconvolution (transposed convolution) in Autoencoders and GANs. A detailed summary of the paper can be found [here](https://gist.github.com/anantzoid/b2dca657003998027c2861f3121c43b7).

These are some conditioned samples generated by the authors of the paper:

![Paper Sample](images/conditioned_samples.png)

## Architecture

This is the architecture for Gated PixelCNN used in the model:

![Gated PCNN](images/gated_cnn.png)

The gating accounts for remembering the context and model more complex interactions, like in LSTM. The network stack on the left is the Vertical stack that takes care of blind spots that occure while convolution due to the masking layer (Refer the Pixel RNN paper to know more about masking). Use of residual connection significantly improves the model performance.

## Usage

This implementation consists of the following models based on the Gated PixelCNN architecture:

- **Unconditional image generation**:
```
python main.py
```
Sample generated by training MNIST dataset after 70 epochs with a cross-entropy loss of 0.104610:
![Unconditional image](images/sample.jpg)
- **Conditional image generation based on class labels**:
```
python main.py --model=conditional
```
As mentioned in the paper, conditionally generated images are more visually appealing though the loss difference is almost same. It has a loss of 0.102719 after 40 epochs:
![Conditional image](images/conditional.gif)
- **Autoencoder with PixelCNN decoder**:
```
python main.py --model=autoencoder
```
The encoder part of the autoencoder has the original architecture as mentioned in [Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction](https://pdfs.semanticscholar.org/1c6d/990c80e60aa0b0059415444cdf94b3574f0f.pdf). The representation is encoded into 10d tensor. The image generated after 10 epochs with a loss of 0.115306:
![AE image](images/ae_sample.jpg)
To only generate images append the `--epochs=0` flag after the command.
To train the any model on CIFAR-10 dataset, add the `--data=cifar` flag.
Refer `main.py` for other available flags for hyperparameter tuning.
## Training Details
The system was trained on a single AWS p2.xlarge spot instance. The implementation was only done on MNIST dataset. Generation of samples based on CIFAR-10 images took the authors 32 GPUs trained for 60 hours.
To visualize the graph and loss during training, run:
```
tensorboard --logdir=logs
```
Loss minimization for the autoencoder model:
![Loss](images/loss.png)
121 changes: 10 additions & 111 deletions autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,37 @@
import tensorflow as tf
import numpy as np
from utils import *
import matplotlib.pyplot as plt
from models import PixelCNN
from layers import conv_op
from models import *

def get_weights(shape, name):
return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1), name=name)

def get_biases(shape, name):
return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32), name=name)

def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')


class AE(object):
def __init__(self, X):
self.num_layers = 6
self.fmap_out = [8, 32]
self.fmap_in = conf.channel
self.fan_in = X
self.filter_size = 4
self.W = []
self.strides = [1, 1, 1, 1]

W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1")
b_conv1 = get_biases([100], "b_conv1")
conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
pool1 = max_pool_2x2(conv1)


W_conv2 = get_weights([5, 5, 100, 150], "W_conv2")
b_conv2 = get_biases([150], "b_conv2")
conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
pool2 = max_pool_2x2(conv2)

W_conv3 = get_weights([3, 3, 150, 200], "W_conv3")
b_conv3 = get_biases([200], "b_conv3")
conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)
conv3_reshape = tf.reshape(conv3, (-1, 7*7*200))

W_fc = get_weights([7*7*200, 10], "W_fc")
b_fc = get_biases([10], "b_fc")
self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))

def trainPixelCNNAE(conf, data):
def trainAE(conf, data):
encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])

encoder = AE(encoder_X)
encoder = ConvolutionalEncoder(encoder_X, conf)
decoder = PixelCNN(decoder_X, conf, encoder.pred)
y = decoder.pred
tf.scalar_summary('loss', decoder.loss)

#loss = tf.reduce_mean(tf.square(encoder_X - y))
#trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(decoder.loss)

clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)


saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('/tmp/mnist_ae', sess.graph)
writer = tf.train.SummaryWriter(conf.summary_path, sess.graph)

sess.run(tf.initialize_all_variables())

if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"


# TODO The training part below and in main.py could be generalized
if conf.epochs > 0:
print "Started Model Training..."
pointer = 0
step = 0
for i in range(conf.epochs):
Expand All @@ -86,71 +44,12 @@ def trainPixelCNNAE(conf, data):
_, l, summary = sess.run([optimizer, decoder.loss, merged], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
writer.add_summary(summary, step)
step += 1

print "Epoch: %d, Cost: %f"%(i, l)
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))

writer.close()
'''
data = input_data.read_data_sets("data/")
n_examples = 10
if conf.data == 'mnist':
test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
test_X = binarize(test_X_pure)
samples = sess.run(y, feed_dict={encoder_X:test_X, decoder_X:test_X})
fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
for i in range(n_examples):
axs[0][i].imshow(np.reshape(test_X_pure[i], (conf.img_height, conf.img_width)))
axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
fig.show()
plt.draw()
plt.waitforbuttonpress()
'''


if __name__ == "__main__":
class Conf(object):
pass

conf = Conf()
conf.conditional=True
conf.data='mnist'
conf.data_path='data'
conf.f_map=32
conf.grad_clip=1
conf.layers=5
conf.samples_path='samples/ae'
conf.ckpt_path='ckpts/ae'
conf.summary_path='/tmp/mnist_ae'
conf.epochs=50
conf.batch_size = 64
conf.num_classes = 10

if conf.data == 'mnist':
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("data/")
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
train_size = data.train.num_examples
else:
from keras.datasets import cifar10
data = cifar10.load_data()
data = data[0][0]
data /= 255.0
data = np.transpose(data, (0, 2, 3, 1))
conf.img_height = 32
conf.img_width = 32
conf.channel = 3
train_size = data.shape[0]

conf.num_batches = train_size // conf.batch_size
conf = makepaths(conf)
if tf.gfile.Exists(conf.summary_path):
tf.gfile.DeleteRecursively(conf.summary_path)
tf.gfile.MakeDirs(conf.summary_path)

trainPixelCNNAE(conf, data)
generate_ae(sess, encoder_X, decoder_X, y, data, conf, '')

Binary file added images/ae_sample.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/conditional.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/conditioned_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/sample.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 6 additions & 1 deletion layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ def get_weights(shape, name, mask=None):
weights_initializer = tf.contrib.layers.xavier_initializer()
W = tf.get_variable(name, shape, tf.float32, weights_initializer)

'''
Use of masking to hide subsequent pixel values
'''
if mask:
filter_mid_x = shape[0]//2
filter_mid_y = shape[1]//2
Expand All @@ -24,6 +27,9 @@ def get_bias(shape, name):
def conv_op(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

class GatedCNN():
def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
self.fan_in = fan_in
Expand Down Expand Up @@ -66,7 +72,6 @@ def gated_conv(self):
conv_f += self.payload
conv_g += self.payload


self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

def simple_conv(self):
Expand Down
37 changes: 20 additions & 17 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import numpy as np
import argparse
from models import PixelCNN
from autoencoder import *
from utils import *

tf.set_random_seed(100)
def train(conf, data):
X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
model = PixelCNN(X, conf)
Expand All @@ -23,6 +23,8 @@ def train(conf, data):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"

if conf.epochs > 0:
print "Started Model Training..."
pointer = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
Expand All @@ -33,31 +35,30 @@ def train(conf, data):
batch_y = one_hot(batch_y, conf.num_classes)
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)
#batch_X, batch_y = next(data)
data_dict = {X:batch_X}
if conf.conditional is True:
#TODO extract one-hot classes
data_dict[model.h] = batch_y
_, cost,_f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict)
_, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)
print "Epoch: %d, Cost: %f"%(i, cost)
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_samples(sess, X, model.h, model.pred, conf, "argmax")
generate_samples(sess, X, model.h, model.pred, conf, "")

generate_samples(sess, X, model.h, model.pred, conf, "")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, default='train')
parser.add_argument('--data', type=str, default='mnist')
parser.add_argument('--layers', type=int, default=12)
parser.add_argument('--f_map', type=int, default=32)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--grad_clip', type=int, default=1)
parser.add_argument('--conditional', type=bool, default=False)
parser.add_argument('--model', type=str, default='')
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--ckpt_path', type=str, default='ckpts')
parser.add_argument('--samples_path', type=str, default='samples')
parser.add_argument('--summary_path', type=str, default='logs')
conf = parser.parse_args()

if conf.data == 'mnist':
Expand All @@ -69,7 +70,7 @@ def train(conf, data):
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
conf.num_batches = data.train.num_examples // conf.batch_size
conf.num_batches = 10#data.train.num_examples // conf.batch_size
else:
from keras.datasets import cifar10
data = cifar10.load_data()
Expand All @@ -81,14 +82,16 @@ def train(conf, data):
conf.channel = 3
conf.num_classes = 10
conf.num_batches = data.shape[0] // conf.batch_size
'''
# TODO debug shape
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(featurewise_center=True,
featurewise_std_normalization=True)
datagen.fit(data)
data = datagen.flow(data, labels, batch_size=conf.batch_size)
'''

conf = makepaths(conf)
train(conf, data)
if conf.model == '':
conf.conditional = False
train(conf, data)
elif conf.model.lower() == 'conditional':
conf.conditional = True
train(conf, data)
elif conf.model.lower() == 'autoencoder':
conf.conditional = True
trainAE(conf, data)


Loading

0 comments on commit 2a7e1bf

Please sign in to comment.