diff --git a/autoencoder.py b/autoencoder.py index b3a78e3..20e632f 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -50,9 +50,9 @@ def trainPixelCNNAE(conf, data): decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) encoder = AE(encoder_X) - conf.num_classes = int(encoder.pred.get_shape()[1]) decoder = PixelCNN(decoder_X, conf, encoder.pred) y = decoder.pred + tf.scalar_summary('loss', decoder.loss) #loss = tf.reduce_mean(tf.square(encoder_X - y)) #trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss) @@ -65,12 +65,17 @@ def trainPixelCNNAE(conf, data): saver = tf.train.Saver(tf.trainable_variables()) with tf.Session() as sess: + merged = tf.merge_all_summaries() + writer = tf.train.SummaryWriter('/tmp/mnist_ae', sess.graph) + sess.run(tf.initialize_all_variables()) + if os.path.exists(conf.ckpt_file): saver.restore(sess, conf.ckpt_file) print "Model Restored" pointer = 0 + step = 0 for i in range(conf.epochs): for j in range(conf.num_batches): if conf.data == 'mnist': @@ -78,11 +83,16 @@ def trainPixelCNNAE(conf, data): else: batch_X, pointer = get_batch(data, pointer, conf.batch_size) - _, l = sess.run([optimizer, decoder.loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X}) + _, l, summary = sess.run([optimizer, decoder.loss, merged], feed_dict={encoder_X: batch_X, decoder_X: batch_X}) + writer.add_summary(summary, step) + step += 1 print "Epoch: %d, Cost: %f"%(i, l) + if (i+1)%10 == 0: + saver.save(sess, conf.ckpt_file) + generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i)) - saver.save(sess, conf.ckpt_file) - #generate_ae(sess, encoder_X, decoder_X, y, data, conf) + writer.close() + ''' data = input_data.read_data_sets("data/") n_examples = 10 if conf.data == 'mnist': @@ -97,7 +107,7 @@ def trainPixelCNNAE(conf, data): fig.show() plt.draw() plt.waitforbuttonpress() - + ''' if __name__ == "__main__": @@ -113,8 +123,10 @@ class Conf(object): conf.layers=5 conf.samples_path='samples/ae' conf.ckpt_path='ckpts/ae' - conf.epochs=10 - conf.batch_size = 100 + conf.summary_path='/tmp/mnist_ae' + conf.epochs=50 + conf.batch_size = 64 + conf.num_classes = 10 if conf.data == 'mnist': from tensorflow.examples.tutorials.mnist import input_data @@ -136,5 +148,9 @@ class Conf(object): conf.num_batches = train_size // conf.batch_size conf = makepaths(conf) + if tf.gfile.Exists(conf.summary_path): + tf.gfile.DeleteRecursively(conf.summary_path) + tf.gfile.MakeDirs(conf.summary_path) + trainPixelCNNAE(conf, data)