Skip to content

Commit

Permalink
pixelCNN decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent 66a419e commit b1ebfda
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 155 deletions.
255 changes: 119 additions & 136 deletions autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,143 +1,126 @@

# coding: utf-8

# In[1]:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from utils import *


# In[2]:

import matplotlib.pyplot as plt
#get_ipython().magic(u'matplotlib inline')

mnist = input_data.read_data_sets("data/")

# In[3]:

# Input image geometry; these match the conf.img_* values set below for the MNIST run.
img_height = 28
img_width = 28
channel = 1

# Encoder hyper-parameters read by the convolution loop later in the script.
num_layers = 3
filter_size = 3
fmap_in = channel
fmap_out = 32
strides = [1, 1, 1, 1]

batch_size = 50

from models import PixelCNN
# Bare attribute container; configuration values are attached to an instance below.
class Conf(object):
pass

# Configuration object that is later passed to PixelCNN(conf).
conf = Conf()
conf.ckpt_path='ckpts'
conf.conditional=True
conf.data='mnist'
conf.data_path='data'
conf.epochs=50
conf.f_map=32
conf.grad_clip=1
conf.layers=5
conf.samples_path='samples'
# Placeholder value: num_classes is overwritten with the flattened encoder width further down.
conf.num_classes = 10
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
# Number of whole batches per pass over the MNIST training split.
conf.num_batches = mnist.train.num_examples // batch_size
conf.type='train'



# Placeholder for a batch of input images laid out as (batch, height, width, channel).
X = tf.placeholder(shape=[None, img_height, img_width, channel], dtype=tf.float32)

# Encoder: num_layers SAME-padded convolutions, each followed by a bias add and tanh.
# NOTE(review): indentation is missing in this view; the statements down to
# `fmap_in = fmap_out` presumably form the loop body -- confirm against the real file.
fan_in = X
W = []
for i in range(num_layers):
# The final layer switches the output depth from 32 to 10 feature maps.
if i == num_layers -1 :
fmap_out = 10
W.append(tf.Variable(tf.truncated_normal(shape=[filter_size, filter_size, fmap_in, fmap_out], stddev=0.1), name="W_%d"%i))
b = tf.Variable(tf.ones(shape=[fmap_out], dtype=tf.float32), name="encoder_b_%d"%i)
en_conv = tf.nn.conv2d(fan_in, W[i], strides, padding='SAME', name="encoder_conv_%d"%i)

fan_in = tf.tanh(tf.add(en_conv, b))
fmap_in = fmap_out

# Flatten the encoder output to one vector per image and reuse its width as
# conf.num_classes, which sizes the conditioning input of the PixelCNN.
fan_in = tf.reshape(fan_in, (-1, conf.img_width*conf.img_height*fmap_out))
conf.num_classes = int(fan_in.get_shape()[1])

# TODO
# Make X enter from model input
model = PixelCNN(conf)
# output is model.pred
# define loss function here after getting prediction
y = model.pred

'''
W.reverse()
for i in range(num_layers):
if i == num_layers-1:
fmap_out = channel
c = tf.Variable(tf.ones(shape=[fmap_out], dtype=tf.float32), name="decoder_b_%d"%i)
de_conv = tf.nn.conv2d_transpose(fan_in, W[i], [tf.shape(X)[0], img_height, img_width, fmap_out], strides, padding='SAME', name="decoder_conv_%d"%i)
fan_in = tf.tanh(tf.add(de_conv, c))
y = fan_in
'''


# In[10]:

# Mean squared error between the input batch and the model prediction,
# minimised with plain gradient descent at a fixed learning rate of 1e-4.
loss = tf.reduce_mean(tf.square(X - y))
trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)


# In[5]:
'''
import cPickle
data = cPickle.load(open('cifar-100-python/test', 'r'))['data']
data = np.reshape(data, (data.shape[0], 3, 32, 32))
data = np.transpose(data, (0, 2, 3, 1))
#data = (data - np.mean(data))/np.std(data)
'''


