# exBERT.py
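# Fine-tunes a 🤗 Transformers sequence-classification model on knowledge-graph
# tasks: triple classification ('tc'), relation prediction ('rp'), or head/tail
# prediction ('htp').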
import logging
import os

import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers.trainer_utils import IntervalStrategy

import cli
from processors import HeadTailPredictionProcessor, RelationPredictionProcessor, TripleClassificationProcessor

logger = logging.getLogger(__name__)

def write_metrics(metrics: dict, output_dir: str):
    with open(os.path.join(output_dir, 'results.txt'), 'w') as f:
        for key, value in metrics.items():
            f.write(f'{key}:\t\t{value}\n')
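# Writes one "key:\t\tvalue" line per metric. For example (hypothetical values),
# write_metrics({'test_accuracy': 0.91}, './out') produces ./out/results.txt
# containing the line "test_accuracy:\t\t0.91"; the real keys come from the
# Trainer's prediction metrics written in main() below.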

def pre_flight_checks(args):
    if args.task.lower() not in ['tc', 'rp', 'htp']:
        raise ValueError("task must be one of 'tc', 'rp' or 'htp'. See --help for more details")
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initialize the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
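# Note: when args.local_rank != -1, the script is expected to be started through
# PyTorch's distributed launcher so that each process receives its rank, e.g.
# (the exBERT-specific flags are illustrative; the real ones are defined in the
# external `cli` module):
#   python -m torch.distributed.launch --nproc_per_node=2 exBERT.py --task tc ...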

def main():
    args = cli.init_cli()
    pre_flight_checks(args)

    task = args.task.lower().strip()
    if task == 'tc':
        processor_class = TripleClassificationProcessor
    elif task == 'rp':
        processor_class = RelationPredictionProcessor
    else:
        processor_class = HeadTailPredictionProcessor

    # Prefer a local custom model if a valid path was given, otherwise fall back
    # to the (possibly remote) pre-trained BERT model name.
    model_name_or_path = (args.custom_model
                          if args.custom_model is not None and os.path.exists(args.custom_model)
                          else args.bert_model)

    kg = processor_class(
        model_name_or_path,
        args.data_dir,
        args.dataset_cache,
        args.max_seq_length
    )
    logger.info("Initialized KG processor")

    training_args = TrainingArguments(
        output_dir=args.output_dir,                          # output directory
        num_train_epochs=args.num_train_epochs,              # total number of training epochs
        per_device_train_batch_size=args.train_batch_size,   # batch size per device during training
        per_device_eval_batch_size=args.eval_batch_size,     # batch size per device during evaluation
        warmup_ratio=args.warmup_proportion,                 # ratio of warmup steps for the learning-rate scheduler
        weight_decay=0.01,                                   # strength of weight decay
        logging_dir='./logs',                                # directory for storing logs
        logging_steps=1000,
        learning_rate=args.learning_rate,
        local_rank=args.local_rank,
        seed=args.seed,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=args.fp16,
        do_train=args.do_train,
        do_eval=args.do_eval,
        no_cuda=args.no_cuda,
        save_strategy=IntervalStrategy.NO,                   # no intermediate checkpoints; the model is saved explicitly below
    )
    logger.info("Created training args")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path,
        num_labels=kg.get_labels_count()
    )
    logger.info("Loaded model from disk or downloaded it")

    logger.info("Creating dataset objects")
    train_ds, eval_ds, test_ds = kg.create_datasets(args.data_dir)

    trainer = Trainer(
        model=model,                         # the instantiated 🤗 Transformers model to be trained
        tokenizer=kg.tokenizer,              # the tokenizer used by the KG processor
        args=training_args,                  # training arguments, defined above
        train_dataset=train_ds,              # training dataset
        eval_dataset=eval_ds,                # evaluation dataset
        compute_metrics=kg.which_metrics(),  # task-specific metrics from the processor
    )
    logger.info("Initialized trainer object")

    if args.do_train or args.do_eval:
        if args.do_train:
            logger.info("Training")
            trainer.train()
            logger.info("Saving model to disk")
            trainer.save_model(args.output_dir)
        if args.do_eval:
            logger.info("Evaluating")
            trainer.evaluate()

    if args.do_predict:
        logger.info("Testing")
        results = trainer.predict(test_ds)
        print(results.metrics)
        logger.info("Writing metrics to disk")
        write_metrics(results.metrics, args.output_dir)

if __name__ == "__main__":
    main()
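# Example invocation (illustrative only; the actual flag names are defined in
# the external `cli` module and are inferred here from the attributes read off
# `args` above):
#   python exBERT.py --task tc --bert_model bert-base-uncased \
#       --data_dir ./data --output_dir ./out --do_train --do_eval --do_predict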