Example: How to use merge_bn correctly #542

Open
g12bftd opened this issue Feb 27, 2023 · 8 comments

g12bftd commented Feb 27, 2023

There is an architecture I would like to quantise and retrain from its floating-point counterpart, and I would like to incorporate the merge_bn operation supported by Brevitas. How exactly would I do this? An overview is good, but some code would be better. Note that I only want to merge/fuse the Conv + BN + ReLU components. Here is my architecture:

class QuantizedModel(nn.Module):
    def __init__(self, config):
        super(QuantizedModel, self).__init__()

        self.weight_config = config

        k = 1
        self.quant_inp = qnn.QuantIdentity(
            bit_width=16, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(in_channels=3, out_channels=int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[0], return_quant_tensor=True, bias=True)
        self.bn1 = nn.BatchNorm2d(int(k * 128))
        self.relu1 = qnn.QuantReLU(bit_width=self.weight_config[0], return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(int(k * 128), int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[1], return_quant_tensor=True, bias=True)
        self.bn2 = nn.BatchNorm2d(int(k * 128))
        self.relu2 = qnn.QuantReLU(bit_width=self.weight_config[1], return_quant_tensor=True)
        self.max_pool1 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv3 = qnn.QuantConv2d(int(k * 128), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[2], return_quant_tensor=True, bias=True)
        self.bn3 = nn.BatchNorm2d(int(k * 256))
        self.relu3 = qnn.QuantReLU(bit_width=self.weight_config[2], return_quant_tensor=True)
        self.conv4 = qnn.QuantConv2d(int(k * 256), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[3], return_quant_tensor=True, bias=True)
        self.bn4 = nn.BatchNorm2d(int(k * 256))
        self.relu4 = qnn.QuantReLU(bit_width=self.weight_config[3], return_quant_tensor=True)
        self.max_pool2 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv5 = qnn.QuantConv2d(int(k * 256), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[4], return_quant_tensor=True, bias=True)
        self.bn5 = nn.BatchNorm2d(int(k * 512))
        self.relu5 = qnn.QuantReLU(bit_width=self.weight_config[4], return_quant_tensor=True)
        self.conv6 = qnn.QuantConv2d(int(k * 512), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[5], return_quant_tensor=True, bias=True)
        self.bn6 = nn.BatchNorm2d(int(k * 512))
        self.relu6 = qnn.QuantReLU(bit_width=self.weight_config[5], return_quant_tensor=True)
        self.max_pool3 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)

        input_feats = 8192

        self.fc1 = qnn.QuantLinear(input_feats, int(k * 1024), weight_bit_width=self.weight_config[6], return_quant_tensor=True, bias=True)
        self.fc2 = qnn.QuantLinear(int(k * 1024), 10, weight_bit_width=self.weight_config[7], bias=True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.max_pool1(out)
        out = self.relu3(self.bn3(self.conv3(out)))
        out = self.relu4(self.bn4(self.conv4(out)))
        out = self.max_pool2(out)
        out = self.relu5(self.bn5(self.conv5(out)))
        out = self.relu6(self.bn6(self.conv6(out)))
        out = self.max_pool3(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

MohamedA95 commented Feb 27, 2023

Hi @g12bftd,
As far as I understand, merging batch normalization layers is usually a post-training optimization, so I would train the model first and then write a script that defines two instances of the same model: one with the batch norm layers and one without. Then I would loop over the model with batch norm, merge each conv with its batch_norm layer, and copy the result into the new model (the one with no batch norm layers), so the code should look roughly as follows:

bn_model = QuantizedModel(bn=True)   # trained model that still contains the BatchNorm layers
model = QuantizedModel(bn=False)     # same architecture without the BatchNorm layers
for name, layer in bn_model.named_children():
    if isinstance(layer, qnn.QuantConv2d):
        merge_bn(layer, following_bn)  # following_bn: the BatchNorm2d right after this conv
        getattr(model, name).load_state_dict(layer.state_dict())
torch.save(model.state_dict(), 'fused_QuantizedModel.pth')

Also, I would recommend defining the model using nn.Sequential to make it easier to loop over the model.
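For example, grouping each stage in an nn.Sequential makes consecutive (conv, bn) pairs easy to find by index; a minimal sketch (the layer parameters below are only illustrative):

import torch.nn as nn
import brevitas.nn as qnn

# one Conv + BN + ReLU stage grouped in an nn.Sequential
features = nn.Sequential(
    qnn.QuantConv2d(3, 128, kernel_size=3, padding=1, weight_bit_width=4, bias=True),
    nn.BatchNorm2d(128),
    qnn.QuantReLU(bit_width=4),
)

# consecutive (QuantConv2d, BatchNorm2d) pairs can then be located by index
conv_bn_pairs = [
    (features[i], features[i + 1])
    for i in range(len(features) - 1)
    if isinstance(features[i], qnn.QuantConv2d) and isinstance(features[i + 1], nn.BatchNorm2d)
]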

@wilfredkisku

@MohamedA95 I am new to Brevitas, so do we need to train with the classical BN layers? Could you elaborate for a newbie such as me? All the models that I am trying to export need to have intermediate BN layers.

@MohamedA95

Hi @wilfredkisku, what do you mean by classical BN layers? Do you mean torch.nn.BatchNorm2d? If that is what you mean, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quant BN yet, but you can use torch BN and then fuse it with the previous conv, as both layers are linear transformations; check this link.
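For context, the fusion itself is just a per-output-channel linear rewrite of the conv parameters; a rough sketch, assuming conv is a Conv2d (or QuantConv2d) that already has a bias and bn is the BatchNorm2d that follows it, both in eval mode:

import torch

with torch.no_grad():
    # per-output-channel scale derived from the BN statistics and affine weight
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
    conv.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)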

@wilfredkisku

@MohamedA95 Thank you for the reply. Yes, the models that I am using require torch.nn.BatchNorm2d layers. Can they also be fused with quantized layers? Thanks again.

@MohamedA95

Yes, they can be fused; Brevitas even has a function to do it, merge_bn, under brevitas.nn.utils.
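A minimal usage sketch (module paths may vary between Brevitas versions; model.conv1 and model.bn are placeholder names for a quant conv and the BatchNorm2d that follows it):

import torch
from brevitas.nn.utils import merge_bn

model.eval()  # use the BN running statistics, not batch statistics
with torch.no_grad():
    merge_bn(model.conv1, model.bn)  # folds model.bn into model.conv1 in place
# after fusing, model.bn can be replaced with nn.Identity() or dropped from forward()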

@wilfredkisku

@MohamedA95 Thanks for all the help. I have been able to understand the idea behind fusing the layers. What I have done now is create two similar models, one with CONV + BN and the other without BN.

###########################
#### MODEL 1 ##############
###########################

from torch.nn import Module
import torch.nn.functional as F

import torch.nn as nn

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant


class QuantWeightActLeNet(Module):
    def __init__(self):
        super(QuantWeightActLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)
        self.bn = nn.BatchNorm2d(6)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.bn(self.conv1(out)))
        return out

###########################
#### MODEL 2 ##############
###########################
class QuantWeightActLeNet_wo(Module):
    def __init__(self):
        super(QuantWeightActLeNet_wo, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        return out

quant_weight_act_lenet_wo = QuantWeightActLeNet_wo()
quant_weight_act_lenet = QuantWeightActLeNet()

I am using the merge_bn function to merge the CONV and BN layers:

#######################
###### MERGE ##########
#######################

# imports needed by the function as copied from the Brevitas source
# (module paths may differ slightly between Brevitas versions)
from torch.nn import Parameter
from brevitas.nn.utils import mul_add_from_bn, compute_channel_view_shape
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


def merge_bn(layer, bn, output_channel_dim=0):
    out = mul_add_from_bn(
        bn_mean=bn.running_mean,
        bn_var=bn.running_var,
        bn_eps=bn.eps,
        bn_weight=bn.weight.data.clone(),
        bn_bias=bn.bias.data.clone())
    mul_factor, add_factor = out

    #compute the shape of the channel
    out_ch_weight_shape = compute_channel_view_shape(layer.weight, output_channel_dim)

    #in-place operations multiply the layer weights with the mul_factor of the BN
    #without making a new copy of the Tensor
    layer.weight.data.mul_(mul_factor.view(out_ch_weight_shape))

    #handle if -> bias = True
    if layer.bias is not None:
        out_ch_bias_shape = compute_channel_view_shape(layer.bias, channel_dim=0)
        layer.bias.data.mul_(mul_factor.view(out_ch_bias_shape))
        layer.bias.data.add_(add_factor.view(out_ch_bias_shape))
    else:
        layer.bias = Parameter(add_factor)
    if (hasattr(layer, 'weight_quant') and
            isinstance(layer.weight_quant, WeightQuantProxyFromInjector)):
        layer.weight_quant.init_tensor_quant()
    if (hasattr(layer, 'bias_quant') and isinstance(layer.bias_quant, BiasQuantProxyFromInjector)):
        layer.bias_quant.init_tensor_quant()
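With the two LeNet models defined above, the merge call itself would look roughly like this (assuming quant_weight_act_lenet has already been trained and is in eval mode):

quant_weight_act_lenet.eval()
merge_bn(quant_weight_act_lenet.conv1, quant_weight_act_lenet.bn)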

But I am having issues while copying the trained weights and biases, plus the additional quantization parameters that are present in the quantized layers such as QuantConv2d. I use concise code like the one below for creating the dictionary of weights for only the CONV layers, skipping the BN (which has been fused with the CONV earlier):

for keys in pretrained_dict.keys():
  if keys.split('.')[0] != 'bn':
    processed_dict[keys] = pretrained_dict[keys]

quant_weight_act_lenet_wo.load_state_dict(processed_dict, strict=False)

I am able to copy the weights, but a few parameters associated with the Brevitas quantization library do not get copied. The message returned by load_state_dict is given below:

_IncompatibleKeys(missing_keys=['quant_inp.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value', 'relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value'], unexpected_keys=[])

I would be thankful for any help in this regard. Thanks again.

@MohamedA95

Hi @wilfredkisku,
I am not sure about your way of copying the state dict; I would do something like the following:
1- Define the two models, one with the batch norm and one without.
2- Loop over the model with BN, fusing each BN into the preceding conv.
3- Loop over the model without batch norm, copying the state dict module by module from the other model, e.g.:
quant_weight_act_lenet_wo.conv1.load_state_dict(quant_weight_act_lenet.conv1.state_dict())
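Put together, a rough sketch of steps 2 and 3 for the LeNet example above (it assumes merge_bn is available as discussed and that the two models use the same attribute names for the shared modules):

quant_weight_act_lenet.eval()

# step 2: fuse the BatchNorm into the conv of the trained model, in place
merge_bn(quant_weight_act_lenet.conv1, quant_weight_act_lenet.bn)

# step 3: copy the state dict module by module into the BN-free model, so the
# quantization proxy parameters (e.g. the activation scaling values) come along too
for name, module in quant_weight_act_lenet_wo.named_children():
    module.load_state_dict(getattr(quant_weight_act_lenet, name).state_dict())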


DDDDDY1 commented Nov 25, 2023

Hi @wilfredkisku, what do you mean by classical BN layers? Do you mean torch.nn.BatchNorm2d? If that is what you mean, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quant BN yet, but you can use torch BN and then fuse it with the previous conv, as both layers are linear transformations; check this link.

Hi, this reply mentioned that we are able to train fixed-point batch norm using BatchNorm2dToQuantScaleBias with power-of-two scale factors. Doesn't that mean quant BN is supported?

However, it is not clear to me how this is done for batch norm. Do the scale and bias change during training, or is it indeed a post-training step for the batch norm?
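For illustration, a rough sketch of how BatchNorm2dToQuantScaleBias might replace the nn.BatchNorm2d in the LeNet example above; the constructor arguments and default quantizers here are assumptions, so check brevitas.nn for the exact signature in your Brevitas version:

import torch.nn as nn
import brevitas.nn as qnn

class QuantWeightActLeNetQuantBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        # learnable quantized per-channel scale and bias in place of nn.BatchNorm2d;
        # these parameters are trained together with the rest of the network,
        # so no post-training fusion is needed for this layer
        self.bn = qnn.BatchNorm2dToQuantScaleBias(6)
        self.relu1 = qnn.QuantReLU(bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.bn(self.conv1(out)))
        return out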
