-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
69 lines (60 loc) · 3.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import argparse
import importlib
from torch_trainer import GraphTrainer
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for a training/evaluation run.

    Every option ends up, via ``vars(args)``, as a key of the settings dict
    passed to ``GraphTrainer``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    parser = argparse.ArgumentParser(description="SEAL")
    # Run identification and input/checkpoint locations
    parser.add_argument("-n", "--name", type=str, default='no_name')
    parser.add_argument("-g", "--graph_path", type=str, default=None)
    parser.add_argument("--pre_trained_model_path", type=str, default=None)
    # Subgraphs in SEAL
    parser.add_argument("--num_hops", type=int, default=1)
    parser.add_argument("--ratio_per_hop", type=float, default=1.0)
    parser.add_argument("--max_nodes_per_hop", type=int, default=800)
    parser.add_argument("--node_label", type=str, default='drnl')
    # Data splits
    parser.add_argument("--neg_pos_ratio_test", type=float, default=1.0)
    parser.add_argument("--neg_pos_ratio", type=float, default=1.0)
    parser.add_argument("--train_fraction", type=float, default=1.0)
    parser.add_argument("--splitting", type=str, default='random')
    parser.add_argument("--valid_fold", type=int, default=1)
    parser.add_argument("--fraction_dist_neg", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=100)
    parser.add_argument("--include_in_train", type=str, default=None)
    parser.add_argument("--mode", type=str, default='normal')
    # Dataloader batch size and number of worker threads
    parser.add_argument("-bs", "--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=6)
    # NN hyperparameters
    parser.add_argument("--model", type=str, default='DGCNN')
    parser.add_argument("--n_epochs", type=int, default=20)
    parser.add_argument("-lr", "--learning_rate", type=float, default=0.0005)
    parser.add_argument("--decay", type=float, default=0.855)
    parser.add_argument("--dropout", type=float, default=0.517)
    parser.add_argument("--n_runs", type=int, default=1)
    # SEAL training hyperparameters
    parser.add_argument("--hidden_channels", type=int, default=128)
    parser.add_argument("--num_layers", type=int, default=6)
    parser.add_argument("--max_z", type=int, default=1000)
    # NOTE(review): parsed as float although the default is an integer;
    # presumably the model accepts fractional values here -- confirm.
    parser.add_argument("--sortpool_k", type=float, default=879)
    parser.add_argument("--graph_norm", action="store_true")
    parser.add_argument("--batch_norm", action="store_true")
    # GAE training hyperparameters
    parser.add_argument("--variational", action="store_true")
    parser.add_argument("--linear", action="store_true")
    parser.add_argument("--out_channels", type=int, default=None)
    # Graph options
    parser.add_argument("--use_attribute", type=str, default='fingerprint')
    parser.add_argument("--use_embedding", action="store_true")
    # Classification parameters
    parser.add_argument("--p_threshold", type=float, default=0.9)
    parser.add_argument("--pos_weight_loss", type=float, default=1.0)
    return parser


def main(argv=None):
    """Parse command-line options and run one training + final-test job.

    Args:
        argv: list of argument strings to parse (without the program name).
            ``None`` means ``sys.argv[1:]``, as with
            ``ArgumentParser.parse_args``.

    Returns:
        Whatever ``GraphTrainer.run(running_test=False, final_test=True)``
        returns.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when
            ``-g/--graph_path`` is not supplied.
    """
    # CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so device
    # errors are reported at the failing call (slower, easier to debug). It is
    # set before the trainer is created; assumes importing torch_trainer does
    # not already initialise CUDA -- TODO confirm.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    parser = build_parser()
    args = parser.parse_args(argv)
    settings = vars(args)

    # argparse always populates max_nodes_per_hop (default 800), so no
    # model-specific fallback is needed here; pass --max_nodes_per_hop to
    # change it.

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and parser.error prints usage and exits cleanly.
    if settings["graph_path"] is None:
        parser.error("-g/--graph_path is required")

    trainer = GraphTrainer(settings)
    return trainer.run(running_test=False, final_test=True)


if __name__ == "__main__":
    main()