Skip to content

Commit

Permalink
fixes for multi-channel
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent da7dde7 commit d0069b1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
7 changes: 3 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@ def train(conf, data):
saver = tf.train.Saver(tf.trainable_variables())

with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"
else:
sess.run(tf.initialize_all_variables())


pointer = 0
for i in range(conf.epochs):
for j in range(conf.num_batches):
Expand All @@ -35,8 +34,8 @@ def train(conf, data):

print "Epoch: %d, Cost: %f"%(i, cost)

generate_and_save(sess, model.X, model.pred, conf)
saver.save(sess, conf.ckpt_file)
generate_and_save(sess, model.X, model.pred, conf)


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
class PixelCNN():
def __init__(self, conf):

data_shape = [conf.batch_size, conf.img_height, conf.img_width, conf.channel]
self.X = tf.placeholder(tf.float32, shape=data_shape)
self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
v_stack_in, h_stack_in = self.X_norm, self.X_norm

Expand Down Expand Up @@ -56,5 +55,5 @@ def __init__(self, conf):
#self.loss = tf.reduce_mean(-tf.reduce_sum(X_ohehot * tf.log(self.fc2), reduction_indices=[1]))

# NOTE or check without argmax
self.preds = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), data_shape)
self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))

8 changes: 5 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ def binarize(images):

def generate_and_save(sess, X, pred, conf):
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, 1), dtype=np.float32)
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(1):
next_sample = binarize(sess.run(pred, {X:samples}))
for k in xrange(conf.channel):
next_sample = sess.run(pred, {X:samples})
if data == "mnist":
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]
images = samples
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
Expand Down

0 comments on commit d0069b1

Please sign in to comment.