Skip to content

Commit

Permalink
pCNN AE
Browse files Browse the repository at this point in the history
sample generation changes
  • Loading branch information
anantzoid committed Nov 19, 2016
1 parent e56e36c commit e4119ea
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
20 changes: 10 additions & 10 deletions autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@

class AE(object):
def __init__(self, X):
self.fmap_out = 32
self.num_layers = 2
self.fmap_out = [8, 32]
self.fmap_in = conf.channel
self.fan_in = X
self.num_layers = 3
self.filter_size = 3
self.filter_size = 4
self.W = []
self.strides = [1, 1, 1, 1]

for i in range(self.num_layers):
if i == self.num_layers -1 :
self.fmap_out = 10
self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out], stddev=0.1), name="W_%d"%i))
b = tf.Variable(tf.ones(shape=[self.fmap_out], dtype=tf.float32), name="encoder_b_%d"%i)
self.W.append(tf.Variable(tf.truncated_normal(shape=[self.filter_size, self.filter_size, self.fmap_in, self.fmap_out[i]], stddev=0.1), name="W_%d"%i))
b = tf.Variable(tf.ones(shape=[self.fmap_out[i]], dtype=tf.float32), name="encoder_b_%d"%i)
en_conv = tf.nn.conv2d(self.fan_in, self.W[i], self.strides, padding='SAME', name="encoder_conv_%d"%i)
en_pool = tf.nn.max_pool(en_conv, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME', name="encoder_pool_%d"%i)

self.fan_in = tf.tanh(tf.add(en_conv, b))
self.fmap_in = self.fmap_out
self.fan_in = tf.tanh(tf.add(en_pool, b))
self.fmap_in = self.fmap_out[i]

self.fan_in = tf.reshape(self.fan_in, (-1, conf.img_width*conf.img_height*self.fmap_out))
op_shape = self.fan_in.get_shape()
self.fan_in = tf.reshape(self.fan_in, (-1, int(op_shape[1])*int(op_shape[2])*int(op_shape[3])))

def decoder(self):
self.W.reverse()
Expand Down
9 changes: 4 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ def train(conf, data):
data_dict[model.h] = batch_y
_, cost,_f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict)
print "Epoch: %d, Cost: %f"%(i, cost)

saver.save(sess, conf.ckpt_file)
generate_samples(sess, X, model.h, model.pred_sample, conf, "sample")
generate_samples(sess, X, model.h, model.pred_argmax, conf, "argmax")
if (i+1)%10 == 0:
saver.save(sess, conf.ckpt_file)
generate_samples(sess, X, model.h, model.pred, conf, "argmax")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, default='train')
parser.add_argument('--data', type=str, default='mnist')
parser.add_argument('--layers', type=int, default=5)
parser.add_argument('--layers', type=int, default=12)
parser.add_argument('--f_map', type=int, default=32)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=100)
Expand Down
10 changes: 6 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def binarize(images):
return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_samples(sess, X, h, pred, conf, suff):
n_row, n_col = 5, 5
n_row, n_col = 10,10
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
# TODO make it generic
labels = one_hot(np.array([1,2,3,4,5]*5), conf.num_classes)
Expand All @@ -28,8 +28,8 @@ def generate_samples(sess, X, h, pred, conf, suff):
save_images(samples, n_row, n_col, conf, suff)


def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
n_row, n_col = 5, 5
def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
n_row, n_col = 1,1
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
if conf.data == 'mnist':
labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
Expand All @@ -44,7 +44,9 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]

save_images(samples, n_row, n_col, conf, '')
np.save('preds_'+suff+'.npy', samples)
print "height:%d"%i
save_images(samples, n_row, n_col, conf, suff)

def save_images(samples, n_row, n_col, conf, suff):
images = samples
Expand Down

0 comments on commit e4119ea

Please sign in to comment.