Fix (gpfa2q): using callback instead of ignored layer names
i-colbert authored and Giuseppe5 committed Feb 15, 2024
1 parent bbd169d commit 35069dc
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/brevitas/graph/gpfq.py
@@ -2,10 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
-from typing import List, Optional
+from typing import Callable, List, Optional

import numpy as np
import torch
+import torch.nn as nn
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
@@ -25,6 +26,14 @@ class gpfq_mode(gpxq_mode):
inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
    GPFQ. Default: False
+p (float): The percentage of processed inputs to use. Default: 1.0
+return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
+    forward call inside the context manager returns None. Default: False
+act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
+use_gpfa2q (bool): Whether to use accumulator-aware GPFQ. Default: False
+accumulator_bit_width (Optional, int): The target accumulator bit width. Default: None
+a2q_layer_filter_fnc (Optional, callable): An optional lambda function to filter layers for
+    accumulator constraints. Should return True for layers to constrain. Default: `lambda x: True`
Example:
>>> with torch.no_grad():
@@ -39,7 +48,7 @@ class gpfq_mode(gpxq_mode):

def __init__(
self,
-model,
+model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
@@ -49,7 +58,7 @@ def __init__(
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
-ignore_layers: Optional[List[str]] = None) -> None:
+a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -66,7 +75,7 @@ def __init__(
# GPFA2Q params
self.use_gpfa2q = use_gpfa2q
self.accumulator_bit_width = accumulator_bit_width
-self.ignore_layers = ignore_layers
+self.a2q_layer_filter_fnc = a2q_layer_filter_fnc  # returns True when GPFA2Q should be applied to the layer

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
@@ -103,7 +112,7 @@ def catch_stopfwd(self, *args, **kwargs):

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
-if (name in self.ignore_layers) or (not self.use_gpfa2q):
+if (not self.a2q_layer_filter_fnc(layer)) or (not self.use_gpfa2q):
return GPFQ(
layer=layer,
name=name,
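To make the new interface concrete, here is a minimal usage sketch of gpfq_mode with the callback. It is illustrative only: the toy model, the Int8ActPerTensorFloat input quantizer, and the random calibration data are placeholders, and the calibration loop follows the usual Brevitas GPxQ pattern (gpfq.model, gpfq.num_layers, gpfq.update()), which should be checked against the installed version.

import torch
import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.gpfq import gpfq_mode
from brevitas.quant.scaled_int import Int8ActPerTensorFloat

# Toy quantized model and calibration data; stand-ins for the real thing.
quant_model = nn.Sequential(
    qnn.QuantConv2d(
        3, 8, 3, weight_bit_width=4, input_quant=Int8ActPerTensorFloat),
    qnn.QuantConv2d(
        8, 8, 3, weight_bit_width=4, input_quant=Int8ActPerTensorFloat))
quant_model.eval()
calib_loader = [torch.randn(2, 3, 32, 32) for _ in range(4)]

# The callback receives each layer module; returning True applies the
# accumulator constraint (GPFA2Q) to that layer, False falls back to GPFQ.
def conv_only(layer: nn.Module) -> bool:
    return isinstance(layer, nn.Conv2d)

with torch.no_grad():
    with gpfq_mode(quant_model,
                   use_gpfa2q=True,
                   accumulator_bit_width=16,
                   a2q_layer_filter_fnc=conv_only) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for img in calib_loader:
                gpfq_model(img)
            gpfq.update()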

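For callers migrating from the removed ignore_layers argument: the callback receives modules rather than names, so the old by-name behavior can be reproduced by mapping modules back to their names. A sketch, assuming model is the network passed to gpfq_mode and names_to_ignore is a hypothetical set of layer names to exclude from the accumulator constraint:

import torch.nn as nn

names_to_ignore = {'features.0', 'classifier'}  # hypothetical layer names
name_of = {module: name for name, module in model.named_modules()}

def not_ignored(layer: nn.Module) -> bool:
    # False -> this layer falls back to plain GPFQ, mirroring the old
    # `name in ignore_layers` check in initialize_module_optimizer.
    return name_of.get(layer) not in names_to_ignore

# Pass a2q_layer_filter_fnc=not_ignored together with use_gpfa2q=True.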