diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 21cb46a19..31ab57361 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -141,7 +141,8 @@ def quantize_model(
         input_quant_granularity=None,
         input_group_size=None,
         quantize_input_zero_point=False,
-        quantize_embedding=False):
+        quantize_embedding=False,
+        device=None):
     """
     Replace float layers with quant layers in the target model
     """
@@ -208,7 +209,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype,},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_row':
@@ -224,7 +226,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_tensor':
@@ -263,7 +266,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype},
+                'dtype': dtype,
+                'device': device},
             **input_float_format)
         if input_scale_type == 'dynamic':
             if input_quant_granularity == 'per_row':
@@ -279,8 +283,12 @@ def quantize_model(
                         'group_dim': -1,
                         'group_size': input_group_size})
     quant_linear_kwargs = {
-        'input_quant': linear_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
-    quant_conv_kwargs = {'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+        'input_quant': linear_input_quant,
+        'weight_quant': weight_quant,
+        'dtype': dtype,
+        'device': device}
+    quant_conv_kwargs = {
+        'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype, 'device': device}

     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
@@ -300,7 +308,8 @@ def quantize_model(
         # activation equalization requires packed_in_proj
         # since it supports only self-attention
         'packed_in_proj': True,
-        'dtype': dtype}
+        'dtype': dtype,
+        'device': device}

     layer_map = {
         nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
@@ -311,7 +320,7 @@ def quantize_model(
         nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)}

     if quantize_embedding:
-        quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype}
+        quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
         layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)

     model = layerwise_quantize(
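
A minimal sketch (not part of the patch) of what the threaded-through `device` keyword does downstream: the kwargs dicts built in `quantize_model` end up as constructor arguments for the Brevitas quant layers when `layerwise_quantize` swaps the float modules, so passing `device` lets the replacement layers be materialized directly on the chosen device. The quantizer choice and layer sizes below are illustrative assumptions, not taken from the diff.

# Illustrative sketch only: mirrors how quant_linear_kwargs is consumed.
import torch
import brevitas.nn as qnn
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat  # assumed stand-in weight quantizer

quant_linear_kwargs = {
    'weight_quant': Int8WeightPerTensorFloat,
    'dtype': torch.float32,
    'device': 'cpu',  # e.g. 'cuda' or 'meta' could be passed instead
}

# layerwise_quantize applies the same kwargs when replacing each nn.Linear.
layer = qnn.QuantLinear(64, 64, **quant_linear_kwargs)
print(layer.weight.device)  # reflects the requested device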