# In[ ]:
# Training schedule for this script run; num_batches is pinned to 1 and the
# full-epoch expression is left commented out on the same line.
epochs = 5
num_batches = 1#mnist.train.num_examples // batch_size

# NOTE(review): indentation is missing in this view, so which statements sit
# inside the session and the loops cannot be confirmed from here. Everything
# that uses `sess` must run inside the `with` block.
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())

for i in range(epochs):
for j in range(num_batches):
# Draw a batch, reshape it to image layout and binarize it.
batch_X = binarize(mnist.train.next_batch(batch_size)[0].reshape(batch_size, img_height, img_width, channel))
# Evaluate the flattened encoder output first; fetching with a one-element
# list returns a one-element list, hence condition[0] below.
condition = sess.run([fan_in], feed_dict={X:batch_X})
# TODO shape of condition does not match: (1, 10, 28, 28, 32) for (?, 10)
_, l = sess.run([trainer, loss], feed_dict={X:batch_X, model.X:batch_X, model.h: condition[0]})
#batch_X = data[:10]/255.0
#_, l = sess.run([trainer, loss], feed_dict={X:batch_X})
# Python 2 print statement: reports the most recent loss value.
print l

n_examples = 10
#test_X = mnist.train.next_batch(n_examples)[0].reshape(n_examples, img_height, img_width, channel)

# Take 10 test images; keep the un-binarized copy for display and feed the binarized one.
o_test_X = mnist.test.next_batch(10)[0].reshape(10, img_height, img_width, channel)
test_X = binarize(o_test_X)
condition = sess.run(fan_in, feed_dict={X:test_X})
samples = sess.run(y, feed_dict={X: test_X, model.X:test_X, model.h:condition})
print samples.shape
#test_X = data[:10]
#samples = sess.run(y, feed_dict={X:test_X/255.0})
# Two-row figure: original test images on top, model outputs underneath.
fig, axs = plt.subplots(2, n_examples, figsize=(10,2))
for i in range(n_examples):
axs[0][i].imshow(np.reshape(o_test_X[i], (img_height, img_width)), cmap='binary')
axs[1][i].imshow(np.reshape(samples[i], (img_height, img_width)), cmap='binary')
fig.show()
plt.draw()
# Block until a key or mouse press so the figure stays visible.
plt.waitforbuttonpress()

# Convolutional encoder (with an unfinished decoder/generate path) built on the
# tensor X; the flattened encoder output is exposed as self.fan_in.
# NOTE(review): indentation is missing in this view; method and loop boundaries
# are inferred from the `def` lines and should be confirmed against the real file.
class AE(object):
# Builds the encoder graph: num_layers SAME-padded 3x3 convolutions with tanh,
# reading module-level `conf` for the channel count and image size.
def __init__(self, X):
self.fmap_out = 32
self.fmap_in = conf.channel
self.fan_in = X
self.num_layers = 3
self.filter_size = 3
self.W = []
self.strides = [1, 1, 1, 1]

for i in range(self.num_layers):
# The final layer switches the output depth from 32 to 10 feature maps.
if i == self.num_layers -1 :
self.fmap_out = 10
self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out], stddev=0.1), name="W_%d"%i))
b = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="encoder_b_%d"%i)
en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i)

self.fan_in = tf.tanh(tf.add(en_conv, b))
self.fmap_in = self.fmap_out

# Flatten to one vector per image; callers read its width via get_shape()[1].
self.fan_in = tf.reshape(self.fan_in, (-1, conf.img_width*conf.img_height*self.fmap_out))

