ae arch change
anantzoid committed Nov 19, 2016
1 parent e4119ea commit 3f7a3d1
Showing 2 changed files with 56 additions and 44 deletions.
autoencoder.py: 98 changes (55 additions, 43 deletions)
@@ -3,68 +3,65 @@
from utils import *
import matplotlib.pyplot as plt
from models import PixelCNN
from layers import conv_op

def get_weights(shape, name):
return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1), name=name)

def get_biases(shape, name):
return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32), name=name)

def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')


class AE(object):
def __init__(self, X):
self.num_layers = 2
self.num_layers = 6
self.fmap_out = [8, 32]
self.fmap_in = conf.channel
self.fan_in = X
self.filter_size = 4
self.W = []
self.strides = [1, 1, 1, 1]

for i in range(self.num_layers):
self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out[i]], stddev=0.1), name="W_%d"%i))
b = tf.Variable(tf.ones(shape=[self.fmap_out[i]], dtype=tf.float32), name="encoder_b_%d"%i)
en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i)
en_pool = tf.nn.max_pool(en_conv, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name="encoder_pool_%d"%i)

self.fan_in = tf.tanh(tf.add(en_pool, b))
self.fmap_in = self.fmap_out[i]

op_shape = self.fan_in.get_shape()
self.fan_in = tf.reshape(self.fan_in, (-1, int(op_shape[1])*int(op_shape[2])*int(op_shape[3])))

def decoder(self):
self.W.reverse()
for i in range(self.num_layers):
if i == self.num_layers-1:
self.fmap_out = conf.channel
c = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="decoder_b_%d"%i)
de_conv = tf.nn.conv2d_transpose(self.fan_in, self.W[i], [tf.shape(X)[0], conf.img_height, conf.img_width, self.fmap_out], self.strides, padding='SAME', name="decoder_conv_%d"%i)
self.fan_in = tf.tanh(tf.add(de_conv, c))
self.y = self.fan_in

def generate(self, conf):
n_examples = 10
if conf.data == 'mnist':
test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
test_X = binarize(test_X_pure)
W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1")
b_conv1 = get_biases([100], "b_conv1")
conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
pool1 = max_pool_2x2(conv1)

condition = sess.run(fan_in, feed_dict={X:test_X})
samples = sess.run(y, feed_dict={X: test_X, decoder.h:condition})
fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
for i in range(n_examples):
axs[0][i].imshow(np.reshape(o_test_X[i], (conf.img_height, conf.img_width)))
axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
fig.show()
plt.draw()
plt.waitforbuttonpress()

W_conv2 = get_weights([5, 5, 100, 150], "W_conv2")
b_conv2 = get_biases([150], "b_conv2")
conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
pool2 = max_pool_2x2(conv2)

W_conv3 = get_weights([3, 3, 150, 200], "W_conv3")
b_conv3 = get_biases([200], "b_conv3")
conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)
conv3_reshape = tf.reshape(conv3, (-1, 7*7*200))

W_fc = get_weights([7*7*200, 10], "W_fc")
b_fc = get_biases([10], "b_fc")
self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))

def trainPixelCNNAE(conf, data):
encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])

encoder = AE(encoder_X)
conf.num_classes = int(encoder.fan_in.get_shape()[1])
decoder = PixelCNN(decoder_X, conf, encoder.fan_in)
conf.num_classes = int(encoder.pred.get_shape()[1])
decoder = PixelCNN(decoder_X, conf, encoder.pred)
y = decoder.pred

loss = tf.reduce_mean(tf.square(encoder_X - y))
trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
#loss = tf.reduce_mean(tf.square(encoder_X - y))
#trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(decoder.loss)

clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)


saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
@@ -81,11 +78,26 @@ def trainPixelCNNAE(conf, data):
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
_, l = sess.run([optimizer, decoder.loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
print "Epoch: %d, Cost: %f"%(i, l)

saver.save(sess, conf.ckpt_file)
generate_ae(sess, encoder_X, decoder_X, y, data, conf)
#generate_ae(sess, encoder_X, decoder_X, y, data, conf)
data = input_data.read_data_sets("data/")
n_examples = 10
if conf.data == 'mnist':
test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel)
test_X = binarize(test_X_pure)

samples = sess.run(y, feed_dict={encoder_X:test_X, decoder_X:test_X})
fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2))
for i in range(n_examples):
axs[0][i].imshow(np.reshape(test_X_pure[i], (conf.img_height, conf.img_width)))
axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width)))
fig.show()
plt.draw()
plt.waitforbuttonpress()



if __name__ == "__main__":
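For readers who want the new encoder in one place, the sketch below strings together the convolutional layers this diff introduces (W_conv1 through W_fc) into a single standalone function. It is a reconstruction rather than the file as committed: conv_op is assumed to be a stride-1 tf.nn.conv2d with SAME padding (layers.py is not part of this diff), the build_encoder wrapper and its channels/num_classes parameters are hypothetical, and the 7*7*200 flatten presumes 28x28 MNIST inputs halved twice by the 2x2 max-pools. The helper functions are copied from the diff so the snippet is self-contained.

import tensorflow as tf

def conv_op(x, W):
    # Assumed stand-in for layers.conv_op: stride-1 convolution, SAME padding.
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def get_weights(shape, name):
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1), name=name)

def get_biases(shape, name):
    return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32), name=name)

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def build_encoder(X, channels=1, num_classes=10):
    # Hypothetical wrapper around the layers added in this diff.
    # 28x28xC -> 14x14x100
    conv1 = tf.nn.relu(conv_op(X, get_weights([5, 5, channels, 100], "W_conv1")) + get_biases([100], "b_conv1"))
    pool1 = max_pool_2x2(conv1)

    # 14x14x100 -> 7x7x150
    conv2 = tf.nn.relu(conv_op(pool1, get_weights([5, 5, 100, 150], "W_conv2")) + get_biases([150], "b_conv2"))
    pool2 = max_pool_2x2(conv2)

    # 7x7x150 -> 7x7x200 (no pooling after the third conv)
    conv3 = tf.nn.relu(conv_op(pool2, get_weights([3, 3, 150, 200], "W_conv3")) + get_biases([200], "b_conv3"))
    flat = tf.reshape(conv3, (-1, 7 * 7 * 200))

    # Fully connected softmax head; this is the tensor the diff feeds to
    # PixelCNN as the conditioning input (encoder.pred).
    W_fc = get_weights([7 * 7 * 200, num_classes], "W_fc")
    b_fc = get_biases([num_classes], "b_fc")
    return tf.nn.softmax(tf.matmul(flat, W_fc) + b_fc)

# Usage:
# X = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
# pred = build_encoder(X)   # shape (batch, 10)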
utils.py: 2 changes (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ def generate_samples(sess, X, h, pred, conf, suff):


def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
n_row, n_col = 1,1
n_row, n_col = 5,5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
if conf.data == 'mnist':
labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
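On the training side, the autoencoder.py diff above replaces the plain mean-squared-error loss and GradientDescentOptimizer with the PixelCNN decoder's own loss, optimized by RMSProp with element-wise gradient clipping. The snippet below is a minimal, self-contained sketch of that compute/clip/apply pattern against the graph-mode TF 1.x API: the 1e-3 learning rate and the clip to [-grad_clip, grad_clip] mirror the diff, while the toy variable, the toy loss, and the grad_clip value are stand-ins invented only to make it runnable on its own.

import tensorflow as tf

# Toy stand-in for decoder.loss, just to make the pattern runnable.
w = tf.Variable(tf.zeros([10]))
loss = tf.reduce_mean(tf.square(w - 1.0))

grad_clip = 1.0  # stands in for conf.grad_clip

# Same pattern as the diff: compute gradients explicitly, clip each one
# element-wise, then apply the clipped gradients.
trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)
clipped = [(tf.clip_by_value(g, -grad_clip, grad_clip), v)
           for g, v in gradients if g is not None]  # skip variables with no gradient
train_op = trainer.apply_gradients(clipped)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(100):
        _, l = sess.run([train_op, loss])  # l: scalar loss after each clipped RMSProp update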
