Skip to content

Commit

Permalink
ae summary
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent 3f7a3d1 commit dc68105
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def trainPixelCNNAE(conf, data):
decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])

encoder = AE(encoder_X)
conf.num_classes = int(encoder.pred.get_shape()[1])
decoder = PixelCNN(decoder_X, conf, encoder.pred)
y = decoder.pred
tf.scalar_summary('loss', decoder.loss)

#loss = tf.reduce_mean(tf.square(encoder_X - y))
#trainer = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
Expand All @@ -65,24 +65,34 @@ def trainPixelCNNAE(conf, data):

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('/tmp/mnist_ae', sess.graph)

sess.run(tf.initialize_all_variables())

if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"

pointer = 0
step = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
if conf.data == 'mnist':
batch_X = binarize(data.train.next_batch(conf.batch_size)[0].reshape(conf.batch_size, conf.img_height, conf.img_width, conf.channel))
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, l = sess.run([optimizer, decoder.loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
_, l, summary = sess.run([optimizer, decoder.loss, merged], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
writer.add_summary(summary, step)
step += 1
print "Epoch: %d, Cost: %f"%(i, l)
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))

saver.save(sess, conf.ckpt_file)
#generate_ae(sess, encoder_X, decoder_X, y, data, conf)
writer.close()
'''
data = input_data.read_data_sets("data/")
n_examples = 10
if conf.data == 'mnist':
Expand All @@ -97,7 +107,7 @@ def trainPixelCNNAE(conf, data):
fig.show()
plt.draw()
plt.waitforbuttonpress()

'''


if __name__ == "__main__":
Expand All @@ -113,8 +123,10 @@ class Conf(object):
conf.layers=5
conf.samples_path='samples/ae'
conf.ckpt_path='ckpts/ae'
conf.epochs=10
conf.batch_size = 100
conf.summary_path='/tmp/mnist_ae'
conf.epochs=50
conf.batch_size = 64
conf.num_classes = 10

if conf.data == 'mnist':
from tensorflow.examples.tutorials.mnist import input_data
Expand All @@ -136,5 +148,9 @@ class Conf(object):

conf.num_batches = train_size // conf.batch_size
conf = makepaths(conf)
if tf.gfile.Exists(conf.summary_path):
tf.gfile.DeleteRecursively(conf.summary_path)
tf.gfile.MakeDirs(conf.summary_path)

trainPixelCNNAE(conf, data)

0 comments on commit dc68105

Please sign in to comment.