# Transposed-convolution decoder that reuses the encoder weights in reverse order
# and stores its output in self.y.
# NOTE(review): `X` below is neither a parameter nor an attribute of this class, so
# it only resolves if a module-level X exists. self.fan_in is already flattened to
# 2-D by __init__ and self.fmap_out is still 10 on entry -- presumably both are
# incompatible with conv2d_transpose as called here; confirm before using.
def decoder(self):
self.W.reverse()
for i in range(self.num_layers):
if i == self.num_layers-1:
self.fmap_out = conf.channel
c = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="decoder_b_%d"%i)
de_conv = tf.nn.conv2d_transpose(self.fan_in, self.W[i], [tf.shape(X)[0], conf.img_height, conf.img_width, self.fmap_out], self.strides, padding='SAME', name="decoder_conv_%d"%i)
self.fan_in = tf.tanh(tf.add(de_conv, c))
self.y = self.fan_in

# Intended to plot 10 inputs above their reconstructions.
# NOTE(review): `data`, `sess`, `fan_in`, `X`, `y`, `decoder` and `o_test_X` are not
# defined in this method or class (the local image array is named test_X_pure), and
# test_X is only assigned when conf.data == 'mnist'. As shown, this looks like code
# carried over from the flat script and is not runnable -- confirm whether it is
# still needed (trainPixelCNNAE calls generate_ae instead).
def generate(self, conf):
n_examples = 10
if conf.data == 'mnist':
test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
test_X = binarize(test_X_pure)

condition = sess.run(fan_in, feed_dict={X:test_X})
samples = sess.run(y, feed_dict={X: test_X, decoder.h:condition})
fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
for i in range(n_examples):
axs[0][i].imshow(np.reshape(o_test_X[i], (conf.img_height, conf.img_width)))
axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
fig.show()
plt.draw()
plt.waitforbuttonpress()



def trainPixelCNNAE(conf, data):
    """Train the convolutional encoder jointly with a conditional PixelCNN decoder.

    The encoder (``AE``) turns each image into a flat vector that is passed to
    ``PixelCNN`` as its conditioning input; the pair is trained to minimise the
    mean squared error between the input images and the PixelCNN prediction.
    After training, the weights are saved to ``conf.ckpt_file`` and
    ``generate_ae`` is called to produce sample images.

    Args:
        conf: configuration object. Reads ``img_height``, ``img_width``,
            ``channel``, ``epochs``, ``num_batches``, ``batch_size``, ``data``
            and ``ckpt_file``; ``num_classes`` is overwritten with the width of
            the flattened encoder output.
        data: for ``conf.data == 'mnist'`` an object exposing
            ``data.train.next_batch(n)``; otherwise a nested sequence whose
            ``data[0][0]`` holds the training images.

    NOTE(review): the source view had lost its indentation. The loss is printed
    once per epoch here (last batch of the epoch) -- confirm that matches the
    original nesting.
    """
    image_shape = [None, conf.img_height, conf.img_width, conf.channel]
    encoder_X = tf.placeholder(tf.float32, shape=image_shape)
    decoder_X = tf.placeholder(tf.float32, shape=image_shape)

    encoder = AE(encoder_X)
    # The conditioning vector is the flattened encoder output, so its width
    # replaces the class count that PixelCNN would otherwise expect.
    conf.num_classes = int(encoder.fan_in.get_shape()[1])
    # TODO keep X out for main.py also
    decoder = PixelCNN(decoder_X, conf, encoder.fan_in)
    y = decoder.pred

    loss = tf.reduce_mean(tf.square(encoder_X - y))
    trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if os.path.exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            # Single-argument parenthesised print behaves the same on Python 2 and 3.
            print("Model Restored")

        for i in range(conf.epochs):
            for j in range(conf.num_batches):
                if conf.data == 'mnist':
                    flat_batch = data.train.next_batch(conf.batch_size)[0]
                    batch_X = binarize(flat_batch.reshape(
                        conf.batch_size, conf.img_height, conf.img_width, conf.channel))
                else:
                    # Walk through the training set batch by batch. Previously every
                    # step reused the first conf.batch_size images, so the model only
                    # ever saw one batch.
                    start = j * conf.batch_size
                    batch_X = data[0][0][start:start + conf.batch_size]
                _, l = sess.run([trainer, loss],
                                feed_dict={encoder_X: batch_X, decoder_X: batch_X})
            print(l)

        saver.save(sess, conf.ckpt_file)
        generate_ae(sess, encoder_X, decoder_X, y, data, conf)


