Feat (activation_calibration): speed-up by skipping quantization #1029

Merged 8 commits on Oct 8, 2024
Changes from all commits
20 changes: 12 additions & 8 deletions src/brevitas/core/quant/float.py
@@ -64,11 +64,10 @@ def __init__(
         if dtype is None:
             dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
-    def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
-
+    def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
@@ -86,10 +85,15 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        y, scale = self.quantize(x)
-        # after quantizing, clamp to special cases like NaN/inf if they are set
-        y, saturating, inf_values, nan_values = self.float_clamp_impl(
-            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        y = self.dequantize(y, scale)
+        scale = self.scaling_impl(x)
+        if self.observer_only:
+            y = x
+            saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
+        else:
+            y, scale = self.quantize(x, scale)
+            # after quantizing, clamp to special cases like NaN/inf if they are set
+            y, saturating, inf_values, nan_values = self.float_clamp_impl(
+                y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+            y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
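
The shape of this change is worth spelling out: scale computation moves out of quantize() so the scaling observer still sees every input, while the quantize/clamp/dequantize round-trip is skipped whenever the flag is set. A minimal self-contained sketch of the pattern (hypothetical names, not the Brevitas classes):

    import torch

    class ObserverOnlyFakeQuant(torch.nn.Module):
        # Sketch only: mirrors the control flow above, not Brevitas internals.
        def __init__(self, scaling_impl):
            super().__init__()
            self.scaling_impl = scaling_impl  # stateful observer, updates running stats
            self.observer_only = False

        def forward(self, x):
            scale = self.scaling_impl(x)  # always runs, so statistics keep updating
            if self.observer_only:
                return x, scale  # calibration fast path: no rounding/clamping work
            y = torch.round(x / scale) * scale  # stand-in for quantize + dequantize
            return y, scale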
17 changes: 14 additions & 3 deletions src/brevitas/core/quant/int.py
@@ -145,6 +145,7 @@ def __init__(
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         int_threshold = self.int_scaling_impl(bit_width)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.int_quant(scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
 
 
@@ -176,6 +180,7 @@ def __init__(
         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
 
 
@@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
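
In all three forward paths above, scale, zero-point, and bit width are still derived from the observers; only the rounding step is guarded. For reference, the per-batch work being skipped is on the order of a standard affine fake-quant (an illustrative formula, not necessarily the exact int_quant implementation):

    import torch

    def affine_fake_quant(x, scale, zero_point, bit_width=8):
        # Round to the integer grid, clamp to the signed range, dequantize.
        qmin = -2 ** (bit_width - 1)
        qmax = 2 ** (bit_width - 1) - 1
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale

During calibration the observers only need to see x itself, so returning it directly removes this round/clamp/multiply traffic from every activation tensor.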
34 changes: 34 additions & 0 deletions src/brevitas/core/stats/stats_op.py
@@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
             m.local_loss_mode = enabled
 
 
+def _set_observer_mode(module, enabled, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            previous_observer_mode[m] = m.observer_only
+            m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            m.observer_only = previous_observer_mode[m]
+
+
 class MSE(torch.nn.Module):
     # References:
     # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
@@ -459,7 +472,12 @@ def __init__(
         self.mse_init_op = mse_init_op
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.num = mse_iters
         self.search_method = mse_search_method
@@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
         self.internal_candidate = candidate
         # Set to local_loss_mode before calling the proxy
         self.set_local_loss_mode(True)
+        self.set_observer_mode(False)
         quant_value = self.proxy_forward(x)
         quant_value = _unpack_quant_tensor(quant_value)
         loss = self.mse_loss_fn(x, quant_value)
         self.set_local_loss_mode(False)
+        self.restore_observer_mode()
         return loss
 
     def mse_grid_search(self, xl, x):
@@ -571,7 +591,12 @@ def __init__(
         self.hqo_init_op = hqo_init_op_scale
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.hqo_iters = hqo_iters_scale
         self.stats_reduce_dim = stats_reduce_dim
@@ -598,8 +623,10 @@ def parameter_search(self, xl, x):
         for i in range(0, self.hqo_iters):
             self.internal_candidate = candidate
             self.set_local_loss_mode(True)
+            self.set_observer_mode(False)
             quant_tensor = self.proxy_forward(x).detach()
             self.set_local_loss_mode(False)
+            self.restore_observer_mode()
             loss = torch.abs(quant_tensor.value - x).mean()
 
             best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
@@ -670,7 +697,12 @@ def __init__(
         self.hqo_init_op_zp = hqo_init_op_zp
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.stats_reduce_dim = stats_reduce_dim
         self.local_loss_mode: bool = False
@@ -688,8 +720,10 @@ def parameter_search(self, xl, x):
         for i in range(0, self.hqo_iters):
             self.internal_candidate = candidate
             self.set_local_loss_mode(True)
+            self.set_observer_mode(False)
             quant_tensor = self.proxy_forward(x).detach()
             self.set_local_loss_mode(False)
+            self.restore_observer_mode()
             qt_value = self.input_view_shape_impl(quant_tensor.value)
             qt_scale = self.input_view_shape_impl(quant_tensor.scale)
             qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
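
The save/restore pair exists because these MSE/HQO searches can run while an outer context (e.g. calibration) has already put some quantizers into observer mode; restoring the recorded per-module state, instead of blindly writing False on exit, avoids clobbering that outer setting. A sketch of the discipline, with a hypothetical helper name:

    def evaluate_with_real_quant(search, x):
        # 'search' is an MSE/HQO-style object exposing the lambdas added in
        # this diff; an outer calibration loop may have set observer_only=True.
        search.set_observer_mode(False)   # loss needs genuinely quantized values
        quant_value = search.proxy_forward(x)
        search.restore_observer_mode()    # hand back whatever the outer state was
        return quant_value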
11 changes: 6 additions & 5 deletions src/brevitas/graph/calibrate.py
@@ -200,8 +200,9 @@ def disable_act_quantization(self, model, is_training):
             if isinstance(module, ActQuantProxyFromInjectorBase):
                 module.train(is_training)
                 if self.call_act_quantizer_impl:
-                    hook = module.register_forward_hook(self.disable_act_quant_hook)
-                    self.disable_act_quant_hooks.append(hook)
+                    for m in module.modules():
+                        if hasattr(m, 'observer_only'):
+                            m.observer_only = True
                 else:
                     module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
@@ -228,9 +229,9 @@ def enable_act_quantization(self, model, is_training):
             elif isinstance(module, ActQuantProxyFromInjectorBase):
                 module.disable_quant = False
                 module.train(is_training)
-        for hook in self.disable_act_quant_hooks:
-            hook.remove()
-        self.disable_act_quant_hooks = []
+                for m in module.modules():
+                    if hasattr(m, 'observer_only'):
+                        m.observer_only = False
 
     def enable_param_quantization(self, model, is_training):
         for module in model.modules():
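
Net effect: disabling activation quantization during calibration is now a flag flip on the quantizer implementations rather than registering and removing forward hooks, which also avoids hook-dispatch overhead on every call. A hedged sketch of the end-to-end flow (model and calib_loader are placeholders; calibration_mode is Brevitas's existing context manager, which drives the methods above):

    import torch
    from brevitas.graph.calibrate import calibration_mode

    model.eval()
    with torch.no_grad(), calibration_mode(model):
        for images, _ in calib_loader:  # observers collect statistics here,
            model(images)               # while quantization itself is skipped
    # On exit, observer_only is False again and forward passes quantize as usual.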
7 changes: 4 additions & 3 deletions tests/brevitas/core/test_float_quant.py
@@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format):
         signed=signed,
         float_clamp_impl=float_clamp)
     expected_out, *_ = float_quant(inp)
-
-    out_quant, scale = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    out_quant, scale = float_quant.quantize(inp, scale)
     exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
     out_quant, *_ = float_quant.float_clamp_impl(
         out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
@@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
         scaling_impl=scaling_impl,
         float_scaling_impl=float_scaling_impl,
         float_clamp_impl=float_clamp)
-    _ = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    _ = float_quant.quantize(inp, scale)
     # scaling implementations should be called exaclty once on the input
     float_scaling_impl.assert_called_once_with(
         torch.tensor(exponent_bit_width),