Commit
Fix and tests
Giuseppe5 committed Jan 22, 2025
1 parent ad43760 commit 18c5104
Showing 4 changed files with 151 additions and 69 deletions.
16 changes: 8 additions & 8 deletions src/brevitas/export/inference/handler.py
@@ -127,13 +127,13 @@ def prepare_for_export(self, module):
         self.group_dim = module.group_dim
 
     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
+        inp_shape = x.shape
+        x, scale, zero_point, *other = self.module_forward(x)
 
         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
-            x = x.flatten(start_dim, start_dim + 1)
-            output_args = tuple([x] + list(other))
+            x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
+            output_args = tuple([x, scale, zero_point] + list(other))
         return output_args


@@ -274,13 +274,13 @@ def prepare_for_export(self, module: nn.Module):
         self.group_dim = module.group_dim
 
     def forward(self, x: Tensor) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
+        inp_shape = x.shape
+        x, scale, zero_point, *other = self.module_forward(x)
 
         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
-            x = x.flatten(start_dim, start_dim + 1)
-            output_args = tuple([x] + list(other))
+            x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
+            output_args = tuple([x, scale, zero_point] + list(other))
         return output_args


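Note on the handler change above: both forward methods now record the input shape before running the quantized forward and, when the quant tensor is skipped, pass the value together with its scale and zero point to groupwise_dequant_expand instead of merely flattening the grouped value. The sketch below illustrates what a groupwise dequant-and-expand step conceptually computes; the function body, the toy shapes and the assumption that x arrives as an integer-valued grouped tensor are illustrative only and do not reproduce the Brevitas helper.

import torch

def groupwise_dequant_expand_sketch(x_int, scale, zero_point, group_dim, inp_shape):
    # Illustrative sketch: x_int is assumed to be an integer-valued tensor whose
    # `group_dim` axis has been split into (num_groups, group_size), with scale and
    # zero_point broadcastable over the group_size axis. The real Brevitas helper
    # may differ in signature and behaviour.
    dequant = (x_int - zero_point) * scale              # per-group affine dequant
    start_dim = group_dim if group_dim >= 0 else group_dim - 1
    value = dequant.flatten(start_dim, start_dim + 1)   # merge (num_groups, group_size)
    return value.reshape(inp_shape), scale, zero_point

# Toy usage: a (2, 8) input quantized along the last dim in groups of 4.
x_int = torch.randint(0, 255, (2, 2, 4)).float()        # (rows, num_groups, group_size)
scale = torch.rand(2, 2, 1) + 1e-3
zero_point = torch.randint(0, 255, (2, 2, 1)).float()
expanded, _, _ = groupwise_dequant_expand_sketch(x_int, scale, zero_point, -1, (2, 8))
assert expanded.shape == (2, 8)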
125 changes: 64 additions & 61 deletions src/brevitas_examples/common/generative/quantize.py
@@ -72,6 +72,7 @@
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
 from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
+from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerGroupFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

@@ -181,7 +182,8 @@
                     'sym': Int8DynamicActPerRowFloat,
                     'asym': ShiftedUint8DynamicActPerRowFloat},
                 'per_group': {
-                    'sym': Int8DynamicActPerGroupFloat}}},
+                    'sym': Int8DynamicActPerGroupFloat,
+                    'asym': ShiftedUint8DynamicActPerGroupFloat}}},
         'po2_scale': {
             'stats': {
                 'per_row': {
@@ -300,18 +302,19 @@ def generate_quantizers(
     linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
         input_scale_precision][input_param_method][input_quant_granularity][input_quant_type]
 
-    kv_quant_type = kv_quant_type if kv_quant_type is not None else input_quant_type
-    kv_quant_granularity = kv_quant_granularity if kv_quant_granularity is not None else input_quant_granularity
-
-    v_quant = k_transposed_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
-        input_scale_precision][input_param_method][kv_quant_granularity][kv_quant_type]
-
     if kv_quant_type is not None:
         q_scaled_quant = attn_output_weights_quant = None
 
     else:
         q_scaled_quant = attn_output_weights_quant = sym_input_quant
 
+    kv_quant_type = kv_quant_type if kv_quant_type is not None else input_quant_type
+    kv_quant_granularity = kv_quant_granularity if kv_quant_granularity is not None else input_quant_granularity
+    print(kv_quant_granularity)
+
+    v_quant = k_transposed_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+        input_scale_precision][input_param_method][kv_quant_granularity][kv_quant_type]
+
     extra_kwargs = {
         'bit_width': input_bit_width,
         'quantize_zero_point': quantize_input_zero_point,
@@ -375,60 +378,60 @@ def generate_quantizers(
     kv_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], 1, shape[-1])
 
     # Modify the input quantizers based on the arguments passed in
-    if input_quant is not None:
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_row':
-                input_quant = input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'permute_dims': None,
-                        'stats_reduce_dim': 1})
-            elif input_quant_granularity == 'per_group':
-                input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})
-    if sym_input_quant is not None:
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_row':
-                q_scaled_quant = q_scaled_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'permute_dims': None,
-                        'stats_reduce_dim': 1})
-                v_quant = v_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
-                        'permute_dims': kv_permute_dims,
-                        'stats_reduce_dim': 1})
-                k_transposed_quant = k_transposed_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
-                        'permute_dims': kv_permute_dims,
-                        'stats_reduce_dim': 1})
-            elif input_quant_granularity == 'per_group':
-                q_scaled_quant = q_scaled_quant.let(
-                    **{
-                        'group_dim': -1, 'group_size': input_group_size})
-                v_quant = v_quant.let(**{'group_dim': -1, 'group_size': input_group_size})
-                k_transposed_quant = k_transposed_quant.let(
-                    **{
-                        'group_dim': -2, 'group_size': input_group_size})
-                v_quant = k_transposed_quant
-            attn_output_weights_quant = q_scaled_quant
-
-    if linear_input_quant is not None:
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_row':
-                linear_input_quant = linear_input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'permute_dims': None,
-                        'stats_reduce_dim': 1})
-            elif input_quant_granularity == 'per_group':
-                linear_input_quant = linear_input_quant.let(
-                    **{
-                        'group_dim': -1, 'group_size': input_group_size})
+    if input_bit_width is not None:
+        # Input Quant
+        if input_quant_granularity == 'per_row':
+            input_quant = input_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1})
+        elif input_quant_granularity == 'per_group':
+            input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})
+
+        # QKV/Softmax Quant
+        if kv_quant_granularity == 'per_row':
+            q_scaled_quant = q_scaled_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1}) if q_scaled_quant is not None else None
+            v_quant = v_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
+                    'permute_dims': kv_permute_dims,
+                    'stats_reduce_dim': 1})
+            k_transposed_quant = k_transposed_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
+                    'permute_dims': kv_permute_dims,
+                    'stats_reduce_dim': 1})
+        elif kv_quant_granularity == 'per_group':
+            q_scaled_quant = q_scaled_quant.let(
+                **{
+                    'group_dim': -1, 'group_size': input_group_size
+                }) if q_scaled_quant is not None else None
+            v_quant = v_quant.let(**{'group_dim': -1, 'group_size': input_group_size})
+            k_transposed_quant = k_transposed_quant.let(
+                **{
+                    'group_dim': -2, 'group_size': input_group_size})
+            v_quant = k_transposed_quant
+        attn_output_weights_quant = q_scaled_quant
+
+        # Input to Linear Layer Quant
+        if input_quant_granularity == 'per_row':
+            linear_input_quant = linear_input_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1})
+        elif input_quant_granularity == 'per_group':
+            linear_input_quant = linear_input_quant.let(
+                **{
+                    'group_dim': -1, 'group_size': input_group_size})
     return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant


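With the new 'asym' per-group entry wired into the map above, and the KV defaults now resolved after the Q/softmax check, a dynamic per-group asymmetric KV-cache configuration resolves to ShiftedUint8DynamicActPerGroupFloat. A toy reproduction of just that lookup follows; the outer key names ('int', 'dynamic', 'float_scale', 'stats') are assumptions used for illustration, and only the per_row/per_group leaves mirror the fragment shown in the diff.

# Toy fragment of the nested activation-quantizer map. Outer key names are
# illustrative assumptions; quantizer classes are represented by strings so the
# snippet stays self-contained.
INPUT_QUANT_MAP_FRAGMENT = {
    'int': {
        'dynamic': {
            'float_scale': {
                'stats': {
                    'per_row': {
                        'sym': 'Int8DynamicActPerRowFloat',
                        'asym': 'ShiftedUint8DynamicActPerRowFloat'},
                    'per_group': {
                        'sym': 'Int8DynamicActPerGroupFloat',
                        'asym': 'ShiftedUint8DynamicActPerGroupFloat'}}}}}}

# Mirrors the v_quant / k_transposed_quant lookup in generate_quantizers once the
# KV defaults have been filled in.
kv_quant_granularity, kv_quant_type = 'per_group', 'asym'
quantizer = INPUT_QUANT_MAP_FRAGMENT['int']['dynamic']['float_scale']['stats'][
    kv_quant_granularity][kv_quant_type]
assert quantizer == 'ShiftedUint8DynamicActPerGroupFloat'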
13 changes: 13 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -13,6 +13,7 @@
 from brevitas.core.stats import NegativeMinOrZero
 from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
 from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
+from brevitas.core.zero_point import RuntimeDynamicGroupZeroPoint
 from brevitas.core.zero_point import StatsFromParameterZeroPoint
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
@@ -99,6 +100,18 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
     scaling_per_output_type = ScalingPerOutputType.GROUP
 
 
+class ShiftedUint8DynamicActPerGroupFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat):
+    """
+    Asymmetric quantizer with per group dynamic scale and zero point.
+    """
+    proxy_class = GroupwiseActQuantProxyFromInjector
+    scaling_impl = RuntimeDynamicGroupStatsScaling
+    scaling_stats_op = 'min_max'
+    scaling_per_output_type = ScalingPerOutputType.GROUP
+    zero_point_impl = RuntimeDynamicGroupZeroPoint
+    zero_point_stats_impl = NegativeMinOrZero
+
+
 class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat):
     """
     Symmetric quantizer with per tensor dynamic scale.
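The new ShiftedUint8DynamicActPerGroupFloat quantizer derives its scale from runtime per-group min/max statistics and its zero point from a NegativeMinOrZero-style statistic. Below is a minimal sketch of that arithmetic for an unsigned 8-bit target, assuming grouping along the last dimension; it is illustrative arithmetic only, not the Brevitas implementation.

import torch

def dynamic_per_group_asym_quant(x, group_size, bit_width=8):
    # Split the last dim into (num_groups, group_size) and compute stats per group.
    g = x.view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    g_min = g.amin(dim=-1, keepdim=True).clamp(max=0.)        # NegativeMinOrZero-style stat
    g_max = g.amax(dim=-1, keepdim=True)
    scale = ((g_max - g_min) / (2 ** bit_width - 1)).clamp(min=1e-8)
    zero_point = torch.round(-g_min / scale)
    q = torch.clamp(torch.round(g / scale) + zero_point, 0, 2 ** bit_width - 1)
    dequant = (q - zero_point) * scale                        # dequantized, still grouped
    return dequant.view_as(x), scale, zero_point

# Toy usage with the group size exercised by the new test below (32).
x = torch.randn(2, 64)
dq, scale, zp = dynamic_per_group_asym_quant(x, group_size=32)
assert dq.shape == x.shape and scale.shape == (2, 2, 1)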
66 changes: 66 additions & 0 deletions tests/brevitas_examples/test_llm.py
@@ -588,6 +588,72 @@ def test_small_models_quant_layer_types_count(caplog, layer_args_types_count):
     assert_layer_types_count(model, exp_layer_types_count)
 
 
+@pytest_cases.fixture(
+    ids=["mistral-kv-quant-fx-sdpa", "mistral-kv-quant-functional-sdpa"],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "act_calibration": False,
+            "input_quant_granularity": "per_row",
+            "kv_quant_granularity": "per_group",
+            "input_group_size": 32,
+            "input_scale_type": "dynamic",
+            "input_quant_type": "sym",
+            "quant_sdpa": True,
+            "functional_sdpa_quant": False,
+            "kv_quant_type": "asym"},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "act_calibration": False,
+            "input_quant_granularity": "per_row",
+            "kv_quant_granularity": "per_group",
+            "input_group_size": 32,
+            "input_scale_type": "dynamic",
+            "input_quant_type": "sym",
+            "quant_sdpa": False,
+            "functional_sdpa_quant": True,
+            "kv_quant_type": "asym"},])
+def layer_args_hyperparam(default_run_args, request):
+    args = default_run_args
+    layer_dict = request.param
+    args.update(**layer_dict)
+    yield args
+
+
+@pytest.mark.llm
+@requires_pt_ge('2.2')
+def test_small_models_quant_layer_hyperparam(caplog, layer_args_hyperparam):
+    from brevitas.nn import QuantScaledDotProductAttention as QuantSDPA
+    from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
+    caplog.set_level(logging.INFO)
+    args = layer_args_hyperparam
+
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
+    quant_sdpa = []
+    for m in model.modules():
+        if isinstance(m, QuantSDPA):
+            quant_sdpa.append(m)
+
+    first_sdpa = quant_sdpa[0]
+
+    # Check that Q/Softmax quantization is disabled
+    assert first_sdpa.q_scaled_quant.act_quant.fused_activation_quant_proxy is None
+    assert first_sdpa.attn_output_weights_quant.act_quant.fused_activation_quant_proxy is None
+    # NOTE: We assume that asym == unsigned. This might change in the future.
+    assert not first_sdpa.v_quant.act_quant.is_signed
+    assert not first_sdpa.k_transposed_quant.act_quant.is_signed
+    # Check for groupwise activation quantization
+    assert isinstance(first_sdpa.v_quant.act_quant, GroupwiseActQuantProxyFromInjector)
+    assert isinstance(first_sdpa.k_transposed_quant.act_quant, GroupwiseActQuantProxyFromInjector)
+    assert first_sdpa.v_quant.act_quant.group_size == args.input_group_size
+    assert first_sdpa.k_transposed_quant.act_quant.group_size == args.input_group_size
+    # Functional quantization uses one shared quant block for everything
+    if args.quant_sdpa:
+        assert len(quant_sdpa) > 1
+    elif args.functional_sdpa_quant:
+        assert len(quant_sdpa) == 1
+
+
 @pytest_cases.fixture(
     ids=[
         "opt-replace-mha",
