Feat (examples/generative): add device flag to quantize_model (#860)
nickfraser authored Feb 19, 2024
1 parent 2cde5ed commit 620cc70
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/brevitas_examples/common/generative/quantize.py
@@ -141,7 +141,8 @@ def quantize_model(
         input_quant_granularity=None,
         input_group_size=None,
         quantize_input_zero_point=False,
-        quantize_embedding=False):
+        quantize_embedding=False,
+        device=None):
     """
     Replace float layers with quant layers in the target model
     """
@@ -208,7 +209,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype,},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_row':
@@ -224,7 +226,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_tensor':
@@ -263,7 +266,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_row':
@@ -279,8 +283,12 @@ def quantize_model(
                     'group_dim': -1, 'group_size': input_group_size})
 
     quant_linear_kwargs = {
-        'input_quant': linear_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
-    quant_conv_kwargs = {'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+        'input_quant': linear_input_quant,
+        'weight_quant': weight_quant,
+        'dtype': dtype,
+        'device': device}
+    quant_conv_kwargs = {
+        'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
 
     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
@@ -300,7 +308,8 @@ def quantize_model(
         # activation equalization requires packed_in_proj
         # since it supports only self-attention
         'packed_in_proj': True,
-        'dtype': dtype}
+        'dtype': dtype,
+        'device': device}
 
     layer_map = {
         nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
@@ -311,7 +320,7 @@ def quantize_model(
         nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)}
 
     if quantize_embedding:
-        quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype}
+        quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
         layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)
 
     model = layerwise_quantize(
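The practical effect of threading `device` through these kwargs dicts is that `layerwise_quantize` can construct the replacement quant layers directly on the target device instead of on the CPU default. Below is a minimal sketch of the same pattern outside the example; it reuses only names visible in the diff (`qnn.QuantLinear` and the `'weight_quant'`/`'dtype'`/`'device'` keys), while `Int8WeightPerTensorFloat` is a stock Brevitas quantizer standing in for the example's configured `weight_quant`.

import torch
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat

# Build the per-layer kwargs the same way quantize_model now does,
# including the device on which the quant layer's parameters are created.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
quant_linear_kwargs = {
    'weight_quant': Int8WeightPerTensorFloat, 'dtype': torch.float32, 'device': device}

# QuantLinear forwards dtype/device to the underlying nn.Linear,
# so the weights land on `device` at construction time.
layer = qnn.QuantLinear(64, 64, **quant_linear_kwargs)
print(layer.weight.device, layer.weight.dtype)

In the example itself, callers simply pass `device=...` to `quantize_model`, and the flag propagates to every quantized layer type listed in `layer_map`, including embeddings when `quantize_embedding` is set.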
