Skip to content

Commit

Permalink
Review, notebook missing
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 20, 2024
1 parent c1de55e commit 2c26cd1
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 65 deletions.
6 changes: 0 additions & 6 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from brevitas.core.scaling.standalone import ParameterScaling
from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat
from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize_impl import act_handler
from brevitas.graph.quantize_impl import add_output_quant_handler
from brevitas.graph.quantize_impl import inp_placeholder_handler
Expand All @@ -26,11 +24,7 @@
from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d
from brevitas.graph.standardize import RemoveStochasticModules
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.proxy.groupwise_float_parameter_quant import \
GroupwiseWeightFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat
Expand Down
5 changes: 4 additions & 1 deletion src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.quant_tensor.groupwise_float_quant_tensor import GroupwiseFloatQuantTensor
from brevitas.quant_tensor.groupwise_int_quant_tensor import GroupwiseIntQuantTensor

from .utils import filter_kwargs

Expand Down Expand Up @@ -71,7 +73,8 @@ def _set_global_is_quant_layer(self, value):
config._IS_INSIDE_QUANT_LAYER = value

def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
quant_tensor_classes = [IntQuantTensor, FloatQuantTensor]
quant_tensor_classes = [
IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor]
for qt_class in quant_tensor_classes:
if len(inp) == len(qt_class._fields):
return qt_class
Expand Down
14 changes: 7 additions & 7 deletions src/brevitas/proxy/float_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTens
y,
x.scale,
x.zero_point,
x.mantissa_bit_width,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.signed,
self.training,
x.saturating,
x.inf_values,
x.nan_values)
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
Expand Down Expand Up @@ -143,11 +143,11 @@ def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuan
x.mantissa_bit_width,
x.exponent_bit_width,
x.exponent_bias,
x.signed,
self.training,
x.saturating,
x.inf_values,
x.nan_values)
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_float_parameter_quant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union

import torch
from torch import Tensor
Expand Down
10 changes: 5 additions & 5 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat
y,
x.scale,
x.zero_point,
self.group_dim,
self.group_size,
x.mantissa_bit_width,
self.group_dim,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.signed,
self.training,
x.saturating,
x.inf_values,
x.nan_values)
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]:
out,
scale,
zero_point,
bit_width,
self.group_dim,
self.group_size,
self.group_dim,
bit_width,
self.is_signed,
self.training)
else: # quantization disabled
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_int_runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQu
y,
x.scale,
x.zero_point,
self.group_dim,
self.group_size,
self.group_dim,
x.bit_width,
x.signed,
self.training)
Expand Down
21 changes: 11 additions & 10 deletions src/brevitas/quant_tensor/float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def cat(tensors, dim, out=None):
exponent_bit_width=output_exponent_bit_width,
mantissa_bit_width=output_mantissa_bit_width,
exponent_bias=output_exponent_bias,
signed=output_signed,
training=output_training,
saturating=output_saturating,
inf_values=output_inf_values,
nan_values=output_nan_values)
nan_values=output_nan_values,
signed=output_signed,
training=output_training)
else:
tensors = [_unpack_quant_tensor(qt) for qt in tensors]
output_value = torch.cat(tensors, dim=dim)
Expand All @@ -280,11 +280,11 @@ def __neg__(self):
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
exponent_bias=self.exponent_bias,
signed=self.signed,
training=self.training,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
nan_values=self.nan_values,
signed=self.signed,
training=self.training)
else:
# TODO: implement
raise NotImplementedError
Expand All @@ -304,7 +304,7 @@ def __mul__(self, other):
return output

def __str__(self):
return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})"

def __truediv__(self, other):
if isinstance(other, QuantTensor):
Expand All @@ -325,10 +325,11 @@ def __abs__(self):
zero_point=self.zero_point,
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
signed=False,
training=self.training,
exponent_bias=self.exponent_bias,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
nan_values=self.nan_values,
signed=False,
training=self.training)
else:
return self
6 changes: 3 additions & 3 deletions src/brevitas/quant_tensor/float_torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def embedding_handler(input, quant_weight, *args, **kwargs):
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
signed,
training,
saturating,
inf_values,
nan_values)
nan_values,
signed,
training)
return out


Expand Down
29 changes: 17 additions & 12 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def expand(self):
return new_value, new_scale, new_zp

