notes: image generation for cifar
restricted dataset
preprocessing on hold
experiment with generation: sample vs. argmax decoding
cifar labels still to be one-hot encoded
anantzoid committed Nov 19, 2016
1 parent a4632ea commit e4c72ec
Showing 4 changed files with 51 additions and 31 deletions.
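
Editor's note on the sample/argmax experiment mentioned above: at every pixel and channel the network emits a softmax over 256 intensity values, and the two generation modes differ only in how a value is drawn from that distribution. A minimal numpy sketch of the idea (decode_pixel is illustrative, not a function in this repo):

import numpy as np

def decode_pixel(probs, mode):
    # probs: softmax distribution over the 256 intensities of one pixel/channel
    if mode == "argmax":
        return np.argmax(probs)                   # deterministic: take the mode
    return np.random.choice(len(probs), p=probs)  # stochastic: draw one sample

probs = np.full(256, 1.0 / 256)  # toy uniform distribution; real ones come from fc2
print decode_pixel(probs, "argmax"), decode_pixel(probs, "sample")

Argmax decoding tends to produce flat, overly smooth images from autoregressive models, while sampling keeps the learned per-pixel variance, hence saving both variants for comparison.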
autoencoder.py: 2 changes (1 addition, 1 deletion)
@@ -114,8 +114,8 @@ class Conf(object):
else:
from keras.datasets import cifar10
data = cifar10.load_data()
# TODO normalize pixel values
data = data[0][0]
data /= 255.0
data = np.transpose(data, (0, 2, 3, 1))
conf.img_height = 32
conf.img_width = 32
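A caveat with the normalization added above (an observation about keras/numpy behavior, not part of the commit): cifar10.load_data() returns uint8 arrays, and in-place division like data /= 255.0 on a uint8 array raises a numpy casting error. Casting first avoids it:

from keras.datasets import cifar10
import numpy as np

data = cifar10.load_data()
data = data[0][0].astype(np.float32)     # load_data() yields uint8; cast before /=
data /= 255.0                            # now safely scaled into [0, 1]
data = np.transpose(data, (0, 2, 3, 1))  # (N, 3, 32, 32) -> (N, 32, 32, 3)

The transpose assumes the Theano-style channel ordering keras used at the time; with a channels-last configuration the arrays already load as (N, 32, 32, 3) and the transpose should be dropped.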
main.py: 30 changes (22 additions, 8 deletions)
@@ -4,9 +4,10 @@
from models import PixelCNN
from utils import *

tf.set_random_seed(100)
def train(conf, data):
X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
model = PixelCNN(conf, X)
model = PixelCNN(X, conf)

trainer = tf.train.RMSPropOptimizer(1e-3)
gradients = trainer.compute_gradients(model.loss)
@@ -31,14 +32,20 @@ def train(conf, data):
conf.img_height, conf.img_width, conf.channel]))
batch_y = one_hot(batch_y, conf.num_classes)
else:
pointer = 0
batch_X, pointer = get_batch(data, pointer, conf.batch_size)

_, cost = sess.run([optimizer, model.loss], feed_dict={X:batch_X, model.h:batch_y})

#batch_X, batch_y = next(data)
data_dict = {X:batch_X}
if conf.conditional is True:
#TODO extract one-hot classes
data_dict[model.h] = batch_y
_, cost, _f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict)
print _f[0]
print "Epoch: %d, Cost: %f"%(i, cost)

saver.save(sess, conf.ckpt_file)
generate_samples(sess, X, model.h, model.pred, conf)
generate_samples(sess, X, model.h, model.pred_sample, conf, "sample")
generate_samples(sess, X, model.h, model.pred_argmax, conf, "argmax")


if __name__ == "__main__":
@@ -69,15 +76,22 @@ def train(conf, data):
else:
from keras.datasets import cifar10
data = cifar10.load_data()
data = data[0][0]
labels = data[0][1]
data = data[0][0] / 255.0
data = np.transpose(data, (0, 2, 3, 1))
conf.img_height = 32
conf.img_width = 32
conf.channel = 3
raise ValueError("Specify num_classes")
conf.num_classes = 10
conf.num_batches = data.shape[0] // conf.batch_size

'''
# TODO debug shape
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(featurewise_center=True,
featurewise_std_normalization=True)
datagen.fit(data)
data = datagen.flow(data, labels, batch_size=conf.batch_size)
'''
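# Sketch of how the on-hold generator path above would plug into the training
# loop (assumed usage, matching the commented-out next(data) line in train()):
# datagen.flow(...) yields (batch_X, batch_y) tuples indefinitely, so the
# get_batch/pointer logic would be replaced by:
#   batch_X, batch_y = next(data)
#   batch_y = one_hot(batch_y, conf.num_classes)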

conf = makepaths(conf)
train(conf, data)
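
Two reference notes on the training loop above. First, pointer is reset to 0 on every iteration, so training keeps seeing the first batch; that may be the deliberate "restricted dataset" from the commit message, but for full training the reset belongs before the batch loop. Second, get_batch lives in utils.py and is not shown in this diff; a minimal version consistent with the call site would be:

def get_batch(data, pointer, batch_size):
    # Return the next slice of data and the advanced pointer, wrapping around.
    if (pointer + 1) * batch_size > data.shape[0]:
        pointer = 0
    batch = data[pointer * batch_size:(pointer + 1) * batch_size]
    return batch, pointer + 1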
models.py: 20 changes (8 additions, 12 deletions)
@@ -4,9 +4,13 @@
class PixelCNN():
def __init__(self, X, conf, h=None):
self.X = X
self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
if conf.data == "mnist":
self.X_norm = X
else:
self.X_norm = X
#self.X_norm = tf.image.per_image_whitening(X)
v_stack_in, h_stack_in = self.X_norm, self.X_norm
# TODO norm for multichannel: subtract mean and divide by std feature-wise

if conf.conditional is True:
if h is not None:
self.h = h
@@ -48,18 +52,10 @@ def __init__(self, X, conf, h=None):
color_dim = 256
with tf.variable_scope("fc_2"):
self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
#fc2_shape = self.fc2.get_shape()
#self.fc2 = tf.reshape(self.fc2, (int(fc2_shape[0]), int(fc2_shape[1]), int(fc2_shape[2]), conf.channel, -1))
#fc2_shape = self.fc2.get_shape()
#self.fc2 = tf.nn.softmax(tf.reshape(self.fc2, (-1, int(fc2_shape[-1]))))
self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

#self.loss = self.categorical_crossentropy(self.fc2, self.X)
#self.X_flat = tf.reshape(self.X, [-1])
#self.fc2_flat = tf.cast(tf.argmax(self.fc2, dimension=tf.rank(self.fc2) - 1), dtype=tf.float32)
self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
#self.loss = tf.reduce_mean(-tf.reduce_sum(X_ohehot * tf.log(self.fc2), reduction_indices=[1]))

# NOTE or check without argmax
self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
#self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
self.pred = tf.reshape(tf.multinomial(tf.nn.softmax(self.fc2), num_samples=1, seed=100), tf.shape(self.X))
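
Note that train() now calls model.pred_sample and model.pred_argmax, while this hunk defines only self.pred. A sketch of exposing both heads in the repo's TF-0.x style (assumed, not part of this diff); one detail worth flagging is that tf.multinomial expects unnormalized log-probabilities, so it is usually fed the raw fc2 logits rather than their softmax:

self.pred_argmax = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
self.pred_sample = tf.reshape(tf.multinomial(self.fc2, num_samples=1, seed=100), tf.shape(self.X))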

utils.py: 30 changes (20 additions, 10 deletions)
@@ -7,20 +7,25 @@
def binarize(images):
return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_samples(sess, X, h, pred, conf):
def generate_samples(sess, X, h, pred, conf, suff):
n_row, n_col = 5, 5
samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
# TODO make it generic
labels = one_hot([1,2,3,4,5]*5, conf.num_classes)
labels = one_hot(np.array([1,2,3,4,5]*5), conf.num_classes)

for i in xrange(conf.img_height):
for j in xrange(conf.img_width):
for k in xrange(conf.channel):
next_sample = sess.run(pred, {X:samples, h: labels})
data_dict = {X:samples}
if conf.conditional is True:
data_dict[h] = labels
next_sample = sess.run(pred, feed_dict=data_dict)
if conf.data == "mnist":
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]
save_images(samples, n_row, n_col, conf)
np.save('preds_'+suff+'.npy', samples)
print "height:%d"%i
save_images(samples, n_row, n_col, conf, suff)


def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
@@ -39,15 +44,20 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
next_sample = binarize(next_sample)
samples[:, i, j, k] = next_sample[:, i, j, k]

save_images(samples, n_row, n_col, conf)
save_images(samples, n_row, n_col, conf, '')

def save_images(samples, n_row, n_col, conf):
def save_images(samples, n_row, n_col, conf, suff):
images = samples
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
images = images.transpose(1, 2, 0, 3)
images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))
if conf.data == "mnist":
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
images = images.transpose(1, 2, 0, 3)
images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))
else:
images = images.reshape((n_row, n_col, conf.img_height, conf.img_width, conf.channel))
images = images.transpose(1, 2, 0, 3, 4)
images = images.reshape((conf.img_height * n_row, conf.img_width * n_col, conf.channel))

filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg"
filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+suff+".jpg"
scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))


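For reference, a minimal one_hot consistent with its call sites in main.py and generate_samples above (the actual helper is defined elsewhere in utils.py and not shown in this diff):

def one_hot(labels, num_classes):
    # Row i gets a 1.0 in the column named by labels[i], zeros elsewhere.
    labels = np.asarray(labels).reshape(-1).astype(np.int64)
    out = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out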
