-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
125 lines (106 loc) · 3.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import logging
import math
import os
import torch
import shutil
import time
# Module-wide handle on the root logger; init_logger() configures this same object.
logger = logging.getLogger()


def init_logger(log_file=None, log_file_level=logging.NOTSET):
    """Configure the root logger with a console handler and an optional log file.

    The root logger's level is set to INFO. Every handler currently attached to
    the root logger is replaced by a single console ``StreamHandler``. When
    ``log_file`` is given, a ``FileHandler`` appending to that path is added as
    well. Both handlers use the format ``[<asctime> <levelname>] <message>``.

    Args:
        log_file: Optional path of a log file. ``None`` or ``""`` disables
            file logging.
        log_file_level: Level threshold applied to the file handler only.

    Returns:
        The root ``logging.Logger`` (the same object as the module-level
        ``logger``).
    """
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # The existing handlers are discarded below. Close any file handlers among
    # them so that calling init_logger repeatedly does not leak open files.
    for old_handler in root.handlers:
        if isinstance(old_handler, logging.FileHandler):
            old_handler.close()

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    root.handlers = [console_handler]

    if log_file:
        # Explicit UTF-8, matching how this module reads and writes text
        # elsewhere, so non-ASCII messages do not depend on the platform locale.
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(log_file_level)
        file_handler.setFormatter(log_format)
        root.addHandler(file_handler)
    return root
def get_perplexity(loss, n_words):
    """Return the perplexity ``exp(loss / n_words)``, with the exponent capped at 100.

    The cap keeps ``math.exp`` from overflowing when the average loss is very
    large; the result therefore never exceeds ``exp(100)``.
    """
    mean_loss = loss / n_words
    exponent = 100 if 100 < mean_loss else mean_loss
    return math.exp(exponent)
# def save_checkpoint(model, step):
# model_state_dict = model.state_dict()
#
# checkpoint = {
# 'model': model_state_dict,
# 'opt': self.args,
# 'optims': self.optims,
# }
#
# checkpoint_dir = './checkpoints/'
# checkpoint_path = os.path.join(checkpoint_dir, 'model_step_%d.pt' % step)
# logger.info("Saving checkpoint %s" % checkpoint_path)
#
# if not os.path.exists(checkpoint_path):
# torch.save(model_state_dict, checkpoint_path)
def test_rouge(temp_dir, cand, ref):
    """Score candidate summaries against reference summaries with pyrouge.

    ``cand`` and ``ref`` are read as UTF-8 text with one summary per line; line
    i of ``cand`` is paired with line i of ``ref``. Pairs whose reference line
    is empty are skipped (no files are written for them). Each remaining pair
    is written to its own file inside a fresh scratch directory under
    ``temp_dir``. That directory is removed afterwards, even if scoring raises.

    Args:
        temp_dir: Existing directory used for the scratch files; it is also
            passed to ``pyrouge.Rouge155`` as its ``temp_dir``.
        cand: Path to the candidate (system) summaries file.
        ref: Path to the reference (model) summaries file.

    Returns:
        The result of ``Rouge155.output_to_dict`` on the evaluation output.

    Raises:
        ValueError: if the two files contain a different number of lines.
    """
    import tempfile

    with open(cand, encoding='utf-8') as cand_file:
        candidates = [line.strip() for line in cand_file]
    with open(ref, encoding='utf-8') as ref_file:
        references = [line.strip() for line in ref_file]
    print(len(candidates))
    print(len(references))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(candidates) != len(references):
        raise ValueError(
            "candidate and reference files differ in length: "
            "{} vs {} lines".format(len(candidates), len(references)))

    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    # mkdtemp always creates a brand-new directory, so two evaluations started
    # in the same second cannot collide on the timestamped name.
    tmp_dir = tempfile.mkdtemp(
        prefix="rouge-tmp-{}-".format(current_time), dir=temp_dir)
    try:
        cand_dir = os.path.join(tmp_dir, "candidate")
        ref_dir = os.path.join(tmp_dir, "reference")
        os.mkdir(cand_dir)
        os.mkdir(ref_dir)
        for i, (candidate, reference) in enumerate(zip(candidates, references)):
            if not reference:
                continue
            with open(os.path.join(cand_dir, "cand.{}.txt".format(i)), "w",
                      encoding="utf-8") as f:
                f.write(candidate)
            with open(os.path.join(ref_dir, "ref.{}.txt".format(i)), "w",
                      encoding="utf-8") as f:
                f.write(reference)

        # pyrouge is used here but was never imported by this module. Importing
        # it locally fixes the NameError and keeps the rest of the module
        # importable when pyrouge is not installed.
        import pyrouge

        r = pyrouge.Rouge155(temp_dir=temp_dir)
        # Trailing separator kept, as in the original "<dir>/" paths.
        r.model_dir = os.path.join(ref_dir, "")
        r.system_dir = os.path.join(cand_dir, "")
        r.model_filename_pattern = 'ref.#ID#.txt'
        r.system_filename_pattern = r'cand.(\d+).txt'
        rouge_results = r.convert_and_evaluate()
        print(rouge_results)
        results_dict = r.output_to_dict(rouge_results)
    finally:
        # Always remove the scratch files, including when scoring fails.
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
    return results_dict
def tile(x, count, dim=0):
    """Repeat every slice of ``x`` along ``dim`` ``count`` times, consecutively.

    Slices ``[a, b]`` along ``dim`` become ``[a, a, b, b]`` for ``count=2``.
    This differs from ``Tensor.repeat``, which tiles the whole block. The size
    of ``dim`` is multiplied by ``count``, and the returned tensor is contiguous.
    """
    order = list(range(x.dim()))
    swapped = dim != 0
    if swapped:
        # Move the target dimension to the front so the work happens on dim 0.
        order[0], order[dim] = order[dim], order[0]
        x = x.permute(order).contiguous()

    target_shape = list(x.size())
    target_shape[0] = target_shape[0] * count
    leading = x.size(0)

    flat = x.view(leading, -1)
    stacked = flat.transpose(0, 1).repeat(count, 1)
    result = stacked.transpose(0, 1).contiguous().view(*target_shape)

    if swapped:
        # Swapping two axes is its own inverse: reusing `order` restores the layout.
        result = result.permute(order).contiguous()
    return result
def rouge_results_to_str(results_dict):
    """Format ROUGE-1/2/L F-scores and recalls as a two-line summary string.

    ``results_dict`` must provide the keys ``rouge_1_f_score``,
    ``rouge_2_f_score``, ``rouge_l_f_score``, ``rouge_1_recall``,
    ``rouge_2_recall`` and ``rouge_l_recall``. Each value is multiplied by 100
    and printed with two decimals (presumably the values are fractions in
    [0, 1]; confirm against the producer of ``results_dict``).
    """
    # Only ROUGE-1, ROUGE-2 and ROUGE-L are reported, so the labels name
    # exactly those three metrics. The old "1/2/3/l" label implied a fourth
    # value that was never printed.
    return (
        ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\n"
        "ROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\n"
    ).format(
        results_dict["rouge_1_f_score"] * 100,
        results_dict["rouge_2_f_score"] * 100,
        results_dict["rouge_l_f_score"] * 100,
        results_dict["rouge_1_recall"] * 100,
        results_dict["rouge_2_recall"] * 100,
        results_dict["rouge_l_recall"] * 100,
    )