From bd46f89e1b53c6f0c26d85da5669d1d2eda496a6 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:25:49 +0100 Subject: [PATCH] Fix (backport): op decomp in make_fx backport (#763) Signed-off-by: Alessandro Pappalardo --- src/brevitas/backport/fx/experimental/proxy_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/backport/fx/experimental/proxy_tensor.py b/src/brevitas/backport/fx/experimental/proxy_tensor.py index 430623e9d..a5118878a 100644 --- a/src/brevitas/backport/fx/experimental/proxy_tensor.py +++ b/src/brevitas/backport/fx/experimental/proxy_tensor.py @@ -317,7 +317,7 @@ def proxy_call(proxy_mode, func, args, kwargs): # `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload` # We treat all other functions as an `external_call`, for instance, a function decorated # with `@torch.fx.wrap` - external_call = not isinstance(func, backport._ops.OpOverload) + external_call = not isinstance(func, (backport._ops.OpOverload, torch._ops.OpOverload)) def can_handle_tensor(x): return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)