-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
59 lines (38 loc) · 1.97 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import defaults
from model import get_model, get_tokenizer
from preprocessing import get_train_data, get_arguments
from random import randint
import tensorflow as tf
from os.path import join
import argparse
def main():
    """Print model predictions for random examples from the dataset CSV file.

    Parses the command-line options, loads the model and tokenizer, preprocesses
    the dataset file and then, for ``--no-tries`` randomly drawn examples, prints
    the input text, the generated prediction and the text label.

    Raises:
        SystemExit: if at least one example was requested but preprocessing
            produced no examples to draw from.
    """
    parser = argparse.ArgumentParser(
        description='test a T5 on random examples from the Webis-Framing-19 dataset.'
    )
    parser.add_argument(
        '--pretrained-name',
        help="Name of the pretrained model to test. Huggingface models and local models can be used.",
        type=str,
        default=defaults.model_savename,
    )
    parser.add_argument(
        '--data-file',
        help="absolute or relative path to Webis-Framing-19 dataset csv file.",
        type=str,
        default=defaults.data_csv_file,
    )
    parser.add_argument(
        '--no-tries',
        help="Test this many random examples from the dataset.",
        type=int,
        default=defaults.no_tries,
    )
    cli_args = parser.parse_args()

    # The model name is joined onto the configured models directory.
    model_path = join(defaults.models_path, cli_args.pretrained_name)
    data_file = cli_args.data_file
    num_tries = cli_args.no_tries

    separator_line = "---------------------------------------------------------------------------------"

    print("Loading Model.")
    model = get_model(model_path)
    tokenizer = get_tokenizer()

    print("Preprocessing data.")
    # Kept under its own name so the parsed CLI namespace is not shadowed.
    examples = get_arguments(data_file)
    encodings, labels = get_train_data(examples, return_text_labels=True)

    # randint(0, -1) would fail with an unhelpful "empty range" ValueError, so
    # report the real problem instead. Only relevant if examples are requested.
    num_examples = len(examples)
    if num_tries > 0 and num_examples == 0:
        raise SystemExit(f"No examples found in {data_file!r}; nothing to test.")

    print("Starting test.")
    for _ in range(num_tries):
        # Drawn with replacement: the same example may be shown more than once.
        ix = randint(0, num_examples - 1)
        x_text = examples[ix].x_text
        # Add a leading axis of size 1 (presumably the batch axis generate() expects).
        x = tf.expand_dims(encodings.input_ids[ix], 0)
        y = labels[ix]
        # Remove only that leading axis; an unrestricted squeeze would also
        # collapse a length-1 output sequence into a scalar.
        out_ids = tf.squeeze(model.generate(x), axis=0)
        tokens = tokenizer.convert_ids_to_tokens(out_ids, skip_special_tokens=True)
        # Extra filter in case padding / end-of-sequence markers are still present.
        tokens = [t for t in tokens if t not in ('<pad>', '</s>')]
        pred = tokenizer.convert_tokens_to_string(tokens)

        print(separator_line)
        print(f"input=\n{x_text}")
        print(f"prediction =\t{pred:>15}")
        print(f"label =\t{y:>15}")
        print(separator_line)


if __name__ == "__main__":
    main()