diff --git a/layers.py b/layers.py index 8903e27..b88dc48 100644 --- a/layers.py +++ b/layers.py @@ -34,8 +34,7 @@ def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activat self.payload = payload self.mask = mask self.activation = activation - # TODO need to map (batch_size,num_classes) to (f_map,) - self.conditional = None#conditional + self.conditional = conditional if gated: self.gated_conv() @@ -51,6 +50,11 @@ def gated_conv(self): b_f = tf.matmul(self.conditional, V_f) V_g = get_weights([h_shape, self.W_shape[3]], "h_V") b_g = tf.matmul(self.conditional, V_g) + + b_f_shape = tf.shape(b_f) + b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1])) + b_g_shape = tf.shape(b_g) + b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1])) else: b_f = get_bias(self.b_shape, "v_b") b_g = get_bias(self.b_shape, "h_b") @@ -62,6 +66,7 @@ def gated_conv(self): conv_f += self.payload conv_g += self.payload + self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) def simple_conv(self): diff --git a/main.py b/main.py index c8f9028..29d92d0 100644 --- a/main.py +++ b/main.py @@ -28,9 +28,7 @@ def train(conf, data): batch_X, batch_y = data.train.next_batch(conf.batch_size) batch_X = binarize(batch_X.reshape([conf.batch_size, \ conf.img_height, conf.img_width, conf.channel])) - y_ = np.zeros((batch_y.shape[0], conf.num_classes)) - y_[np.arange(batch_y.shape[0]), batch_y] = 1 - batch_y = y_ + batch_y = one_hot(batch_y, conf.num_classes) else: batch_X, pointer = get_batch(data, pointer, conf.batch_size) @@ -39,7 +37,7 @@ def train(conf, data): print "Epoch: %d, Cost: %f"%(i, cost) saver.save(sess, conf.ckpt_file) - generate_and_save(sess, model.X, model.pred, conf) + generate_and_save(sess, model.X, model.h, model.pred, conf) if __name__ == "__main__": @@ -51,6 +49,7 @@ def train(conf, data): parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--batch_size', type=int, default=100) parser.add_argument('--grad_clip', type=int, default=1) + parser.add_argument('--conditional', type=bool, default=False) parser.add_argument('--data_path', type=str, default='data') parser.add_argument('--ckpt_path', type=str, default='ckpts') parser.add_argument('--samples_path', type=str, default='samples') @@ -65,7 +64,7 @@ def train(conf, data): conf.img_height = 28 conf.img_width = 28 conf.channel = 1 - conf.num_batches = 10#mnist.train.num_examples // conf.batch_size + conf.num_batches = data.train.num_examples // conf.batch_size else: import cPickle data = cPickle.load(open('cifar-100-python/train', 'r'))['data'] @@ -81,7 +80,7 @@ def train(conf, data): # Implementing tf.image.per_image_whitening for normalization # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0 - conf.num_batches = data.shape[0] // conf.batch_size + conf.num_batches = 10#data.shape[0] // conf.batch_size ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map)) if not os.path.exists(ckpt_full_path): diff --git a/models.py b/models.py index 64d942a..e213a61 100644 --- a/models.py +++ b/models.py @@ -8,7 +8,9 @@ def __init__(self, conf): self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0) v_stack_in, h_stack_in = self.X_norm, self.X_norm # TODO norm for multichannel: dubtract mean and divide by std feature-wise - self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes]) + self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes]) + if conf.conditional is False: + self.h = None for i in range(conf.layers): filter_size = 3 if i > 0 else 7 diff --git a/utils.py b/utils.py index c9e9de1..32b8f24 100644 --- a/utils.py +++ b/utils.py @@ -7,14 +7,17 @@ def binarize(images): return (np.random.uniform(size=images.shape) < images).astype(np.float32) -def generate_and_save(sess, X, pred, conf): +def generate_and_save(sess, X, h, pred, conf): n_row, n_col = 5, 5 samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32) + # TODO make it generic + labels = one_hot([1,2,3,4,5]*5) + for i in xrange(conf.img_height): for j in xrange(conf.img_width): for k in xrange(conf.channel): - next_sample = sess.run(pred, {X:samples}) - if data == "mnist": + next_sample = sess.run(pred, {X:samples, h: labels}) + if conf.data == "mnist": next_sample = binarize(next_sample) samples[:, i, j, k] = next_sample[:, i, j, k] images = samples @@ -32,3 +35,10 @@ def get_batch(data, pointer, batch_size): batch = data[batch_size * pointer : batch_size * (pointer + 1)] pointer += 1 return [batch, pointer] + +def one_hot(batch_y, num_classes): + y_ = np.zeros((batch_y.shape[0], num_classes)) + y_[np.arange(batch_y.shape[0]), batch_y] = 1 + return y_ + +