Skip to content

Commit

Permalink
conditional: test 1
Browse files Browse the repository at this point in the history
fixes
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent 40ac4a4 commit 715985e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
9 changes: 7 additions & 2 deletions layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activat
self.payload = payload
self.mask = mask
self.activation = activation
# TODO need to map (batch_size,num_classes) to (f_map,)
self.conditional = None#conditional
self.conditional = conditional

if gated:
self.gated_conv()
Expand All @@ -51,6 +50,11 @@ def gated_conv(self):
b_f = tf.matmul(self.conditional, V_f)
V_g = get_weights([h_shape, self.W_shape[3]], "h_V")
b_g = tf.matmul(self.conditional, V_g)

b_f_shape = tf.shape(b_f)
b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1]))
b_g_shape = tf.shape(b_g)
b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1]))
else:
b_f = get_bias(self.b_shape, "v_b")
b_g = get_bias(self.b_shape, "h_b")
Expand All @@ -62,6 +66,7 @@ def gated_conv(self):
conv_f += self.payload
conv_g += self.payload


self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

def simple_conv(self):
Expand Down
11 changes: 5 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ def train(conf, data):
batch_X, batch_y = data.train.next_batch(conf.batch_size)
batch_X = binarize(batch_X.reshape([conf.batch_size, \
conf.img_height, conf.img_width, conf.channel]))
y_ = np.zeros((batch_y.shape[0], conf.num_classes))
y_[np.arange(batch_y.shape[0]), batch_y] = 1
batch_y = y_
batch_y = one_hot(batch_y, conf.num_classes)
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

Expand All @@ -39,7 +37,7 @@ def train(conf, data):
print "Epoch: %d, Cost: %f"%(i, cost)

saver.save(sess, conf.ckpt_file)
generate_and_save(sess, model.X, model.pred, conf)
generate_and_save(sess, model.X, model.h, model.pred, conf)


if __name__ == "__main__":
Expand All @@ -51,6 +49,7 @@ def train(conf, data):
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--grad_clip', type=int, default=1)
parser.add_argument('--conditional', type=bool, default=False)
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--ckpt_path', type=str, default='ckpts')
parser.add_argument('--samples_path', type=str, default='samples')
Expand All @@ -65,7 +64,7 @@ def train(conf, data):
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
conf.num_batches = 10#mnist.train.num_examples // conf.batch_size
conf.num_batches = data.train.num_examples // conf.batch_size
else:
import cPickle
data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
Expand All @@ -81,7 +80,7 @@ def train(conf, data):
# Implementing tf.image.per_image_whitening for normalization
# data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0

conf.num_batches = data.shape[0] // conf.batch_size
conf.num_batches = 10#data.shape[0] // conf.batch_size

ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(ckpt_full_path):
Expand Down
4 changes: 3 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def __init__(self, conf):
self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
v_stack_in, h_stack_in = self.X_norm, self.X_norm
# TODO norm for multichannel: dubtract mean and divide by std feature-wise
self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
if conf.conditional is False:
self.h = None

for i in range(conf.layers):
filter_size = 3 if i > 0 else 7
Expand Down
16 changes: 13 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
def binarize(images):
return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_and_save(sess, X, pred, conf):
def generate_and_save(sess, X, h, pred, conf):
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
# TODO make it generic
labels = one_hot([1,2,3,4,5]*5)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(conf.channel):
next_sample = sess.run(pred, {X:samples})
if data == "mnist":
next_sample = sess.run(pred, {X:samples, h: labels})
if conf.data == "mnist":
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]
images = samples
Expand All @@ -32,3 +35,10 @@ def get_batch(data, pointer, batch_size):
batch = data[batch_size * pointer : batch_size * (pointer + 1)]
pointer += 1
return [batch, pointer]

def one_hot(batch_y, num_classes):
y_ = np.zeros((batch_y.shape[0], num_classes))
y_[np.arange(batch_y.shape[0]), batch_y] = 1
return y_


0 comments on commit 715985e

Please sign in to comment.