From ec089712c45f205b527ee7f2dc74f674ab274823 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Nov 2023 13:45:38 +0000 Subject: [PATCH] Fix Test --- tests/brevitas_examples/test_jit_trace.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/brevitas_examples/test_jit_trace.py b/tests/brevitas_examples/test_jit_trace.py index 52a4c6b1d..4c16bf800 100644 --- a/tests/brevitas_examples/test_jit_trace.py +++ b/tests/brevitas_examples/test_jit_trace.py @@ -6,7 +6,6 @@ import pytest import torch -from brevitas.utils.jit_utils import jit_patches_generator from brevitas_examples.bnn_pynq.models import model_with_cfg FC_INPUT_SIZE = (1, 1, 28, 28) @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits): fc, _ = model_with_cfg(nname.lower(), pretrained=False) fc.train(False) input_tensor = torch.randn(FC_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(fc, input_tensor) out_traced = traced_model(input_tensor) out = fc(input_tensor) @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits): cnv, _ = model_with_cfg(nname.lower(), pretrained=False) cnv.train(False) input_tensor = torch.randn(CNV_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(cnv, input_tensor) out_traced = traced_model(input_tensor) out = cnv(input_tensor)