Skip to content

Commit

Permalink
FID pytorch Old version and Windows Fix (#2061)
Browse files Browse the repository at this point in the history
* Old version error fix

* Changed to 6 digit precision checking

* Update pytorch-version-tests.yml

* Changed rel_err to 1e-5

* Removed manual trigger

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
gucifer and vfdev-5 authored Jun 18, 2021
1 parent 0007e86 commit f240e6e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
5 changes: 4 additions & 1 deletion ignite/metrics/gan/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Ten
r"""
Calculates covariance from mean and sum of products of variables
"""
sub_matrix = torch.outer(total, total)
if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
sub_matrix = torch.ger(total, total)
else:
sub_matrix = torch.outer(total, total)
sub_matrix = sub_matrix / self._num_examples
return (sigma - sub_matrix) / (self._num_examples - 1)

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pytest-xdist
dill
# Test contrib dependencies
scipy
pytorch_fid
pytorch_fid==0.1.1
tqdm
scikit-learn
matplotlib
Expand Down
9 changes: 6 additions & 3 deletions tests/ignite/metrics/gan/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_fid_function():

sigma1 = torch.tensor(sigma1, dtype=torch.float64)
sigma2 = torch.tensor(sigma2, dtype=torch.float64)
assert pytest.approx(fid_score(mu1, mu2, sigma1, sigma2)) == pytorch_fid_score.calculate_frechet_distance(
assert pytest.approx(fid_score(mu1, mu2, sigma1, sigma2), rel=1e-5) == pytorch_fid_score.calculate_frechet_distance(
mu1, sigma1, mu2, sigma2
)

Expand All @@ -86,7 +86,10 @@ def test_compute_fid_from_features():
mu1, sigma1 = train_samples.mean(axis=0), cov(train_samples, rowvar=False)
mu2, sigma2 = test_samples.mean(axis=0), cov(test_samples, rowvar=False)

assert pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)) == fid_scorer.compute()
assert (
pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=1e-5)
== fid_scorer.compute()
)


def test_compute_fid_sqrtm():
Expand Down Expand Up @@ -187,7 +190,7 @@ def update(_, i):
evaluator = pytorch_fid_score.calculate_frechet_distance
mu1, sigma1 = y_pred.mean(axis=0).to("cpu"), cov(y_pred.to("cpu"), rowvar=False)
mu2, sigma2 = y_true.mean(axis=0).to("cpu"), cov(y_true.to("cpu"), rowvar=False)
assert pytest.approx(evaluator(mu1, sigma1, mu2, sigma2)) == m.compute()
assert pytest.approx(evaluator(mu1, sigma1, mu2, sigma2), rel=1e-5) == m.compute()

metric_devices = [torch.device("cpu")]
if device.type != "xla":
Expand Down

0 comments on commit f240e6e

Please sign in to comment.