-
Notifications
You must be signed in to change notification settings - Fork 0
/
loader.py
93 lines (85 loc) · 2.59 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Tools to load Universal Dependencies and Europarl data.
import numpy as np
import pandas as pd
import re

# Column names for the 10 tab-separated CoNLL-U fields.
field_names = [
    'sentence_ix',
    'word',
    'stem',
    'pos',
    '_pos_',
    'meaning',
    'parse_ix',
    'parse_role',
    'blank0',
    'blank1']

# Language code -> UD treebank directory suffix (UD_<name>).
lang_map = {
    'en': 'English',
    'fi': 'Finnish',
    'fr': 'French'}

# Raw strings: '\d' in a plain string is an invalid escape sequence
# (DeprecationWarning now, a SyntaxError in future Python versions).
number_re = re.compile(r'\d')
website_re = re.compile(r'^http')
def clean(word):
    """Normalize a single token.

    Missing values become 'NAN', tokens containing a digit become
    'NUMBER', tokens starting with 'http' become 'WEBSITE'; anything
    else passes through unchanged.
    """
    if pd.isnull(word):
        return 'NAN'
    if re.search(r'\d', word):
        return 'NUMBER'
    if re.match(r'^http', word):
        return 'WEBSITE'
    return word
def ud_load(lang_str):
    """Load a Universal Dependencies treebank (train + dev splits).

    Parameters
    ----------
    lang_str : str
        Language code ('en', 'fi', or 'fr'), mapped to the treebank
        directory name via `lang_map`.

    Returns
    -------
    (list[list[str]], list[list[str]])
        Per-sentence word lists (passed through `clean`) and the
        parallel per-sentence POS-tag lists.
    """
    ud_format = 'ud-treebanks-v2.0/UD_{}/{}-ud-{}.conllu'

    def read_split(split):
        # CoNLL-U is unquoted tab-separated text; quoting=3 (QUOTE_NONE)
        # keeps bare '"' tokens from being swallowed as quote characters.
        return pd.read_csv(
            ud_format.format(lang_map[lang_str], lang_str, split),
            names=field_names,
            delimiter='\t',
            quoting=3)

    lang = pd.concat(
        [read_split('train'), read_split('dev')],
        ignore_index=True)
    # Drop comment lines ('# ...') and multiword-token ranges ('4-5')
    # that CoNLL-U interleaves with the numeric token rows; only then
    # can the index column be made numeric.
    if not np.issubdtype(lang.sentence_ix.dtype, np.number):
        lang = lang[~lang.sentence_ix.str.startswith('#')]
        lang = lang[~lang.sentence_ix.str.contains('-')]
        lang['sentence_ix'] = pd.to_numeric(lang['sentence_ix'])
    lang = lang.reset_index(drop=True)
    # Token indices restart at 1 on each new sentence, so a negative
    # first difference marks a sentence boundary.
    lang['sent_mark'] = lang.sentence_ix.diff()
    lang_wd = [[]]
    lang_pos = [[]]
    counter = 0
    # itertuples avoids the per-row .iloc lookup of the original loop,
    # which is far slower on large treebanks.
    for row in lang.itertuples(index=False):
        if (not np.isnan(row.sent_mark)) and (row.sent_mark < 0):
            lang_wd.append([])
            lang_pos.append([])
            counter += 1
        lang_wd[counter].append(clean(row.word))
        lang_pos[counter].append(row.pos)
    return lang_wd, lang_pos
def ep_load(src_str, tgt_str):
    """Load a tokenized, lowercased Europarl parallel corpus.

    Parameters
    ----------
    src_str, tgt_str : str
        Source and target language codes; files follow the pattern
        'europarl/tok.lower.{tgt}-{src}.{lang}'.

    Returns
    -------
    (list[list[str]], list[list[str]])
        Cleaned token lists, one inner list per line, for the source
        and target sides respectively.
    """
    ep_format = 'europarl/tok.lower.{}-{}.{}'

    def read_side(lang):
        # 'with' guarantees the handle is closed even if a read raises;
        # the original left the file open on any exception.
        path = ep_format.format(tgt_str, src_str, lang)
        with open(path, 'r') as fh:
            return [
                [clean(word.strip()) for word in line.strip().split(' ')]
                for line in fh]

    # One helper replaces the two copy-pasted read loops.
    return read_side(src_str), read_side(tgt_str)