Skip to content

Commit

Permalink
More fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 24, 2024
1 parent 1e45caf commit 9a3f25d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/core/function_wrapper/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list[self.group_dim] = (
tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x
Expand Down
5 changes: 3 additions & 2 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ def forward(
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped) / threshold
threshold = self.restrict_clamp_scaling(self.restrict_module(threshold))
out = self.scaling_stats_impl(stats_input_reshaped)
# Apply log scaling
out = self.restrict_module(out)
# Scaling min val
out = self.restrict_clamp_scaling(out)
out = self.restrict_clamp_scaling(out) / threshold
return out
5 changes: 3 additions & 2 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dependencies import value

from brevitas.core.function_wrapper.ops_ste import CeilSte
from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
from brevitas.inject import ExtendedInjector
from brevitas.inject.enum import RestrictValueType
Expand Down Expand Up @@ -46,14 +47,14 @@ class GroupwiseActProxyMixin(ExtendedInjector):
class MXWeightMixin(ExtendedInjector):
group_size = 32
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = CeilSte
restrict_value_float_to_int_impl = FloorSte
scaling_per_output_type = ScalingPerOutputType.GROUP


class MXActMixin(ExtendedInjector):
group_size = 32
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = CeilSte
restrict_value_float_to_int_impl = FloorSte
scaling_impl = RuntimeDynamicGroupStatsScaling
scaling_per_output_type = ScalingPerOutputType.GROUP

Expand Down

0 comments on commit 9a3f25d

Please sign in to comment.