-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy patharchitectures.py
445 lines (361 loc) · 23.3 KB
/
architectures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
# -*- coding: utf-8 -*-
#!/usr/bin/env python2
'''
Based on code by kyubyong park at https://www.github.com/kyubyong/dc_tts
'''
from data_load import get_batch, load_vocab
from networks import TextEnc, AudioEnc, AudioDec, Attention, SSRN, FixedAttention, LinearTransformLabels, MerlinTextEnc
import tensorflow as tf
from utils import get_global_attention_guide, learning_rate_decay
class Graph(object):
    '''
    Base class for the TensorFlow graphs defined in this module.

    Subclasses provide build_model(), plus build_loss() and get_batchsize(),
    which are only called in 'train' mode. The constructor adds the input data
    (batched tensors when training, placeholders otherwise), builds the model
    and, when training, the loss and the optimisation ops.
    '''

    def __init__(self, hp, mode="train", reuse=None):
        '''
        hp: hyperparameter object; its attributes are read while building the graph.
        mode: one of 'train', 'synthesize', 'generate_attention'.
        reuse: stored on self.reuse and forwarded to add_data.
        '''
        assert mode in ['train', 'synthesize', 'generate_attention']
        self.mode = mode
        self.training = (mode == "train")
        self.reuse = reuse
        self.hp = hp
        self.add_data(reuse=reuse)  ## TODO: reuse??
        self.build_model()
        if self.training:
            self.build_loss()
            self.build_training_scheme()

    def add_data(self, reuse=None):
        '''
        Add either batched input tensors (for training) or placeholders (for
        synthesis / attention generation) to the graph.

        In every mode this sets self.L, self.mels and self.speakers. Depending on
        hp it also sets self.durations, self.merlin_label and self.position_in_phone.
        'train' additionally sets mags, fnames, num_batch, gts and
        prev_max_attentions. 'synthesize' additionally adds a
        prev_max_attentions placeholder.

        reuse: not used by this method (kept for interface compatibility).
        '''
        # Data Feeding
        ## L: Text. (B, N), int32
        ## mels: Reduced melspectrogram. (B, T/r, n_mels) float32
        ## mags: Magnitude. (B, T, n_fft//2+1) float32
        # String modes are compared with ==; `is` on strings depends on interning.
        if self.mode == 'train':
            self._add_training_data()
        elif self.mode == 'synthesize':
            self._add_placeholders()
            self.prev_max_attentions = tf.placeholder(tf.int32, shape=(None,))
        elif self.mode == 'generate_attention':
            self._add_placeholders()

    def _add_training_data(self):
        '''Fetch one batch of training inputs via get_batch and store them as attributes.'''
        hp = self.hp
        batchsize = self.get_batchsize()
        batchdict = get_batch(hp, batchsize)
        self.L, self.mels, self.mags, self.fnames, self.num_batch = \
            batchdict['text'], batchdict['mel'], batchdict['mag'], batchdict['fname'], batchdict['num_batch']
        if hp.multispeaker:
            ## check multispeaker config is valid:- TODO: to config validation?
            for position in hp.multispeaker:
                assert position in ['text_encoder_input', 'text_encoder_towards_end',
                                    'audio_decoder_input', 'ssrn_input', 'audio_encoder_input',
                                    'learn_channel_contributions', 'speaker_dependent_phones']
            self.speakers = batchdict['speaker']
        else:
            self.speakers = None
        if hp.attention_guide_dir:
            self.gts = batchdict['attention_guide']
        else:
            self.gts = tf.convert_to_tensor(get_global_attention_guide(hp))
        if hp.use_external_durations:
            self.durations = batchdict['duration']
        if hp.merlin_label_dir:
            self.merlin_label = batchdict['merlin_label']
        if 'position_in_phone' in hp.history_type:
            self.position_in_phone = batchdict['position_in_phone']
        self.prev_max_attentions = tf.ones(shape=(batchsize,), dtype=tf.int32)

    def _add_placeholders(self):
        '''Create the input placeholders shared by 'synthesize' and 'generate_attention' modes.'''
        hp = self.hp
        self.L = tf.placeholder(tf.int32, shape=(None, None))
        self.speakers = None
        if hp.multispeaker:
            self.speakers = tf.placeholder(tf.int32, shape=(None, None))
        if hp.use_external_durations:
            self.durations = tf.placeholder(tf.float32, shape=(None, None, None))
        if hp.merlin_label_dir:
            self.merlin_label = tf.placeholder(tf.float32, shape=(None, None, hp.merlin_lab_dim))
        if 'position_in_phone' in hp.history_type:
            self.position_in_phone = tf.placeholder(tf.float32, shape=(None, None, 1))
        self.mels = tf.placeholder(tf.float32, shape=(None, None, hp.n_mels))

    def build_training_scheme(self):
        '''
        Create the global step, learning rate, Adam optimiser, clipped
        gradients, self.train_op and the merged summary op self.merged.

        hp.update_weights: list of strings of regular expressions used to match
        scope prefixes of variables with tf.get_collection. Only these will be
        updated by the graph's train_op; others stay frozen in training. When
        empty/falsy, all trainable variables are updated.
        '''
        hp = self.hp
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        if hp.decay_lr:
            self.lr = learning_rate_decay(hp.lr, self.global_step)
        else:
            self.lr = hp.lr
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=hp.beta1, beta2=hp.beta2, epsilon=hp.epsilon)
        tf.summary.scalar("lr", self.lr)

        if hp.update_weights:
            train_variables = filter_variables_for_update(hp.update_weights)
            print('Subset of trainable variables chosen for finetuning.')  ## TODO: add to logging!
            print('Variables not in this list will remain frozen:')
            for variable in train_variables:
                print(variable.name)
        else:
            train_variables = None  ## default value -- everything included in compute_gradients

        ## gradient clipping
        self.gvs = self.optimizer.compute_gradients(self.loss, var_list=train_variables)
        # compute_gradients yields (None, var) for variables the loss does not
        # depend on; tf.clip_by_value cannot convert None, so those pairs are
        # skipped (apply_gradients does not update variables without a gradient).
        self.clipped = [(tf.clip_by_value(grad, -1., 1.), var)
                        for grad, var in self.gvs if grad is not None]
        self.train_op = self.optimizer.apply_gradients(self.clipped, global_step=self.global_step)

        # Summary
        self.merged = tf.summary.merge_all()
