Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 15, 2025
1 parent 8004832 commit 9936c35
Show file tree
Hide file tree
Showing 9 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
return output_args
Expand Down Expand Up @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
return output_args
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_float_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_int_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/quant/solver/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output == ScalingPerOutputType.GROUP:
reduce_dim = group_dim + 1 if group_dim > 0 else group_dim
reduce_dim = group_dim + 1 if group_dim >= 0 else group_dim
return reduce_dim

@value
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def expand(self):

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
group_dim = group_dim if group_dim > 0 else group_dim - 1
group_dim = group_dim if group_dim >= 0 else group_dim - 1
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def expand(self):

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
group_dim = group_dim if group_dim > 0 else group_dim - 1
group_dim = group_dim if group_dim >= 0 else group_dim - 1
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def float_to_int_impl_to_enum(module):

def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape):
curr_shape = value_.shape
start_dim = group_dim if group_dim > 0 else group_dim - 1
start_dim = group_dim if group_dim >= 0 else group_dim - 1
new_value = value_.flatten(start_dim, start_dim + 1)
if scale_.shape != ():
new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
Expand Down

0 comments on commit 9936c35

Please sign in to comment.