This repository has been archived by the owner on Jul 25, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
112 lines (86 loc) · 4.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import tensorflow as tf
import numpy as np
from Unet import UNet,UNet1
from discriminator import Discriminator
import os
from PIL import Image
import random
import time
class Train():
    """Build the training graph for a pix2pix-style colorization GAN.

    A U-Net generator turns a line drawing (``realB``) into a colour image
    (``fakeA``); a discriminator scores (colour, line-art) pairs.  The class
    exposes the placeholders, losses and optimizer ops as attributes for the
    training loop to feed/run.
    """
    def __init__(self):
        # realA: ground-truth RGB (colour) image batch.
        self.realA = tf.placeholder(tf.float32, shape=[None, 512, 512, 3])
        # realB: line-drawing batch — the generator's conditioning input.
        self.realB = tf.placeholder(tf.float32, shape=[None, 512, 512, 3])
        # fakeA: colourised output produced by the U-Net from the line art.
        # NOTE(review): assumes UNet1(...).dec_dc0 is the final decoder
        # output with shape (None, 512, 512, 3) — confirm against Unet.py.
        self.fakeA = UNet1(self.realB).dec_dc0
        # Conditional-GAN inputs: the discriminator always sees the colour
        # image stacked with the line art along the channel axis.
        realAB = tf.concat([self.realA, self.realB], 3)  # positive pair
        fakeAB = tf.concat([self.fakeA, self.realB], 3)  # negative pair
        # Two discriminator applications; the second is constructed with the
        # reuse flag set so both share one set of 'd_'-prefixed variables.
        dis_r = Discriminator(realAB, False)
        real_logits = dis_r.last_h
        real_out = dis_r.out
        dis_f = Discriminator(fakeAB, True)
        fake_logits = dis_f.last_h
        fake_out = dis_f.out
        # Standard GAN cross-entropy terms: real pairs labelled 1, fake
        # pairs labelled 0 for D; the generator is trained to make D emit 1.
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_out)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_out)))
        self.UNet_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_out)))
        self.d_loss = self.d_loss_fake + self.d_loss_real
        # Generator loss = adversarial term + L1 reconstruction term
        # weighted by 100, as in the pix2pix paper.
        self.g_loss = self.UNet_loss + 100 * tf.reduce_mean(tf.abs(self.realA - self.fakeA))
        # Split trainable variables by name prefix so each optimizer only
        # updates its own network.
        training_var = tf.trainable_variables()
        d_var = [var for var in training_var if 'd_' in var.name]
        g_var = [var for var in training_var if 'g_' in var.name]
        self.opt_d = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.d_loss, var_list=d_var)
        self.opt_g = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.g_loss, var_list=g_var)
        # Output directories for checkpoints and visualisations.
        if not os.path.exists('./saved/'):
            os.mkdir('./saved/')
        if not os.path.exists('./visualized/'):
            os.mkdir('./visualized/')
def sample(size, channel, path, batch_files):
    """Load a batch of images from *path* as one float32 array.

    Args:
        size: expected image side length (images are size x size).
        channel: expected number of channels per image.
        path: directory prefix the file names are appended to.
        batch_files: iterable of file names to load.

    Returns:
        np.ndarray of shape (len(batch_files), size, size, channel),
        dtype float32.  An empty batch yields shape (0, size, size, channel).
    """
    # Collect per-file arrays in a Python list and convert once at the end.
    # The original np.append-per-file copied the whole accumulated array on
    # every iteration (accidentally quadratic in the batch size).
    images = [np.array(Image.open(path + file_name)).astype(np.float32)
              for file_name in batch_files]
    # The reshape both validates the expected geometry and handles the
    # empty-batch case identically to the original implementation.
    return np.asarray(images, dtype=np.float32).reshape((-1, size, size, channel))
def visualize_g(size, g_img, x_img, t_img, batch_size, epoch, i):
    """Write side-by-side comparison JPEGs to ./visualized/.

    Each saved image is the horizontal concatenation of the generated
    output, the line-art input, and the ground-truth target for one batch
    element.

    Args:
        size: unused — kept for interface compatibility with callers.
        g_img: generated images, indexable per batch element.
        x_img: generator inputs (line art), same indexing.
        t_img: ground-truth targets, same indexing.
        batch_size: number of batch elements to write out.
        epoch: current epoch number (embedded in the file name).
        i: batch offset within the epoch (embedded in the file name).
    """
    for idx in range(batch_size):
        # NOTE(review): in the file-name template, 'batch_num' receives the
        # in-batch index and 'batch' receives the batch offset — preserved
        # exactly as the original wrote it.
        strip = np.concatenate((g_img[idx], x_img[idx], t_img[idx]), axis=1)
        Image.fromarray(np.uint8(strip)).save(
            './visualized/epoch{}batch_num{}batch{}.jpg'.format(epoch, idx, i))
def main():
    """Train the colorization GAN.

    Samples 1000 file names (with replacement) from ./data/rgb512/, then
    for each epoch draws random mini-batches, alternates one discriminator
    update and one generator update per batch, dumps visualisations, and
    checkpoints the model.  Matching line-art files are expected under
    ./data/linedraw512/ with identical file names.
    """
    batch_size = 5
    epochs = 3000
    # random.choice with replacement: the 1000-entry list may contain
    # duplicates; it acts as a fixed sampling pool for the whole run.
    filenames = [random.choice(os.listdir('./data/rgb512/')) for _ in range(1000)]
    data_size = len(filenames)
    train = Train()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        # Emit the graph for TensorBoard; the writer object itself is not
        # needed afterwards (the original bound it to an unused local).
        tf.summary.FileWriter('./logas', sess.graph)
        for epoch in range(epochs):
            epoch_start = time.time()
            for i in range(0, data_size, batch_size):
                batch_files = [random.choice(filenames) for _ in range(batch_size)]
                rgb512 = sample(512, 3, './data/rgb512/', batch_files)
                linedraw512 = sample(512, 3, './data/linedraw512/', batch_files)
                batch_start = time.time()
                # One D step, then one G step, on the same feed dict.
                d_loss, _ = sess.run([train.d_loss, train.opt_d],
                                     {train.realA: rgb512, train.realB: linedraw512})
                g_img, g_loss, _ = sess.run([train.fakeA, train.g_loss, train.opt_g],
                                            {train.realA: rgb512, train.realB: linedraw512})
                visualize_g(512, g_img, linedraw512, rgb512, batch_size, epoch, i)
                print(' g_loss:', g_loss, ' d_loss:', d_loss, ' speed:', time.time() - batch_start, " batches / s")
            print('--------------------------------')
            print('epoch_num:', epoch, ' epoch_time:', time.time() - epoch_start)
            print('--------------------------------')
            # Checkpoint after every epoch.  NOTE(review): indentation was
            # lost in the scraped source; per-epoch saving is assumed here —
            # confirm against the original file if available.
            saver.save(sess, "saved/model.ckpt")

if __name__ == "__main__":
    main()