Feat (brevitas_examples/llm): separate KV Cache quantization #1165

Open: wants to merge 6 commits into base: dev
Changes from all commits
16 changes: 8 additions & 8 deletions src/brevitas/export/inference/handler.py
@@ -127,13 +127,13 @@ def prepare_for_export(self, module):
         self.group_dim = module.group_dim

     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
+        inp_shape = x.shape
+        x, scale, zero_point, *other = self.module_forward(x)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
-            x = x.flatten(start_dim, start_dim + 1)
-        output_args = tuple([x] + list(other))
+            x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
+        output_args = tuple([x, scale, zero_point] + list(other))
         return output_args


@@ -274,13 +274,13 @@ def prepare_for_export(self, module: nn.Module):
         self.group_dim = module.group_dim

     def forward(self, x: Tensor) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
+        inp_shape = x.shape
+        x, scale, zero_point, *other = self.module_forward(x)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            start_dim = self.group_dim if self.group_dim >= 0 else self.group_dim - 1
-            x = x.flatten(start_dim, start_dim + 1)
-        output_args = tuple([x] + list(other))
+            x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
+        output_args = tuple([x, scale, zero_point] + list(other))
         return output_args


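For readers unfamiliar with the groupwise export path, the sketch below illustrates the shape bookkeeping the new handler code relies on: the grouped tensor is dequantized with its per-group scale/zero-point and expanded back to the caller's original shape, so that `scale` and `zero_point` can be returned alongside `x`. This is only an illustrative stand-in, not the Brevitas implementation; the real `groupwise_dequant_expand` may handle padding and other edge cases differently.

```python
# Illustrative sketch (not the Brevitas implementation): how a groupwise-quantized
# tensor of shape (..., num_groups, group_size) can be dequantized and expanded
# back to the original activation shape, which is what the handler above now does
# before returning scale and zero_point alongside x.
import torch

inp_shape = (2, 8, 64)                  # hypothetical activation shape
group_size, group_dim = 16, -1
x_grouped = torch.randint(0, 255, (2, 8, 4, 16)).float()    # (..., num_groups, group_size)
scale = torch.rand(2, 8, 4, 1)                               # one scale per group
zero_point = torch.randint(0, 255, (2, 8, 4, 1)).float()     # one zero-point per group

dequant = (x_grouped - zero_point) * scale                   # per-group dequant via broadcasting
start_dim = group_dim if group_dim >= 0 else group_dim - 1   # same index math as the old handler
expanded = dequant.flatten(start_dim, start_dim + 1)         # merge group dims back to (2, 8, 64)
assert expanded.shape == inp_shape
```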
198 changes: 104 additions & 94 deletions src/brevitas_examples/common/generative/quantize.py
@@ -72,6 +72,7 @@
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
 from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
+from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerGroupFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

@@ -181,7 +182,8 @@
                 'sym': Int8DynamicActPerRowFloat,
                 'asym': ShiftedUint8DynamicActPerRowFloat},
             'per_group': {
-                'sym': Int8DynamicActPerGroupFloat}}},
+                'sym': Int8DynamicActPerGroupFloat,
+                'asym': ShiftedUint8DynamicActPerGroupFloat}}},
         'po2_scale': {
             'stats': {
                 'per_row': {
@@ -245,11 +247,14 @@ def generate_quantizers(
         input_quant_type=None,
         input_quant_granularity=None,
         input_group_size=None,
+        kv_quant_type=None,
+        kv_quant_granularity=None,
         quantize_input_zero_point=False,
         scale_rounding_func_type=None,
         device=None,
         weight_kwargs=None,
         input_kwargs=None,
+        quant_attn_mode=None,
         scaling_min_val=1e-4):
     """
     Replace float layers with quant layers in the target model
@@ -274,6 +279,17 @@
     weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
         weight_param_method][weight_quant_granularity][weight_quant_type]

+    if input_kwargs is None:
+        input_kwargs = dict()
+
+    if scale_rounding_func_type is not None:
+        scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
+        scale_type = scale_rounding_func_dict[scale_rounding_func_type]
+        input_kwargs = {**input_kwargs, **{'restrict_value_float_to_int_impl': scale_type}}
+
+    if scaling_min_val is not None:
+        input_kwargs = {**input_kwargs, **{'scaling_min_val': scaling_min_val}}
+
     if input_bit_width is not None and input_scale_type == 'no_scale':
         input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][
             input_scale_type][input_quant_type]
@@ -286,17 +302,40 @@
         linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
             input_scale_precision][input_param_method][input_quant_granularity][input_quant_type]

-        if input_kwargs is None:
-            input_kwargs = dict()
+        if kv_quant_type is not None:
+            q_scaled_quant = attn_output_weights_quant = None

+        else:
+            q_scaled_quant = attn_output_weights_quant = sym_input_quant

+        kv_quant_type = kv_quant_type if kv_quant_type is not None else input_quant_type
+        kv_quant_granularity = kv_quant_granularity if kv_quant_granularity is not None else input_quant_granularity
+        print(kv_quant_granularity)

+        v_quant = k_transposed_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+            input_scale_precision][input_param_method][kv_quant_granularity][kv_quant_type]

+        extra_kwargs = {
+            'bit_width': input_bit_width,
+            'quantize_zero_point': quantize_input_zero_point,
+            'dtype': dtype,
+            'device': device}
+        input_kwargs = {**input_kwargs, **extra_kwargs, **input_float_format}

+        input_quant = input_quant.let(**input_kwargs)
+        sym_input_quant = sym_input_quant.let(**input_kwargs)
+        linear_input_quant = linear_input_quant.let(**input_kwargs)
+        v_quant = v_quant.let(**input_kwargs)
+        k_transposed_quant = k_transposed_quant.let(**input_kwargs)
+        q_scaled_quant = q_scaled_quant.let(**input_kwargs) if q_scaled_quant is not None else None
+        attn_output_weights_quant = attn_output_weights_quant.let(
+            **input_kwargs) if attn_output_weights_quant is not None else None

     else:
         input_quant = None
         sym_input_quant = None
         linear_input_quant = None
+        q_scaled_quant = attn_output_weights_quant = v_quant = k_transposed_quant = None

     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
@@ -310,26 +349,13 @@
         scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
         scale_type = scale_rounding_func_dict[scale_rounding_func_type]
         weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
-        if input_quant is not None:
-            input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
-        if sym_input_quant is not None:
-            sym_input_quant = sym_input_quant.let(
-                **{'restrict_value_float_to_int_impl': scale_type})
-        if linear_input_quant is not None:
-            linear_input_quant = linear_input_quant.let(
-                **{'restrict_value_float_to_int_impl': scale_type})

     if weight_group_dim is not None:
         weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})

     if scaling_min_val is not None:
         weight_quant = weight_quant.let(**{'scaling_min_val': scaling_min_val})
-        input_quant = input_quant.let(
-            **{'scaling_min_val': scaling_min_val}) if input_quant is not None else None
-        linear_input_quant = linear_input_quant.let(
-            **{'scaling_min_val': scaling_min_val}) if linear_input_quant is not None else None
-        sym_input_quant = sym_input_quant.let(
-            **{'scaling_min_val': scaling_min_val}) if sym_input_quant is not None else None

     if weight_kwargs is not None:
         weight_quant = weight_quant.let(**weight_kwargs)

@@ -344,84 +370,68 @@
     if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
         weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint)

+    if quant_attn_mode == 'sdpa':
+        kv_permute_dims = (0, 1, 3, 2)
+        kv_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], 1, shape[-2], shape[-1])
+    elif quant_attn_mode == 'mha':
+        kv_permute_dims = (0, 2, 1)
+        kv_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], 1, shape[-1])

     # Modify the input quantizers based on the arguments passed in
-    if input_quant is not None:
-        input_quant = input_quant.let(
-            **{
-                'bit_width': input_bit_width,
-                'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype,
-                'device': device},
-            **input_float_format)
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_row':
-                input_quant = input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'stats_reduce_dim': 1})
-            elif input_quant_granularity == 'per_group':
-                input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})
-    if sym_input_quant is not None:
-        sym_input_quant = sym_input_quant.let(
-            **{
-                'bit_width': input_bit_width,
-                'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype,
-                'device': device},
-            **input_float_format)
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_tensor':
-                q_scaled_quant = sym_input_quant
-                k_transposed_quant = sym_input_quant
-            elif input_quant_granularity == 'per_row':
-                q_scaled_quant = sym_input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'permute_dims': None,
-                        'stats_reduce_dim': 1})
-                k_transposed_quant = sym_input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn':
-                        lambda x,
-                        shape: x.view(shape[0], 1, shape[-1]),
-                        'permute_dims': (0, 2, 1),
-                        'stats_reduce_dim':
-                        1})
-            elif input_quant_granularity == 'per_group':
-                q_scaled_quant = sym_input_quant.let(
-                    **{
-                        'group_dim': -1, 'group_size': input_group_size})
-                k_transposed_quant = sym_input_quant.let(
-                    **{
-                        'group_dim': -2, 'group_size': input_group_size})
-            v_quant = k_transposed_quant
-            attn_output_weights_quant = q_scaled_quant
-        else:
-            q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = sym_input_quant
-    else:
-        q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = None
-    if linear_input_quant is not None:
-        linear_input_quant = linear_input_quant.let(
-            **{
-                'bit_width': input_bit_width,
-                'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype,
-                'device': device},
-            **input_float_format)
-        if input_scale_type == 'dynamic':
-            if input_quant_granularity == 'per_row':
-                linear_input_quant = linear_input_quant.let(
-                    **{
-                        'dynamic_scaling_broadcastable_fn': lambda x,
-                        shape: x.view(*shape[:-1], 1),
-                        'permute_dims': None,
-                        'stats_reduce_dim': 1})
-            elif input_quant_granularity == 'per_group':
-                linear_input_quant = linear_input_quant.let(
-                    **{
-                        'group_dim': -1, 'group_size': input_group_size})
+    if input_bit_width is not None:
+        # Input Quant
+        if input_quant_granularity == 'per_row':
+            input_quant = input_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1})
+        elif input_quant_granularity == 'per_group':
+            input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})

+        # QKV/Softmax Quant
+        if kv_quant_granularity == 'per_row':
+            q_scaled_quant = q_scaled_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1}) if q_scaled_quant is not None else None
+            v_quant = v_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
+                    'permute_dims': kv_permute_dims,
+                    'stats_reduce_dim': 1})
+            k_transposed_quant = k_transposed_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': kv_broadcastable_shape_lambda,
+                    'permute_dims': kv_permute_dims,
+                    'stats_reduce_dim': 1})
+        elif kv_quant_granularity == 'per_group':
+            q_scaled_quant = q_scaled_quant.let(
+                **{
+                    'group_dim': -1, 'group_size': input_group_size
+                }) if q_scaled_quant is not None else None
+            v_quant = v_quant.let(**{'group_dim': -1, 'group_size': input_group_size})
+            k_transposed_quant = k_transposed_quant.let(
+                **{
+                    'group_dim': -2, 'group_size': input_group_size})
+        v_quant = k_transposed_quant
+        attn_output_weights_quant = q_scaled_quant

+        # Input to Linear Layer Quant
+        if input_quant_granularity == 'per_row':
+            linear_input_quant = linear_input_quant.let(
+                **{
+                    'dynamic_scaling_broadcastable_fn': lambda x,
+                    shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
+                    'stats_reduce_dim': 1})
+        elif input_quant_granularity == 'per_group':
+            linear_input_quant = linear_input_quant.let(
+                **{
+                    'group_dim': -1, 'group_size': input_group_size})
     return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant


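The key behavioural change above is the fallback logic for the new arguments: `kv_quant_type` and `kv_quant_granularity` default to the corresponding input settings, and setting `kv_quant_type` explicitly disables the q-scaled and attention-weight quantizers so that only the KV cache is quantized. A minimal, self-contained restatement of that logic follows; the helper name `resolve_kv_options` is hypothetical and not part of Brevitas.

```python
def resolve_kv_options(input_quant_type, input_quant_granularity,
                       kv_quant_type=None, kv_quant_granularity=None):
    # Mirrors the branch added in generate_quantizers above: an explicit
    # kv_quant_type means "KV-cache-only" quantization.
    kv_only = kv_quant_type is not None
    kv_quant_type = kv_quant_type if kv_quant_type is not None else input_quant_type
    kv_quant_granularity = (
        kv_quant_granularity if kv_quant_granularity is not None else input_quant_granularity)
    return kv_only, kv_quant_type, kv_quant_granularity


# Example: asym per-row KV cache on top of a sym per-tensor input configuration
print(resolve_kv_options('sym', 'per_tensor', 'asym', 'per_row'))
# -> (True, 'asym', 'per_row')
```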
13 changes: 13 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -13,6 +13,7 @@
 from brevitas.core.stats import NegativeMinOrZero
 from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
 from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
+from brevitas.core.zero_point import RuntimeDynamicGroupZeroPoint
 from brevitas.core.zero_point import StatsFromParameterZeroPoint
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
@@ -99,6 +100,18 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
     scaling_per_output_type = ScalingPerOutputType.GROUP


+class ShiftedUint8DynamicActPerGroupFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat):
+    """
+    Asymmetric quantizer with per group dynamic scale and zero point.
+    """
+    proxy_class = GroupwiseActQuantProxyFromInjector
+    scaling_impl = RuntimeDynamicGroupStatsScaling
+    scaling_stats_op = 'min_max'
+    scaling_per_output_type = ScalingPerOutputType.GROUP
+    zero_point_impl = RuntimeDynamicGroupZeroPoint
+    zero_point_stats_impl = NegativeMinOrZero
+
+
 class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat):
     """
     Symmetric quantizer with per tensor dynamic scale.
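As a usage note, the new injector is specialized the same way the existing dynamic activation quantizers are in quantize.py above, i.e. via `.let()`. A minimal sketch follows; the group size and bit width values are arbitrary examples, not defaults.

```python
from brevitas_examples.common.generative.quantizers import \
    ShiftedUint8DynamicActPerGroupFloat

# Asymmetric, dynamically-scaled per-group activation quantizer, e.g. for the
# KV cache: groups of 32 along the last dimension, 8-bit.
kv_cache_act_quant = ShiftedUint8DynamicActPerGroupFloat.let(
    **{'group_dim': -1, 'group_size': 32, 'bit_width': 8})
```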
10 changes: 9 additions & 1 deletion src/brevitas_examples/llm/README.md
@@ -33,8 +33,9 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--input-param-method {stats,mse}]
                [--input-scale-precision {float_scale,po2_scale}]
                [--input-scale-type {static,dynamic,no_scale}]
-               [--input-quant-type {sym,asym}]
+               [--input-quant-type {sym,asym}] [--kv-quant-type {sym,asym}]
                [--input-quant-granularity {per_tensor,per_row,per_group}]
+               [--kv-quant-granularity {per_tensor,per_row,per_group}]
                [--input-group-size INPUT_GROUP_SIZE]
                [--learned-round-lr LEARNED_ROUND_LR]
                [--learned-round-scale-lr LEARNED_ROUND_SCALE_LR]
@@ -125,9 +126,16 @@ options:
                         value.
   --input-quant-type {sym,asym}
                         Input quantization type. Default: asym.
+  --kv-quant-type {sym,asym}
+                        KV quantization type. If None, it will follow input
+                        quant type. If set, will perform only KV cache
+                        quantization. Default: None
   --input-quant-granularity {per_tensor,per_row,per_group}
                         Granularity for scales/zero-point of inputs. Default:
                         per_tensor.
+  --kv-quant-granularity {per_tensor,per_row,per_group}
+                        Granularity for scales/zero-point of KV cache. If None,
+                        it will follow input quant granularity. Default: None
   --input-group-size INPUT_GROUP_SIZE
                         Group size for per_group input quantization. Default:
                         64.
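A hypothetical invocation combining the new flags; the non-KV flags shown here (e.g. `--model`, `--input-bit-width`, `--input-scale-type`) are assumed from the rest of the LLM example CLI and may differ in your version.

```sh
# Quantize only the KV cache: asymmetric, per-row dynamic scales/zero-points.
python main.py \
    --model facebook/opt-125m \
    --input-bit-width 8 \
    --input-scale-type dynamic \
    --kv-quant-type asym \
    --kv-quant-granularity per_row
```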
3 changes: 3 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -35,6 +35,8 @@ input_quant_granularity: per_tensor
 input_quant_type: asym
 input_scale_precision: float_scale
 input_scale_type: static
+kv_quant_granularity: null
+kv_quant_type: null
 learned_round: null
 learned_round_fast_update: false
 learned_round_iters: 200
@@ -59,6 +61,7 @@ replace_rmsnorm: false
 rotation: null
 rotation_mode: had
 rotation_orphan_sink: false
+rotation_sdpa_regions: false
 scale_rounding_func_type: null
 scaling_min_val: 0.0001
 seed: 0
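For completeness, a hypothetical YAML override enabling KV-cache-only quantization on top of these defaults; the values are illustrative, not recommended settings.

```yaml
# Asymmetric, per-group quantization of the KV cache only; the kv_* keys are the
# ones added by this PR, input_group_size is the existing per-group size option.
kv_quant_type: asym
kv_quant_granularity: per_group
input_group_size: 32
```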