Skip to content

Commit

Permalink
restored QuantConvTranspose3d to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev committed Jan 31, 2024
1 parent 22fe5c2 commit ee01b84
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 9 deletions.
3 changes: 2 additions & 1 deletion tests/brevitas/export/quant_module_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
QuantConv2d,
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d, #QuantConvTranspose3d,
QuantConvTranspose2d,
QuantConvTranspose3d,
]
BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8
BIAS_BIT_WIDTHS = [8, 16, 32]
Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@
QuantConv2d,
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d, #QuantConvTranspose3d,
QuantConvTranspose2d,
QuantConvTranspose3d,
]

ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32]
Expand Down
6 changes: 3 additions & 3 deletions tests/brevitas/nn/test_a2q.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def test_quant_wbiol_a2q(model_input, current_cases):
elif kwargs[
'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size)
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3))
#elif kwargs[
# 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size)
# quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4))
elif kwargs[
'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size)
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4))
else:
raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.")

Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas/nn/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
QuantConv2d,
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d, #QuantConvTranspose3d,
QuantConvTranspose2d,
QuantConvTranspose3d,
]


Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant):
QuantConv2d,
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d, #QuantConvTranspose3d,
QuantConvTranspose2d,
QuantConvTranspose3d,
]


Expand Down
3 changes: 1 addition & 2 deletions tests/brevitas_ort/test_quant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def test_ort_wbiol(model, export_type, current_cases):
o_bit_width = case_id.split('-')[-5]
i_bit_width = case_id.split('-')[-3]

#if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d','QuantConvTranspose3d') and export_type == 'qop':
if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop':
if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', 'QuantConvTranspose3d') and export_type == 'qop':
pytest.skip('Export of ConvTranspose is not supported for QOperation')
if 'per_channel' in quantizer and 'asymmetric' in quantizer:
pytest.skip('Per-channel zero-point is not well supported in ORT.')
Expand Down

0 comments on commit ee01b84

Please sign in to comment.