Commit 2e8485b

Giuseppe5 committed Jan 22, 2025
1 parent bc0ca5d

Showing 2 changed files with 13 additions and 6 deletions.
src/brevitas_examples/common/generative/quantize.py: 12 additions & 6 deletions

@@ -251,6 +251,7 @@ def generate_quantizers(
         device=None,
         weight_kwargs=None,
         input_kwargs=None,
+        quant_attn_mode=None,
         scaling_min_val=1e-4):
     """
     Replace float layers with quant layers in the target model
@@ -361,6 +362,13 @@ def generate_quantizers(
     if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
         weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint)
 
+    if quant_attn_mode == 'sdpa':
+        k_permute_dims = (0, 1, 3, 2)
+        k_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], 1, shape[-2], shape[-1])
+    elif quant_attn_mode == 'mha':
+        k_permute_dims = (0, 2, 1)
+        k_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], 1, shape[-1])
+
     # Modify the input quantizers based on the arguments passed in
     if input_quant is not None:
         if input_scale_type == 'dynamic':
@@ -369,6 +377,7 @@ def generate_quantizers(
                 **{
                     'dynamic_scaling_broadcastable_fn': lambda x,
                     shape: x.view(*shape[:-1], 1),
+                    'permute_dims': None,
                     'stats_reduce_dim': 1})
         elif input_quant_granularity == 'per_group':
             input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})
@@ -389,12 +398,9 @@ def generate_quantizers(
                     'stats_reduce_dim': 1})
             k_transposed_quant = k_transposed_quant.let(
                 **{
-                    'dynamic_scaling_broadcastable_fn':
-                    lambda x,
-                    shape: x.view(shape[0], 1, shape[-1]),
-                    'permute_dims': (0, 2, 1),
-                    'stats_reduce_dim':
-                    1})
+                    'dynamic_scaling_broadcastable_fn': k_broadcastable_shape_lambda,
+                    'permute_dims': k_permute_dims,
+                    'stats_reduce_dim': 1})
         elif input_quant_granularity == 'per_group':
             q_scaled_quant = q_scaled_quant.let(
                 **{
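Note on the shape logic above: a minimal, self-contained sketch of what the two new lambdas do, under assumed tensor shapes (illustration only, not Brevitas internals). In the SDPA path the transposed key tensor is assumed 4-D, so a scale computed by reducing over one dimension is viewed back to a 4-D shape that broadcasts over it; the MHA path is the same idea one dimension lower.

import torch

# Assumed shapes, for illustration: (batch, heads, head_dim, seq_len) for the
# transposed SDPA key tensor, (batch, head_dim, seq_len) for the MHA case.
sdpa_fn = lambda x, shape: x.view(shape[0], 1, shape[-2], shape[-1])
mha_fn = lambda x, shape: x.view(shape[0], 1, shape[-1])

# SDPA: reduce stats over dim 1; the lambda views the result as
# (batch, 1, head_dim, seq_len), which broadcasts back over dim 1.
k_t = torch.randn(2, 4, 16, 8)
stats = k_t.abs().amax(dim=1)            # (2, 16, 8)
scale = sdpa_fn(stats, k_t.shape)        # (2, 1, 16, 8)
assert (k_t / scale).shape == k_t.shape  # broadcasts cleanly

# MHA: same idea on a 3-D tensor: (2, 16, 8) stats -> (2, 8) -> (2, 1, 8).
k_t = torch.randn(2, 16, 8)
stats = k_t.abs().amax(dim=1)            # (2, 8)
scale = mha_fn(stats, k_t.shape)         # (2, 1, 8)
assert (k_t / scale).shape == k_t.shape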
src/brevitas_examples/llm/main.py: 1 addition & 0 deletions

@@ -383,6 +383,7 @@ def quantize_llm(args):
         input_group_size=args.input_group_size,
         quantize_input_zero_point=args.quantize_input_zero_point,
         scale_rounding_func_type=args.scale_rounding_func_type,
+        quant_attn_mode='sdpa' if (quant_sdpa_fx or args.functional_sdpa_quant) else 'mha',
         device=device,
         scaling_min_val=args.scaling_min_val)
     layer_map = generate_quant_maps(
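The call-site change reduces to a two-way switch; a hypothetical standalone rendering of that selection, with the helper name invented here for illustration:

# Hypothetical helper mirroring the inline expression added in main.py: any
# FX-captured or functional SDPA quantization request selects the 4-D 'sdpa'
# handling; everything else keeps the original 3-D 'mha' path.
def select_attn_mode(quant_sdpa_fx: bool, functional_sdpa_quant: bool) -> str:
    return 'sdpa' if (quant_sdpa_fx or functional_sdpa_quant) else 'mha'

assert select_attn_mode(True, False) == 'sdpa'
assert select_attn_mode(False, True) == 'sdpa'
assert select_attn_mode(False, False) == 'mha'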
