diff --git a/autoencoder.py b/autoencoder.py index 62f6902..b3a78e3 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -3,10 +3,21 @@ from utils import * import matplotlib.pyplot as plt from models import PixelCNN +from layers import conv_op + +def get_weights(shape, name): + return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1), name=name) + +def get_biases(shape, name): + return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32), name=name) + +def max_pool_2x2(x): + return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') + class AE(object): def __init__(self, X): - self.num_layers = 2 + self.num_layers = 6 self.fmap_out = [8, 32] self.fmap_in = conf.channel self.fan_in = X @@ -14,57 +25,43 @@ def __init__(self, X): self.W = [] self.strides = [1, 1, 1, 1] - for i in range(self.num_layers): - self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out[i]], stddev=0.1), name="W_%d"%i)) - b = tf.Variable(tf.ones(shape=[self.fmap_out[i]], dtype=tf.float32), name="encoder_b_%d"%i) - en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i) - en_pool = tf.nn.max_pool(en_conv, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name="encoder_pool_%d"%i) - - self.fan_in = tf.tanh(tf.add(en_pool, b)) - self.fmap_in = self.fmap_out[i] - - op_shape = self.fan_in.get_shape() - self.fan_in = tf.reshape(self.fan_in, (-1, int(op_shape[1])*int(op_shape[2])*int(op_shape[3]))) - - def decoder(self): - self.W.reverse() - for i in range(self.num_layers): - if i == self.num_layers-1: - self.fmap_out = conf.channel - c = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="decoder_b_%d"%i) - de_conv = tf.nn.conv2d_transpose(self.fan_in, self.W[i], [tf.shape(X)[0], conf.img_height, conf.img_width, self.fmap_out], self.strides, padding='SAME', name="decoder_conv_%d"%i) - self.fan_in = tf.tanh(tf.add(de_conv, c)) - self.y = self.fan_in - - def generate(self, conf): - n_examples = 10 - if conf.data == 'mnist': - test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel) - test_X = binarize(test_X_pure) + W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1") + b_conv1 = get_biases([100], "b_conv1") + conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1) + pool1 = max_pool_2x2(conv1) - condition = sess.run(fan_in, feed_dict={X:test_X}) - samples = sess.run(y, feed_dict={X: test_X, decoder.h:condition}) - fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2)) - for i in range(n_examples): - axs[0][i].imshow(np.reshape(o_test_X[i], (conf.img_height, conf.img_width))) - axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width))) - fig.show() - plt.draw() - plt.waitforbuttonpress() + W_conv2 = get_weights([5, 5, 100, 150], "W_conv2") + b_conv2 = get_biases([150], "b_conv2") + conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2) + pool2 = max_pool_2x2(conv2) + + W_conv3 = get_weights([3, 3, 150, 200], "W_conv3") + b_conv3 = get_biases([200], "b_conv3") + conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3) + conv3_reshape = tf.reshape(conv3, (-1, 7*7*200)) + W_fc = get_weights([7*7*200, 10], "W_fc") + b_fc = get_biases([10], "b_fc") + self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc)) def trainPixelCNNAE(conf, data): encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) encoder = AE(encoder_X) - conf.num_classes = int(encoder.fan_in.get_shape()[1]) - decoder = PixelCNN(decoder_X, conf, encoder.fan_in) + conf.num_classes = int(encoder.pred.get_shape()[1]) + decoder = PixelCNN(decoder_X, conf, encoder.pred) y = decoder.pred - loss = tf.reduce_mean(tf.square(encoder_X - y)) - trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss) + #loss = tf.reduce_mean(tf.square(encoder_X - y)) + #trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss) + trainer = tf.train.RMSPropOptimizer(1e-3) + gradients = trainer.compute_gradients(decoder.loss) + + clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients] + optimizer = trainer.apply_gradients(clipped_gradients) + saver = tf.train.Saver(tf.trainable_variables()) with tf.Session() as sess: @@ -81,11 +78,26 @@ def trainPixelCNNAE(conf, data): else: batch_X, pointer = get_batch(data, pointer, conf.batch_size) - _, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X}) + _, l = sess.run([optimizer, decoder.loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X}) print "Epoch: %d, Cost: %f"%(i, l) saver.save(sess, conf.ckpt_file) - generate_ae(sess, encoder_X, decoder_X, y, data, conf) + #generate_ae(sess, encoder_X, decoder_X, y, data, conf) + data = input_data.read_data_sets("data/") + n_examples = 10 + if conf.data == 'mnist': + test_X_pure = data.train.next_batch(n_examples)[0].reshape(n_examples, conf.img_height, conf.img_width, conf.channel) + test_X = binarize(test_X_pure) + + samples = sess.run(y, feed_dict={encoder_X:test_X, decoder_X:test_X}) + fig, axs = plt.subplots(2, n_examples, figsize=(n_examples, 2)) + for i in range(n_examples): + axs[0][i].imshow(np.reshape(test_X_pure[i], (conf.img_height, conf.img_width))) + axs[1][i].imshow(np.reshape(samples[i], (conf.img_height, conf.img_width))) + fig.show() + plt.draw() + plt.waitforbuttonpress() + if __name__ == "__main__": diff --git a/utils.py b/utils.py index 8481299..cb051b1 100644 --- a/utils.py +++ b/utils.py @@ -29,7 +29,7 @@ def generate_samples(sess, X, h, pred, conf, suff): def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''): - n_row, n_col = 1,1 + n_row, n_col = 5,5 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) if conf.data == 'mnist': labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))