class SSRNGraph(Graph):
    '''
    SSRN graph: predicts magnitude spectrograms (self.Z / self.Z_logits) from
    self.mels, trained against self.mags.
    '''

    def get_batchsize(self):
        return self.hp.batchsize['ssrn']  ## TODO: naming?

    def build_model(self):
        with tf.variable_scope("SSRN"):
            ## OSW: use 'mels' for input both in training and synthesis -- can be either variable or placeholder
            self.Z_logits, self.Z = SSRN(self.hp, self.mels, training=self.training, speaker_codes=self.speakers, reuse=self.reuse)

    def build_loss(self):
        '''
        Build L1, L2 and binary-divergence losses on magnitudes and combine them
        into self.loss using either hp.loss_weights['ssrn'] (new format) or the
        legacy hp.lw_mag / hp.lw_bd2 / hp.lw_ssrn_l2 attributes.
        '''
        hp = self.hp
        ## L2 loss (new)
        self.loss_l2 = tf.reduce_mean(tf.squared_difference(self.Z, self.mags))
        # mag L1 loss
        self.loss_mags = tf.reduce_mean(tf.abs(self.Z - self.mags))
        # mag binary divergence loss
        self.loss_bd2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.Z_logits, labels=self.mags))
        if not hp.squash_output_ssrn:
            self.loss_bd2 = tf.zeros_like(self.loss_bd2)
            print("binary divergence loss disabled because squash_output_ssrn==False")

        # total loss
        # Only the weight lookup is guarded: a missing/malformed loss_weights
        # config falls back to the legacy attributes, while errors in building
        # the loss itself are no longer swallowed by a bare except.
        try:  ## new way to configure loss weights:- TODO: ensure all configs use new pattern, and remove 'except' branch
            ssrn_weights = hp.loss_weights['ssrn']
            lw_l1 = ssrn_weights['L1']
            lw_bd = ssrn_weights['binary_divergence']
            lw_l2 = ssrn_weights['L2']
        except (AttributeError, KeyError, TypeError):
            self.lw_mag = hp.lw_mag
            self.lw_bd2 = hp.lw_bd2
            self.lw_ssrn_l2 = hp.lw_ssrn_l2
            lw_l1, lw_bd, lw_l2 = self.lw_mag, self.lw_bd2, self.lw_ssrn_l2
        else:
            print("New loss weight format used!")
        self.loss = (lw_l1 * self.loss_mags) + (lw_bd * self.loss_bd2) + (lw_l2 * self.loss_l2)

        # loss_components attribute is used for reporting to log (osw)
        self.loss_components = [self.loss, self.loss_mags, self.loss_bd2, self.loss_l2]

        # summary used for reporting to tensorboard (kp)
        tf.summary.scalar('train/loss_mags', self.loss_mags)
        tf.summary.scalar('train/loss_bd2', self.loss_bd2)
        tf.summary.image('train/mag_gt', tf.expand_dims(tf.transpose(self.mags[:1], [0, 2, 1]), -1))
        tf.summary.image('train/mag_hat', tf.expand_dims(tf.transpose(self.Z[:1], [0, 2, 1]), -1))
