-
Notifications
You must be signed in to change notification settings - Fork 7
/
creation.py
157 lines (103 loc) · 3.36 KB
/
creation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python
# coding: utf-8
#
# # Creation of the NICE dataset
# In[1]:
# define paths
# Input XML files (labels + indication texts).
PATH_TO_LABELS = 'source/labels.xml'
PATH_TO_TEXTS = 'source/texts.xml'
# Output path PREFIXES: '_train.txt' / '_test.txt' are appended when writing.
PATH_TO_BINARY_RESULTS = 'NICE_binary/NICE'
PATH_TO_RESULTS = 'NICE/NICE'
# In[2]:
# define parameters
# Seed for random.sample so the train/test split is reproducible.
RANDOM_STATE = 42
# Fraction of examples that go into the train split.
TRAIN_SPLIT = 0.7
# ### Load the labels and texts
# In[3]:
# parse labels.xml: map each entry id to its (isGoodOrService, classNumber) pair
import xml.etree.ElementTree as ET
tree = ET.parse(PATH_TO_LABELS)
root = tree.getroot()
# each child of the root is a NICE class; each grandchild is one labelled entry
labels = {
    entry.attrib['id']: (cl.attrib['isGoodOrService'], cl.attrib['classNumber'])
    for cl in root
    for entry in cl
}
# In[4]:
# parse texts.xml: map each idRef to its indication text
import xml.etree.ElementTree as ET
tree = ET.parse(PATH_TO_TEXTS)
root = tree.getroot()
# root[0] is ClassesTexts (skipped); root[1] is GoodsAndServicesTexts.
# For each text node, the indication string lives at node[0][0].text.
texts = {node.attrib['idRef']: node[0][0].text for node in root[1]}
# ### Write the results
# In[5]:
# preprocessing
import re
import string
import unicodedata

# Hoisted out of preprocess() so the regex is compiled and the translation
# table built once, not on every call.
_BRACKETS_RE = re.compile(r'\[.*?\]')          # non-greedy: "[...]" spans only
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)

def preprocess(text):
    """Normalize a raw indication string for the dataset.

    Steps, in order: lowercase, remove text inside [] brackets, strip
    ASCII punctuation, and remove accents (NFD-decompose, then drop
    combining marks). Whitespace is left untouched, so a removed
    bracketed span may leave a double space behind.

    Args:
        text: the raw indication string.

    Returns:
        The normalized string.
    """
    # NOTE: the original parameter was named `str`, shadowing the builtin;
    # renamed (positional calls are unaffected).
    text = text.lower()
    text = _BRACKETS_RE.sub('', text)
    text = text.translate(_PUNCT_TABLE)
    # decompose accented characters, then drop the combining marks (category Mn)
    return ''.join(
        c for c in unicodedata.normalize('NFD', text)
        if unicodedata.category(c) != 'Mn'
    )
# In[6]:
# create splits with a fixed seed so the partition is reproducible
import random
random.seed(RANDOM_STATE)
length = len(texts)
split_length = int(length * TRAIN_SPLIT)
# sample the train positions, then take the complement as the test positions
train_indices = random.sample(range(length), split_length)
# set gives O(1) membership tests; testing against the list was O(n) per
# element, i.e. O(n^2) overall
train_index_set = set(train_indices)
test_indices = [i for i in range(length) if i not in train_index_set]
print('Train:', len(train_indices), 'Test:', len(test_indices))
# map positions back to the dict keys (insertion order of `texts`)
keys = list(texts.keys())
test_keys = [keys[i] for i in test_indices]
train_keys = [keys[i] for i in train_indices]
# In[7]:
# create the binary dataset: one "<isGoodOrService>\t<preprocessed text>" line
# per entry, no trailing newline
assert labels.keys() == texts.keys()

def _write_binary_split(path, keys):
    """Write one split of the binary dataset to `path`.

    '\\n'.join avoids special-casing the last line and, unlike the previous
    keys[-1] access, writes an empty file instead of raising IndexError when
    a split is empty.
    """
    with open(path, 'w', encoding="utf-8") as f:
        f.write('\n'.join(
            labels[key][0] + '\t' + preprocess(texts[key]) for key in keys
        ))

# train split
_write_binary_split(PATH_TO_BINARY_RESULTS + '_train.txt', train_keys)
# test split
_write_binary_split(PATH_TO_BINARY_RESULTS + '_test.txt', test_keys)
# In[8]:
# create the multi-class dataset: one "<classNumber>\t<preprocessed text>" line
# per entry, no trailing newline
assert labels.keys() == texts.keys()

def _write_class_split(path, keys):
    """Write one split of the class-number dataset to `path`.

    '\\n'.join avoids special-casing the last line and, unlike the previous
    keys[-1] access, writes an empty file instead of raising IndexError when
    a split is empty.
    """
    with open(path, 'w', encoding="utf-8") as f:
        f.write('\n'.join(
            labels[key][1] + '\t' + preprocess(texts[key]) for key in keys
        ))

# train split
_write_class_split(PATH_TO_RESULTS + '_train.txt', train_keys)
# test split
_write_class_split(PATH_TO_RESULTS + '_test.txt', test_keys)
# In[ ]: