diff --git a/autoencoder.py b/autoencoder.py
index ff02194..791c2dd 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -114,8 +114,8 @@ class Conf(object):
     else:
         from keras.datasets import cifar10
         data = cifar10.load_data()
-        # TODO normalize pixel values
         data = data[0][0]
+        data /= 255.0
         data = np.transpose(data, (0, 2, 3, 1))
         conf.img_height = 32
         conf.img_width = 32
diff --git a/main.py b/main.py
index e95b4ce..f3d0fb1 100644
--- a/main.py
+++ b/main.py
@@ -4,9 +4,10 @@
 from models import PixelCNN
 from utils import *
 
+tf.set_random_seed(100)
 def train(conf, data):
     X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
-    model = PixelCNN(conf, X)
+    model = PixelCNN(X, conf)
 
     trainer = tf.train.RMSPropOptimizer(1e-3)
     gradients = trainer.compute_gradients(model.loss)
@@ -31,14 +32,20 @@ def train(conf, data):
                             conf.img_height, conf.img_width, conf.channel]))
                     batch_y = one_hot(batch_y, conf.num_classes)
                 else:
+                    pointer = 0
                     batch_X, pointer = get_batch(data, pointer, conf.batch_size)
-
-                _, cost = sess.run([optimizer, model.loss], feed_dict={X:batch_X, model.h:batch_y})
-
+                    #batch_X, batch_y = next(data)
+                data_dict = {X:batch_X}
+                if conf.conditional is True:
+                    #TODO extract one-hot classes
+                    data_dict[model.h] = batch_y
+                _, cost,_f = sess.run([optimizer, model.loss, model.fc2], feed_dict=data_dict)
+                print _f[0]
                 print "Epoch: %d, Cost: %f"%(i, cost)
 
         saver.save(sess, conf.ckpt_file)
-    generate_samples(sess, X, model.h, model.pred, conf)
+    generate_samples(sess, X, model.h, model.pred_sample, conf, "sample")
+    generate_samples(sess, X, model.h, model.pred_argmax, conf, "argmax")
 
 
 if __name__ == "__main__":
@@ -69,15 +76,22 @@ def train(conf, data):
     else:
         from keras.datasets import cifar10
         data = cifar10.load_data()
-        data = data[0][0]
+        labels = data[0][1]
+        data = data[0][0] / 255.0
         data = np.transpose(data, (0, 2, 3, 1))
         conf.img_height = 32
         conf.img_width = 32
         conf.channel = 3
-        raise ValueError("Specify num_classes")
         conf.num_classes = 10
 
     conf.num_batches = data.shape[0] // conf.batch_size
-
+    '''
+    # TODO debug shape
+    from keras.preprocessing.image import ImageDataGenerator
+    datagen = ImageDataGenerator(featurewise_center=True,
+            featurewise_std_normalization=True)
+    datagen.fit(data)
+    data = datagen.flow(data, labels, batch_size=conf.batch_size)
+    '''
     conf = makepaths(conf)
     train(conf, data)
diff --git a/models.py b/models.py
index 354ec68..13177f6 100644
--- a/models.py
+++ b/models.py
@@ -4,9 +4,13 @@
 class PixelCNN():
     def __init__(self, X, conf, h=None):
         self.X = X
-        self.X_norm = self.X if conf.data == "mnist" else tf.div(self.X, 255.0)
+        if conf.data == "mnist":
+            self.X_norm = X
+        else:
+            self.X_norm = X
+            #self.X_norm = tf.image.per_image_whitening(X)
         v_stack_in, h_stack_in = self.X_norm, self.X_norm
-        # TODO norm for multichannel: dubtract mean and divide by std feature-wise
+
         if conf.conditional is True:
             if h is not None:
                 self.h = h
@@ -48,18 +52,10 @@ def __init__(self, X, conf, h=None):
         color_dim = 256
         with tf.variable_scope("fc_2"):
             self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, gated=False, mask='b', activation=False).output()
-            #fc2_shape = self.fc2.get_shape()
-            #self.fc2 = tf.reshape(self.fc2, (int(fc2_shape[0]), int(fc2_shape[1]), int(fc2_shape[2]), conf.channel, -1))
-            #fc2_shape = self.fc2.get_shape()
-            #self.fc2 = tf.nn.softmax(tf.reshape(self.fc2, (-1, int(fc2_shape[-1]))))
             self.fc2 = tf.reshape(self.fc2, (-1, color_dim))
 
-        #self.loss = self.categorical_crossentropy(self.fc2, self.X)
-        #self.X_flat = tf.reshape(self.X, [-1])
-        #self.fc2_flat = tf.cast(tf.argmax(self.fc2, dimension=tf.rank(self.fc2) - 1), dtype=tf.float32)
         self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(self.fc2, tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))
 
-        #self.loss = tf.reduce_mean(-tf.reduce_sum(X_ohehot * tf.log(self.fc2), reduction_indices=[1]))
-        # NOTE or check without argmax
-        self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
+        #self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), dimension=tf.rank(self.fc2) - 1), tf.shape(self.X))
+        self.pred = tf.reshape(tf.multinomial(tf.nn.softmax(self.fc2), num_samples=1, seed=100), tf.shape(self.X))
 
diff --git a/utils.py b/utils.py
index 71af225..8f558d5 100644
--- a/utils.py
+++ b/utils.py
@@ -7,20 +7,25 @@
 def binarize(images):
     return (np.random.uniform(size=images.shape) < images).astype(np.float32)
 
-def generate_samples(sess, X, h, pred, conf):
+def generate_samples(sess, X, h, pred, conf, suff):
     n_row, n_col = 5, 5
     samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
     # TODO make it generic
-    labels = one_hot([1,2,3,4,5]*5, conf.num_classes)
+    labels = one_hot(np.array([1,2,3,4,5]*5), conf.num_classes)
 
     for i in xrange(conf.img_height):
         for j in xrange(conf.img_width):
             for k in xrange(conf.channel):
-                next_sample = sess.run(pred, {X:samples, h: labels})
+                data_dict = {X:samples}
+                if conf.conditional is True:
+                    data_dict[h] = labels
+                next_sample = sess.run(pred, feed_dict=data_dict)
                 if conf.data == "mnist":
                     next_sample = binarize(next_sample)
                 samples[:, i, j, k] = next_sample[:, i, j, k]
-    save_images(samples, n_row, n_col, conf)
+        np.save('preds_'+suff+'.npy', samples)
+        print "height:%d"%i
+    save_images(samples, n_row, n_col, conf, suff)
 
 
 def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
@@ -39,15 +44,20 @@ def generate_ae(sess, encoder_X, decoder_X, y, data, conf):
                 next_sample = binarize(next_sample)
             samples[:, i, j, k] = next_sample[:, i, j, k]
 
-    save_images(samples, n_row, n_col, conf)
+    save_images(samples, n_row, n_col, conf, '')
 
 
-def save_images(samples, n_row, n_col, conf):
+def save_images(samples, n_row, n_col, conf, suff):
     images = samples
-    images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
-    images = images.transpose(1, 2, 0, 3)
-    images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))
+    if conf.data == "mnist":
+        images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
+        images = images.transpose(1, 2, 0, 3)
+        images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))
+    else:
+        images = images.reshape((n_row, n_col, conf.img_height, conf.img_width, conf.channel))
+        images = images.transpose(1, 2, 0, 3, 4)
+        images = images.reshape((conf.img_height * n_row, conf.img_width * n_col, conf.channel))
 
-    filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+".jpg"
+    filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+suff+".jpg"
     scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))
 