class Text2MelGraph(Graph):
    '''
    Text2Mel graph: builds keys/values K, V from text (or Merlin labels),
    queries Q from the audio history, attention output R, and decoded mel
    frames Y / Y_logits. In training it adds mel, binary-divergence, attention
    and (optionally) attention-confidence losses.
    '''

    def get_batchsize(self):
        return self.hp.batchsize['t2m']  ## TODO: naming?

    def build_model(self):
        with tf.variable_scope("Text2Mel"):
            # Get S or decoder inputs. (B, T//r, n_mels). This is audio shifted 1 frame to the right.
            self.S = tf.concat((tf.zeros_like(self.mels[:, :1, :]), self.mels[:, :-1, :]), 1)

            # Networks
            if self.hp.text_encoder_type == 'none':
                assert self.hp.merlin_label_dir
                self.K = self.V = self.merlin_label
            elif self.hp.text_encoder_type == 'minimal_feedforward':
                assert self.hp.merlin_label_dir
                self.K = self.V = LinearTransformLabels(self.hp, self.merlin_label, training=self.training, reuse=self.reuse)
            elif self.hp.text_encoder_type == 'MerlinTextEnc':
                assert self.hp.merlin_label_dir
                with tf.variable_scope("MerlinTextEnc"):
                    self.K, self.V = MerlinTextEnc(self.hp, self.L, self.merlin_label, training=self.training, speaker_codes=self.speakers, reuse=self.reuse)  # (N, Tx, e)
            else:  ## default DCTTS text encoder
                with tf.variable_scope("TextEnc"):
                    self.K, self.V = TextEnc(self.hp, self.L, training=self.training, speaker_codes=self.speakers, reuse=self.reuse)  # (N, Tx, e)

            with tf.variable_scope("AudioEnc"):
                if self.hp.history_type in ['fractional_position_in_phone', 'absolute_position_in_phone']:
                    self.Q = self.position_in_phone
                elif self.hp.history_type == 'minimal_history':
                    # Previously sys.exit(...) without importing sys (a NameError).
                    raise NotImplementedError('Not implemented: hp.history_type=="minimal_history"')
                else:
                    assert self.hp.history_type == 'DCTTS_standard'
                    self.Q = AudioEnc(self.hp, self.S, training=self.training, speaker_codes=self.speakers, reuse=self.reuse)

            with tf.variable_scope("Attention"):
                # R: (B, T/r, 2d)
                # alignments: (B, N, T/r)
                # max_attentions: (B,)
                # Modes are compared with ==; `is` on strings depends on interning.
                if self.hp.use_external_durations:
                    self.R, self.alignments, self.max_attentions = FixedAttention(self.hp, self.durations, self.Q, self.V)
                elif self.mode == 'synthesize':
                    self.R, self.alignments, self.max_attentions = Attention(self.hp, self.Q, self.K, self.V,
                                                                             monotonic_attention=True,
                                                                             prev_max_attentions=self.prev_max_attentions)
                elif self.mode == 'train':
                    self.R, self.alignments, self.max_attentions = Attention(self.hp, self.Q, self.K, self.V,
                                                                             monotonic_attention=False,
                                                                             prev_max_attentions=self.prev_max_attentions)
                elif self.mode == 'generate_attention':
                    self.R, self.alignments, self.max_attentions = Attention(self.hp, self.Q, self.K, self.V,
                                                                             monotonic_attention=False,
                                                                             prev_max_attentions=None)

            with tf.variable_scope("AudioDec"):
                self.Y_logits, self.Y = AudioDec(self.hp, self.R, training=self.training, speaker_codes=self.speakers, reuse=self.reuse)  # (B, T/r, n_mels)

    def build_loss(self):
        '''
        Build the training losses and combine them into self.loss.

        Base terms (L1, binary divergence, attention, L2) are weighted by
        hp.loss_weights['t2m'] (new format) or the legacy hp.lw_* attributes.
        The attention-confidence terms (CDP, Ain, Aout) are weighted by
        hp.lw_cdp / hp.lw_ain / hp.lw_aout under either format; previously
        they were silently left out of the loss on the new-format path even
        though they were computed and logged.
        '''
        hp = self.hp
        ## L2 loss (new)
        self.loss_l2 = tf.reduce_mean(tf.squared_difference(self.Y, self.mels))
        # mel L1 loss
        self.loss_mels = tf.reduce_mean(tf.abs(self.Y - self.mels))
        # mel binary divergence loss
        self.loss_bd1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.Y_logits, labels=self.mels))
        if not hp.squash_output_t2m:
            self.loss_bd1 = tf.zeros_like(self.loss_bd1)
            print("binary divergence loss disabled because squash_output_t2m==False")

        self._build_attention_loss()

        use_confidence_losses = hp.lw_cdp != 0.0 or hp.lw_ain != 0.0 or hp.lw_aout != 0.0
        if use_confidence_losses:
            self._build_attention_confidence_losses()

        # total loss
        # Only the weight lookup is guarded, so genuine graph-building errors are not swallowed.
        try:  ## new way to configure loss weights:- TODO: ensure all configs use new pattern, and remove 'except' branch
            t2m_weights = hp.loss_weights['t2m']
            lw_mel = t2m_weights['L1']
            lw_bd1 = t2m_weights['binary_divergence']
            lw_att = t2m_weights['attention']
            lw_l2 = t2m_weights['L2']
        except (AttributeError, KeyError, TypeError):
            self.lw_mel = hp.lw_mel
            self.lw_bd1 = hp.lw_bd1
            self.lw_att = hp.lw_att
            self.lw_t2m_l2 = hp.lw_t2m_l2
            self.lw_cdp = hp.lw_cdp
            self.lw_ain = hp.lw_ain
            self.lw_aout = hp.lw_aout
            lw_mel, lw_bd1, lw_att, lw_l2 = self.lw_mel, self.lw_bd1, self.lw_att, self.lw_t2m_l2
        self.loss = (lw_mel * self.loss_mels) + (lw_bd1 * self.loss_bd1) + (lw_att * self.loss_att) + (lw_l2 * self.loss_l2)
        if hp.lw_cdp != 0.0:
            self.loss += (self.loss_cdp * hp.lw_cdp)
        if hp.lw_ain != 0.0:
            self.loss += (self.loss_Ain * hp.lw_ain)
        if hp.lw_aout != 0.0:
            self.loss += (self.loss_Aout * hp.lw_aout)

        # loss_components attribute is used for reporting to log (osw)
        if use_confidence_losses:
            self.loss_components = [self.loss, self.loss_mels, self.loss_bd1, self.loss_att, self.loss_l2, self.loss_cdp, self.loss_Ain, self.loss_Aout]
        else:
            self.loss_components = [self.loss, self.loss_mels, self.loss_bd1, self.loss_att, self.loss_l2]

        # summary used for reporting to tensorboard (kp)
        tf.summary.scalar('train/loss_mels', self.loss_mels)
        tf.summary.scalar('train/loss_bd1', self.loss_bd1)
        tf.summary.scalar('train/loss_att', self.loss_att)
        tf.summary.image('train/mel_gt', tf.expand_dims(tf.transpose(self.mels[:1], [0, 2, 1]), -1))
        tf.summary.image('train/mel_hat', tf.expand_dims(tf.transpose(self.Y[:1], [0, 2, 1]), -1))

    def _build_attention_loss(self):
        '''Set self.loss_att (and self.A, self.attention_masks, self.mask_sum) from self.alignments and self.gts.'''
        hp = self.hp
        # Alignments and guides are padded/cropped to (max_N, max_T) so they can
        # be combined; the -1 padding marks cells outside the batch's N x T.
        paddings = [(0, 0), (0, hp.max_N), (0, hp.max_T)]
        if hp.attention_guide_fa is False:  # Use guided attention loss (identity check: only an explicit False)
            print("Guided attention loss!!")
            self.A = tf.pad(self.alignments, paddings, mode="CONSTANT", constant_values=-1.)[:, :hp.max_N, :hp.max_T]
            ## att guides in dir need padding to max_T and max_N because their sizes are sentence dependent
            if hp.attention_guide_dir:
                self.gts = tf.pad(self.gts, paddings, mode="CONSTANT", constant_values=1.0)[:, :hp.max_N, :hp.max_T]  ## TODO: check adding penalty here (1.0 is the right thing)
            self.attention_masks = tf.to_float(tf.not_equal(self.A, -1))  # 1 inside the batch's N x T, 0 in padding
            # (B, N, T) * (N, T) broadcasts when the global guide is used.
            self.loss_att = tf.reduce_sum(tf.abs(self.A * self.gts) * self.attention_masks)
            self.mask_sum = tf.reduce_sum(self.attention_masks)
            # Normalised over the batch's own N x T cells, not max_N x max_T.
            self.loss_att /= self.mask_sum
        else:  ## Use MSE attention loss - treat guide as target
            print("MSE attention loss!!")
            masked = tf.pad(self.alignments, paddings, mode="CONSTANT", constant_values=-1.)[:, :hp.max_N, :hp.max_T]
            self.attention_masks = tf.to_float(tf.not_equal(masked, -1))
            self.mask_sum = tf.reduce_sum(self.attention_masks)
            self.A = tf.pad(self.alignments, paddings, mode="CONSTANT", constant_values=0.0)[:, :hp.max_N, :hp.max_T]
            if hp.attention_guide_dir:
                self.gts = tf.pad(self.gts, paddings, mode="CONSTANT", constant_values=0.0)[:, :hp.max_N, :hp.max_T]
            self.loss_att = tf.reduce_sum((self.A - self.gts) * (self.A - self.gts))
            self.loss_att /= self.mask_sum

    def _build_attention_confidence_losses(self):
        '''
        Set self.loss_cdp, self.loss_Ain and self.loss_Aout from self.alignments
        (see "Confidence through attention", http://arxiv.org/abs/1710.03743).
        '''
        alignments = self.alignments  # B x Nbatch x Tbatch

        # CDP
        att_per_input = tf.reduce_sum(alignments, axis=2)  # summed over frames, per input token
        ones = tf.ones_like(att_per_input)
        self.loss_cdp = tf.reduce_sum(tf.log(ones + (ones - att_per_input) ** 2))
        self.loss_cdp /= tf.reduce_sum(ones)

        # Ain entropy: each token's attention normalised over frames (axis 2)
        Amat = alignments
        num_phones = tf.reduce_sum(tf.ones_like(tf.reduce_sum(Amat, axis=2)))  # remove frame dim
        num_frames = tf.reduce_sum(Amat, axis=0)  # remove batch dim
        num_frames = tf.reduce_sum(tf.ones_like(tf.reduce_sum(num_frames, axis=0)))  # remove phone dim
        norm_per_token = tf.reduce_sum(Amat, axis=2, keepdims=True)
        Amat = Amat / norm_per_token
        Amat = tf.where(tf.is_nan(Amat), tf.zeros_like(Amat), Amat)  # 0/0 -> 0
        entropy = tf.where(tf.not_equal(Amat, 0), Amat * tf.log(Amat), tf.zeros_like(Amat))
        self.loss_Ain = tf.reduce_sum(entropy)
        self.loss_Ain /= num_phones
        self.loss_Ain /= tf.log(num_frames)
        self.loss_Ain *= -1.0

        # Aout entropy: each frame's attention normalised over tokens (axis 1)
        Amat = alignments
        num_frames = tf.reduce_sum(tf.ones_like(tf.reduce_sum(Amat, axis=1)))  # remove phone dim
        num_phones = tf.reduce_sum(Amat, axis=0)  # remove batch dim
        num_phones = tf.reduce_sum(tf.ones_like(tf.reduce_sum(num_phones, axis=1)))  # remove frame dim
        norm_per_token = tf.reduce_sum(Amat, axis=1, keepdims=True)
        Amat = Amat / norm_per_token
        Amat = tf.where(tf.is_nan(Amat), tf.zeros_like(Amat), Amat)  # same 0/0 guard as Ain
        entropy = tf.where(tf.not_equal(Amat, 0), Amat * tf.log(Amat), tf.zeros_like(Amat))
        self.loss_Aout = tf.reduce_sum(entropy)
        self.loss_Aout /= num_frames
        self.loss_Aout /= tf.log(num_phones)
        self.loss_Aout *= -1.0
