diff --git a/main.py b/main.py index 655a945..72f9569 100644 --- a/main.py +++ b/main.py @@ -16,12 +16,11 @@ def train(conf, data): saver = tf.train.Saver(tf.trainable_variables()) with tf.Session() as sess: + sess.run(tf.initialize_all_variables()) if os.path.exists(conf.ckpt_file): saver.restore(sess, conf.ckpt_file) print "Model Restored" - else: - sess.run(tf.initialize_all_variables()) - + pointer = 0 for i in range(conf.epochs): for j in range(conf.num_batches): @@ -35,8 +34,8 @@ def train(conf, data): print "Epoch: %d, Cost: %f"%(i, cost) - generate_and_save(sess, model.X, model.pred, conf) saver.save(sess, conf.ckpt_file) + generate_and_save(sess, model.X, model.pred, conf) if __name__ == "__main__": diff --git a/models.py b/models.py index 922b190..fa6af63 100644 --- a/models.py +++ b/models.py @@ -4,8 +4,7 @@ class PixelCNN(): def __init__(self, conf): - data_shape = [conf.batch_size, conf.img_height, conf.img_width, conf.channel] - self.X = tf.placeholder(tf.float32, shape=data_shape) + self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0) v_stack_in, h_stack_in = self.X_norm, self.X_norm @@ -56,5 +55,5 @@ def __init__(self, conf): #self.loss = tf.reduce_mean(-tf.reduce_sum(X_ohehot * tf.log(self.fc2), reduction_indices=[1])) # NOTE or check without argmax - self.preds = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), data_shape) + self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X)) diff --git a/utils.py b/utils.py index 7183e02..c9e9de1 100644 --- a/utils.py +++ b/utils.py @@ -9,11 +9,13 @@ def binarize(images): def generate_and_save(sess, X, pred, conf): n_row, n_col = 5, 5 - samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, 1), dtype=np.float32) + samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), 
dtype=np.float32) for i in xrange(conf.img_height): for j in xrange(conf.img_width): - for k in xrange(1): - next_sample = binarize(sess.run(pred, {X:samples})) + for k in xrange(conf.channel): + next_sample = sess.run(pred, {X:samples}) + if conf.data == "mnist": + next_sample = binarize(next_sample) samples[:, i, j, k] = next_sample[:, i, j, k] images = samples images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))