def pydrive_create_drive_manager(cmd_auth):
    """Authenticate against Google Drive and return a PyDrive2 client.

    Args:
        cmd_auth: When truthy, authenticate through the command-line flow
            (``CommandLineAuth``); otherwise use the local-webserver flow
            (``LocalWebserverAuth``).

    Returns:
        A ``GoogleDrive`` instance built from the authorized ``GoogleAuth``.
    """
    # Imported lazily so that pydrive2 is only required when the PyDrive
    # download path is actually requested.
    from pydrive2.auth import GoogleAuth
    from pydrive2.drive import GoogleDrive

    auth = GoogleAuth()
    if cmd_auth:
        auth.CommandLineAuth()
    else:
        auth.LocalWebserverAuth()

    auth.Authorize()
    print("authorized access to google drive API!")
    return GoogleDrive(auth)


def pydrive_extract_files_id(drive, link):
    """Extract the Google Drive file ID from a share/download link.

    The ID is the path or query component that follows ``/d/``, ``id=`` or
    ``rs/`` (as in ``folders/``). It ends at the next ``/``, ``?``, ``&``,
    ``#`` or whitespace, so trailing query parameters such as
    ``&export=download`` are not swallowed into the ID.

    Args:
        drive: Unused; kept so existing callers keep working.
        link: URL string to parse.

    Returns:
        The file ID string, or ``None`` (after printing a diagnostic) when
        no ID can be found or ``link`` is not a string.
    """
    import re

    reason = "no file ID found in link"
    try:
        match = re.search(r"(?:/d/|id=|rs/)([^/?&#\s]+)", link)
    except TypeError as error:
        # ``link`` was not a string (e.g. None).
        match = None
        reason = str(error)

    if match is not None:
        return match.group(1)

    print("error : " + reason)
    print("Link is probably invalid")
    print(link)
    return None


def pydrive_download_file(drive, spec, stats, chunk_size=128, num_attempts=10):
    """Download one file through PyDrive2 and update the shared statistics.

    Args:
        drive: Object exposing ``CreateFile({'id': ...})`` whose result has a
            ``GetContentFile(path)`` method (a PyDrive2 ``GoogleDrive``).
        spec: Mapping with ``'file_url'`` (Drive link) and ``'file_path'``
            (destination path).
        stats: Mapping whose ``'files_done'`` and ``'bytes_done'`` counters
            are incremented on success.
        chunk_size: Unused here; accepted so the same keyword arguments can
            be forwarded as for ``download_file``.
        num_attempts: Maximum number of download attempts.

    Raises:
        ValueError: If no file ID can be extracted from ``spec['file_url']``.
        Exception: Whatever the final failed download attempt raised.
    """
    import uuid

    link = spec['file_url']
    save_path = spec['file_path']

    file_id = pydrive_extract_files_id(drive, link)
    if file_id is None:
        # Fail fast instead of sending a request for id=None and retrying it.
        raise ValueError('Could not extract a Google Drive file ID from %r' % (link,))

    file_dir = os.path.dirname(save_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    pydrive_file = drive.CreateFile({'id': file_id})

    # Download to a temporary name and move it into place only on success, so
    # an interrupted transfer never leaves a partial file at save_path (an
    # existing file at that path is treated as already downloaded).
    tmp_path = save_path + '.tmp.' + uuid.uuid4().hex
    try:
        for attempts_left in reversed(range(num_attempts)):
            try:
                pydrive_file.GetContentFile(tmp_path)
                break
            except Exception:
                if not attempts_left:
                    raise
        os.replace(tmp_path, save_path)
    finally:
        # Best-effort cleanup; after a successful replace the file is gone.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

    # NOTE(review): these updates are not synchronized here -- confirm whether
    # stats needs locking when several download threads share it.
    stats['files_done'] += 1
    stats['bytes_done'] += os.stat(save_path).st_size
-def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs): +def download_files(file_specs, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs): # Determine which files to download. done_specs = {spec['file_path']: spec for spec in file_specs if os.path.isfile(spec['file_path'])} @@ -169,7 +216,7 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5 exception_queue = queue.Queue() for spec in missing_specs: spec_queue.put(spec) - thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, download_kwargs=download_kwargs) + thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, drive=drive, download_kwargs=download_kwargs) for _thread_idx in range(min(num_threads, len(missing_specs))): threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start() @@ -206,12 +253,15 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5 except queue.Empty: pass -def _download_thread(spec_queue, exception_queue, stats, download_kwargs): +def _download_thread(spec_queue, exception_queue, stats, drive, download_kwargs): with requests.Session() as session: while not spec_queue.empty(): spec = spec_queue.get() try: - download_file(session, spec, stats, **download_kwargs) + if drive is not None: + pydrive_download_file(drive, spec, stats, **download_kwargs) + else: + download_file(session, spec, stats, **download_kwargs) except: exception_queue.put(sys.exc_info()) @@ -350,10 +400,15 @@ def recreate_aligned_images(json_data, dst_dir='realign1024x1024', output_size=1 #---------------------------------------------------------------------------- -def run(tasks, **download_kwargs): +def run(tasks, pydrive, cmd_auth, **download_kwargs): + if pydrive: + drive = pydrive_create_drive_manager(cmd_auth) + else: + drive = None + if not os.path.isfile(json_spec['file_path']) or not 
os.path.isfile('LICENSE.txt'): print('Downloading JSON metadata...') - download_files([json_spec, license_specs['json']], **download_kwargs) + download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs) print('Parsing JSON metadata...') with open(json_spec['file_path'], 'rb') as f: @@ -375,7 +430,7 @@ def run(tasks, **download_kwargs): if len(specs): print('Downloading %d files...' % len(specs)) np.random.shuffle(specs) # to make the workload more homogeneous - download_files(specs, **download_kwargs) + download_files(specs, drive=drive, **download_kwargs) if 'align' in tasks: recreate_aligned_images(json_data) @@ -390,6 +445,8 @@ def run_cmdline(argv): parser.add_argument('-t', '--thumbs', help='download 128x128 thumbnails as PNG (1.95 GB)', dest='tasks', action='append_const', const='thumbs') parser.add_argument('-w', '--wilds', help='download in-the-wild images as PNG (955 GB)', dest='tasks', action='append_const', const='wilds') parser.add_argument('-r', '--tfrecords', help='download multi-resolution TFRecords (273 GB)', dest='tasks', action='append_const', const='tfrecords') + parser.add_argument('--pydrive', help='use pydrive interface to download files. it overrides google drive quota limitation this requires google credentials (default: False)', action='store_true') + parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true') parser.add_argument('-a', '--align', help='recreate 1024x1024 images from in-the-wild images', dest='tasks', action='append_const', const='align') parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM') parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')