class TextEncGraph(Graph):  ## partial graph for deployment only
    '''
    Deployment-only partial graph. Under the "Text2Mel" scope it builds the
    shifted decoder input S and, under "Text2Mel/TextEnc", the text encoder
    outputs K and V. No audio encoder, attention or decoder is built.
    '''

    def build_model(self):
        hp, mels = self.hp, self.mels
        with tf.variable_scope("Text2Mel"):
            # Decoder input S: an all-zero first frame followed by mels minus its last frame. (B, T//r, n_mels)
            leading_zeros = tf.zeros_like(mels[:, :1, :])
            delayed_frames = mels[:, :-1, :]
            self.S = tf.concat((leading_zeros, delayed_frames), 1)

            # Text encoder only
            with tf.variable_scope("TextEnc"):
                keys, values = TextEnc(hp, self.L, training=self.training,
                                       speaker_codes=self.speakers, reuse=self.reuse)  # (N, Tx, e)
                self.K, self.V = keys, values
class BabblerGraph(Graph):
    '''
    Audio-only model that predicts the next audio step from the audio history.

    It can babble on its own at synthesis time from a seed (for example a few
    frames of silence, or the start of a sentence to complete), or its weights
    can initialise the matching weights of a Text2Mel model. Following
    "Semi-Supervised Training for Improving Data Efficiency in End-to-End
    Speech Synthesis" (Yu-An Chung et al., 2018: https://arxiv.org/abs/1808.10128),
    all-zero dummy text-encoder outputs stand in for the real ones during training.
    '''

    def get_batchsize(self):
        batch_sizes = self.hp.batchsize
        return batch_sizes.get('babbler', 32)  ## default = 32

    def build_model(self):
        hp = self.hp
        mels = self.mels
        ## Scope names mirror the full Text2Mel graph so parameters can be reused later.
        with tf.variable_scope("Text2Mel"):
            # S: audio shifted one frame to the right, zero-filled at the start. (B, T//r, n_mels)
            zero_frame = tf.zeros_like(mels[:, :1, :])
            self.S = tf.concat((zero_frame, mels[:, :-1, :]), 1)

            ## No text encoder in the babbler.
            with tf.variable_scope("AudioEnc"):
                self.Q = AudioEnc(hp, self.S, training=self.training, reuse=self.reuse)

            with tf.variable_scope("Attention"):
                ## No real attention: a zero tensor shaped like Q stands in for the
                ## text-side attention output and is concatenated in front of Q.
                self.R = tf.concat((tf.zeros_like(self.Q), self.Q), -1)

            with tf.variable_scope("AudioDec"):
                logits, prediction = AudioDec(hp, self.R, training=self.training,
                                              speaker_codes=self.speakers, reuse=self.reuse)  # (B, T/r, n_mels)
                self.Y_logits, self.Y = logits, prediction

    def build_loss(self):
        '''Weighted sum of mel L1 and binary-divergence losses, using hp.loss_weights['babbler'].'''
        # mel L1 loss
        self.loss_mels = tf.reduce_mean(tf.abs(self.Y - self.mels))
        # mel binary divergence loss
        self.loss_bd = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.Y_logits, labels=self.mels))

        babbler_weights = self.hp.loss_weights['babbler']
        weighted_l1 = babbler_weights['L1'] * self.loss_mels
        weighted_bd = babbler_weights['binary_divergence'] * self.loss_bd
        self.loss = weighted_l1 + weighted_bd

        # loss_components attribute is used for reporting to log (osw)
        self.loss_components = [self.loss, self.loss_mels, self.loss_bd]

        # summary used for reporting to tensorboard (kp)
        tf.summary.scalar('train/loss_mels', self.loss_mels)
        tf.summary.scalar('train/loss_bd', self.loss_bd)
        for tag, spectrogram in (('train/mel_gt', self.mels), ('train/mel_hat', self.Y)):
            tf.summary.image(tag, tf.expand_dims(tf.transpose(spectrogram[:1], [0, 2, 1]), -1))
def filter_variables_for_update(update_weights):
    '''
    Return the trainable variables matched by any of the given scope patterns.

    update_weights: iterable of regular-expression strings, each passed as the
        scope filter to tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, ...).
    Returns a list holding each matched variable once, in first-match order.
    '''
    to_train = []
    seen_ids = set()
    for pattern_string in update_weights:
        for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, pattern_string):
            # De-duplicate by object identity (a variable may match several
            # patterns). Using id() gives O(1) membership tests. It also avoids
            # `variable in list`, which calls tf.Variable.__eq__; that can be
            # elementwise rather than identity when TF2 equality semantics are enabled.
            if id(variable) not in seen_ids:
                seen_ids.add(id(variable))
                to_train.append(variable)
    return to_train