@staticmethod
def from_expanded(value, group_dim, group_size, compress=False):
def from_expanded(value, group_size, group_dim, compress=False):
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
Expand Down Expand Up @@ -253,22 +253,24 @@ def __neg__(self):
# In case the dtype of self.minifloat is different from the one of the scale
neg_value = neg_value.type(scale.dtype)
neg_value = GroupwiseFloatQuantTensor.from_expanded(
neg_value, self.group_dim, self.group_size, compress=False)
neg_value, self.group_size, self.group_dim, compress=False)
scale = GroupwiseFloatQuantTensor.from_expanded(
scale, self.group_dim, self.group_size, compress=True)
scale, self.group_size, self.group_dim, compress=True)
if self.signed:
return GroupwiseFloatQuantTensor(
value=neg_value,
scale=scale,
zero_point=self.zero_point,
group_size=self.group_size,
group_dim=self.group_dim,
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
exponent_bias=self.exponent_bias,
signed=self.signed,
training=self.training,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
nan_values=self.nan_values,
signed=self.signed,
training=self.training)
else:
# TODO: implement
raise NotImplementedError
Expand All @@ -288,7 +290,7 @@ def __mul__(self, other):
return output

def __str__(self):
return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})"

def __truediv__(self, other):
if isinstance(other, QuantTensor):
Expand All @@ -307,19 +309,22 @@ def __abs__(self):
# In case the dtype of self.minifloat is different from the one of the scale
abs_value = abs_value.type(scale.dtype)
abs_value = GroupwiseFloatQuantTensor.from_expanded(
abs_value, self.group_dim, self.group_size, compress=False)
abs_value, self.group_size, self.group_dim, compress=False)
scale = GroupwiseFloatQuantTensor.from_expanded(
scale, self.group_dim, self.group_size, compress=True)
scale, self.group_size, self.group_dim, compress=True)
return GroupwiseFloatQuantTensor(
value=abs_value,
scale=self.scale,
zero_point=self.zero_point,
group_size=self.group_size,
group_dim=self.group_dim,
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
signed=False,
training=self.training,
exponent_bias=self.exponent_bias,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
nan_values=self.nan_values,
signed=False,
training=self.training)
else:
return self
31 changes: 14 additions & 17 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def expand(self):
return new_value, new_scale, new_zp

@staticmethod
def from_expanded(value, group_dim, group_size, compress=False):
def from_expanded(value, group_size, group_dim, compress=False):
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
Expand Down Expand Up @@ -243,22 +243,20 @@ def __neg__(self):
# In case the dtype of self.minifloat is different from the one of the scale
neg_value = neg_value.type(scale.dtype)
neg_value = GroupwiseIntQuantTensor.from_expanded(
neg_value, self.group_dim, self.group_size, compress=False)
neg_value, self.group_size, self.group_dim, compress=False)
scale = GroupwiseIntQuantTensor.from_expanded(
scale, self.group_dim, self.group_size, compress=True)
scale, self.group_size, self.group_dim, compress=True)
if self.signed:
return GroupwiseIntQuantTensor(
value=neg_value,
scale=scale,
zero_point=self.zero_point,
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
exponent_bias=self.exponent_bias,
group_size=self.group_size,
group_dim=self.group_dim,
bit_width=self.bit_width,
signed=self.signed,
training=self.training,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
saturating=self.saturating)
else:
# TODO: implement
raise NotImplementedError
Expand All @@ -278,7 +276,7 @@ def __mul__(self, other):
return output

def __str__(self):
return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"

def __truediv__(self, other):
if isinstance(other, QuantTensor):
Expand All @@ -297,19 +295,18 @@ def __abs__(self):
# In case the dtype of self.minifloat is different from the one of the scale
abs_value = abs_value.type(scale.dtype)
abs_value = GroupwiseIntQuantTensor.from_expanded(
abs_value, self.group_dim, self.group_size, compress=False)
abs_value, self.group_size, self.group_dim, compress=False)
scale = GroupwiseIntQuantTensor.from_expanded(
scale, self.group_dim, self.group_size, compress=True)
scale, self.group_size, self.group_dim, compress=True)
return GroupwiseIntQuantTensor(
value=abs_value,
scale=self.scale,
zero_point=self.zero_point,
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
group_size=self.group_size,
group_dim=self.group_dim,
bit_width=self.bit_width,
signed=False,
training=self.training,
saturating=self.saturating,
inf_values=self.inf_values,
nan_values=self.nan_values)
saturating=self.saturating)
else:
return self

0 comments on commit 2c26cd1

Please sign in to comment.