diff --git a/gan.py b/gan.py
index 2ac25f2..7cc6215 100644
--- a/gan.py
+++ b/gan.py
@@ -14,17 +14,19 @@ from generator import Generator
 
 
 def concat_elu(inputs):
-    return tf.nn.elu(tf.concat(3, [-inputs, inputs]))
+    return tf.nn.elu(tf.concat(axis=3, values=[-inputs, inputs]))
 
 
 class GAN(Generator):
     def __init__(self, hidden_size, batch_size, learning_rate):
         self.input_tensor = tf.placeholder(tf.float32, [None, 28 * 28])
+        self.is_training = tf.placeholder_with_default(True, [])
 
         with arg_scope([layers.conv2d, layers.conv2d_transpose],
                        activation_fn=concat_elu,
                        normalizer_fn=layers.batch_norm,
-                       normalizer_params={'scale': True}):
+                       normalizer_params={'scale': True,
+                                          'is_training': self.is_training}):
             with tf.variable_scope("model"):
                 D1 = discriminator(self.input_tensor)  # positive examples
                 D_params_num = len(tf.trainable_variables())
@@ -40,13 +42,18 @@ def __init__(self, hidden_size, batch_size, learning_rate):
         params = tf.trainable_variables()
         D_params = params[:D_params_num]
         G_params = params[D_params_num:]
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        g_update_ops = [op for op in update_ops if op.name.startswith('model_1/')]
+        d_update_ops = [op for op in update_ops if op not in g_update_ops]
         # train_discrimator = optimizer.minimize(loss=D_loss, var_list=D_params)
         # train_generator = optimizer.minimize(loss=G_loss, var_list=G_params)
         global_step = tf.contrib.framework.get_or_create_global_step()
-        self.train_discrimator = layers.optimize_loss(
-            D_loss, global_step, learning_rate / 10, 'Adam', variables=D_params, update_ops=[])
-        self.train_generator = layers.optimize_loss(
-            G_loss, global_step, learning_rate, 'Adam', variables=G_params, update_ops=[])
+        with tf.control_dependencies(d_update_ops):
+            self.train_discrimator = layers.optimize_loss(
+                D_loss, global_step, learning_rate / 10, 'Adam', variables=D_params, update_ops=[])
+        with tf.control_dependencies(g_update_ops):
+            self.train_generator = layers.optimize_loss(
+                G_loss, global_step, learning_rate, 'Adam', variables=G_params, update_ops=[])
 
         self.sess = tf.Session()
         self.sess.run(tf.global_variables_initializer())
diff --git a/generator.py b/generator.py
index dffc8b9..6bd114f 100644
--- a/generator.py
+++ b/generator.py
@@ -21,7 +21,7 @@ def generate_and_save_images(self, num_samples, directory):
             num_samples: number of samples to generate
             directory: a directory to save the images
         '''
-        imgs = self.sess.run(self.sampled_tensor)
+        imgs = self.sess.run(self.sampled_tensor, feed_dict={self.is_training: False})
         for k in range(imgs.shape[0]):
             imgs_folder = os.path.join(directory, 'imgs')
             if not os.path.exists(imgs_folder):
diff --git a/vae.py b/vae.py
index 207e710..0c7ee07 100644
--- a/vae.py
+++ b/vae.py
@@ -19,11 +19,13 @@ class VAE(Generator):
 
     def __init__(self, hidden_size, batch_size, learning_rate):
         self.input_tensor = tf.placeholder(
             tf.float32, [None, 28 * 28])
+        self.is_training = tf.placeholder_with_default(True, [])
 
         with arg_scope([layers.conv2d, layers.conv2d_transpose],
                        activation_fn=tf.nn.elu,
                        normalizer_fn=layers.batch_norm,
-                       normalizer_params={'scale': True}):
+                       normalizer_params={'scale': True,
+                                          'is_training': self.is_training}):
             with tf.variable_scope("model") as scope:
                 encoded = encoder(self.input_tensor, hidden_size * 2)
 
@@ -44,8 +46,10 @@ def __init__(self, hidden_size, batch_size, learning_rate):
             output_tensor, self.input_tensor)
 
         loss = vae_loss + rec_loss
-        self.train = layers.optimize_loss(loss, tf.contrib.framework.get_or_create_global_step(
-        ), learning_rate=learning_rate, optimizer='Adam', update_ops=[])
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(update_ops):
+            self.train = layers.optimize_loss(loss, tf.contrib.framework.get_or_create_global_step(
+            ), learning_rate=learning_rate, optimizer='Adam', update_ops=[])
 
         self.sess = tf.Session()
         self.sess.run(tf.global_variables_initializer())
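Note (not part of the patch): all three files apply the usual TF1 batch-norm recipe — pass layers.batch_norm an is_training flag, run the moving-average assignments collected in tf.GraphKeys.UPDATE_OPS together with the train op via tf.control_dependencies, and feed is_training=False when sampling. The sketch below is a minimal, self-contained illustration of that pattern only; the layer sizes, placeholder names, and the toy regression loss are invented for the example and are not code from this repository.

# Minimal TF1 sketch of the batch-norm / UPDATE_OPS pattern (illustrative only).
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 1])
# Defaults to training mode; feed False at inference time.
is_training = tf.placeholder_with_default(True, [])

h = tf.contrib.layers.fully_connected(
    x, 16,
    normalizer_fn=tf.contrib.layers.batch_norm,
    normalizer_params={'scale': True, 'is_training': is_training})
pred = tf.contrib.layers.fully_connected(h, 1, activation_fn=None)
loss = tf.losses.mean_squared_error(y, pred)

# batch_norm registers its moving-mean/variance assignments in UPDATE_OPS;
# without this control dependency they never run, and inference would use
# the untrained initial statistics.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training step (is_training defaults to True):
    #   sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
    # Inference (use the accumulated moving statistics):
    #   sess.run(pred, feed_dict={x: batch_x, is_training: False})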