From 40ac4a4a3f80c54de4cd336c5043374203234fc3 Mon Sep 17 00:00:00 2001
From: Anant Gupta
Date: Tue, 8 Nov 2016 19:39:19 +0530
Subject: [PATCH] gatedcnn-1

---
 layers.py | 18 +++++++++++++-----
 main.py   | 15 +++++++++++----
 models.py |  9 +++++----
 3 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/layers.py b/layers.py
index 569aac5..8903e27 100644
--- a/layers.py
+++ b/layers.py
@@ -25,7 +25,7 @@ def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

 class GatedCNN():
-    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
+    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
         self.fan_in = fan_in
         in_dim = self.fan_in.get_shape()[-1]
         self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
@@ -34,8 +34,9 @@ def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activat
         self.payload = payload
         self.mask = mask
         self.activation = activation
+        # TODO: need to map (batch_size, num_classes) to (f_map,)
+        self.conditional = None  # conditional
-
         if gated:
             self.gated_conv()
         else:
@@ -43,10 +44,17 @@ def gated_conv(self):
         W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
-        b_f = get_bias(self.b_shape, "v_b")
         W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
-        b_g = get_bias(self.b_shape, "h_b")
-
+        if self.conditional is not None:
+            h_shape = int(self.conditional.get_shape()[1])
+            V_f = get_weights([h_shape, self.W_shape[3]], "v_V")
+            b_f = tf.matmul(self.conditional, V_f)
+            V_g = get_weights([h_shape, self.W_shape[3]], "h_V")
+            b_g = tf.matmul(self.conditional, V_g)
+        else:
+            b_f = get_bias(self.b_shape, "v_b")
+            b_g = get_bias(self.b_shape, "h_b")
+
         conv_f = conv_op(self.fan_in, W_f)
         conv_g = conv_op(self.fan_in, W_g)

diff --git a/main.py b/main.py
index 72f9569..c8f9028 100644
--- a/main.py
+++ b/main.py
@@ -25,12 +25,16 @@ def train(conf, data):
     for i in range(conf.epochs):
         for j in range(conf.num_batches):
             if conf.data == "mnist":
-                batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
-                    .reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
+                batch_X, batch_y = data.train.next_batch(conf.batch_size)
+                batch_X = binarize(batch_X.reshape([conf.batch_size, \
+                    conf.img_height, conf.img_width, conf.channel]))
+                y_ = np.zeros((batch_y.shape[0], conf.num_classes))
+                y_[np.arange(batch_y.shape[0]), batch_y] = 1
+                batch_y = y_
             else:
                 batch_X, pointer = get_batch(data, pointer, conf.batch_size)

-            _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X})
+            _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X, model.h:batch_y})

         print "Epoch: %d, Cost: %f"%(i, cost)

@@ -57,10 +61,11 @@ def train(conf, data):
         if not os.path.exists(conf.data_path):
             os.makedirs(conf.data_path)
         data = input_data.read_data_sets(conf.data_path)
+        conf.num_classes = 10
         conf.img_height = 28
         conf.img_width = 28
         conf.channel = 1
-        conf.num_batches = mnist.train.num_examples // conf.batch_size
+        conf.num_batches = 10  # mnist.train.num_examples // conf.batch_size
     else:
         import cPickle
         data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
@@ -70,6 +75,8 @@ def train(conf, data):
         data = np.reshape(data, (data.shape[0], conf.channel, \
             conf.img_height, conf.img_width))
         data = np.transpose(data, (0, 2, 3, 1))
+        raise ValueError("Specify num_classes")
+        conf.num_classes = 10
         # Implementing tf.image.per_image_whitening for normalization
         # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0
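For reference, the gating wired in above follows the conditional Gated PixelCNN form tanh(W_f*x + V_f h) * sigmoid(W_g*x + V_g h), where h is the one-hot label batch built in main.py. The NumPy sketch below (the conditional_gate helper and toy shapes are illustrative, not code from this repo) shows the intended arithmetic. Note the explicit reshape to (batch, 1, 1, f_map): a raw (batch, f_map) bias from tf.matmul will not broadcast against the (batch, height, width, f_map) conv output, which is what the TODO in __init__ flags and presumably why the patch still pins self.conditional to None.

import numpy as np

def conditional_gate(conv_f, conv_g, h, V_f, V_g):
    # conv_f, conv_g: (batch, height, width, f_map) masked-conv pre-activations.
    # h: (batch, num_classes) one-hot condition; V_f, V_g: (num_classes, f_map).
    b_f = h.dot(V_f)[:, None, None, :]  # (batch, 1, 1, f_map): broadcasts over pixels
    b_g = h.dot(V_g)[:, None, None, :]
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    return np.tanh(conv_f + b_f) * sigmoid(conv_g + b_g)

rng = np.random.RandomState(0)
conv_f = rng.randn(2, 4, 4, 3)          # toy batch of 2, 4x4 feature maps, f_map=3
conv_g = rng.randn(2, 4, 4, 3)
h = np.eye(10)[[3, 7]]                  # one-hot labels, as built in main.py
out = conditional_gate(conv_f, conv_g, h, rng.randn(10, 3), rng.randn(10, 3))
assert out.shape == (2, 4, 4, 3)

Each class thus contributes one learned scalar bias per feature map inside both the tanh and sigmoid gates, identical at every spatial position.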
diff --git a/models.py b/models.py
index fa6af63..64d942a 100644
--- a/models.py
+++ b/models.py
@@ -7,22 +7,23 @@ def __init__(self, conf):
         self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
         self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
         v_stack_in, h_stack_in = self.X_norm, self.X_norm
+        # TODO norm for multichannel: subtract mean and divide by std feature-wise
+        self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])

         for i in range(conf.layers):
             filter_size = 3 if i > 0 else 7
-            in_dim = conf.f_map if i > 0 else conf.channel
             mask = 'b' if i > 0 else 'a'
             residual = True if i > 0 else False
             i = str(i)
             with tf.variable_scope("v_stack"+i):
-                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask).output()
+                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask, conditional=self.h).output()
                 v_stack_in = v_stack

             with tf.variable_scope("v_stack_1"+i):
                 v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output()

             with tf.variable_scope("h_stack"+i):
-                h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask).output()
+                h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask, conditional=self.h).output()

             with tf.variable_scope("h_stack_1"+i):
                 h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output()

@@ -36,7 +37,7 @@ def __init__(self, conf):
         if conf.data == "mnist":
             with tf.variable_scope("fc_2"):
                 self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
-            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X))
+            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.fc2, self.X))
             self.pred = tf.nn.sigmoid(self.fc2)
         else:
             color_dim = 256
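On the MNIST branch the model is trained as a per-pixel Bernoulli autoregressor: the loss is mean sigmoid cross-entropy between the fc2 logits and the binarized input, with the labels entering only through the V_f h / V_g h gate biases added in layers.py. As a minimal sketch of what that objective computes (sigmoid_xent and the toy shapes here are illustrative, using the numerically stable formulation that tf.nn.sigmoid_cross_entropy_with_logits documents):

import numpy as np

def sigmoid_xent(logits, targets):
    # Stable form of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))],
    # computed element-wise like tf.nn.sigmoid_cross_entropy_with_logits.
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

rng = np.random.RandomState(0)
X = (rng.rand(2, 28, 28, 1) > 0.5).astype(np.float64)  # stands in for binarized batch_X
logits = rng.randn(2, 28, 28, 1)                       # stands in for self.fc2
loss = sigmoid_xent(logits, X).mean()                  # mirrors tf.reduce_mean(...)

The diff's other change here, model.fc2/model.X to self.fc2/self.X, fixes a reference to the not-yet-constructed model object inside its own __init__.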