From 0fd67ad776f02699c2189389880f0b75ccd3a4dd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 6 Dec 2023 21:19:45 +0000 Subject: [PATCH] Fix test --- src/brevitas/graph/equalize.py | 19 +++++++++---------- tests/brevitas/graph/equalization_fixtures.py | 6 ++++-- tests/brevitas/graph/test_equalization.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index ce58a6d91..60d0f28fd 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -174,10 +174,13 @@ def dict_name_to_module(model, regions): name_set = set() for region in regions: for name in region.srcs: + name = name.split("$")[0] name_set.add(name) for name in region.sinks: + name = name.split("$")[0] name_set.add(name) for name in region.acts: + name = name.split("$")[0] name_set.add(name) for name, module in model.named_modules(): if name in name_set: @@ -776,7 +779,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, weight = get_weight_sink([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) # It is not possible to equalize through LayerNorm as sink - if isinstance(module, (nn.LayerNorm,)): + if isinstance(module, (nn.LayerNorm,) + _batch_norm): # state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP)) state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP else: @@ -1049,10 +1052,12 @@ def setup(self): batch_dim = 0 region_to_search = region.sinks if len(region.acts) == 0 else region.acts for name in list(region.srcs.keys()) + list(region.sinks.keys()): + name = name.split("$")[0] module = name_to_module[name] if hasattr(module, 'batch_first'): batch_dim = 0 if module.batch_first == True else 1 for name in region_to_search: + name = name.split("$")[0] act_module = name_to_module[name] if not isinstance(get_module(self.graph_model, name), KwargsForwardHook): use_inp = True if region_to_search == region.sinks else False @@ -1074,16 +1079,10 @@ def apply(self, alpha): name_to_module = dict_name_to_module(self.graph_model, self.regions) for region in self.regions: region_to_search = region.sinks if len(region.acts) == 0 else region.acts - if any([self.float_act_map[name] is None for name in region_to_search]): + if any([self.float_act_map[name.split("$")[0]] is None for name in region_to_search]): continue - act_module = [name_to_module[act_name] for act_name in region.acts] - list_of_act_val = [self.float_act_map[name] for name in region_to_search] - sinks = [name_to_module[sink] for sink in region.sinks] - # Filter out scale_varying activations from the srcs - srcs = [ - name_to_module[src] - for src in region.srcs - if not isinstance(name_to_module[src], _scale_varying_activations)] + act_module = [name_to_module[act_name.split("$")[0]] for act_name in region.acts] + list_of_act_val = [self.float_act_map[name.split("$")[0]] for name in region_to_search] list_of_insert_mul_node_fn = None if self.add_mul_node and any([ diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index bab0d363e..0c873d061 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -21,8 +21,8 @@ 'shufflenet_v2_x0_5': [0.318, 0.649], 'mobilenet_v2': [0.161, 0.320], 'resnet18': [0.487, 0.952], - 'googlenet': [0.1826, 0.413], - 'inception_v3': [0.264, 0.6], + 'googlenet': [0.495, 0.982], + 'inception_v3': [0.582, 0.989], 'alexnet': [0.875, 0.875],} IN_SIZE_CONV = (1, 3, 224, 224) @@ -34,8 +34,10 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_ name_set = set() for region in regions: for name in region.srcs: + name = name.split('$')[0] name_set.add(name) for name in region.sinks: + name = name.split('$')[0] name_set.add(name) scale_factors_regions = [] for name, module in model.named_modules(): diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 3b3955d01..1e31fda93 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -86,8 +86,8 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool srcs = set() sinks = set() for r in regions: - srcs.update(list(r.srcs)) - sinks.update(list(r.sinks)) + srcs.update([x.split("$")[0] for x in list(r.srcs)]) + sinks.update([x.split("$")[0] for x in list(r.sinks)]) count_region_srcs = 0 count_region_sinks = 0