diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 390582e05..985986789 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -94,26 +94,6 @@ def forward(self, x):
     return BNConvModel
 
 
-@pytest_cases.fixture
-def quant_conv_with_input_quant_model():
-
-    class QuantConvModel(nn.Module):
-
-        def __init__(self) -> None:
-            super().__init__()
-            self.conv_0 = qnn.QuantConv2d(
-                3, 16, kernel_size=3)  # gpxq tests assume no quant on first layer
-            self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)
-
-        def forward(self, x):
-            x = self.conv_0(x)
-            x = torch.relu(x)
-            x = self.conv_1(x)
-            return x
-
-    return QuantConvModel
-
-
 @pytest_cases.fixture
 @pytest_cases.parametrize('bias', [True, False])
 @pytest_cases.parametrize('add_bias_kv', [True, False])
@@ -227,26 +207,6 @@ def forward(self, x):
     return ConvDepthConvModel
 
 
-@pytest_cases.fixture
-def quant_convdepthconv_model():
-
-    class QuantConvDepthConvModel(nn.Module):
-
-        def __init__(self) -> None:
-            super().__init__()
-            self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
-            self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
-            self.relu = qnn.QuantReLU(return_quant_tensor=True)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.relu(x)
-            x = self.conv_0(x)
-            return x
-
-    return QuantConvDepthConvModel
-
-
 @pytest_cases.fixture
 def convbn_model():
 
@@ -295,28 +255,6 @@ def forward(self, x):
     return ResidualModel
 
 
-@pytest_cases.fixture
-def quant_residual_model():
-
-    class QuantResidualModel(nn.Module):
-
-        def __init__(self) -> None:
-            super().__init__()
-            self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
-            self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
-            self.relu = qnn.QuantReLU(return_quant_tensor=True)
-
-        def forward(self, x):
-            start = x
-            x = self.conv(x)
-            x = self.relu(x)
-            x = self.conv_0(x)
-            x = start + x
-            return x
-
-    return QuantResidualModel
-
-
 @pytest_cases.fixture
 def srcsinkconflict_model():
     """
@@ -409,26 +347,6 @@ def forward(self, x):
     return ConvTransposeModel
 
 
-@pytest_cases.fixture
-def quant_convtranspose_model():
-
-    class QuantConvTransposeModel(nn.Module):
-
-        def __init__(self) -> None:
-            super().__init__()
-            self.relu = qnn.QuantReLU(return_quant_tensor=True)
-            self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
-            self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)
-
-        def forward(self, x):
-            x = self.conv_0(x)
-            x = self.relu(x)
-            x = self.conv_1(x)
-            return x
-
-    return QuantConvTransposeModel
-
-
 list_of_fixtures = [
     'residual_model',
     'srcsinkconflict_model',
@@ -441,17 +359,8 @@ def forward(self, x):
     'convgroupconv_model',
     'convtranspose_model']
 
-list_of_quant_fixtures = [
-    'quant_conv_with_input_quant_model',
-    'quant_convdepthconv_model',
-    'quant_residual_model',
-    'quant_convtranspose_model']
-
 toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures)
 
-toy_quant_model = fixture_union(
-    'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
-
 RESNET_18_REGIONS = [
     [('layer3.0.bn1',), ('layer3.0.conv2',)],
     [('layer4.1.bn1',), ('layer4.1.conv2',)],
@@ -468,3 +377,95 @@ def forward(self, x):
      ('layer1.0.conv1', 'layer1.1.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0')],
     [('layer2.0.bn1',), ('layer2.0.conv2',)],
     [('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],]
+
+
+@pytest_cases.fixture
+def quant_conv_with_input_quant_model():
+
+    class QuantConvModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv_0 = qnn.QuantConv2d(
+                3, 16, kernel_size=3)  # gpxq tests assume no quant on first layer
+            self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)
+
+        def forward(self, x):
+            x = self.conv_0(x)
+            x = torch.relu(x)
+            x = self.conv_1(x)
+            return x
+
+    return QuantConvModel
+
+
+@pytest_cases.fixture
+def quant_convdepthconv_model():
+
+    class QuantConvDepthConvModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
+            self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
+            self.relu = qnn.QuantReLU(return_quant_tensor=True)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.conv_0(x)
+            return x
+
+    return QuantConvDepthConvModel
+
+
+@pytest_cases.fixture
+def quant_residual_model():
+
+    class QuantResidualModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
+            self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
+            self.relu = qnn.QuantReLU(return_quant_tensor=True)
+
+        def forward(self, x):
+            start = x
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.conv_0(x)
+            x = start + x
+            return x
+
+    return QuantResidualModel
+
+
+@pytest_cases.fixture
+def quant_convtranspose_model():
+
+    class QuantConvTransposeModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.relu = qnn.QuantReLU(return_quant_tensor=True)
+            self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
+            self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)
+
+        def forward(self, x):
+            x = self.conv_0(x)
+            x = self.relu(x)
+            x = self.conv_1(x)
+            return x
+
+    return QuantConvTransposeModel
+
+
+list_of_quant_fixtures = [
+    'quant_conv_with_input_quant_model',
+    'quant_convdepthconv_model',
+    'quant_residual_model',
+    'quant_convtranspose_model']
+
+toy_quant_model = fixture_union(
+    'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)