Skip to content

Commit

Permalink
pending changes
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent b1ebfda commit a4632ea
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
20 changes: 11 additions & 9 deletions autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def trainPixelCNNAE(conf, data):

encoder = AE(encoder_X)
conf.num_classes = int(encoder.fan_in.get_shape()[1])
# TODO keep X out for main.py also
decoder = PixelCNN(decoder_X, conf, encoder.fan_in)
y = decoder.pred

Expand All @@ -74,16 +73,16 @@ def trainPixelCNNAE(conf, data):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"

pointer = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
if conf.data == 'mnist':
batch_X = binarize(data.train.next_batch(conf.batch_size)[0].reshape(conf.batch_size, conf.img_height, conf.img_width, conf.channel))
else:
# TODO move batch
batch_X = data[0][0][:conf.batch_size]
#condition = sess.run(encoder.fan_in, feed_dict={encoder_X:batch_X})
_, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X:batch_X})
print l
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, l = sess.run([trainer, loss], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
print "Epoch: %d, Cost: %f"%(i, l)

saver.save(sess, conf.ckpt_file)
generate_ae(sess, encoder_X, decoder_X, y, data, conf)
Expand Down Expand Up @@ -116,9 +115,12 @@ class Conf(object):
from keras.datasets import cifar10
data = cifar10.load_data()
# TODO normalize pixel values
data[0][0] = np.transpose(data[0][0], (0, 2, 3, 1))
data[1][0] = np.transpose(data[1][0], (0, 2, 3, 1))
train_size = data[0][0].shape[0]
data = data[0][0]
data = np.transpose(data, (0, 2, 3, 1))
conf.img_height = 32
conf.img_width = 32
conf.channel = 3
train_size = data.shape[0]

conf.num_batches = train_size // conf.batch_size
conf = makepaths(conf)
Expand Down
20 changes: 9 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from utils import *

def train(conf, data):
model = PixelCNN(conf)
X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
model = PixelCNN(conf, X)

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(model.loss)
Expand All @@ -32,12 +33,12 @@ def train(conf, data):
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X, model.h:batch_y})
_, cost = sess.run([optimizer, model.loss], feed_dict={X:batch_X, model.h:batch_y})

print "Epoch: %d, Cost: %f"%(i, cost)

saver.save(sess, conf.ckpt_file)
generate_samples(sess, model.X, model.h, model.pred, conf)
generate_samples(sess, X, model.h, model.pred, conf)


if __name__ == "__main__":
Expand Down Expand Up @@ -66,20 +67,17 @@ def train(conf, data):
conf.channel = 1
conf.num_batches = data.train.num_examples // conf.batch_size
else:
import cPickle
data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
from keras.datasets import cifar10
data = cifar10.load_data()
data = data[0][0]
data = np.transpose(data, (0, 2, 3, 1))
conf.img_height = 32
conf.img_width = 32
conf.channel = 3
data = np.reshape(data, (data.shape[0], conf.channel, \
conf.img_height, conf.img_width))
data = np.transpose(data, (0, 2, 3, 1))
raise ValueError("Specify num_classes")
conf.num_classes = 10
conf.num_batches = 10#data.shape[0] // conf.batch_size
conf.num_batches = data.shape[0] // conf.batch_size

# Implementing tf.image.per_image_whitening for normalization
# data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0

conf = makepaths(conf)
train(conf, data)
4 changes: 3 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def generate_samples(sess, X, h, pred, conf):


def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
n_row, n_col = 3, 3
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
if conf.data == 'mnist':
labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
else:
labels = get_batch(data, 0, n_row*n_col)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
Expand Down

0 comments on commit a4632ea

Please sign in to comment.