Skip to content

Commit

Permalink
Fix tfds builders that try to access gcs even though the data is local.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666405348
  • Loading branch information
The TensorFlow Datasets Authors committed Aug 28, 2024
1 parent 858fbe5 commit e6948a6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
25 changes: 20 additions & 5 deletions tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

from collections.abc import Mapping
import os

import numpy as np
Expand Down Expand Up @@ -44,13 +45,27 @@ def _info(self):
homepage='https://github.com/deepmind/pg19',
)

def _get_paths(self, data_dir: str) -> Mapping[str, str]:
return {
'metadata': os.path.join(data_dir, 'metadata.csv'),
'train': os.path.join(data_dir, 'train'),
'validation': os.path.join(data_dir, 'validation'),
'test': os.path.join(data_dir, 'test'),
}

def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
del dl_manager # Unused

metadata_dict = dict()
metadata_path = os.path.join(_DATA_DIR, 'metadata.csv')
metadata = tf.io.gfile.GFile(metadata_path).read().splitlines()
if self.data_dir and all(
map(os.path.exists, self._get_paths(self.data_dir).values())
):
data_dir = self._data_dir
else:
data_dir = _DATA_DIR
paths = self._get_paths(data_dir)
metadata = tf.io.gfile.GFile(paths['metadata']).read().splitlines()

for row in metadata:
row_split = row.split(',')
Expand All @@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
name=tfds.Split.TRAIN,
gen_kwargs={
'metadata': metadata_dict,
'filepath': os.path.join(_DATA_DIR, 'train'),
'filepath': paths['train'],
},
),
tfds.core.SplitGenerator(
name=tfds.Split.VALIDATION,
gen_kwargs={
'metadata': metadata_dict,
'filepath': os.path.join(_DATA_DIR, 'validation'),
'filepath': paths['validation'],
},
),
tfds.core.SplitGenerator(
name=tfds.Split.TEST,
gen_kwargs={
'metadata': metadata_dict,
'filepath': os.path.join(_DATA_DIR, 'test'),
'filepath': paths['test'],
},
),
]
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_datasets/robotics/dataset_importer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def get_relative_dataset_location(self):
pass

def get_dataset_location(self):
    """Returns the directory where this dataset lives.

    A configured local `_data_dir` takes precedence; otherwise the dataset's
    location inside the GCS bucket is used.
    """
    local_dir = self._data_dir
    if local_dir:
        return local_dir
    # No local copy configured — fall back to the bucket path.
    relative = self.get_relative_dataset_location()
    return os.path.join(str(self._GCS_BUCKET), relative)
Expand Down

0 comments on commit e6948a6

Please sign in to comment.