Commit

fixes
fix
anantzoid committed Nov 19, 2016
1 parent 32e5221 commit 856b99c
Showing 2 changed files with 16 additions and 17 deletions.

models.py: 11 additions & 12 deletions
@@ -1,9 +1,9 @@
 import tensorflow as tf
 import numpy as np
 
-def get_weights(shape, mask=None):
+def get_weights(shape, name, mask=None):
     weights_initializer = tf.contrib.layers.xavier_initializer()
-    W = tf.get_variable("weights", shape, tf.float32, weights_initializer)
+    W = tf.get_variable(name, shape, tf.float32, weights_initializer)
 
     if mask:
         filter_mid_x = shape[0]//2
@@ -18,9 +18,8 @@ def get_weights(shape, mask=None):
         W *= mask_filter
     return W
 
-
-def get_bias(shape):
-    return tf.get_variable("biases", shape, tf.float32, tf.zeros_initializer)
+def get_bias(shape, name):
+    return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer)
 
 def conv_op(x, W):
     return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
@@ -40,11 +39,11 @@ def __init__(self, W_shape, b_shape, fan_in, gated=True, payload=None, mask=None
             self.simple_conv()
 
     def gated_conv(self):
-        W_f = get_weights(self.W_shape, mask=self.mask)
-        b_f = get_bias(self.b_shape)
-        W_g = get_weights(self.W_shape, mask=self.mask)
-        b_g = get_bias(self.b_shape)
+        W_f = get_weights(self.W_shape, "v_W", mask=self.mask)
+        b_f = get_bias(self.b_shape, "v_b")
+        W_g = get_weights(self.W_shape, "h_W", mask=self.mask)
+        b_g = get_bias(self.b_shape, "h_b")
 
         conv_f = conv_op(self.fan_in, W_f)
         conv_g = conv_op(self.fan_in, W_g)
 
@@ -55,8 +54,8 @@ def gated_conv(self):
         self.fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))
 
     def simple_conv(self):
-        W = get_weights(self.W_shape, mask=self.mask)
-        b = get_bias(self.b_shape)
+        W = get_weights(self.W_shape, "W", mask=self.mask)
+        b = get_bias(self.b_shape, "b")
         conv = conv_op(self.fan_in, W)
         if self.activation:
             self.fan_out = tf.nn.relu(tf.add(conv, b))
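
Context for the models.py change: tf.get_variable looks variables up by name inside the enclosing tf.variable_scope, so the old code, which hard-coded "weights" and "biases", could not create more than one weight or bias per scope (a second call either raises a ValueError or silently shares the existing variable when reuse is enabled). Passing the name in lets gated_conv keep four distinct parameters in one scope. Below is a minimal sketch of the resulting pattern, written against the TF 0.x-era API used in this repo; the function and the example shapes are illustrative, not the repository's exact code.

import tensorflow as tf

def gated_conv_sketch(x, w_shape, b_shape):
    # Two weight/bias pairs live in the same variable scope, so each
    # tf.get_variable call needs its own name (here "v_W", "v_b", "h_W",
    # "h_b", mirroring the names chosen in the diff).
    init = tf.contrib.layers.xavier_initializer()
    W_f = tf.get_variable("v_W", w_shape, tf.float32, init)
    b_f = tf.get_variable("v_b", b_shape, tf.float32, tf.zeros_initializer)
    W_g = tf.get_variable("h_W", w_shape, tf.float32, init)
    b_g = tf.get_variable("h_b", b_shape, tf.float32, tf.zeros_initializer)

    conv_f = tf.nn.conv2d(x, W_f, strides=[1, 1, 1, 1], padding='SAME')
    conv_g = tf.nn.conv2d(x, W_g, strides=[1, 1, 1, 1], padding='SAME')

    # Gated activation unit: a tanh feature branch modulated by a sigmoid
    # gate (tf.mul is the pre-1.0 name of tf.multiply).
    return tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
with tf.variable_scope("v_stack0"):
    out = gated_conv_sketch(x, [3, 3, 1, 16], [16])

With a per-layer scope such as "v_stack"+i wrapped around each PixelCNN block, as in pixelcnn_tf.py below, these names also stay unique across layers.
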
pixelcnn_tf.py: 5 additions & 5 deletions
@@ -1,10 +1,10 @@
 # TODO
 # try changing 0.0 to np.random
-# make for imagenet data
 # check network arch
 # upscale-downscale-q_level
-# cost on test set
 # autoencoder
+# cost on test set
+# make for imagenet data
 
 import tensorflow as tf
 import numpy as np
@@ -38,7 +38,7 @@
     i = str(i)
 
     with tf.variable_scope("v_stack"+i):
-        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask=mask).output()
+        v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask='a').output()
         v_stack_in = v_stack
 
     with tf.variable_scope("v_stack_1"+i):
@@ -97,8 +97,8 @@ def generate_and_save(sess):
                 .reshape([batch_size, img_height, img_width, 1]))
         _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})
 
-        print "Epoch: %d, Cost: %.2f"%(i, cost)
+        print "Epoch: %d, Cost: %f"%(i, cost)
         generate_and_save(sess)
 
-    generate_and_save(sess)
+    generate_and_save(sess)
 
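A note on the mask='a' argument in the changed v_stack line: PixelCNN distinguishes two mask types for its convolution filters. Type 'a' zeroes the centre position of the filter as well as everything to the right of and below it, so the prediction for a pixel never sees that pixel's own value; type 'b' keeps the centre and is used in deeper layers. The mask construction inside get_weights is mostly collapsed in this diff, so the sketch below shows the standard scheme rather than this repository's exact code; make_mask and the [height, width, in_channels, out_channels] layout are assumptions for illustration.

import numpy as np

def make_mask(shape, mask_type):
    # Standard PixelCNN filter mask: 1 keeps a weight, 0 zeroes it out.
    # shape = [filter_h, filter_w, in_channels, out_channels];
    # mask_type is 'a' (blind centre) or 'b' (centre visible).
    mask = np.ones(shape, dtype=np.float32)
    mid_y, mid_x = shape[0] // 2, shape[1] // 2

    mask[mid_y + 1:, :, :, :] = 0.0        # rows below the centre
    mask[mid_y, mid_x + 1:, :, :] = 0.0    # positions right of the centre
    if mask_type == 'a':
        mask[mid_y, mid_x, :, :] = 0.0     # blind the centre pixel itself
    return mask

# A 3x3 type-'a' mask keeps only the four positions above/left of the centre.
print(make_mask([3, 3, 1, 1], 'a')[:, :, 0, 0])

Such a mask is what get_weights multiplies into the filter, which is what the W *= mask_filter line in models.py does before the convolution is applied.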