Skip to content

Commit

Permalink
img generation, tf sanitization
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Oct 30, 2016
1 parent 4d20c8a commit 32e5221
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pyc
10 changes: 5 additions & 5 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import numpy as np

def get_weights(shape, mask=None):
# TODO set init bounds
W = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))
weights_initializer = tf.contrib.layers.xavier_initializer()
W = tf.get_variable("weights", shape, tf.float32, weights_initializer)

if mask:
filter_mid_x = shape[0]//2
Expand All @@ -20,7 +20,7 @@ def get_weights(shape, mask=None):


def get_bias(shape):
return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32))
return tf.get_variable("biases", shape, tf.float32, tf.zeros_initializer)

def conv_op(x, W):
    """Convolve input `x` with filter `W` using unit strides and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
Expand Down Expand Up @@ -59,9 +59,9 @@ def simple_conv(self):
b = get_bias(self.b_shape)
conv = conv_op(self.fan_in, W)
if self.activation:
self.fan_out = tf.nn.relu(conv + b)
self.fan_out = tf.nn.relu(tf.add(conv, b))
else:
self.fan_out = conv + b
self.fan_out = tf.add(conv, b)


def output(self):
Expand Down
86 changes: 57 additions & 29 deletions pixelcnn_tf.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,104 @@
# TODO sanity check pred,
# generate and plot,
# TODO
# try changing 0.0 to np.random
# make for imagenet data
# check network arch
# upscale-downscale-q_level
# validation
# cost on test set
# autoencoder

import tensorflow as tf
import numpy as np
from models import PixelCNN
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import scipy.misc
import os
from models import PixelCNN

mnist = input_data.read_data_sets("data/")
epochs = 10
epochs = 50
batch_size = 50
grad_clip = 1

img_height = 28
img_width = 28
channel = 1

LAYERS = 3
F_MAP = 32
FILTER_SIZE = 7
CHANNEL = 1

X = tf.placeholder(tf.float32, shape=[None, 784])
X_image = tf.reshape(X, [-1, 28, 28, CHANNEL])
X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel])
# TODO mean pixel value if not mnist
v_stack_in, h_stack_in = X_image, X_image
v_stack_in, h_stack_in = X, X

for i in range(LAYERS):
FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
CHANNEL = F_MAP if i > 0 else CHANNEL
in_dim = F_MAP if i > 0 else channel
mask = 'b' if i > 0 else 'a'
i = str(i)

with tf.name_scope("v_stack"+i):
v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
with tf.variable_scope("v_stack"+i):
v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
v_stack_in = v_stack

with tf.name_scope("v_stack_1"+i):
with tf.variable_scope("v_stack_1"+i):
v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False, mask=mask).output()

with tf.name_scope("h_stack"+i):
h_stack = PixelCNN([1, FILTER_SIZE, CHANNEL, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
with tf.variable_scope("h_stack"+i):
h_stack = PixelCNN([1, FILTER_SIZE, in_dim, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()

with tf.name_scope("h_stack_1"+i):
with tf.variable_scope("h_stack_1"+i):
h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False, mask=mask).output()
h_stack_1 += h_stack_in
h_stack_in = h_stack_1

with tf.name_scope("fc_1"):
with tf.variable_scope("fc_1"):
fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output()

# handle Imagenet differently
with tf.name_scope("fc_2"):
with tf.variable_scope("fc_2"):
fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output()
pred = tf.nn.sigmoid(fc2)

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X_image))
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X))

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)

clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)

#correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1))
#accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

#summary = tf.train.SummaryWriter('logs', sess.graph)
def binarize(images, threshold=0.0):
    """Map pixel intensities to {0.0, 1.0}.

    Args:
        images: array (or array-like) of pixel intensities.
        threshold: values strictly greater than this become 1.0, everything
            else becomes 0.0. Defaults to 0.0, i.e. any positive pixel is on.

    Returns:
        A float32 numpy array with the same shape as `images`.
    """
    return (np.asarray(images) > threshold).astype(np.float32)

def generate_and_save(sess):
    """Sample a 5x5 grid of images from the model and save it as a JPEG.

    Generation starts from an all-zero batch and fills in one pixel position
    at a time, re-evaluating `pred` on the partially generated batch before
    each position is written. The tiled result is saved under `samples/`
    with a timestamped file name.

    Args:
        sess: session used to evaluate `pred`, with `X` fed from the
            partially generated batch.
    """
    n_row, n_col = 5, 5
    n_samples = n_row * n_col
    samples = np.zeros((n_samples, img_height, img_width, 1), dtype=np.float32)
    for i in xrange(img_height):
        for j in xrange(img_width):
            for k in xrange(1):
                probs = sess.run(pred, {X: samples})[:, i, j, k]
                # pred is a sigmoid output, so thresholding it at 0.0 would
                # switch every pixel on. Draw each pixel as a Bernoulli
                # sample with probability pred instead.
                draws = np.random.rand(n_samples) < probs
                samples[:, i, j, k] = draws.astype(np.float32)

    # Tile the batch into one 2-D image. After the transpose the axes are
    # (n_col, height, n_row, width), so the flattened image is
    # (n_col * height) x (n_row * width).
    images = samples.reshape((n_row, n_col, img_height, img_width))
    images = images.transpose(1, 2, 0, 3)
    images = images.reshape((n_col * img_height, n_row * img_width))

    # Saving does not create missing directories, so make sure it exists.
    if not os.path.isdir("samples"):
        os.makedirs("samples")
    filename = '%s_%s.jpg' % ("sample", str(datetime.now()))
    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join("samples", filename))

num_batches = mnist.train.num_examples // batch_size
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for i in range(epochs):
batch_X = mnist.train.next_batch(batch_size)[0]
_, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
print cost

#if i%1 == 0:
#print accuracy.eval(feed_dict={X:batch_X})
#print accuracy.eval(feed_dict={X:mnist.test.images})
for j in range(num_batches):
batch_X = binarize(mnist.train.next_batch(batch_size)[0] \
.reshape([batch_size, img_height, img_width, 1]))
_, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})

print "Epoch: %d, Cost: %.2f"%(i, cost)
generate_and_save(sess)

generate_and_save(sess)

0 comments on commit 32e5221

Please sign in to comment.