From e4119ea8827b3b6f8b61e1162dae212647a65561 Mon Sep 17 00:00:00 2001 From: Anant Gupta Date: Wed, 16 Nov 2016 21:18:09 +0530 Subject: [PATCH] pCNN AE sample generation changes --- autoencoder.py | 20 ++++++++++---------- main.py | 9 ++++----- utils.py | 10 ++++++---- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/autoencoder.py b/autoencoder.py index 791c2dd..62f6902 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -6,25 +6,25 @@ class AE(object): def __init__(self, X): - self.fmap_out = 32 + self.num_layers = 2 + self.fmap_out = [8, 32] self.fmap_in = conf.channel self.fan_in = X - self.num_layers = 3 - self.filter_size = 3 + self.filter_size = 4 self.W = [] self.strides = [1, 1, 1, 1] for i in range(self.num_layers): - if i == self.num_layers -1 : - self.fmap_out = 10 - self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out], stddev=0.1), name="W_%d"%i)) - b = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="encoder_b_%d"%i) + self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out[i]], stddev=0.1), name="W_%d"%i)) + b = tf.Variable(tf.ones(shape=[self.fmap_out[i]], dtype=tf.float32), name="encoder_b_%d"%i) en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i) + en_pool = tf.nn.max_pool(en_conv, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name="encoder_pool_%d"%i) - self.fan_in = tf.tanh(tf.add(en_conv, b)) - self.fmap_in = self.fmap_out + self.fan_in = tf.tanh(tf.add(en_pool, b)) + self.fmap_in = self.fmap_out[i] - self.fan_in = tf.reshape(self.fan_in, (-1, conf.img_width*conf.img_height*self.fmap_out)) + op_shape = self.fan_in.get_shape() + self.fan_in = tf.reshape(self.fan_in, (-1, int(op_shape[1])*int(op_shape[2])*int(op_shape[3]))) def decoder(self): self.W.reverse() diff --git a/main.py b/main.py index a49575f..f752ccf 100644 --- a/main.py +++ b/main.py @@ -40,17 +40,16 @@ def train(conf, data): data_dict[model.h] = batch_y _, cost,_f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict) print "Epoch: %d, Cost: %f"%(i, cost) - - saver.save(sess, conf.ckpt_file) - generate_samples(sess, X, model.h, model.pred_sample, conf, "sample") - generate_samples(sess, X, model.h, model.pred_argmax, conf, "argmax") + if (i+1)%10 == 0: + saver.save(sess, conf.ckpt_file) + generate_samples(sess, X, model.h, model.pred, conf, "argmax") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--type', type=str, default='train') parser.add_argument('--data', type=str, default='mnist') - parser.add_argument('--layers', type=int, default=5) + parser.add_argument('--layers', type=int, default=12) parser.add_argument('--f_map', type=int, default=32) parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--batch_size', type=int, default=100) diff --git a/utils.py b/utils.py index 8f558d5..8481299 100644 --- a/utils.py +++ b/utils.py @@ -8,7 +8,7 @@ def binarize(images): return (np.random.uniform(size=images.shape) < images).astype(np.float32) def generate_samples(sess, X, h, pred, conf, suff): - n_row, n_col = 5, 5 + n_row, n_col = 10,10 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) # TODO make it generic labels = one_hot(np.array([1,2,3,4,5]*5), conf.num_classes) @@ -28,8 +28,8 @@ def generate_samples(sess, X, h, pred, conf, suff): save_images(samples, n_row, n_col, conf, suff) -def generate_ae(sess, encoder_X, decoder_X, y, data, conf): - n_row, n_col = 5, 5 +def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''): + n_row, n_col = 1,1 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) if conf.data == 'mnist': labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel)) @@ -44,7 +44,9 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf): next_sample = binarize(next_sample) samples[:, i, j, k] = next_sample[:, i, j, k] - save_images(samples, n_row, n_col, conf, '') + np.save('preds_'+suff+'.npy', samples) + print "height:%d"%i + save_images(samples, n_row, n_col, conf, suff) def save_images(samples, n_row, n_col, conf, suff): images = samples