prettify code
anantzoid committed Nov 19, 2016
1 parent c132983 commit fe30471
Showing 5 changed files with 183 additions and 158 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1 +1,5 @@
*.pyc
ckpts
samples
data
.DS_Store
71 changes: 71 additions & 0 deletions layers.py
@@ -0,0 +1,71 @@
import tensorflow as tf
import numpy as np

def get_weights(shape, name, mask=None):
weights_initializer = tf.contrib.layers.xavier_initializer()
W = tf.get_variable(name, shape, tf.float32, weights_initializer)

if mask:
filter_mid_x = shape[0]//2
filter_mid_y = shape[1]//2
mask_filter = np.ones(shape, dtype=np.float32)
mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
mask_filter[filter_mid_x+1:, :, :, :] = 0.

if mask == 'a':
mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.

W *= mask_filter
return W

def get_bias(shape, name):
return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer)

def conv_op(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

class GatedCNN():
def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
self.fan_in = fan_in
in_dim = self.fan_in.get_shape()[-1]
self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
self.b_shape = W_shape[2]

self.payload = payload
self.mask = mask
self.activation = activation


if gated:
self.gated_conv()
else:
self.simple_conv()

def gated_conv(self):
W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
b_f = get_bias(self.b_shape, "v_b")
W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
b_g = get_bias(self.b_shape, "h_b")

conv_f = conv_op(self.fan_in, W_f)
conv_g = conv_op(self.fan_in, W_g)

if self.payload is not None:
conv_f += self.payload
conv_g += self.payload

self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

def simple_conv(self):
W = get_weights(self.W_shape, "W", mask=self.mask)
b = get_bias(self.b_shape, "b")
conv = conv_op(self.fan_in, W)
if self.activation:
self.fan_out = tf.nn.relu(tf.add(conv, b))
else:
self.fan_out = tf.add(conv, b)

def output(self):
return self.fan_out
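
For readers unfamiliar with PixelCNN masking, the snippet below reproduces the mask construction from get_weights above in plain NumPy so the 'a'/'b' difference can be inspected directly; the helper name build_mask and the 3x3 single-channel shape are illustrative choices, not part of this commit.

import numpy as np

def build_mask(shape, mask_type):
    # Same construction as get_weights above: zero the taps to the right of
    # the centre in the middle row, and every row below the centre.
    mid_x, mid_y = shape[0] // 2, shape[1] // 2
    m = np.ones(shape, dtype=np.float32)
    m[mid_x, mid_y + 1:, :, :] = 0.
    m[mid_x + 1:, :, :, :] = 0.
    if mask_type == 'a':
        # Mask 'a' (first layer only) also hides the current pixel itself.
        m[mid_x, mid_y, :, :] = 0.
    return m

print(build_mask((3, 3, 1, 1), 'a')[:, :, 0, 0])   # centre tap zeroed
print(build_mask((3, 3, 1, 1), 'b')[:, :, 0, 0])   # centre tap kept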


145 changes: 66 additions & 79 deletions main.py
@@ -1,84 +1,71 @@
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import argparse
from models import PixelCNN
from utils import *

mnist = input_data.read_data_sets("data/")
epochs = 50
batch_size = 100
grad_clip = 1
num_batches = mnist.train.num_examples // batch_size
ckpt_dir = "ckpts"
samples_dir = "samples"
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
ckpt_file = os.path.join(ckpt_dir, "model.ckpt")

img_height = 28
img_width = 28
channel = 1

LAYERS = 12
F_MAP = 32
FILTER_SIZE = 7

X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel])
v_stack_in, h_stack_in = X, X

for i in range(LAYERS):
FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
in_dim = F_MAP if i > 0 else channel
mask = 'b' if i > 0 else 'a'
residual = True if i > 0 else False
i = str(i)
with tf.variable_scope("v_stack"+i):
v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output()
v_stack_in = v_stack

with tf.variable_scope("v_stack_1"+i):
v_stack_1 = PixelCNN([1, 1, F_MAP], v_stack_in, gated=False, mask=mask).output()

with tf.variable_scope("h_stack"+i):
h_stack = PixelCNN([1, FILTER_SIZE, F_MAP], h_stack_in, payload=v_stack_1, mask=mask).output()

with tf.variable_scope("h_stack_1"+i):
h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output()
if residual:
h_stack_1 += h_stack_in # Residual connection
h_stack_in = h_stack_1

with tf.variable_scope("fc_1"):
fc1 = PixelCNN([1, 1, F_MAP], h_stack_in, gated=False, mask='b').output()

# handle Imagenet differently
with tf.variable_scope("fc_2"):
fc2 = PixelCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
pred = tf.nn.sigmoid(fc2)

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X))

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)

clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)

saver = tf.train.Saver()
with tf.Session() as sess:
if os.path.exists(ckpt_file):
saver.restore(sess, ckpt_file)
print "Model Restored"
else:
sess.run(tf.initialize_all_variables())

for i in range(epochs):
for j in range(num_batches):
batch_X = binarize(mnist.train.next_batch(batch_size)[0] \
.reshape([batch_size, img_height, img_width, 1]))
_, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})

print "Epoch: %d, Cost: %f"%(i, cost)

generate_and_save(sess, X, pred, img_height, img_width, epochs, samples_dir)
saver.save(sess, ckpt_file)
def train(conf, data):
model = PixelCNN(conf)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X))

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)

clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
if os.path.exists(conf.ckpt_file):
saver.restore(sess, conf.ckpt_file)
print "Model Restored"
else:
sess.run(tf.initialize_all_variables())

for i in range(conf.epochs):
for j in range(conf.num_batches):
batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \
.reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
_, cost = sess.run([optimizer, loss], feed_dict={model.X:batch_X})

print "Epoch: %d, Cost: %f"%(i, cost)

generate_and_save(sess, model.X, model.pred, conf)
saver.save(sess, conf.ckpt_file)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, default='train')
parser.add_argument('--data', type=str, default='mnist')
parser.add_argument('--layers', type=int, default=12)
parser.add_argument('--f_map', type=int, default=32)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--grad_clip', type=int, default=1)
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--ckpt_path', type=str, default='ckpts')
parser.add_argument('--samples_path', type=str, default='samples')
conf = parser.parse_args()

if conf.data == 'mnist':
from tensorflow.examples.tutorials.mnist import input_data
if not os.path.exists(conf.data_path):
os.makedirs(conf.data_path)
data = input_data.read_data_sets(conf.data_path)
conf.img_height = 28
conf.img_width = 28
conf.channel = 1
conf.num_batches = data.train.num_examples // conf.batch_size
conf.filter_size = 7

ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(ckpt_full_path):
os.makedirs(ckpt_full_path)
conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")

conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map))
if not os.path.exists(conf.samples_path):
os.makedirs(conf.samples_path)

train(conf, data)
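
The script is driven entirely by the argparse flags above (e.g. python main.py --data mnist --epochs 50). For completeness, a minimal stand-alone driver is sketched below; it assumes the TensorFlow 0.x environment this commit targets, and every conf attribute it sets is one that train(), PixelCNN or generate_and_save reads. The epoch count and directory names are illustrative.

import os
from argparse import Namespace
from tensorflow.examples.tutorials.mnist import input_data
from main import train

data = input_data.read_data_sets("data")
conf = Namespace(layers=12, f_map=32, epochs=1, batch_size=100, grad_clip=1,
                 img_height=28, img_width=28, channel=1, filter_size=7,
                 num_batches=data.train.num_examples // 100,
                 ckpt_file="ckpts/model.ckpt", samples_path="samples")
for d in ("ckpts", "samples"):
    # train() and generate_and_save() expect these directories to exist.
    if not os.path.exists(d):
        os.makedirs(d)
train(conf, data)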
103 changes: 35 additions & 68 deletions models.py
@@ -1,71 +1,38 @@
import tensorflow as tf
import numpy as np

def get_weights(shape, name, mask=None):
weights_initializer = tf.contrib.layers.xavier_initializer()
W = tf.get_variable(name, shape, tf.float32, weights_initializer)

if mask:
filter_mid_x = shape[0]//2
filter_mid_y = shape[1]//2
mask_filter = np.ones(shape, dtype=np.float32)
mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
mask_filter[filter_mid_x+1:, :, :, :] = 0.

if mask == 'a':
mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.

W *= mask_filter
return W

def get_bias(shape, name):
return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer)

def conv_op(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
from layers import GatedCNN

class PixelCNN():
def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
self.fan_in = fan_in
in_dim = self.fan_in.get_shape()[-1]
self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
self.b_shape = W_shape[2]

self.payload = payload
self.mask = mask
self.activation = activation


if gated:
self.gated_conv()
else:
self.simple_conv()

def gated_conv(self):
W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
b_f = get_bias(self.b_shape, "v_b")
W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
b_g = get_bias(self.b_shape, "h_b")

conv_f = conv_op(self.fan_in, W_f)
conv_g = conv_op(self.fan_in, W_g)

if self.payload is not None:
conv_f += self.payload
conv_g += self.payload

self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

def simple_conv(self):
W = get_weights(self.W_shape, "W", mask=self.mask)
b = get_bias(self.b_shape, "b")
conv = conv_op(self.fan_in, W)
if self.activation:
self.fan_out = tf.nn.relu(tf.add(conv, b))
else:
self.fan_out = tf.add(conv, b)

def output(self):
return self.fan_out


def __init__(self, conf):

self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
v_stack_in, h_stack_in = self.X, self.X

for i in range(conf.layers):
filter_size = 3 if i > 0 else conf.filter_size
in_dim = conf.f_map if i > 0 else conf.channel
mask = 'b' if i > 0 else 'a'
residual = True if i > 0 else False
i = str(i)
with tf.variable_scope("v_stack"+i):
v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask).output()
v_stack_in = v_stack

with tf.variable_scope("v_stack_1"+i):
v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output()

with tf.variable_scope("h_stack"+i):
h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask).output()

with tf.variable_scope("h_stack_1"+i):
h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output()
if residual:
h_stack_1 += h_stack_in # Residual connection
h_stack_in = h_stack_1

with tf.variable_scope("fc_1"):
fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output()

# handle Imagenet differently
with tf.variable_scope("fc_2"):
self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
self.pred = tf.nn.sigmoid(self.fc2)
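
The gated unit applied by every GatedCNN layer (see layers.py above) is an element-wise product of a tanh "feature" path and a sigmoid "gate" path. The NumPy sketch below, with made-up toy tensors, mirrors what tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) computes.

import numpy as np

def gated_activation(conv_f, b_f, conv_g, b_g):
    # Element-wise tanh(feature) * sigmoid(gate), as in GatedCNN.gated_conv.
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    return np.tanh(conv_f + b_f) * sigmoid(conv_g + b_g)

# Toy (batch, height, width, f_map) activations.
f = np.random.randn(1, 2, 2, 3).astype(np.float32)
g = np.random.randn(1, 2, 2, 3).astype(np.float32)
print(gated_activation(f, 0.0, g, 0.0).shape)   # (1, 2, 2, 3), same shape as the inputs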
18 changes: 7 additions & 11 deletions utils.py
@@ -7,24 +7,20 @@
def binarize(images):
return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_and_save(sess, X, pred, img_height, img_width, epoch, samples_dir):
sample_save_dir = samples_dir.rstrip("/")+"_"+str(epoch)
if not os.path.isdir(sample_save_dir):
os.makedirs(sample_save_dir)

def generate_and_save(sess, X, pred, conf):
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32)
for i in xrange(img_height):
for j in xrange(img_width):
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, 1), dtype=np.float32)
for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(1):
next_sample = binarize(sess.run(pred, {X:samples}))
samples[:, i, j, k] = next_sample[:, i, j, k]
images = samples
images = images.reshape((n_row, n_col, img_height, img_width))
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
images = images.transpose(1, 2, 0, 3)
images = images.reshape((img_height * n_row, img_width * n_col))
images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))

filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg"
scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(sample_save_dir, filename))
scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))
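
As a sanity check on the stochastic binarization used both for the training batches in main.py and for the pixel-by-pixel sampling above: binarize treats each grayscale intensity as a Bernoulli probability and draws a hard 0/1 image. The toy probabilities below are made up for illustration, with the one-liner from binarize reproduced inline.

import numpy as np

# binarize(images) from utils.py, reproduced inline.
probs = np.array([[0.0, 0.25, 0.5, 0.75, 1.0]], dtype=np.float32)
binary = (np.random.uniform(size=probs.shape) < probs).astype(np.float32)
print(binary)   # the endpoints are deterministic (0. and 1.); the middle pixels are random draws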

