test_buffer.py
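"""Incremental inference-time benchmark (summary inferred from the code below).

Feeds each test conversation tree to the model one node at a time, optionally
reusing the buffered attention and weights from the previous step, and logs the
average per-step inference time to TensorBoard for several tree-size caps
(25/50/75/100 nodes).
"""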
import argparse
import torch
import random
import numpy as np
import os
from tqdm import tqdm
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, accuracy_score
from data_loader import RumorDataset
from model import get_model
from experiment import get_experiment
import time
import torch_geometric.transforms as T
import warnings
# Pin the run to a single GPU and silence library warnings for cleaner timing logs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
warnings.filterwarnings("ignore")
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(
    description='Tree Rumor Detection and Verification')
parser.add_argument('--batch_size', type=int, default=1, metavar='N',
                    help='input batch size for testing (default: 1)')
parser.add_argument('--hidden_dim', type=int, default=768, metavar='N',
                    help='hidden dimension (default: 768)')
parser.add_argument('--max_len', type=int, default=64, metavar='N',
                    help='maximum length of the conversation (default: 64)')
parser.add_argument('--experiment', type=str, metavar='N',
                    help='experiment name')
parser.add_argument('--model', type=str, default="CDGTNB", metavar='N',
                    help='model name (default: CDGTNB)')
parser.add_argument('--fold', type=int, default=0, metavar='N',
                    help='cross-validation fold (default: 0)')
parser.add_argument('--seed', type=int, default=0, metavar='N',
                    help='random seed (default: 0)')
# Note: argparse's type=bool treats any non-empty string as True, so these two
# flags are effectively fixed to their defaults unless changed in code.
parser.add_argument('--aug', type=bool, default=True, metavar='N',
                    help='use the augmented dataset (default: True)')
parser.add_argument('--buffer', type=bool, default=True, metavar='N',
                    help='reuse buffered attention between incremental steps (default: True)')
args = parser.parse_args()
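# Example invocation (the experiment name is illustrative; valid names are
# whatever get_experiment() in experiment.py accepts):
#   python test_buffer.py --experiment twitter15 --model CDGTNB --fold 0 --seed 0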
def test():
    # Seed everything for reproducibility.
    RANDOM_SEED = args.seed
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Resolve dataset location, language, and label set from the experiment config.
    experiment = get_experiment(args.experiment)
    root_dir = os.path.join(experiment["root_dir"], str(args.fold))
    language = experiment["language"]
    classes = experiment["classes"]

    test_dataset = RumorDataset(
        root=root_dir,
        classes=classes,
        split='test',
        language=language,
        max_length=args.max_len,
        aug=args.aug
    )
    test_loader = DataListLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False)
    print('num of testing samples : {}'.format(len(test_dataset)))

    model = get_model(args.model, args.hidden_dim, len(classes), 0.0, language=language)
    model = DataParallel(model).to(device)
    model.eval()

    comment = f'{args.model}_{args.experiment}_{args.fold}_{args.seed}'
    writer = SummaryWriter(log_dir="runs/{}_{}".format(str(int(time.time())), "time_" + comment))
    # Benchmark incremental inference: for each sufficiently large conversation
    # tree, feed the first `step` nodes one step at a time and record the
    # average per-step inference time for several tree-size caps.
    MAX_NODES = [25, 50, 75, 100]
    for MAX_NODE in MAX_NODES:
        total_times = 0.0
        total_count = 0.0
        for _, batch in enumerate(tqdm(test_loader)):
            # Labels are collected but unused here; this script only measures latency.
            labels = torch.cat([data.y for data in batch]).to(device).long()
            for idx, data in enumerate(batch):
                buffer_atten = None
                buffer_weight = None
                num_nodes = int(data.num_nodes)
                # Skip trivial trees and trees smaller than the current cap.
                if num_nodes == 1:
                    continue
                if num_nodes < MAX_NODE:
                    continue
                for step in range(1, num_nodes + 1):
                    if step > MAX_NODE:
                        continue
                    # Keep only the first `step` nodes of the tree.
                    num_true = int(step)
                    num_false = int(num_nodes - num_true)
                    tensor_false = torch.zeros(num_false, dtype=torch.bool)
                    tensor_true = torch.ones(num_true, dtype=torch.bool)
                    subset = torch.cat([tensor_true, tensor_false])
                    batch[idx] = data.subgraph(subset)
                    buffer_size = int(batch[idx].num_nodes)
                    if step > 1 and args.buffer:
                        # With buffering, pass only the newest node and its direct
                        # neighbours to the model; earlier context is carried by the
                        # buffered attention/weights from the previous step.
                        edge_index = batch[idx].edge_index
                        last_id = edge_index[1][-1]
                        edge_index = edge_index[:, edge_index[1] == last_id]
                        edge_index = torch.flatten(edge_index)
                        edge_index = [i in edge_index for i in range(buffer_size)]
                        subset = torch.tensor(edge_index, dtype=torch.bool)
                        batch[idx] = batch[idx].subgraph(subset)
                    if buffer_atten is not None:
                        if args.buffer:
                            # Attach the buffered state from the previous step.
                            batch[idx].buffer_size = torch.tensor(buffer_size)
                            batch[idx].buffer_atten = buffer_atten
                            batch[idx].buffer_weight = buffer_weight
                        start = time.time()
                        outputs, buffer_atten, buffer_weight = model(batch)
                        _, preds = torch.max(outputs, 1)
                        end = time.time()
                    else:
                        start = time.time()
                        outputs, buffer_atten, buffer_weight = model(batch)
                        _, preds = torch.max(outputs, 1)
                        end = time.time()
                    total_time = end - start
                    total_times += total_time
                    total_count += 1
        # Log the average per-step inference time for this tree-size cap
        # (guard against the case where no tree reached the cap).
        if total_count > 0:
            times = total_times / total_count
            writer.add_scalar("Time(s)", times, MAX_NODE)
if __name__ == "__main__":
    test()
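# The resulting timing curves can be inspected with TensorBoard, e.g.:
#   tensorboard --logdir runs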