diff --git a/.gitignore b/.gitignore index 0d20b64..540c9a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ *.pyc +ckpts +samples +data +.DS_Store diff --git a/layers.py b/layers.py new file mode 100644 index 0000000..569aac5 --- /dev/null +++ b/layers.py @@ -0,0 +1,71 @@ +import tensorflow as tf +import numpy as np + +def get_weights(shape, name, mask=None): + weights_initializer = tf.contrib.layers.xavier_initializer() + W = tf.get_variable(name, shape, tf.float32, weights_initializer) + + if mask: + filter_mid_x = shape[0]//2 + filter_mid_y = shape[1]//2 + mask_filter = np.ones(shape, dtype=np.float32) + mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0. + mask_filter[filter_mid_x+1:, :, :, :] = 0. + + if mask == 'a': + mask_filter[filter_mid_x, filter_mid_y, :, :] = 0. + + W *= mask_filter + return W + +def get_bias(shape, name): + return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer) + +def conv_op(x, W): + return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') + +class GatedCNN(): + def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True): + self.fan_in = fan_in + in_dim = self.fan_in.get_shape()[-1] + self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]] + self.b_shape = W_shape[2] + + self.payload = payload + self.mask = mask + self.activation = activation + + + if gated: + self.gated_conv() + else: + self.simple_conv() + + def gated_conv(self): + W_f = get_weights(self.W_shape, "v_W", mask=self.mask) + b_f = get_bias(self.b_shape, "v_b") + W_g = get_weights(self.W_shape, "h_W", mask=self.mask) + b_g = get_bias(self.b_shape, "h_b") + + conv_f = conv_op(self.fan_in, W_f) + conv_g = conv_op(self.fan_in, W_g) + + if self.payload is not None: + conv_f += self.payload + conv_g += self.payload + + self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) + + def simple_conv(self): + W = get_weights(self.W_shape, "W", mask=self.mask) + b = get_bias(self.b_shape, "b") + conv = conv_op(self.fan_in, W) + if self.activation: + self.fan_out = tf.nn.relu(tf.add(conv, b)) + else: + self.fan_out = tf.add(conv, b) + + def output(self): + return self.fan_out + + diff --git a/main.py b/main.py index c88b917..e190226 100644 --- a/main.py +++ b/main.py @@ -1,84 +1,71 @@ import tensorflow as tf import numpy as np -from tensorflow.examples.tutorials.mnist import input_data +import argparse from models import PixelCNN from utils import * -mnist = input_data.read_data_sets("data/") -epochs = 50 -batch_size = 100 -grad_clip = 1 -num_batches = mnist.train.num_examples // batch_size -ckpt_dir = "ckpts" -samples_dir = "samples" -if not os.path.isdir(ckpt_dir): - os.makedirs(ckpt_dir) -ckpt_file = os.path.join(ckpt_dir, "model.ckpt") - -img_height = 28 -img_width = 28 -channel = 1 - -LAYERS = 12 -F_MAP = 32 -FILTER_SIZE = 7 - -X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel]) -v_stack_in, h_stack_in = X, X - -for i in range(LAYERS): - FILTER_SIZE = 3 if i > 0 else FILTER_SIZE - in_dim = F_MAP if i > 0 else channel - mask = 'b' if i > 0 else 'a' - residual = True if i > 0 else False - i = str(i) - with tf.variable_scope("v_stack"+i): - v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output() - v_stack_in = v_stack - - with tf.variable_scope("v_stack_1"+i): - v_stack_1 = PixelCNN([1, 1, F_MAP], v_stack_in, gated=False, mask=mask).output() - - with tf.variable_scope("h_stack"+i): - h_stack = PixelCNN([1, FILTER_SIZE, F_MAP], h_stack_in, payload=v_stack_1, mask=mask).output() - - with tf.variable_scope("h_stack_1"+i): - h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output() - if residual: - h_stack_1 += h_stack_in # Residual connection - h_stack_in = h_stack_1 - -with tf.variable_scope("fc_1"): - fc1 = PixelCNN([1, 1, F_MAP], h_stack_in, gated=False, mask='b').output() - -# handle Imagenet differently -with tf.variable_scope("fc_2"): - fc2 = PixelCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output() -pred = tf.nn.sigmoid(fc2) - -loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X)) - -trainer = tf.train.RMSPropOptimizer(1e-3) -gradients = trainer.compute_gradients(loss) - -clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients] -optimizer = trainer.apply_gradients(clipped_gradients) - -saver = tf.train.Saver() -with tf.Session() as sess: - if os.path.exists(ckpt_file): - saver.restore(sess, ckpt_file) - print "Model Restored" - else: - sess.run(tf.initialize_all_variables()) - - for i in range(epochs): - for j in range(num_batches): - batch_X = binarize(mnist.train.next_batch(batch_size)[0] \ - .reshape([batch_size, img_height, img_width, 1])) - _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X}) - - print "Epoch: %d, Cost: %f"%(i, cost) - - generate_and_save(sess, X, pred, img_height, img_width, epochs, samples_dir) - saver.save(sess, ckpt_file) +def train(conf, data): + model = PixelCNN(conf) + loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model.fc2, model.X)) + + trainer = tf.train.RMSPropOptimizer(1e-3) + gradients = trainer.compute_gradients(loss) + + clipped_gradients = [(tf.clip_by_value(_[0], -conf.grad_clip, conf.grad_clip), _[1]) for _ in gradients] + optimizer = trainer.apply_gradients(clipped_gradients) + + saver = tf.train.Saver(tf.trainable_variables()) + with tf.Session() as sess: + if os.path.exists(conf.ckpt_file): + saver.restore(sess, conf.ckpt_file) + print "Model Restored" + else: + sess.run(tf.initialize_all_variables()) + + for i in range(conf.epochs): + for j in range(conf.num_batches): + batch_X = binarize(data.train.next_batch(conf.batch_size)[0] \ + .reshape([conf.batch_size, conf.img_height, conf.img_width, conf.channel])) + _, cost = sess.run([optimizer, loss], feed_dict={model.X:batch_X}) + + print "Epoch: %d, Cost: %f"%(i, cost) + + generate_and_save(sess, model.X, model.pred, conf) + saver.save(sess, conf.ckpt_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--type', type=str, default='train') + parser.add_argument('--data', type=str, default='mnist') + parser.add_argument('--layers', type=int, default=12) + parser.add_argument('--f_map', type=int, default=32) + parser.add_argument('--epochs', type=int, default=50) + parser.add_argument('--batch_size', type=int, default=100) + parser.add_argument('--grad_clip', type=int, default=1) + parser.add_argument('--data_path', type=str, default='data') + parser.add_argument('--ckpt_path', type=str, default='ckpts') + parser.add_argument('--samples_path', type=str, default='samples') + conf = parser.parse_args() + + if conf.data == 'mnist': + from tensorflow.examples.tutorials.mnist import input_data + if not os.path.exists(conf.data_path): + os.makedirs(conf.data_path) + data = input_data.read_data_sets(conf.data_path) + conf.img_height = 28 + conf.img_width = 28 + conf.channel = 1 + conf.num_batches = 10#mnist.train.num_examples // conf.batch_size + conf.filter_size = 7 + + ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map)) + if not os.path.exists(ckpt_full_path): + os.makedirs(ckpt_full_path) + conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt") + + conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map)) + if not os.path.exists(conf.samples_path): + os.makedirs(conf.samples_path) + + train(conf, data) diff --git a/models.py b/models.py index e3e8e4a..9705161 100644 --- a/models.py +++ b/models.py @@ -1,71 +1,38 @@ import tensorflow as tf -import numpy as np - -def get_weights(shape, name, mask=None): - weights_initializer = tf.contrib.layers.xavier_initializer() - W = tf.get_variable(name, shape, tf.float32, weights_initializer) - - if mask: - filter_mid_x = shape[0]//2 - filter_mid_y = shape[1]//2 - mask_filter = np.ones(shape, dtype=np.float32) - mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0. - mask_filter[filter_mid_x+1:, :, :, :] = 0. - - if mask == 'a': - mask_filter[filter_mid_x, filter_mid_y, :, :] = 0. - - W *= mask_filter - return W - -def get_bias(shape, name): - return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer) - -def conv_op(x, W): - return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') +from layers import GatedCNN class PixelCNN(): - def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True): - self.fan_in = fan_in - in_dim = self.fan_in.get_shape()[-1] - self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]] - self.b_shape = W_shape[2] - - self.payload = payload - self.mask = mask - self.activation = activation - - - if gated: - self.gated_conv() - else: - self.simple_conv() - - def gated_conv(self): - W_f = get_weights(self.W_shape, "v_W", mask=self.mask) - b_f = get_bias(self.b_shape, "v_b") - W_g = get_weights(self.W_shape, "h_W", mask=self.mask) - b_g = get_bias(self.b_shape, "h_b") - - conv_f = conv_op(self.fan_in, W_f) - conv_g = conv_op(self.fan_in, W_g) - - if self.payload is not None: - conv_f += self.payload - conv_g += self.payload - - self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g)) - - def simple_conv(self): - W = get_weights(self.W_shape, "W", mask=self.mask) - b = get_bias(self.b_shape, "b") - conv = conv_op(self.fan_in, W) - if self.activation: - self.fan_out = tf.nn.relu(tf.add(conv, b)) - else: - self.fan_out = tf.add(conv, b) - - def output(self): - return self.fan_out - - + def __init__(self, conf): + + self.X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel]) + v_stack_in, h_stack_in = self.X, self.X + + for i in range(conf.layers): + filter_size = 3 if i > 0 else conf.filter_size + in_dim = conf.f_map if i > 0 else conf.channel + mask = 'b' if i > 0 else 'a' + residual = True if i > 0 else False + i = str(i) + with tf.variable_scope("v_stack"+i): + v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, mask=mask).output() + v_stack_in = v_stack + + with tf.variable_scope("v_stack_1"+i): + v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, gated=False, mask=mask).output() + + with tf.variable_scope("h_stack"+i): + h_stack = GatedCNN([1, filter_size, conf.f_map], h_stack_in, payload=v_stack_1, mask=mask).output() + + with tf.variable_scope("h_stack_1"+i): + h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, gated=False, mask=mask).output() + if residual: + h_stack_1 += h_stack_in # Residual connection + h_stack_in = h_stack_1 + + with tf.variable_scope("fc_1"): + fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, gated=False, mask='b').output() + + # handle Imagenet differently + with tf.variable_scope("fc_2"): + self.fc2 = GatedCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output() + self.pred = tf.nn.sigmoid(self.fc2) diff --git a/utils.py b/utils.py index 34299f3..c47de82 100644 --- a/utils.py +++ b/utils.py @@ -7,24 +7,20 @@ def binarize(images): return (np.random.uniform(size=images.shape) < images).astype(np.float32) -def generate_and_save(sess, X, pred, img_height, img_width, epoch, samples_dir): - sample_save_dir = samples_dir.rstrip("/")+"_"+str(epoch) - if not os.path.isdir(sample_save_dir): - os.makedirs(sample_save_dir) - +def generate_and_save(sess, X, pred, conf): n_row, n_col = 5, 5 - samples = np.zeros((n_row*n_col, img_height, img_width, 1), dtype=np.float32) - for i in xrange(img_height): - for j in xrange(img_width): + samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, 1), dtype=np.float32) + for i in xrange(conf.img_height): + for j in xrange(conf.img_width): for k in xrange(1): next_sample = binarize(sess.run(pred, {X:samples})) samples[:, i, j, k] = next_sample[:, i, j, k] images = samples - images = images.reshape((n_row, n_col, img_height, img_width)) + images = images.reshape((n_row, n_col, conf.img_height, conf.img_width)) images = images.transpose(1, 2, 0, 3) - images = images.reshape((img_height * n_row, img_width * n_col)) + images = images.reshape((conf.img_height * n_row, conf.img_width * n_col)) filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg" - scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(sample_save_dir, filename)) + scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))