Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix (jit): remove patcher #752

Merged
merged 3 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/brevitas/core/quant/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, quant_delay_steps):
super(_DelayQuant, self).__init__()
self.quant_delay_steps: int = brevitas.jit.Attribute(quant_delay_steps, int)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def forward(self, x: Tensor, y: Tensor) -> Tensor:
if self.quant_delay_steps > 0:
self.quant_delay_steps = self.quant_delay_steps - 1
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/core/quant/int_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
y = x / scale
y = y + zero_point
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def to_int(
self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor,
x: Tensor) -> Tensor:
Expand Down
7 changes: 1 addition & 6 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from brevitas.proxy.quant_proxy import QuantProxyProtocol
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.jit_utils import jit_patches_generator
from brevitas.utils.python_utils import patch


Expand Down Expand Up @@ -162,7 +161,6 @@ class BaseManager(ABC):

target_name = None
handlers = []
_base_trace_patches_generator = jit_patches_generator
_fn_to_cache = []
_fn_cache = []
_cached_io_handler_map = {}
Expand All @@ -183,10 +181,7 @@ def _gen_patches(cls, fn_dispatcher):

@classmethod
def _trace_patches(cls):
patches = []
if cls._base_trace_patches_generator is not None:
patches += cls._base_trace_patches_generator()
patches += cls._gen_patches(cls._trace_fn_dispatcher)
patches = cls._gen_patches(cls._trace_fn_dispatcher)
return patches

@classmethod
Expand Down
9 changes: 0 additions & 9 deletions src/brevitas/jit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from packaging import version
import torch

from brevitas.config import JIT_ENABLED

IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')


def _disabled(fn):
return fn
Expand All @@ -20,15 +17,9 @@ def _disabled(fn):
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

if not IS_ABOVE_110:
script_method_110_disabled = _disabled
else:
script_method_110_disabled = script_method

else:

script_method = _disabled
script = _disabled
script_method_110_disabled = _disabled
ScriptModule = torch.nn.Module
Attribute = lambda val, type: val
30 changes: 1 addition & 29 deletions src/brevitas/utils/jit_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import inspect

import torch

try:
from torch._jit_internal import get_torchscript_modifier
except:
get_torchscript_modifier = None

from dependencies import Injector
from packaging import version
import torch

from brevitas import torch_version
from brevitas.inject import ExtendedInjector
from brevitas.jit import IS_ABOVE_110

from .python_utils import patch


def _get_modifier_wrapper(fn):
if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
return None
else:
return get_torchscript_modifier(fn)


if IS_ABOVE_110:

def jit_patches_generator():
return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
else:
jit_patches_generator = None


def clear_class_registry():
Expand Down
7 changes: 0 additions & 7 deletions tests/brevitas_examples/test_jit_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch

from brevitas.utils.jit_utils import jit_patches_generator
from brevitas_examples.bnn_pynq.models import model_with_cfg

FC_INPUT_SIZE = (1, 1, 28, 28)
Expand All @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits):
fc, _ = model_with_cfg(nname.lower(), pretrained=False)
fc.train(False)
input_tensor = torch.randn(FC_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(fc, input_tensor)
out_traced = traced_model(input_tensor)
out = fc(input_tensor)
Expand All @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits):
cnv, _ = model_with_cfg(nname.lower(), pretrained=False)
cnv.train(False)
input_tensor = torch.randn(CNV_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(cnv, input_tensor)
out_traced = traced_model(input_tensor)
out = cnv(input_tensor)
Expand Down
Loading