Skip to content

Commit

Permalink
Fix (Channel-Splitting): clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Jan 26, 2024
1 parent 331bebc commit 3e0f0e3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 7 additions & 2 deletions src/brevitas/graph/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _channels_to_split(
return torch.unique(channels_to_split)


# decorator is needed to modify the weights in-place using a view
@torch.no_grad()
def _split_channels(
module: nn.Module,
Expand Down Expand Up @@ -253,7 +254,11 @@ def _clean_regions(regions: List[Region]) -> List[Region]:

class GraphChannelSplitting(GraphTransform):

def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True):
def __init__(
self,
split_ratio: float = 0.02,
split_criterion: str = 'maxabs',
split_input: bool = True):
super(GraphChannelSplitting, self).__init__()

self.split_ratio = split_ratio
Expand All @@ -262,7 +267,7 @@ def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True)

def apply(
self,
model,
model: GraphModule,
return_regions: bool = False
) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(model)
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def preprocess_for_quantize(
merge_bn=True,
equalize_bias_shrinkage: str = 'vaiq',
equalize_scale_computation: str = 'maxabs',
channel_splitting_ratio=0.0,
channel_splitting_split_input=True,
channel_splitting_ratio: float = 0.0,
channel_splitting_split_input: bool = True,
channel_splitting_criterion: str = 'maxabs'):

training_state = model.training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def unique(sequence):
'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
'channel_splitting_ratio': [0.0
], # Channel Splitting ratio, 0.0 means no splitting is performed
'channel_splitting_ratio': [0.0], # Channel Splitting ratio, 0.0 means no splitting
'split_input': [True], # Whether to split the input channels when applying channel splitting
'merge_bn': [True]} # Whether to merge BN layers

Expand Down

0 comments on commit 3e0f0e3

Please sign in to comment.