Skip to content

Commit

Permalink
modular and experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent 856b99c commit d2c5abc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
10 changes: 6 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ def conv_op(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

class PixelCNN():
def __init__(self, W_shape, b_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
self.W_shape = W_shape
self.b_shape = b_shape
def __init__(self, W_shape, fan_in, gated=True, payload=None, mask=None, activation=True):
self.fan_in = fan_in
in_dim = self.fan_in.get_shape()[-1]
self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
self.b_shape = W_shape[2]

self.payload = payload
self.mask = mask
self.activation = activation


if gated:
self.gated_conv()
Expand Down Expand Up @@ -62,7 +65,6 @@ def simple_conv(self):
else:
self.fan_out = tf.add(conv, b)


def output(self):
return self.fan_out

Expand Down
32 changes: 17 additions & 15 deletions pixelcnn_tf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# TODO
# try changing 0.0 to np.random
# kundan: concat payload instead of add
# : arch: without 1X1 in 1st layer and last 2 layers
# : replaces masking with n/2 filter
# check network arch
# upscale-downscale-q_level
# autoencoder
# cost on test set
# make for imagenet data
# make for imagenet data: upscale-downscale-q_level, mean pixel value if not mnist
# stats
# logger

import tensorflow as tf
import numpy as np
Expand All @@ -15,7 +18,7 @@
from models import PixelCNN

mnist = input_data.read_data_sets("data/")
epochs = 50
epochs = 1
batch_size = 50
grad_clip = 1

Expand All @@ -28,36 +31,35 @@
FILTER_SIZE = 7

X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channel])
# TODO mean pixel value if not mnist
v_stack_in, h_stack_in = X, X

# TODO encapsulate
for i in range(LAYERS):
FILTER_SIZE = 3 if i > 0 else FILTER_SIZE
in_dim = F_MAP if i > 0 else channel
mask = 'b' if i > 0 else 'a'
i = str(i)

with tf.variable_scope("v_stack"+i):
v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, in_dim, F_MAP], [F_MAP], v_stack_in, mask='a').output()
v_stack = PixelCNN([FILTER_SIZE, FILTER_SIZE, F_MAP], v_stack_in, mask=mask).output()
v_stack_in = v_stack

with tf.variable_scope("v_stack_1"+i):
v_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], v_stack_in, gated=False, mask=mask).output()
v_stack_1 = PixelCNN([1, 1, F_MAP], v_stack_in, gated=False, mask=mask).output()

with tf.variable_scope("h_stack"+i):
h_stack = PixelCNN([1, FILTER_SIZE, in_dim, F_MAP], [F_MAP], h_stack_in, gated=True, payload=v_stack_1, mask=mask).output()
h_stack = PixelCNN([1, FILTER_SIZE, F_MAP], h_stack_in, payload=v_stack_1, mask=mask).output()

with tf.variable_scope("h_stack_1"+i):
h_stack_1 = PixelCNN([1, 1, F_MAP, F_MAP], [F_MAP], h_stack, gated=False, mask=mask).output()
h_stack_1 += h_stack_in
h_stack_1 = PixelCNN([1, 1, F_MAP], h_stack, gated=False, mask=mask).output()
#h_stack_1 += h_stack_in # Residual connection
h_stack_in = h_stack_1

with tf.variable_scope("fc_1"):
fc1 = PixelCNN([1, 1, F_MAP, F_MAP],[F_MAP], h_stack_in, gated=False, mask='b').output()
fc1 = PixelCNN([1, 1, F_MAP], h_stack_in, gated=False, mask='b').output()

# handle Imagenet differently
with tf.variable_scope("fc_2"):
fc2 = PixelCNN([1, 1, F_MAP, 1],[1], fc1, gated=False, mask='b', activation=False).output()
fc2 = PixelCNN([1, 1, 1], fc1, gated=False, mask='b', activation=False).output()
pred = tf.nn.sigmoid(fc2)

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fc2, X))
Expand Down Expand Up @@ -97,7 +99,7 @@ def generate_and_save(sess):
.reshape([batch_size, img_height, img_width, 1]))
_, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})

print "Epoch: %d, Cost: %f"%(i, cost)
print "Epoch: %d, Cost: %f"%(i, cost)
generate_and_save(sess)

generate_and_save(sess)
Expand Down

0 comments on commit d2c5abc

Please sign in to comment.