forked from revsic/tf-alae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_mlp.py
94 lines (76 loc) · 2.61 KB
/
mnist_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import os
import numpy as np
import tensorflow as tf
from utils.trainer import Trainer
from datasets.mnist import MNIST
from models.mlpalae import MlpAlae
# Command-line interface for this script.
PARSER = argparse.ArgumentParser()
# nargs='?' makes the positional optional so that default='train' can take
# effect; without it argparse treats the positional as required and the
# declared default is never used.
PARSER.add_argument('option', type=str, nargs='?', default='train',
                    help="action to run; only 'train' is handled")
PARSER.add_argument('--name', default='mnist_mlp',
                    help='model name, used as the sub-directory under '
                         '--summarydir and --ckptdir')
PARSER.add_argument('--summarydir', default='./summary',
                    help='root directory for summaries')
PARSER.add_argument('--ckptdir', default='./ckpt',
                    help='root directory for checkpoints')
PARSER.add_argument('--epochs', default=50, type=int,
                    help='number of training epochs')
PARSER.add_argument('--seed', default=1234, type=int,
                    help='random seed for TensorFlow and NumPy')
PARSER.add_argument('--batch-size', default=128, type=int,
                    help='batch size for the train and test datasets')
class MnistAlae(MlpAlae):
    """MLP-ALAE variant for flattened, conditioned MNIST samples.

    The default output width is 784 + 10: 784 matches the 28x28 pixels that
    `generate` reshapes; the extra 10 columns are presumably the condition
    (class label) part -- confirm against the dataset pipeline.
    """

    def __init__(self, settings=None):
        """Initialize the model.

        Args:
            settings: dict of hyperparameters forwarded to `MlpAlae`,
                or None to fall back to `MnistAlae.default_setting()`.
        """
        config = MnistAlae.default_setting() if settings is None else settings
        super().__init__(config)

    def generate(self, z):
        """Generate images from latent vectors.

        Calls the parent `generate`, keeps the first 784 columns of its
        output, clips them to [0, 1] and reshapes the result to images.

        Args:
            z: tf.Tensor, latent vectors ([B, latent_dim] according to the
                original documentation -- confirm against `MlpAlae.generate`).
        Returns:
            tf.Tensor, [B, 28, 28, 1], generated images.
        """
        flat = super().generate(z)
        # Drop everything past the 784 pixel columns before clipping.
        pixels = flat[:, :784]
        clipped = tf.clip_by_value(pixels, 0, 1)
        return tf.reshape(clipped, [-1, 28, 28, 1])

    @staticmethod
    def default_setting(z_dim=128, latent_dim=50, output_dim=784 + 10):
        """Build the default hyperparameter dictionary.

        Args:
            z_dim: int, stored under 'z_dim'.
            latent_dim: int, stored under 'latent_dim' and used as the last
                size of the 'f' and 'e' lists.
            output_dim: int, stored under 'output_dim' and used as the last
                size of the 'g' list.
        Returns:
            dict, settings consumed by `MlpAlae`. The 'f', 'g', 'e', 'd'
            entries are size lists (presumably per-network layer widths --
            confirm against `MlpAlae`).
        """
        hidden = 1024
        settings = {
            'z_dim': z_dim,
            'latent_dim': latent_dim,
            'output_dim': output_dim,
            'gamma': 10,
        }
        settings.update({
            'f': [hidden, latent_dim],
            'g': [hidden, output_dim],
            'e': [hidden, latent_dim],
            'd': [hidden, 1],
        })
        settings.update({
            'lr': 0.002,
            'beta1': 0.0,
            'beta2': 0.99,
        })
        return settings
def train(args):
    """Train the MNIST MLP-ALAE model.

    Creates the summary and checkpoint directories for the run (named after
    `args.name`), then hands the model and the train/test datasets to
    `Trainer.train`.

    Args:
        args: parsed command-line namespace; reads `name`, `summarydir`,
            `ckptdir`, `epochs` and `batch_size`.
    Returns:
        int, always 0.
    """
    mnist = MNIST()
    mlpalae = MnistAlae()

    modelname = args.name
    summary_path = os.path.join(args.summarydir, modelname)
    ckpt_path = os.path.join(args.ckptdir, modelname)
    for path in (summary_path, ckpt_path):
        # exist_ok=True replaces the racy "check exists, then create" pattern.
        os.makedirs(path, exist_ok=True)

    trainer = Trainer(summary_path, ckpt_path)
    trainer.train(
        mlpalae,
        args.epochs,
        mnist.datasets(bsize=args.batch_size, flatten=True, condition=True),
        mnist.datasets(
            bsize=args.batch_size, flatten=True, condition=True, train=False),
        # Number of whole batches in the training set (presumably the
        # steps-per-epoch count expected by Trainer -- confirm).
        len(mnist.x_train) // args.batch_size)

    return 0
def main(args):
    """Seed the random generators and run the requested action.

    Args:
        args: parsed command-line namespace; reads `seed` and `option`.
            Only the 'train' option triggers any work.
    """
    # Seed both TensorFlow and NumPy with the same value.
    seed = args.seed
    tf.random.set_seed(seed)
    np.random.seed(seed)

    if args.option != 'train':
        return
    train(args)
if __name__ == '__main__':
    # Script entry point: parse the command line, then dispatch.
    cli_args = PARSER.parse_args()
    main(cli_args)