Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 6, 2023
1 parent 61bcc70 commit 0fd67ad
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
19 changes: 9 additions & 10 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,13 @@ def dict_name_to_module(model, regions):
name_set = set()
for region in regions:
for name in region.srcs:
name = name.split("$")[0]
name_set.add(name)
for name in region.sinks:
name = name.split("$")[0]
name_set.add(name)
for name in region.acts:
name = name.split("$")[0]
name_set.add(name)
for name, module in model.named_modules():
if name in name_set:
Expand Down Expand Up @@ -776,7 +779,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
weight = get_weight_sink([module])
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
# It is not possible to equalize through LayerNorm as sink
if isinstance(module, (nn.LayerNorm,)):
if isinstance(module, (nn.LayerNorm,) + _batch_norm):
# state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP))
state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
else:
Expand Down Expand Up @@ -1049,10 +1052,12 @@ def setup(self):
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in list(region.srcs.keys()) + list(region.sinks.keys()):
name = name.split("$")[0]
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first == True else 1
for name in region_to_search:
name = name.split("$")[0]
act_module = name_to_module[name]
if not isinstance(get_module(self.graph_model, name), KwargsForwardHook):
use_inp = True if region_to_search == region.sinks else False
Expand All @@ -1074,16 +1079,10 @@ def apply(self, alpha):
name_to_module = dict_name_to_module(self.graph_model, self.regions)
for region in self.regions:
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
if any([self.float_act_map[name] is None for name in region_to_search]):
if any([self.float_act_map[name.split("$")[0]] is None for name in region_to_search]):
continue
act_module = [name_to_module[act_name] for act_name in region.acts]
list_of_act_val = [self.float_act_map[name] for name in region_to_search]
sinks = [name_to_module[sink] for sink in region.sinks]
# Filter out scale_varying activations from the srcs
srcs = [
name_to_module[src]
for src in region.srcs
if not isinstance(name_to_module[src], _scale_varying_activations)]
act_module = [name_to_module[act_name.split("$")[0]] for act_name in region.acts]
list_of_act_val = [self.float_act_map[name.split("$")[0]] for name in region_to_search]

list_of_insert_mul_node_fn = None
if self.add_mul_node and any([
Expand Down
6 changes: 4 additions & 2 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
'shufflenet_v2_x0_5': [0.318, 0.649],
'mobilenet_v2': [0.161, 0.320],
'resnet18': [0.487, 0.952],
'googlenet': [0.1826, 0.413],
'inception_v3': [0.264, 0.6],
'googlenet': [0.495, 0.982],
'inception_v3': [0.582, 0.989],
'alexnet': [0.875, 0.875],}

IN_SIZE_CONV = (1, 3, 224, 224)
Expand All @@ -34,8 +34,10 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_
name_set = set()
for region in regions:
for name in region.srcs:
name = name.split('$')[0]
name_set.add(name)
for name in region.sinks:
name = name.split('$')[0]
name_set.add(name)
scale_factors_regions = []
for name, module in model.named_modules():
Expand Down
4 changes: 2 additions & 2 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool
srcs = set()
sinks = set()
for r in regions:
srcs.update(list(r.srcs))
sinks.update(list(r.sinks))
srcs.update([x.split("$")[0] for x in list(r.srcs)])
sinks.update([x.split("$")[0] for x in list(r.sinks)])

count_region_srcs = 0
count_region_sinks = 0
Expand Down

0 comments on commit 0fd67ad

Please sign in to comment.