if __name__ == "__main__":
    class Conf(object):
        """Bare attribute container for the run configuration."""

    # Hyper-parameters and output locations for the autoencoder run.
    conf = Conf()
    conf.conditional = True
    conf.data = 'mnist'
    conf.data_path = 'data'
    conf.f_map = 32
    conf.grad_clip = 1
    conf.layers = 5
    conf.samples_path = 'samples/ae'
    conf.ckpt_path = 'ckpts/ae'
    conf.epochs = 10
    conf.batch_size = 100

    if conf.data == 'mnist':
        from tensorflow.examples.tutorials.mnist import input_data
        data = input_data.read_data_sets("data/")
        conf.img_height = 28
        conf.img_width = 28
        conf.channel = 1
        train_size = data.train.num_examples
    else:
        from keras.datasets import cifar10
        # load_data() returns nested tuples, which do not support item assignment
        # (`data[0][0] = ...` raises TypeError), so unpack and rebuild the structure.
        (train_X, train_y), (test_X, test_y) = cifar10.load_data()
        # TODO normalize pixel values
        # Move the channel axis last; this assumes the loader yields
        # (N, channel, height, width) arrays -- TODO confirm for the installed Keras.
        data = (
            (np.transpose(train_X, (0, 2, 3, 1)), train_y),
            (np.transpose(test_X, (0, 2, 3, 1)), test_y),
        )
        train_size = data[0][0].shape[0]

    # Only whole batches are used; a trailing partial batch is dropped.
    conf.num_batches = train_size // conf.batch_size
    conf = makepaths(conf)
    trainPixelCNNAE(conf, data)

15 changes: 3 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def train(conf, data):
print "Epoch: %d, Cost: %f"%(i, cost)

saver.save(sess, conf.ckpt_file)
generate_and_save(sess, model.X, model.h, model.pred, conf)
generate_samples(sess, model.X, model.h, model.pred, conf)


if __name__ == "__main__":
Expand Down Expand Up @@ -76,19 +76,10 @@ def train(conf, data):
data = np.transpose(data, (0, 2, 3, 1))
raise ValueError("Specify num_classes")
conf.num_classes = 10
conf.num_batches = 10#data.shape[0] // conf.batch_size

# Implementing tf.image.per_image_whitening for normalization
# data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0

conf.num_batches = 10#data.shape[0] // conf.batch_size

ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(ckpt_full_path):
os.makedirs(ckpt_full_path)
conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")

conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(conf.samples_path):
os.makedirs(conf.samples_path)

conf = makepaths(conf)
train(conf, data)
13 changes: 8 additions & 5 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
from layers import GatedCNN

class PixelCNN():
def __init__(self, conf):

self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
def __init__(self, X, conf, h=None):
self.X = X
self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
v_stack_in, h_stack_in = self.X_norm, self.X_norm
# TODO norm for multichannel: subtract mean and divide by std feature-wise
self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
if conf.conditional is False:
if conf.conditional is True:
if h is not None:
self.h = h
else:
self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
else:
self.h = None

for i in range(conf.layers):
Expand Down
35 changes: 33 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
def binarize(images):
    """Stochastically binarize *images*.

    Each element becomes 1.0 when a fresh uniform draw from [0, 1) is strictly
    below the element's value, and 0.0 otherwise. The result is a float32 array
    with the same shape as the input.
    """
    noise = np.random.uniform(size=images.shape)
    below_intensity = noise < images
    return below_intensity.astype(np.float32)

