Skip to content

Commit

Permalink
Fix (quantizer): fix shapes for mse/hqo optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 13, 2024
1 parent 482531c commit 9596911
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def evaluate_loss(self, x, candidate):
self.set_observer_mode(False)
quant_value = self.proxy_forward(x)
quant_value = _unpack_quant_tensor(quant_value)
quant_value = quant_value.view(x.shape)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
self.restore_observer_mode()
Expand Down
1 change: 1 addition & 0 deletions src/brevitas/export/inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
Expand Down
3 changes: 0 additions & 3 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,6 @@ class MSEAsymmetricScale(ExtendedInjector):

mse_scale = MSEAsymmetricScaleSubInjector
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
Expand All @@ -561,7 +560,6 @@ class MSESymmetricScale(ExtendedInjector):

mse_scale = MSESymmetricScaleSubInjector
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
Expand All @@ -583,7 +581,6 @@ class MSEZeroPoint(ExtendedInjector):
"""

mse_zero_point = MSEZeroPointSubInjector
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
Expand Down
1 change: 0 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ def main(args):

if args.bias_corr:
model = add_zero_bias_to_linear(model)

model = offload_model(model)

with torch.no_grad():
Expand Down

0 comments on commit 9596911

Please sign in to comment.