From 0428c3413c696a6add648b949ad392e3791ba700 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 1 May 2024 14:03:39 +0100 Subject: [PATCH] added code to turn value into tensor if not already --- src/brevitas/nn/mixin/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 6e48339ab..bad417939 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -79,7 +79,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) - elif isinstance(inp, Tensor): + else: + if isinstance(inp, Tensor) is False: + inp = torch.tensor(inp) inp = inp.rename(None) return inp