def generate_and_save(sess, X, h, pred, conf):
def generate_samples(sess, X, h, pred, conf):
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
# TODO make it generic
labels = one_hot([1,2,3,4,5]*5)
labels = one_hot([1,2,3,4,5]*5, conf.num_classes)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
Expand All @@ -20,6 +20,26 @@ def generate_and_save(sess, X, h, pred, conf):
if conf.data == "mnist":
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]
save_images(samples, n_row, n_col, conf)


def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
    """Sample a 3x3 grid of images from the autoencoder and save them.

    Nine training images are fed to the encoder as conditioning input. The
    output canvas starts as zeros and is filled one (row, column, channel)
    position at a time: each step re-runs the model on the canvas so far and
    copies only the current position from the prediction.

    Args:
        sess: session used to evaluate ``y``.
        encoder_X: placeholder receiving the conditioning images.
        decoder_X: placeholder receiving the partially generated canvas.
        y: prediction tensor to evaluate.
        data: for ``conf.data == 'mnist'`` an object exposing
            ``data.train.next_batch(n)``; otherwise a nested sequence whose
            ``data[0][0]`` holds the training images.
        conf: configuration; reads ``img_height``, ``img_width``, ``channel``
            and ``data``.
    """
    n_row, n_col = 3, 3
    n_images = n_row * n_col
    samples = np.zeros(
        (n_images, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)

    if conf.data == 'mnist':
        flat_batch = data.train.next_batch(n_images)[0]
        labels = binarize(flat_batch.reshape(
            n_images, conf.img_height, conf.img_width, conf.channel))
    else:
        # Previously `labels` was never assigned on this path, so the first
        # sess.run failed with an unbound-name error. Take the conditioning images
        # the same way the training loop reads non-MNIST data.
        # NOTE(review): save_images is not fully visible here -- confirm it can
        # handle multi-channel samples.
        labels = data[0][0][:n_images]

    # `range` instead of `xrange`: same iteration, and it also runs on Python 3.
    for i in range(conf.img_height):
        for j in range(conf.img_width):
            for k in range(conf.channel):
                next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples})
                if conf.data == 'mnist':
                    next_sample = binarize(next_sample)
                # Keep only the position being generated; later positions are
                # re-predicted once this one is fixed.
                samples[:, i, j, k] = next_sample[:, i, j, k]

    save_images(samples, n_row, n_col, conf)

def save_images(samples, n_row, n_col, conf):
images = samples
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
images = images.transpose(1, 2, 0, 3)
Expand All @@ -42,3 +62,14 @@ def one_hot(batch_y, num_classes):
return y_


def makepaths(conf):
    """Create the checkpoint and sample directories for this run.

    Both directory names encode the run's hyper-parameters. ``conf.ckpt_file``
    is set to ``model.ckpt`` inside the checkpoint directory, and
    ``conf.samples_path`` is replaced by the run-specific samples directory.

    Args:
        conf: configuration; reads ``ckpt_path``, ``samples_path``, ``data``,
            ``batch_size``, ``layers``, ``f_map`` and ``epochs``.

    Returns:
        The same ``conf`` object, updated in place.

    Raises:
        OSError: if a directory cannot be created, including when the path
            already exists but is not a directory.
    """
    ckpt_full_path = os.path.join(
        conf.ckpt_path,
        "data=%s_bs=%d_layers=%d_fmap=%d" % (conf.data, conf.batch_size, conf.layers, conf.f_map))
    _ensure_dir(ckpt_full_path)
    conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")

    conf.samples_path = os.path.join(
        conf.samples_path,
        "epoch=%d_bs=%d_layers=%d_fmap=%d" % (conf.epochs, conf.batch_size, conf.layers, conf.f_map))
    _ensure_dir(conf.samples_path)

    return conf


def _ensure_dir(path):
    """Create *path* and any missing parents, tolerating an existing directory."""
    try:
        os.makedirs(path)
    except OSError:
        # Checking exists() before makedirs() is racy; attempt the creation and
        # ignore the failure only when the directory really is there.
        if not os.path.isdir(path):
            raise

0 comments on commit b1ebfda

Please sign in to comment.