file rename, model save and restore
fix

residual conn
anantzoid committed Nov 19, 2016
1 parent d2c5abc commit c132983
Showing 2 changed files with 52 additions and 44 deletions.
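
For reference, the "residual conn" part of this commit boils down to adding the layer input back onto the 1x1 output of the horizontal stack for every layer except the first (mask 'a') one, as the main.py hunk below shows. A minimal sketch of that rule using plain ops of the same TF era; the 1x1 convolution here is a hypothetical stand-in for this repo's PixelCNN([1, 1, F_MAP], ...) class:

import tensorflow as tf

def h_stack_1x1(h_stack_in, f_map, residual):
    # Hypothetical 1x1 convolution standing in for PixelCNN([1, 1, F_MAP], ...) from models.py.
    W = tf.get_variable("W_1x1", [1, 1, f_map, f_map])
    h_stack_1 = tf.nn.conv2d(h_stack_in, W, strides=[1, 1, 1, 1], padding="SAME")
    if residual:
        # Residual connection: skipped on the first layer, where the input is the raw image.
        h_stack_1 += h_stack_in
    return h_stack_1
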
66 changes: 22 additions & 44 deletions pixelcnn_tf.py → main.py
@@ -1,43 +1,36 @@
# TODO
# kundan: concat payload instead of add
# : arch: without 1X1 in 1st layer and last 2 layers
# : replaces masking with n/2 filter
# check network arch
# autoencoder
# cost on test set
# make for imagenet data: upscale-downscale-q_level, mean pixel value if not mnist
# stats
# logger

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import scipy.misc
import os
from models import PixelCNN
from utils import *

mnist = input_data.read_data_sets("data/")
epochs = 1
batch_size = 50
epochs = 50
batch_size = 100
grad_clip = 1
num_batches = mnist.train.num_examples // batch_size
ckpt_dir = "ckpts"
samples_dir = "samples"
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
ckpt_file = os.path.join(ckpt_dir, "model.ckpt")

img_height = 28
img_width = 28
channel = 1

LAYERS = 3
LAYERS = 12
F_MAP = 32
FILTER_SIZE = 7

X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel])
v_stack_in, h_stack_in = X, X

# TODO encapsulate
for i in range(LAYERS):
    FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
    in_dim = F_MAP if i > 0 else channel
    mask = 'b' if i > 0 else 'a'
    residual = True if i > 0 else False
    i = str(i)
    with tf.variable_scope("v_stack"+i):
        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output()
@@ -51,7 +44,8 @@

    with tf.variable_scope("h_stack_1"+i):
        h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output()
        #h_stack_1 += h_stack_in # Residual connection
        if residual:
            h_stack_1 += h_stack_in # Residual connection
        h_stack_in = h_stack_1

with tf.variable_scope("fc_1"):
@@ -70,37 +64,21 @@
clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)


def binarize(images):
    return (0.0 < images).astype(np.float32)

def generate_and_save(sess):
    n_row, n_col = 5, 5
    samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32)
    for i in xrange(img_height):
        for j in xrange(img_width):
            for k in xrange(1):
                next_sample = binarize(sess.run(pred, {X:samples}))
                samples[:, i, j, k] = next_sample[:, i, j, k]
    images = samples
    images = images.reshape((n_row, n_col, img_height, img_width))
    images = images.transpose(1, 2, 0, 3)
    images = images.reshape((img_height * n_row, img_width * n_col))

    filename = '%s_%s.jpg' % ("sample", str(datetime.now()))
    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join("samples", filename))

num_batches = mnist.train.num_examples // batch_size
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if os.path.exists(ckpt_file):
        saver.restore(sess, ckpt_file)
        print "Model Restored"
    else:
        sess.run(tf.initialize_all_variables())

    for i in range(epochs):
        for j in range(num_batches):
            batch_X = binarize(mnist.train.next_batch(batch_size)[0] \
                    .reshape([batch_size, img_height, img_width, 1]))
            _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})

        print "Epoch: %d, Cost: %f"%(i, cost)
        generate_and_save(sess)

    generate_and_save(sess)

    generate_and_save(sess, X, pred, img_height, img_width, epochs, samples_dir)
    saver.save(sess, ckpt_file)
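
The "model save and restore" half of the commit is the tf.train.Saver usage above: the old pixelcnn_tf.py always initialized variables from scratch (the first sess.run(tf.initialize_all_variables()) line), while the new main.py restores from ckpts/model.ckpt when it exists, initializes otherwise, and writes a checkpoint after training. A minimal standalone sketch of the same restore-or-initialize pattern, with a trivial variable standing in for the PixelCNN graph (tf.initialize_all_variables is the TF-0.x-era initializer used throughout this repo):

import os
import tensorflow as tf

ckpt_file = os.path.join("ckpts", "model.ckpt")
w = tf.Variable(tf.zeros([10]), name="w")  # stand-in for the real model graph
saver = tf.train.Saver()

with tf.Session() as sess:
    if os.path.exists(ckpt_file):
        saver.restore(sess, ckpt_file)           # resume from the saved checkpoint
    else:
        sess.run(tf.initialize_all_variables())  # fresh start when no checkpoint exists
    # ... train ...
    saver.save(sess, ckpt_file)
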
30 changes: 30 additions & 0 deletions utils.py
@@ -0,0 +1,30 @@
import numpy as np
import os
import scipy.misc
from datetime import datetime


def binarize(images):
    return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_and_save(sess, X, pred, img_height, img_width, epoch, samples_dir):
    sample_save_dir = samples_dir.rstrip("/")+"_"+str(epoch)
    if not os.path.isdir(sample_save_dir):
        os.makedirs(sample_save_dir)

    n_row, n_col = 5, 5
    samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32)
    for i in xrange(img_height):
        for j in xrange(img_width):
            for k in xrange(1):
                next_sample = binarize(sess.run(pred, {X:samples}))
                samples[:, i, j, k] = next_sample[:, i, j, k]
    images = samples
    images = images.reshape((n_row, n_col, img_height, img_width))
    images = images.transpose(1, 2, 0, 3)
    images = images.reshape((img_height * n_row, img_width * n_col))

    filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg"
    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(sample_save_dir, filename))
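
generate_and_save draws samples autoregressively: the canvas starts at zero, the network is re-run on the partially completed images at every pixel position, and only the pixel at the current (i, j) is kept, sampled as a Bernoulli draw against the predicted probability (which is also what the new binarize does, replacing the old hard threshold in pixelcnn_tf.py). A rough NumPy-only sketch of that loop, with a hypothetical predict_probs callable standing in for sess.run(pred, {X: samples}):

import numpy as np

def sample_canvas(predict_probs, n, height, width):
    # predict_probs: hypothetical callable mapping an (n, height, width, 1) canvas
    # to per-pixel Bernoulli probabilities of the same shape.
    canvas = np.zeros((n, height, width, 1), dtype=np.float32)
    for i in range(height):
        for j in range(width):
            probs = predict_probs(canvas)
            # Sample only the current pixel; pixels generated earlier stay fixed.
            draw = (np.random.uniform(size=(n,)) < probs[:, i, j, 0]).astype(np.float32)
            canvas[:, i, j, 0] = draw
    return canvas
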

