#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities for downloading and building data.
These can be replaced if your particular file system does not support them.
"""

import importlib
import time
import datetime
import os
import requests
import shutil
import tqdm


def built(path, version_string=None):
"""
Check if '.built' flag has been set for that task.
If a version_string is provided, this has to match, or the version
is regarded as not built.
"""
if version_string:
fname = os.path.join(path, '.built')
if not os.path.isfile(fname):
return False
else:
with open(fname, 'r') as read:
text = read.read().split('\n')
return len(text) > 1 and text[1] == version_string
else:
return os.path.isfile(os.path.join(path, '.built'))


def mark_done(path, version_string=None):
"""
Mark this path as prebuilt.
Marks the path as done by adding a '.built' file with the current timestamp
plus a version description string if specified.
:param str path:
The file path to mark as built.
:param str version_string:
The version of this dataset.
"""
with open(os.path.join(path, '.built'), 'w') as write:
write.write(str(datetime.datetime.today()))
if version_string:
write.write('\n' + version_string)
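
# A minimal usage sketch of the build-flag pattern above (the task path and
# version below are hypothetical): check `built()`, rebuild from scratch if the
# flag is missing or stale, then call `mark_done()` so later runs can skip the work.
#
#   dpath = os.path.join(datapath, 'my_task')       # `datapath` assumed available
#   if not built(dpath, version_string='v1.0'):
#       remove_dir(dpath)                            # clear any outdated build
#       make_dir(dpath)
#       ...                                          # download / process files here
#       mark_done(dpath, version_string='v1.0')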


def download(url, path, fname, redownload=False):
    """
    Download file using `requests`.

    If ``redownload`` is false (the default), the file will not be downloaded
    again if it is already present.
    """
outfile = os.path.join(path, fname)
download = not os.path.isfile(outfile) or redownload
print("[ downloading: " + url + " to " + outfile + " ]")
retry = 5
exp_backoff = [2 ** r for r in reversed(range(retry))]
pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))
while download and retry >= 0:
resume_file = outfile + '.part'
resume = os.path.isfile(resume_file)
if resume:
resume_pos = os.path.getsize(resume_file)
mode = 'ab'
else:
resume_pos = 0
mode = 'wb'
response = None
with requests.Session() as session:
try:
header = (
{'Range': 'bytes=%d-' % resume_pos, 'Accept-Encoding': 'identity'}
if resume
else {}
)
response = session.get(url, stream=True, timeout=5, headers=header)
# negative reply could be 'none' or just missing
if resume and response.headers.get('Accept-Ranges', 'none') == 'none':
resume_pos = 0
mode = 'wb'
CHUNK_SIZE = 32768
total_size = int(response.headers.get('Content-Length', -1))
# server returns remaining size if resuming, so adjust total
total_size += resume_pos
pbar.total = total_size
done = resume_pos
with open(resume_file, mode) as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if total_size > 0:
done += len(chunk)
if total_size < done:
# don't freak out if content-length was too small
total_size = done
pbar.total = total_size
pbar.update(len(chunk))
break
except requests.exceptions.ConnectionError:
retry -= 1
pbar.clear()
if retry >= 0:
print('Connection error, retrying. (%d retries left)' % retry)
time.sleep(exp_backoff[retry])
else:
print('Retried too many times, stopped retrying.')
finally:
if response:
response.close()
if retry < 0:
raise RuntimeWarning('Connection broken too many times. Stopped retrying.')
if download and retry > 0:
pbar.update(done - pbar.n)
if done < total_size:
raise RuntimeWarning(
'Received less data than specified in '
+ 'Content-Length header for '
+ url
+ '.'
+ ' There may be a download problem.'
)
move(resume_file, outfile)
pbar.close()


def make_dir(path):
"""Make the directory and any nonexistent parent directories (`mkdir -p`)."""
# the current working directory is a fine path
if path != '':
os.makedirs(path, exist_ok=True)


def move(path1, path2):
"""Rename the given file."""
shutil.move(path1, path2)


def remove_dir(path):
"""Remove the given directory, if it exists."""
shutil.rmtree(path, ignore_errors=True)


def untar(path, fname, deleteTar=True):
"""
Unpack the given archive file to the same directory.
:param str path:
The folder containing the archive. Will contain the contents.
:param str fname:
The filename of the archive file.
:param bool deleteTar:
If true, the archive will be deleted after extraction.
"""
print('unpacking ' + fname)
fullpath = os.path.join(path, fname)
shutil.unpack_archive(fullpath, path)
if deleteTar:
os.remove(fullpath)
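
# A minimal usage sketch combining `download` and `untar` when building a dataset
# (the URL, directory, and filename below are hypothetical):
#
#   fname = 'data.tar.gz'
#   download('http://example.com/' + fname, dpath, fname)  # resumes and retries
#   untar(dpath, fname)   # unpacks into dpath and removes the archive by default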


def cat(file1, file2, outfile, deleteFiles=True):
"""Concatenate two files to an outfile, possibly deleting the originals."""
with open(outfile, 'wb') as wfd:
for f in [file1, file2]:
with open(f, 'rb') as fd:
shutil.copyfileobj(fd, wfd, 1024 * 1024 * 10)
# 10MB per writing chunk to avoid reading big file into memory.
if deleteFiles:
os.remove(file1)
os.remove(file2)


def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None


def download_from_google_drive(gd_id, destination):
"""Use the requests package to download a file from Google Drive."""
URL = 'https://docs.google.com/uc?export=download'
with requests.Session() as session:
response = session.get(URL, params={'id': gd_id}, stream=True)
token = _get_confirm_token(response)
if token:
response.close()
params = {'id': gd_id, 'confirm': token}
response = session.get(URL, params=params, stream=True)
CHUNK_SIZE = 32768
with open(destination, 'wb') as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
response.close()
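
# A minimal usage sketch for fetching a single file from Google Drive (the file id
# and destination path below are hypothetical):
#
#   download_from_google_drive('1AbCdEfGhIjKlMnOpQrSt', '/tmp/my_task/data.zip')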


def get_model_dir(datapath):
return os.path.join(datapath, 'models')


def download_models(
opt, fnames, model_folder, version='v1.0', path='aws', use_model_type=False
):
"""
Download models into the ParlAI model zoo from a url.
:param fnames: list of filenames to download
:param model_folder: models will be downloaded into models/model_folder/model_type
:param path: url for downloading models; defaults to downloading from AWS
:param use_model_type: whether models are categorized by type in AWS
"""
model_type = opt.get('model_type', None)
if model_type is not None:
dpath = os.path.join(opt['datapath'], 'models', model_folder, model_type)
else:
dpath = os.path.join(opt['datapath'], 'models', model_folder)
if not built(dpath, version):
for fname in fnames:
print('[building data: ' + dpath + '/' + fname + ']')
if built(dpath):
# An older version exists, so remove these outdated files.
remove_dir(dpath)
make_dir(dpath)
# Download the data.
for fname in fnames:
if path == 'aws':
url = 'http://parl.ai/downloads/_models/'
url += model_folder + '/'
if use_model_type:
url += model_type + '/'
url += fname
else:
url = path + '/' + fname
download(url, dpath, fname)
if '.tgz' in fname or '.gz' in fname or '.zip' in fname:
untar(dpath, fname)
# Mark the data as built.
mark_done(dpath, version)
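
# A minimal usage sketch for pulling model files into the zoo (the opt dict,
# filenames, and folder below are hypothetical; 'datapath' is the key this
# function actually reads, plus an optional 'model_type'):
#
#   opt = {'datapath': '/tmp/ParlAI/data'}
#   download_models(opt, ['model.tgz'], 'my_model', version='v1.0')
#   # files land under /tmp/ParlAI/data/models/my_model/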


def modelzoo_path(datapath, path):
    """
    Map pretrained model filenames to their paths on disk.

    If path starts with 'models:' or 'zoo:', then we remap it to the model zoo
    path within the data directory (default is ParlAI/data/models).
    We download models from the model zoo if they are not here yet.
    """
if path is None:
return None
if (
not path.startswith('models:')
and not path.startswith('zoo:')
and not path.startswith('izoo:')
):
return path
elif path.startswith('models:') or path.startswith('zoo:'):
zoo = path.split(':')[0]
zoo_len = len(zoo) + 1
# Check if we need to download the model
animal = path[zoo_len : path.rfind('/')].replace('/', '.')
if '.' not in animal:
animal += '.build'
module_name = 'parlai.zoo.{}'.format(animal)
try:
my_module = importlib.import_module(module_name)
my_module.download(datapath)
except (ImportError, AttributeError):
pass
return os.path.join(datapath, 'models', path[zoo_len:])
else:
# Internal path (starts with "izoo:") -- useful for non-public
# projects. Save the path to your internal model zoo in
# parlai_internal/.internal_zoo_path
# TODO: test the internal zoo.
zoo_path = 'parlai_internal/zoo/.internal_zoo_path'
if not os.path.isfile('parlai_internal/zoo/.internal_zoo_path'):
raise RuntimeError(
'Please specify the path to your internal zoo in the '
'file parlai_internal/zoo/.internal_zoo_path in your '
'internal repository.'
)
else:
with open(zoo_path, 'r') as f:
zoo = f.read().split('\n')[0]
return os.path.join(zoo, path[5:])
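

# A minimal usage sketch of zoo-path resolution (the datapath below is
# hypothetical): a 'zoo:' or 'models:' prefix is remapped under <datapath>/models,
# after attempting to import the matching parlai.zoo module and calling its
# download().
#
#   modelzoo_path('/tmp/ParlAI/data', 'zoo:blender/blender_90M/model')
#   # -> '/tmp/ParlAI/data/models/blender/blender_90M/model'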