From da7dde7c36e6b39308197688ba04132cc83ac876 Mon Sep 17 00:00:00 2001
From: Anant Gupta
Date: Mon, 7 Nov 2016 14:16:29 +0530
Subject: [PATCH] 3 channel

---
 main.py   | 36 +++++++++++++++++++++++++++---------
 models.py | 38 ++++++++++++++++++++++++++++++--------
 utils.py  |  6 ++++++
 3 files changed, 63 insertions(+), 17 deletions(-)

diff --git a/main.py b/main.py
index e190226..655a945 100644
--- a/main.py
+++ b/main.py
@@ -6,29 +6,34 @@
 def train(conf, data):
     model = PixelCNN(conf)
 
-    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X))
     trainer = tf.train.RMSPropOptimizer(1e-3)
-    gradients = trainer.compute_gradients(loss)
+    gradients = trainer.compute_gradients(model.loss)
     clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
     optimizer = trainer.apply_gradients(clipped_gradients)
 
     saver = tf.train.Saver(tf.trainable_variables())
+
     with tf.Session() as sess:
         if os.path.exists(conf.ckpt_file):
             saver.restore(sess, conf.ckpt_file)
             print "Model Restored"
         else:
             sess.run(tf.initialize_all_variables())
-
+
+        pointer = 0
         for i in range(conf.epochs):
             for j in range(conf.num_batches):
-                batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
+                if conf.data == "mnist":
+                    batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
                         .reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
-                _, cost = sess.run([optimizer, loss], feed_dict={model.X:batch_X})
+                else:
+                    batch_X, pointer = get_batch(data, pointer, conf.batch_size)
+
+                _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X})
 
-            print "Epoch: %d, Cost: %f"%(i, cost)
+            print "Epoch: %d, Cost: %f"%(i, cost)
 
         generate_and_save(sess, model.X, model.pred, conf)
         saver.save(sess, conf.ckpt_file)
@@ -38,7 +43,7 @@ def train(conf, data):
     parser = argparse.ArgumentParser()
     parser.add_argument('--type', type=str, default='train')
     parser.add_argument('--data', type=str, default='mnist')
-    parser.add_argument('--layers', type=int, default=12)
+    parser.add_argument('--layers', type=int, default=5)
     parser.add_argument('--f_map', type=int, default=32)
     parser.add_argument('--epochs', type=int, default=50)
     parser.add_argument('--batch_size', type=int, default=100)
@@ -56,8 +61,21 @@ def train(conf, data):
         conf.img_height = 28
         conf.img_width = 28
         conf.channel = 1
-        conf.num_batches = 10#mnist.train.num_examples // conf.batch_size
-        conf.filter_size = 7
+        conf.num_batches = mnist.train.num_examples // conf.batch_size
+    else:
+        import cPickle
+        data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
+        conf.img_height = 32
+        conf.img_width = 32
+        conf.channel = 3
+        data = np.reshape(data, (data.shape[0], conf.channel, \
+                conf.img_height, conf.img_width))
+        data = np.transpose(data, (0, 2, 3, 1))
+
+        # Implementing tf.image.per_image_whitening for normalization
+        # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0
+
+        conf.num_batches = data.shape[0] // conf.batch_size
 
     ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
     if not os.path.exists(ckpt_full_path):
diff --git a/models.py b/models.py
index 9705161..922b190 100644
--- a/models.py
+++ b/models.py
@@ -3,12 +3,14 @@
 class PixelCNN():
     def __init__(self, conf):
-
-        self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
-        v_stack_in, h_stack_in = self.X, self.X
+
+        data_shape = [conf.batch_size, conf.img_height, conf.img_width, conf.channel]
+        self.X = tf.placeholder(tf.float32, shape=data_shape)
+        self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
+        v_stack_in, h_stack_in = self.X_norm, self.X_norm
 
         for i in range(conf.layers):
-            filter_size = 3 if i > 0 else conf.filter_size
+            filter_size = 3 if i > 0 else 7
             in_dim = conf.f_map if i > 0 else conf.channel
             mask = 'b' if i > 0 else 'a'
             residual = True if i > 0 else False
@@ -32,7 +34,27 @@ def __init__(self, conf):
         with tf.variable_scope("fc_1"):
             fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output()
 
-        # handle Imagenet differently
-        with tf.variable_scope("fc_2"):
-            self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
-        self.pred = tf.nn.sigmoid(self.fc2)
+        if conf.data == "mnist":
+            with tf.variable_scope("fc_2"):
+                self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
+            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.fc2, self.X))
+            self.pred = tf.nn.sigmoid(self.fc2)
+        else:
+            color_dim = 256
+            with tf.variable_scope("fc_2"):
+                self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
+            #fc2_shape = self.fc2.get_shape()
+            #self.fc2 = tf.reshape(self.fc2, (int(fc2_shape[0]), int(fc2_shape[1]), int(fc2_shape[2]), conf.channel, -1))
+            #fc2_shape = self.fc2.get_shape()
+            #self.fc2 = tf.nn.softmax(tf.reshape(self.fc2, (-1, int(fc2_shape[-1]))))
+            self.fc2 = tf.reshape(self.fc2, (-1, color_dim))
+
+            #self.loss = self.categorical_crossentropy(self.fc2, self.X)
+            #self.X_flat = tf.reshape(self.X, [-1])
+            #self.fc2_flat = tf.cast(tf.argmax(self.fc2, dimension=tf.rank(self.fc2) - 1), dtype=tf.float32)
+            self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
+            #self.loss = tf.reduce_mean(-tf.reduce_sum(X_onehot * tf.log(self.fc2), reduction_indices=[1]))
+
+            # NOTE or check without argmax
+            self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), data_shape)
+
diff --git a/utils.py b/utils.py
index c47de82..7183e02 100644
--- a/utils.py
+++ b/utils.py
@@ -24,3 +24,9 @@ def generate_and_save(sess, X, pred, conf):
         scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))
 
+def get_batch(data, pointer, batch_size):
+    # wrap around once the next full batch would run past the end of the data
+    if batch_size * (pointer + 1) > data.shape[0]:
+        pointer = 0
+    batch = data[batch_size * pointer : batch_size * (pointer + 1)]
+    pointer += 1
+    return [batch, pointer]
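
Note (not part of the patch): the CIFAR path in models.py treats every RGB sub-pixel as a 256-way classification problem, reshaping the fc_2 logits to [batch*height*width*channel, 256] and using the raw 0-255 intensities as integer labels. The snippet below is a minimal NumPy sketch of that shape bookkeeping with made-up toy dimensions; it only mirrors, rather than calls, tf.nn.sparse_softmax_cross_entropy_with_logits.

    import numpy as np

    batch, height, width, channel, color_dim = 2, 32, 32, 3, 256

    # fc_2 output: one 256-way score vector per (pixel, channel)
    logits = np.random.randn(batch, height, width, channel * color_dim)
    logits = logits.reshape(-1, color_dim)          # [batch*h*w*channel, 256]

    # raw image intensities 0..255 act as the class labels
    images = np.random.randint(0, color_dim, size=(batch, height, width, channel))
    labels = images.reshape(-1).astype(np.int32)    # [batch*h*w*channel]

    # same quantity the sparse softmax cross-entropy op averages in models.py
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(labels.size), labels].mean()
    print "loss per sub-pixel: %.3f" % loss         # ~log(256) = 5.545 for random logits

Prediction then takes an argmax (or a sample) over each 256-way row and reshapes back to [batch, height, width, channel], which is what the self.pred line in the else branch does.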