From 831c463397925ba60e920e7aa53d1c0975c95c18 Mon Sep 17 00:00:00 2001 From: ilia-kats Date: Sun, 28 May 2023 17:04:37 +0200 Subject: [PATCH] make ProvenanceTensor behave more like a Tensor (closes #3218) (#3220) * make ProvenanceTensor behave more like a Tensor (closes #3218) the data is now stored in the Tensor itself instead of an attribute. This fixes torch.to_tensor returning empty tensors when called with a ProvenanceTensor and and a device as arguments * fix compatibility with PyTorch 1.11 * make detach_provenance always return the exact same object this is important when using Tensors as keys in a dict, e.g. the Pyro param store * preserve .unconstrained attribute in detach_provenance * simplify * add unit test * simplify further * use Tensor.as_subclass instead of modifying __class__ also remove unnecessary check in __init__ --- pyro/ops/provenance.py | 11 +++++---- tests/ops/test_provenance.py | 45 ++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 tests/ops/test_provenance.py diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py index 02a515c7e5..5058e66414 100644 --- a/pyro/ops/provenance.py +++ b/pyro/ops/provenance.py @@ -46,14 +46,15 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs): assert not isinstance(data, ProvenanceTensor) if not provenance: return data - return super().__new__(cls) + ret = data.as_subclass(cls) + ret._t = data # this makes sure that detach_provenance always + # returns the same object. This is important when + # using the tensor as key in a dict, e.g. the global + # param store + return ret def __init__(self, data, provenance=frozenset()): assert isinstance(provenance, frozenset) - if isinstance(data, ProvenanceTensor): - provenance |= data._provenance - data = data._t - self._t = data self._provenance = provenance def __repr__(self): diff --git a/tests/ops/test_provenance.py b/tests/ops/test_provenance.py new file mode 100644 index 0000000000..818d9d83ca --- /dev/null +++ b/tests/ops/test_provenance.py @@ -0,0 +1,45 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.ops.provenance import ProvenanceTensor +from tests.common import assert_equal, requires_cuda + + +@requires_cuda +@pytest.mark.parametrize( + "dtype1", + [ + torch.float16, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ], +) +@pytest.mark.parametrize( + "dtype2", + [ + torch.float16, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ], +) +def test_provenance_tensor(dtype1, dtype2): + device = torch.device("cuda") + x = torch.tensor([1, 2, 3], dtype=dtype1) + y = ProvenanceTensor(x, frozenset(["x"])) + z = torch.as_tensor(y, device=device, dtype=dtype2) + + assert x.shape == y.shape == z.shape + assert_equal(x, z.cpu())