-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtest.py
executable file
·99 lines (81 loc) · 2.71 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
import os
import argparse
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from data import PersianLexicon
from model import Encoder, Decoder
from config import DataConfig, ModelConfig, TestConfig
def load_model(model_path, model):
    """Load saved weights into *model* and prepare it for inference.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        model: Module instance whose ``load_state_dict`` receives the
            loaded checkpoint.

    Returns:
        The same *model* object, moved to ``TestConfig.device`` and
        switched to evaluation mode.
    """
    def _keep_storage(storage, loc):
        # Returning the storage untouched makes torch.load use it as-is
        # rather than remapping it to the device recorded in the checkpoint;
        # the explicit .to() below decides the final placement.
        return storage

    state_dict = torch.load(model_path, map_location=_keep_storage)
    model.load_state_dict(state_dict)

    model.to(TestConfig.device)
    model.eval()
    return model
class G2P(object):
    """Grapheme-to-phoneme converter backed by a trained encoder/decoder.

    Construction builds the ``PersianLexicon`` dataset (used here only for
    its ``g2idx`` and ``idx2p`` lookup tables) and loads the encoder and
    decoder checkpoints named in ``TestConfig``. Calling the instance with a
    word runs greedy decoding and returns the predicted phoneme symbols.
    """

    def __init__(self):
        # data
        self.ds = PersianLexicon(
            DataConfig.graphemes_path,
            DataConfig.phonemes_path,
            DataConfig.lexicon_path
        )

        # model
        self.encoder_model = Encoder(
            ModelConfig.graphemes_size,
            ModelConfig.hidden_size
        )
        load_model(TestConfig.encoder_model_path, self.encoder_model)

        self.decoder_model = Decoder(
            ModelConfig.phonemes_size,
            ModelConfig.hidden_size
        )
        load_model(TestConfig.decoder_model_path, self.decoder_model)

    def __call__(self, word, visualize=False, max_len=100):
        """Convert *word* to a list of phoneme symbols.

        Args:
            word: String of graphemes; every character must be a key of
                ``self.ds.g2idx``.
            visualize: When true, save the attention matrix as
                ``attention/<DataConfig.language>/<word>.png``.
            max_len: Upper bound on the number of decoding steps, so that a
                model which never predicts the stop index cannot loop
                forever. Expected to be a positive integer; pass ``None``
                to decode without a limit.

        Returns:
            The predicted phoneme symbols in order. The symbol for the stop
            index (1) is included as the last element when it was predicted
            within *max_len* steps.

        Raises:
            KeyError: If *word* contains a character missing from
                ``self.ds.g2idx``.
        """
        # Index 0 is prepended and index 1 appended to the grapheme ids; the
        # plot labels below name these positions '<sos>' and '<eos>'.
        grapheme_ids = [0] + [self.ds.g2idx[ch] for ch in word] + [1]
        # Shape (seq_len, 1). The input must live on the same device as the
        # encoder weights, otherwise the forward pass fails off-CPU.
        enc_input = torch.tensor(grapheme_ids).long().unsqueeze(1)
        enc_input = enc_input.to(TestConfig.device)
        with torch.no_grad():
            enc = self.encoder_model(enc_input)

        phonemes, att_weights = [], []
        # Decoding starts from index 0 and an all-ones hidden state.
        dec_input = torch.zeros(1, 1).long().to(TestConfig.device)
        hidden = torch.ones(
            1,
            1,
            ModelConfig.hidden_size
        ).to(TestConfig.device)

        step = 0
        while max_len is None or step < max_len:
            with torch.no_grad():
                out, hidden, att_weight = self.decoder_model(
                    dec_input,
                    enc,
                    hidden
                )

            att_weights.append(att_weight.detach().cpu())
            # Greedy choice: feed the highest-scoring index back in as the
            # next decoder input.
            max_index = out[0, 0].argmax()
            dec_input = max_index.unsqueeze(0).unsqueeze(0)
            step += 1

            predicted = max_index.item()
            phonemes.append(self.ds.idx2p[predicted])
            if predicted == 1:
                # Index 1 is the stop marker.
                break

        if visualize:
            att_matrix = torch.cat(att_weights).squeeze(1).numpy().T
            n_rows, n_cols = att_matrix.shape

            out_path = f'attention/{DataConfig.language}/{word}.png'
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)

            # Draw on a dedicated figure and close it afterwards so repeated
            # calls neither overlay earlier plots nor accumulate figures.
            fig = plt.figure()
            plt.imshow(att_matrix, cmap='gray')
            plt.yticks(range(n_rows), ['<sos>'] + list(word) + ['<eos>'])
            plt.xticks(range(n_cols), phonemes)
            plt.savefig(out_path)
            plt.close(fig)

        return phonemes
if __name__ == '__main__':
    # Command-line entry point: transcribe one word and print its phonemes
    # separated by dots.
    cli = argparse.ArgumentParser()
    cli.add_argument('--word', type=str, default='پایتون')
    cli.add_argument('--visualize', action='store_true')
    options = cli.parse_args()

    # Building the converter loads the lexicon and both model checkpoints.
    converter = G2P()
    predicted_phonemes = converter(options.word, options.visualize)
    print('.'.join(predicted_phonemes))