sample run with gradient clipping
anantzoid committed Oct 30, 2016
1 parent 11b7bc0 commit 4d20c8a
Showing 2 changed files with 99 additions and 14 deletions.
70 changes: 70 additions & 0 deletions models.py
@@ -0,0 +1,70 @@
import tensorflow as tf
import numpy as np

def get_weights(shape, mask=None):
    # TODO set init bounds
    W = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))

    if mask:
        filter_mid_x = shape[0]//2
        filter_mid_y = shape[1]//2
        mask_filter = np.ones(shape, dtype=np.float32)
        mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
        mask_filter[filter_mid_x+1:, :, :, :] = 0.

        if mask == 'a':
            mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.

        W *= mask_filter
    return W
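
For reference, the masks that get_weights builds can be inspected with plain numpy. A standalone sketch, not part of this commit, assuming a 3x3 filter with one input and one output channel:

# illustration only: reproduce the mask construction from get_weights
import numpy as np
shape = (3, 3, 1, 1)
mid_x, mid_y = shape[0]//2, shape[1]//2
mask_filter = np.ones(shape, dtype=np.float32)
mask_filter[mid_x, mid_y+1:, :, :] = 0.   # zero out pixels to the right of the centre
mask_filter[mid_x+1:, :, :, :] = 0.       # zero out all rows below the centre
print(mask_filter[:, :, 0, 0])            # mask 'b': [[1,1,1],[1,1,0],[0,0,0]]
mask_filter[mid_x, mid_y, :, :] = 0.      # mask 'a' additionally hides the current pixel
print(mask_filter[:, :, 0, 0])            # mask 'a': [[1,1,1],[1,0,0],[0,0,0]]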


def get_bias(shape):
    return tf.Variable(tf.constant(shape=shape, value=0.1, dtype=tf.float32))

def conv_op(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

class PixelCNN():
    def __init__(self, W_shape, b_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
        self.W_shape = W_shape
        self.b_shape = b_shape
        self.fan_in = fan_in
        self.payload = payload
        self.mask = mask
        self.activation = activation

        if gated:
            self.gated_conv()
        else:
            self.simple_conv()

    def gated_conv(self):
        W_f = get_weights(self.W_shape, mask=self.mask)
        b_f = get_bias(self.b_shape)
        W_g = get_weights(self.W_shape, mask=self.mask)
        b_g = get_bias(self.b_shape)

        conv_f = conv_op(self.fan_in, W_f)
        conv_g = conv_op(self.fan_in, W_g)

        if self.payload is not None:
            conv_f += self.payload
            conv_g += self.payload

        self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

    def simple_conv(self):
        W = get_weights(self.W_shape, mask=self.mask)
        b = get_bias(self.b_shape)
        conv = conv_op(self.fan_in, W)
        if self.activation:
            self.fan_out = tf.nn.relu(conv + b)
        else:
            self.fan_out = conv + b

    def output(self):
        return self.fan_out
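
As a usage sketch, with hypothetical filter sizes and feature-map counts (the actual stack wiring lives in pixelcnn_tf.py below): the first layer uses mask 'a' so a pixel never sees its own value, deeper layers use mask 'b', and a gated layer can take a payload from another stack that is added to both the tanh and sigmoid branches.

import tensorflow as tf
from models import PixelCNN

x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
h1 = PixelCNN([7, 7, 1, 32], [32], x, gated=True, mask='a').output()     # first layer: mask 'a'
h2 = PixelCNN([3, 3, 32, 32], [32], h1, gated=True, mask='b').output()   # later layers: mask 'b'
logits = PixelCNN([1, 1, 32, 1], [1], h2, gated=False, mask='b', activation=False).output()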


43 changes: 29 additions & 14 deletions pixelcnn_tf.py
@@ -1,12 +1,17 @@
# TODO sanity check pred,
# generate and plot,
# upscale-downscale-q_level
# validation

import tensorflow as tf
import numpy as np
from models import PixelCNN
from tensorflow.examples.tutorials.mnist import input_data

# TODO get mean if pixel value >1
mnist = input_data.read_data_sets("data/")
epochs = 10
batch_size = 50
grad_clip = 1

LAYERS = 3
F_MAP = 32
@@ -15,6 +20,7 @@

X = tf.placeholder(tf.float32, shape=[None, 784])
X_image = tf.reshape(X, [-1, 28, 28, CHANNEL])
# TODO mean pixel value if not mnist
v_stack_in, h_stack_in = X_image, X_image

for i in range(LAYERS):
@@ -38,24 +44,33 @@
        h_stack_1 += h_stack_in
        h_stack_in = h_stack_1

with tf.name_scope("f_layer"):
pred = PixelCNN([1, 1, F_MAP, 1],[1], h_stack_in, gated=False, mask='b', activation='sigmoid').output()
pred = tf.reshape(pred, [batch_size, 784])
with tf.name_scope("fc_1"):
fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output()
# handle Imagenet differently
with tf.name_scope("fc_2"):
fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output()
pred = tf.nn.sigmoid(fc2)

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X_image))

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(loss)

clipped_gradients = [(tf.clip_by_value(_[0], -grad_clip, grad_clip), _[1]) for _ in gradients]
optimizer = trainer.apply_gradients(clipped_gradients)

cross_entropy = tf.reduce_mean(-tf.reduce_sum(X * tf.log(pred), reduction_indices=[1]))
#TODO gradient clipping
trainer = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
#correct_preds = tf.equal(tf.argmax(X,1), tf.argmax(pred, 1))
#accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

#summary = tf.train.SummaryWriter('logs', sess.graph)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for i in range(epochs):
        batch_X, batch_y = mnist.train.next_batch(batch_size)
        sess.run(trainer, feed_dict={X:batch_X})
        batch_X = mnist.train.next_batch(batch_size)[0]
        _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
        print cost

        if i%1 == 0:
            print accuracy.eval(feed_dict={X:batch_X})
            print accuracy.eval(feed_dict={X:mnist.test.images})
        #if i%1 == 0:
            #print accuracy.eval(feed_dict={X:batch_X})
            #print accuracy.eval(feed_dict={X:mnist.test.images})
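
The clipping introduced in this commit follows the usual compute/clip/apply split. A minimal standalone sketch of the same pattern (toy variable and loss, illustrative only, using the same TF 0.x-era API as above):

import tensorflow as tf

w = tf.Variable(tf.zeros([10]))
loss = tf.reduce_sum(tf.square(w - 1.0))   # toy loss for illustration
grad_clip = 1

trainer = tf.train.RMSPropOptimizer(1e-3)
grads_and_vars = trainer.compute_gradients(loss)
# clip every gradient element-wise into [-grad_clip, grad_clip] before applying
clipped = [(tf.clip_by_value(g, -grad_clip, grad_clip), v) for g, v in grads_and_vars]
optimizer = trainer.apply_gradients(clipped)

tf.clip_by_value bounds each gradient element independently; tf.clip_by_global_norm is a common alternative that rescales the whole gradient instead.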
