
Commit

gatedcnn-1
anantzoid committed Nov 19, 2016
1 parent d0069b1 commit 40ac4a4
Showing 3 changed files with 29 additions and 13 deletions.
layers.py: 18 changes (13 additions, 5 deletions)

@@ -25,7 +25,7 @@ def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

 class GatedCNN():
-    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
+    def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True, conditional=None):
         self.fan_in = fan_in
         in_dim = self.fan_in.get_shape()[-1]
         self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
@@ -34,19 +34,27 @@ def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
         self.payload = payload
         self.mask = mask
         self.activation = activation
+        # TODO need to map (batch_size,num_classes) to (f_map,)
+        self.conditional = None#conditional


         if gated:
             self.gated_conv()
         else:
             self.simple_conv()

     def gated_conv(self):
         W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
-        b_f = get_bias(self.b_shape, "v_b")
         W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
-        b_g = get_bias(self.b_shape, "h_b")

+        if self.conditional is not None:
+            h_shape = int(self.conditional.get_shape()[1])
+            V_f = get_weights([h_shape, self.W_shape[3]], "v_V")
+            b_f = tf.matmul(self.conditional, V_f)
+            V_g = get_weights([h_shape, self.W_shape[3]], "h_V")
+            b_g = tf.matmul(self.conditional, V_g)
+        else:
+            b_f = get_bias(self.b_shape, "v_b")
+            b_g = get_bias(self.b_shape, "h_b")

         conv_f = conv_op(self.fan_in, W_f)
         conv_g = conv_op(self.fan_in, W_g)
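Note: the TODO above flags the open shape question. tf.matmul(self.conditional, V_f) yields a bias of shape (batch_size, f_map), while conv_op returns (batch_size, height, width, f_map), so the bias must be reshaped to (batch_size, 1, 1, f_map) before it can broadcast over the spatial axes; this is presumably why self.conditional is pinned to None for now. A minimal NumPy sketch of the shape bookkeeping (all names and sizes here are illustrative, not from the commit):

import numpy as np

batch_size, num_classes, f_map = 16, 10, 32       # hypothetical sizes
h = np.zeros((batch_size, num_classes), dtype=np.float32)
h[np.arange(batch_size), 3] = 1.0                 # one-hot labels, all class 3
V_f = np.random.randn(num_classes, f_map).astype(np.float32)

b_f = h.dot(V_f)                                  # (16, 32), like tf.matmul(self.conditional, V_f)
conv_f = np.random.randn(batch_size, 28, 28, f_map)  # stand-in for conv_op(self.fan_in, W_f)

# (16, 32) does not broadcast against (16, 28, 28, 32); inserting singleton
# spatial axes makes the per-image, per-channel bias broadcast cleanly.
out = conv_f + b_f.reshape(batch_size, 1, 1, f_map)
assert out.shape == (batch_size, 28, 28, f_map)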
main.py: 15 changes (11 additions, 4 deletions)

@@ -25,12 +25,16 @@ def train(conf, data):
     for i in range(conf.epochs):
         for j in range(conf.num_batches):
             if conf.data == "mnist":
-                batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
-                    .reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
+                batch_X, batch_y = data.train.next_batch(conf.batch_size)
+                batch_X = binarize(batch_X.reshape([conf.batch_size, \
+                    conf.img_height, conf.img_width, conf.channel]))
+                y_ = np.zeros((batch_y.shape[0], conf.num_classes))
+                y_[np.arange(batch_y.shape[0]), batch_y] = 1
+                batch_y = y_
             else:
                 batch_X, pointer = get_batch(data, pointer, conf.batch_size)

-            _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X})
+            _, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X, model.h:batch_y})

             print "Epoch: %d, Cost: %f"%(i, cost)

@@ -57,10 +61,11 @@ def train(conf, data):
     if not os.path.exists(conf.data_path):
         os.makedirs(conf.data_path)
     data = input_data.read_data_sets(conf.data_path)
+    conf.num_classes = 10
     conf.img_height = 28
     conf.img_width = 28
     conf.channel = 1
-    conf.num_batches = mnist.train.num_examples // conf.batch_size
+    conf.num_batches = 10#mnist.train.num_examples // conf.batch_size
 else:
     import cPickle
     data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
@@ -70,6 +75,8 @@ def train(conf, data):
     data = np.reshape(data, (data.shape[0], conf.channel, \
             conf.img_height, conf.img_width))
     data = np.transpose(data, (0, 2, 3, 1))
+    raise ValueError("Specify num_classes")
+    conf.num_classes = 10

     # Implementing tf.image.per_image_whitening for normalization
     # data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0
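Note: the two-line one-hot construction in the MNIST branch (allocate zeros, then index with np.arange) is equivalent to indexing an identity matrix by the integer labels. A small sketch of that equivalence, with made-up labels:

import numpy as np

batch_y = np.array([3, 1, 4])        # hypothetical integer labels
num_classes = 10

# The commit's version: zeros, then one 1 per row at the label's column.
y_ = np.zeros((batch_y.shape[0], num_classes))
y_[np.arange(batch_y.shape[0]), batch_y] = 1

# The identity-matrix idiom yields the same (batch, num_classes) matrix.
assert (y_ == np.eye(num_classes)[batch_y]).all()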
models.py: 9 changes (5 additions, 4 deletions)

@@ -7,22 +7,23 @@ def __init__(self, conf):
         self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
         self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
         v_stack_in, h_stack_in = self.X_norm, self.X_norm
+        # TODO norm for multichannel: subtract mean and divide by std feature-wise
+        self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])

         for i in range(conf.layers):
             filter_size = 3 if i > 0 else 7
             in_dim = conf.f_map if i > 0 else conf.channel
             mask = 'b' if i > 0 else 'a'
             residual = True if i > 0 else False
             i = str(i)
             with tf.variable_scope("v_stack"+i):
-                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask).output()
+                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask, conditional=self.h).output()
                 v_stack_in = v_stack

             with tf.variable_scope("v_stack_1"+i):
                 v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output()

             with tf.variable_scope("h_stack"+i):
-                h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask).output()
+                h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask, conditional=self.h).output()

             with tf.variable_scope("h_stack_1"+i):
                 h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output()
@@ -36,7 +37,7 @@ def __init__(self, conf):
         if conf.data == "mnist":
             with tf.variable_scope("fc_2"):
                 self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
-            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X))
+            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.fc2, self.X))
             self.pred = tf.nn.sigmoid(self.fc2)
         else:
             color_dim = 256
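Note: the loss change is a straight bug fix; inside __init__ there is no name model, so model.fc2 and model.X would raise a NameError at graph construction, while self.fc2 and self.X are the attributes that exist. Also, 2016-era TensorFlow accepted positional sigmoid_cross_entropy_with_logits(logits, targets), but later 1.x releases require keyword arguments. A minimal standalone sketch assuming the keyword API, with stand-in tensors:

import tensorflow as tf

logits = tf.zeros([4, 28, 28, 1])    # stand-in for self.fc2
targets = tf.zeros([4, 28, 28, 1])   # stand-in for the binarized input self.X

# Mean per-pixel Bernoulli cross-entropy, as in the commit's MNIST branch.
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))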
