Fix (gpfq): updating input processing and L1-norm constraints for GPFA2Q #852

Merged 4 commits on Feb 15, 2024
40 changes: 26 additions & 14 deletions src/brevitas/graph/gpfq.py
@@ -2,10 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
from typing import List, Optional
from typing import Callable, List, Optional

import numpy as np
import torch
import torch.nn as nn
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
@@ -22,9 +23,21 @@ class gpfq_mode(gpxq_mode):

Args:
model (Module): The model to quantize with GPFQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPFQ in place or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating-point weights before applying
GPFQ. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPFQ. Default: False
p (float): The percentage of processed inputs to use. Default: 1.0
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
use_gpfa2q (bool): Whether to use accumulator-aware GPFQ (GPFA2Q). Default: False
accumulator_bit_width (Optional, int): The target accumulator bit width. Default: None
a2q_layer_filter_fnc (Optional, callable): An optional lambda function to filter layers for
accumulator constraints. Should return True for layers to constrain. Default: `lambda x: True`

Example:
>>> with torch.no_grad():
@@ -39,7 +52,7 @@ class gpfq_mode(gpxq_mode):

def __init__(
self,
model,
model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
@@ -48,7 +61,8 @@ def __init__(
return_forward_output: bool = False,
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None) -> None:
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -65,6 +79,7 @@ def __init__(
# GPFA2Q params
self.use_gpfa2q = use_gpfa2q
self.accumulator_bit_width = accumulator_bit_width
self.a2q_layer_filter_fnc = a2q_layer_filter_fnc  # returns True for layers that should use GPFA2Q

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
@@ -101,7 +116,7 @@ def catch_stopfwd(self, *args, **kwargs):

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
if not self.use_gpfa2q:
if (not self.a2q_layer_filter_fnc(layer)) or (not self.use_gpfa2q):
return GPFQ(
layer=layer,
name=name,
@@ -287,12 +302,13 @@ def __init__(
p=p)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None
self.requires_quant_input = True # force true

def single_layer_update(self):
# raise error in case no quant-input is here
if self.quant_input is None:
raise ValueError(
'Expected quant input to calculate Upper Bound on L1 norm, but received None')
'Expected quant input to calculate L1-norm upper bound, but received None')
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
@@ -314,7 +330,8 @@ def single_layer_update(self):
s = self.layer.quant_weight_scale()
s = s.view(self.groups, -1) # [Groups, OC/Groups]

l1_norm = torch.zeros(weight.shape[:-1], device=dev)
# initialize cumulative l1-norm
z = torch.zeros(weight.shape[:-1], device=dev)

# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
@@ -345,18 +362,13 @@
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

max_q_arg = s[group_index, :] * torch.clamp_min(T - z[group_index, :], 0.)
q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg)
weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)
z += q.abs() / s # increment cumulative l1-norm

for group_index in range(self.groups):
candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index])
candidate_l1_mask = candidate_l1 > T * s[group_index]
if torch.any(candidate_l1_mask):
# set all values to 0 that are exceeding T * s
weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0
q[group_index][candidate_l1_mask] = 0
else:
l1_norm[group_index] = candidate_l1
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :,
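
The key algorithmic change in this file is in the last hunk: instead of accumulating the L1 norm and zeroing whole columns after the fact whenever it exceeded T * s, GPFA2Q now clamps each candidate weight to the budget that remains before it is quantized. Below is a minimal toy sketch of that clamp in plain PyTorch; T, s, and the candidate values are made up for illustration (the real code derives T from get_upper_bound_on_l1_norm, s from quant_weight_scale, and increments z with the actual quantized weights q).

import torch

T = torch.tensor(10.)           # L1-norm budget per output channel, in integer units
s = torch.tensor([0.05, 0.10])  # per-channel weight scales
z = torch.zeros(2)              # running L1 norm of the committed weights, divided by s

def constrain(q_arg: torch.Tensor) -> torch.Tensor:
    """Clamp a candidate weight column so the cumulative L1 norm never exceeds T."""
    global z
    max_q_arg = s * torch.clamp_min(T - z, 0.)                      # remaining budget, in weight units
    q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg)
    z = z + q_arg.abs() / s                                         # spend part of the budget
    return q_arg

print(constrain(torch.tensor([0.30, -0.80])))  # within budget, returned unchanged
print(constrain(torch.tensor([0.40, 0.90])))   # clamped to the remaining headroom

In the PR itself this constraint is only applied when use_gpfa2q=True and the new a2q_layer_filter_fnc returns True for the layer; initialize_module_optimizer falls back to plain GPFQ otherwise.
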
8 changes: 8 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -27,9 +27,17 @@ class gptq_mode(gpxq_mode):

Args:
model (Module): The model to quantize with GPTQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPTQ in place or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating-point weights before applying
GPTQ. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPTQ. Default: False
num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False

Example:
>>> with torch.no_grad():
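
The gptq.py hunk only touches the docstring, but since its Example block is truncated in this view, here is a hedged usage sketch assembled from the documented arguments and the gpxq_mode Example shown further down; model (an already quantized Brevitas module) and calib_loader (an iterable of (img, target) batches) are assumptions, not part of the diff.

import torch
from brevitas.graph.gptq import gptq_mode

# model and calib_loader are assumed to exist; see the note above.
with torch.no_grad():
    with gptq_mode(model, use_quant_activations=False, act_order=False) as gptq:
        gptq_model = gptq.model
        for _ in range(gptq.num_layers):    # one pass of calibration data per layer
            for img, _ in calib_loader:
                gptq_model(img)
            gptq.update()                   # apply the GPTQ update for the current layer
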
87 changes: 70 additions & 17 deletions src/brevitas/graph/gpxq.py
@@ -34,6 +34,32 @@ class LayerHandler:


class gpxq_mode(ABC):
"""
Apply GPxQ algorithm.

Args:
model (Module): The model to quantize with GPxQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPxQ in place or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating-point weights before applying
GPxQ. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPxQ. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False

Example:
>>> with torch.no_grad():
>>> with gpxq_mode(model) as gpxq:
>>> gpxq_mode = gpxq.model
>>> for i in tqdm(range(gpxq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpxq_mode(img)
>>> gpxq.update()
"""

def __init__(
self,
@@ -181,32 +207,59 @@ def __init__(
self.disable_pre_forward_hook = False
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_input = None
self.requires_quant_input = False # For GPFA2Q

@property
def layer_requires_input_quant(self):
# some weight quantizers require a quant input (e.g., A2Q)
check_1 = self.layer.weight_quant_requires_quant_input
# if input_quant is enabled, then we will store its information
check_2 = self.layer.is_input_quant_enabled
# GPFA2Q requires the quantized input to be stored
check_3 = self.requires_quant_input
requires_input_quant = check_1 or check_2 or check_3
return requires_input_quant

def process_input(self, inp):
# Input is a tuple, so we take first element
inp = inp[0]
# If using Quant Activations, inp could be QuantTensor

# if the quant_input is not already cached, then get
# metadata from QuantWBIOL module
if self.quant_input is None:
inp_scale = self.layer.quant_input_scale()
inp_zero_point = self.layer.quant_input_zero_point()
inp_bit_width = self.layer.quant_input_bit_width()
inp_signed = self.layer.is_quant_input_signed
inp_training = self.layer.training

# If using quantized activations, inp could be QuantTensor. In
# this case, we overwrite the metadata if it is specified.
if isinstance(inp, QuantTensor):
if self.layer.weight_quant_requires_quant_input:
# Can minimize memory allocation by not storing actual values
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp.scale,
zero_point=inp.zero_point,
bit_width=inp.bit_width,
signed=inp.signed,
training=inp.training)
if self.layer_requires_input_quant and (self.quant_input is None):
if inp.scale is not None:
inp_scale = inp.scale
if inp.zero_point is not None:
inp_zero_point = inp.zero_point
if inp.bit_width is not None:
inp_bit_width = inp.bit_width
if inp.signed is not None:
inp_signed = inp.signed
if inp.training is not None:
inp_training = inp.training
inp = inp.value
elif self.layer.is_input_quant_enabled:

# if the layer requires an input quant and the quant input cache has
# yet to be populated, then populate with the collected metadata
if self.layer_requires_input_quant and (self.quant_input is None):
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=self.layer.quant_input_scale(),
zero_point=self.layer.quant_input_zero_point(),
bit_width=self.layer.quant_input_bit_width(),
signed=self.layer.is_quant_input_signed,
training=self.layer.training)
scale=inp_scale,
zero_point=inp_zero_point,
bit_width=inp_bit_width,
signed=inp_signed,
training=inp_training)

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
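
The process_input refactor above boils down to a precedence rule: start from the layer's own input-quant metadata, override any field that an incoming QuantTensor actually carries, and cache a single placeholder QuantTensor only when layer_requires_input_quant holds (the weight quantizer needs a quant input, the input quant is enabled, or GPFA2Q forces requires_quant_input). A standalone sketch of that precedence logic follows; the field names mirror QuantTensor, but the class and helper are illustrative only, not the Brevitas API.

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantMeta:
    scale: Optional[float] = None
    zero_point: Optional[float] = None
    bit_width: Optional[int] = None
    signed: Optional[bool] = None

def resolve_quant_meta(layer_meta: QuantMeta, inp_meta: Optional[QuantMeta]) -> QuantMeta:
    """Layer defaults first; any field the incoming QuantTensor specifies wins."""
    resolved = QuantMeta(**vars(layer_meta))
    if inp_meta is not None:
        for field, value in vars(inp_meta).items():
            if value is not None:          # only override metadata that was actually provided
                setattr(resolved, field, value)
    return resolved

layer_defaults = QuantMeta(scale=0.1, zero_point=0.0, bit_width=8, signed=True)
incoming = QuantMeta(bit_width=4)          # e.g. a QuantTensor that only reports its bit width
print(resolve_quant_meta(layer_defaults, incoming))  # scale/zero_point from the layer, bit_width=4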