expand_train.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0
import json
from copy import deepcopy

from newspaper import Article
from tqdm import tqdm
from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer

from utils import process_text, get_offsets

def expand_data(data, offset):
    tokenizer = SpacyTokenizer()

    # Collect the unique set of source URLs across all documents.
    all_urls = list()
    for item in data:
        for doc in item["documents"]:
            all_urls.append(doc["url"])
    all_urls = list(set(all_urls))

    # Download and parse each article with newspaper, retrying up to 5 times per URL.
    mapping = dict()
    for w_url in tqdm(all_urls):
        counter = 0
        found = False
        while not found and counter < 5:
            try:
                cc_article = Article(w_url)
                cc_article.download()
                cc_article.parse()
                if cc_article.text != "":
                    text = process_text(cc_article.text.strip())
                    mapping[w_url] = text
                    found = True
                counter += 1
            except Exception:
                counter += 1
                continue
    print("Total documents present: ", len(all_urls))
    print("Total documents mapped: ", len(mapping))

    new_data = list()
    skipped = 0
    for item_index, item in tqdm(enumerate(data)):
        segments = dict()
        to_add = True
        # If we were unable to download any of the item's URLs, we skip the item.
        for doc_index, doc in enumerate(item["documents"]):
            if doc["url"] not in mapping:
                to_add = False
                break
        if to_add:
            # Attach the downloaded text and sentence-level offsets to each document.
            for doc_index, doc in enumerate(item["documents"]):
                data[item_index]["documents"][doc_index]["text"] = mapping[doc["url"]]
                data[item_index]["documents"][doc_index]["sentences"] = get_offsets(mapping[doc["url"]], tokenizer, doc["text_id"])
                for segment in data[item_index]["documents"][doc_index]["sentences"]:
                    segments[segment["segment_id"]] = segment["text"]
            if offset == "relative":
                # Character offsets are relative to the statement's segment.
                for statement_index, statement in enumerate(item["statements"]):
                    data[item_index]["statements"][statement_index]["text"] = segments[statement["segment_id"]][statement["start_char"]:statement["end_char"]]
            else:
                # Character offsets are absolute within the document text
                # (`doc` here is the last document from the loop above).
                for statement_index, statement in enumerate(item["statements"]):
                    data[item_index]["statements"][statement_index]["text"] = mapping[doc["url"]][statement["start_char"]:statement["end_char"]]
            new_data.append(deepcopy(data[item_index]))
        else:
            skipped += 1
            continue
    print("Total examples originally: ", len(data))
    print("Examples skipped: ", skipped)
    return deepcopy(new_data)

if __name__ == '__main__':
    print("Expanding Gold train set")
    gold_data = json.load(open("./data/train.json"))
    gold_data = expand_data(gold_data, "relative")
    json.dump(gold_data, open("./data/expanded_train.json", "w"), indent=4)
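
# For reference, a minimal sketch of the record shape this script assumes,
# inferred from the field accesses above; the authoritative schema lives in
# ./data/train.json and may contain additional fields. The identifiers below
# ("doc_0", "doc_0_sent_3") are hypothetical examples, not values from the data:
#
# {
#     "documents": [
#         {"url": "https://example.com/article", "text_id": "doc_0"}
#     ],
#     "statements": [
#         {"segment_id": "doc_0_sent_3", "start_char": 0, "end_char": 42}
#     ]
# }
#
# After expansion, each document gains "text" (the downloaded article) and
# "sentences" (segment offsets from utils.get_offsets), and each statement
# gains "text", sliced from its segment ("relative") or document (otherwise).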