diff --git a/noxfile.py b/noxfile.py index ad4811b..8779177 100644 --- a/noxfile.py +++ b/noxfile.py @@ -16,8 +16,8 @@ def lint(session): @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"]) def tests(session): session.install( - 'torch==2.2.1', - 'torchvision', + 'torch==2.2.1', + 'torchvision', '--index-url', 'https://download.pytorch.org/whl/cpu' ) session.install('.') diff --git a/src/pytorch_fid/fid_score.py b/src/pytorch_fid/fid_score.py index ac82b53..5102a4b 100755 --- a/src/pytorch_fid/fid_score.py +++ b/src/pytorch_fid/fid_score.py @@ -64,8 +64,9 @@ def tqdm(x): help=('Dimensionality of Inception features to use. ' 'By default, uses pool3 features')) parser.add_argument('--save-stats', action='store_true', - help=('Generate an npz archive from a directory of samples. ' - 'The first path is used as input and the second as output.')) + help=('Generate an npz archive from a directory of ' + 'samples. The first path is used as input and the ' + 'second as output.')) parser.add_argument('path', type=str, nargs=2, help=('Paths to the generated images or ' 'to .npz statistic files')) @@ -307,7 +308,11 @@ def main(): num_workers = args.num_workers if args.save_stats: - save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers) + save_fid_stats(args.path, + args.batch_size, + device, + args.dims, + num_workers) return fid_value = calculate_fid_given_paths(args.path,