From c2cc8f9857bd2de85225438774d8626ca91122fb Mon Sep 17 00:00:00 2001 From: yzeng22 Date: Sun, 8 Jan 2023 15:10:00 -0500 Subject: [PATCH] recursiv dir --- src/pytorch_fid/fid_score.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/pytorch_fid/fid_score.py b/src/pytorch_fid/fid_score.py index ab83bcd..186d397 100755 --- a/src/pytorch_fid/fid_score.py +++ b/src/pytorch_fid/fid_score.py @@ -230,6 +230,17 @@ def calculate_activation_statistics(files, model, batch_size=50, dims=2048, sigma = np.cov(act, rowvar=False) return mu, sigma +def _list_image_files_recursively(data_dir): + results = [] + for entry in sorted(os.listdir(data_dir)): + full_path = os.path.join(data_dir, entry) + ext = entry.split(".")[-1] + if "." in entry and ext.lower() in IMAGE_EXTENSIONS: + results.append(full_path) + elif os.path.isdir(full_path): + results.extend(_list_image_files_recursively(full_path)) + return results + def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1): @@ -237,9 +248,10 @@ def compute_statistics_of_path(path, model, batch_size, dims, device, with np.load(path) as f: m, s = f['mu'][:], f['sigma'][:] else: - path = pathlib.Path(path) - files = sorted([file for ext in IMAGE_EXTENSIONS - for file in path.glob('*.{}'.format(ext))]) + #path = pathlib.Path(path) + files = _list_image_files_recursively(path) + #files = sorted([file for ext in IMAGE_EXTENSIONS + # for file in path.glob('*.{}'.format(ext))]) m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers)