Skip to content

Commit

Permalink
3 channel
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent fe30471 commit da7dde7
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 17 deletions.
36 changes: 27 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,34 @@

def train(conf, data):
# Build the PixelCNN graph, set up RMSProp training with clipped gradients,
# run conf.epochs x conf.num_batches optimisation steps and save the session
# to conf.ckpt_file.
# NOTE(review): this listing is a rendered commit diff whose indentation and
# +/- markers were stripped, so some statements appear twice (the two
# `gradients = ...` assignments, the two `sess.run([optimizer, ...])` calls,
# the two prints) -- presumably the pre- and post-commit versions. Confirm
# against the repository before running this text as-is.
model = PixelCNN(conf)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X))

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)
gradients = trainer.compute_gradients(model.loss)

# Clamp every gradient element to [-conf.grad_clip, conf.grad_clip] before
# handing the (gradient, variable) pairs back to the optimizer.
clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)

# Only the trainable variables are checkpointed.
saver = tf.train.Saver(tf.trainable_variables())

with tf.Session() as sess:
# Resume from conf.ckpt_file when it exists on disk; otherwise start from
# freshly initialised variables.
if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"
else:
sess.run(tf.initialize_all_variables())


# Batch cursor that is passed to, and returned by, get_batch on the
# non-MNIST path.
pointer = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
# MNIST batches come from data.train.next_batch and are binarized and
# reshaped to [batch_size, img_height, img_width, channel]; any other
# dataset is sliced out of `data` by get_batch.
batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
if conf.data == "mnist":
batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
.reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
_, cost = sess.run([optimizer, loss], feed_dict={model.X:batch_X})
else:
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, cost = sess.run([optimizer, model.loss], feed_dict={model.X:batch_X})

print "Epoch: %d, Cost: %f"%(i, cost)
print "Epoch: %d, Cost: %f"%(i, cost)

# Call generate_and_save with the model's input placeholder and prediction
# op, then write the checkpoint.
# NOTE(review): indentation is lost in this view, so whether these two calls
# run once per epoch or once after all epochs cannot be determined here.
generate_and_save(sess, model.X, model.pred, conf)
saver.save(sess, conf.ckpt_file)
Expand All @@ -38,7 +43,7 @@ def train(conf, data):
parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, default='train')
parser.add_argument('--data', type=str, default='mnist')
parser.add_argument('--layers', type=int, default=12)
parser.add_argument('--layers', type=int, default=5)
parser.add_argument('--f_map', type=int, default=32)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=100)
Expand All @@ -56,8 +61,21 @@ def train(conf, data):
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
conf.num_batches = 10#mnist.train.num_examples // conf.batch_size
conf.filter_size = 7
conf.num_batches = mnist.train.num_examples // conf.batch_size
else:
import cPickle
data = cPickle.load(open('cifar-100-python/train', 'r'))['data']
conf.img_height = 32
conf.img_width = 32
conf.channel = 3
data = np.reshape(data, (data.shape[0], conf.channel, \
conf.img_height, conf.img_width))
data = np.transpose(data, (0, 2, 3, 1))

# Implementing tf.image.per_image_whitening for normalization
# data = (data-np.mean(data)) / max(np.std(data), 1.0/np.sqrt(sum(data.shape))) * 255.0

conf.num_batches = data.shape[0] // conf.batch_size

ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(ckpt_full_path):
Expand Down
38 changes: 30 additions & 8 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

class PixelCNN():
def __init__(self, conf):

self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
v_stack_in, h_stack_in = self.X, self.X

data_shape = [conf.batch_size, conf.img_height, conf.img_width, conf.channel]
self.X = tf.placeholder(tf.float32, shape=data_shape)
self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
v_stack_in, h_stack_in = self.X_norm, self.X_norm

for i in range(conf.layers):
filter_size = 3 if i > 0 else conf.filter_size
filter_size = 3 if i > 0 else 7
in_dim = conf.f_map if i > 0 else conf.channel
mask = 'b' if i > 0 else 'a'
residual = True if i > 0 else False
Expand All @@ -32,7 +34,27 @@ def __init__(self, conf):
with tf.variable_scope("fc_1"):
fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output()

# handle Imagenet differently
with tf.variable_scope("fc_2"):
self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
self.pred = tf.nn.sigmoid(self.fc2)
if conf.data == "mnist":
    # MNIST path: the final 1x1 layer's raw output is used as logits for a
    # sigmoid cross-entropy against the input placeholder itself.
    with tf.variable_scope("fc_2"):
        self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
    # Use `self` here: `model` is not defined inside the constructor and
    # would raise NameError as soon as the MNIST graph is built.
    self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.fc2, self.X))
    self.pred = tf.nn.sigmoid(self.fc2)
else:
    # Colour path: treat every channel value as a 256-way classification.
    color_dim = 256
    with tf.variable_scope("fc_2"):
        self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
    # Flatten to one row of `color_dim` logits per (pixel, channel) value so
    # the rows line up with the flattened integer targets below.
    self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

    # The raw (un-normalised) input values are the class indices.
    labels = tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)
    self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, labels))

    # NOTE or check without argmax
    self.preds = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), data_shape)
    # Expose the prediction under `pred` as well: the MNIST branch and the
    # training loop both use `model.pred`, so defining only `preds` here
    # would raise AttributeError on the non-MNIST path.
    self.pred = self.preds

6 changes: 6 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,9 @@ def generate_and_save(sess, X, pred, conf):
scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))


def get_batch(data, pointer, batch_size):
    """Return the next full mini-batch of ``data`` and the advanced cursor.

    Args:
        data: array-like with a ``shape`` attribute, sliceable along axis 0.
        pointer: index of the batch to fetch; batch ``k`` covers rows
            ``[k * batch_size, (k + 1) * batch_size)``.
        batch_size: number of rows per batch.

    Returns:
        A two-element list ``[batch, next_pointer]``.

    When the requested batch would run past the end of ``data`` the cursor
    wraps back to the first batch, so trailing rows that do not fill a whole
    batch are skipped and every returned batch has ``batch_size`` rows
    (provided ``data`` holds at least ``batch_size`` rows).
    """
    # Wrap only when batch `pointer` would extend beyond the data. The former
    # test, (batch_size + 1) * pointer >= n, both wrapped too early (skipping
    # valid trailing batches) and too late (returning a short batch).
    if batch_size * (pointer + 1) > data.shape[0]:
        pointer = 0
    start = batch_size * pointer
    batch = data[start:start + batch_size]
    return [batch, pointer + 1]

0 comments on commit da7dde7

Please sign in to comment.