Fix (tests): packing quant fixtures together
i-colbert committed Mar 7, 2024
1 parent c06da79 commit 4c19b46
Showing 1 changed file with 92 additions and 91 deletions.
183 changes: 92 additions & 91 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -94,26 +94,6 @@ def forward(self, x):
return BNConvModel


@pytest_cases.fixture
def quant_conv_with_input_quant_model():

class QuantConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv_0 = qnn.QuantConv2d(
3, 16, kernel_size=3) # gpxq tests assume no quant on first layer
self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)

def forward(self, x):
x = self.conv_0(x)
x = torch.relu(x)
x = self.conv_1(x)
return x

return QuantConvModel


@pytest_cases.fixture
@pytest_cases.parametrize('bias', [True, False])
@pytest_cases.parametrize('add_bias_kv', [True, False])
@@ -227,26 +207,6 @@ def forward(self, x):
return ConvDepthConvModel


@pytest_cases.fixture
def quant_convdepthconv_model():

class QuantConvDepthConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
return x

return QuantConvDepthConvModel


@pytest_cases.fixture
def convbn_model():

@@ -295,28 +255,6 @@ def forward(self, x):
return ResidualModel


@pytest_cases.fixture
def quant_residual_model():

class QuantResidualModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
start = x
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = start + x
return x

return QuantResidualModel


@pytest_cases.fixture
def srcsinkconflict_model():
"""
@@ -409,26 +347,6 @@ def forward(self, x):
return ConvTransposeModel


@pytest_cases.fixture
def quant_convtranspose_model():

class QuantConvTransposeModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.relu = qnn.QuantReLU(return_quant_tensor=True)
self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)

def forward(self, x):
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return QuantConvTransposeModel


list_of_fixtures = [
'residual_model',
'srcsinkconflict_model',
@@ -441,17 +359,8 @@ def forward(self, x):
'convgroupconv_model',
'convtranspose_model']

list_of_quant_fixtures = [
'quant_conv_with_input_quant_model',
'quant_convdepthconv_model',
'quant_residual_model',
'quant_convtranspose_model']

toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures)

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)

RESNET_18_REGIONS = [
[('layer3.0.bn1',), ('layer3.0.conv2',)],
[('layer4.1.bn1',), ('layer4.1.conv2',)],
@@ -468,3 +377,95 @@ def forward(self, x):
('layer1.0.conv1', 'layer1.1.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0')],
[('layer2.0.bn1',), ('layer2.0.conv2',)],
[('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],]


@pytest_cases.fixture
def quant_conv_with_input_quant_model():

class QuantConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv_0 = qnn.QuantConv2d(
3, 16, kernel_size=3) # gpxq tests assume no quant on first layer
self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)

def forward(self, x):
x = self.conv_0(x)
x = torch.relu(x)
x = self.conv_1(x)
return x

return QuantConvModel


@pytest_cases.fixture
def quant_convdepthconv_model():

class QuantConvDepthConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
return x

return QuantConvDepthConvModel


@pytest_cases.fixture
def quant_residual_model():

class QuantResidualModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
start = x
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = start + x
return x

return QuantResidualModel


@pytest_cases.fixture
def quant_convtranspose_model():

class QuantConvTransposeModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.relu = qnn.QuantReLU(return_quant_tensor=True)
self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)

def forward(self, x):
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return QuantConvTransposeModel


list_of_quant_fixtures = [
'quant_conv_with_input_quant_model',
'quant_convdepthconv_model',
'quant_residual_model',
'quant_convtranspose_model']

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
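For reference, a minimal sketch of how the relocated `toy_quant_model` union might be consumed, assuming pytest_cases collects this fixtures module as usual; the test name, input shape, and assertion below are illustrative and not part of this commit:

# Hypothetical usage sketch (not in this commit): pytest_cases expands the
# fixture union, so this test runs once per quant fixture listed above.
# Each fixture returns a model class, which the test instantiates itself.
import torch


def test_toy_quant_model_forward(toy_quant_model):
    model = toy_quant_model()  # fixtures return the nn.Module subclass
    model.eval()
    x = torch.randn(1, 3, 10, 10)  # 3-channel input assumed from the fixtures
    with torch.no_grad():
        out = model(x)
    assert out.shape[0] == 1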
