Commit
Giuseppe5 committed Dec 21, 2023
1 parent 0f67b6b commit 704772e
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions src/brevitas/graph/equalize.py
@@ -638,29 +638,24 @@ def _is_reshaping_op(node: Node) -> bool:
     return node.target in _reshaping_op


-def get_weight_source(module_list):
-    transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1)
-    for i, module in enumerate(module_list):
-        if isinstance(module, nn.MultiheadAttention):
-            if hasattr(module, 'out_proj'):
-                module_list[i] = module.out_proj
-            else:
-                raise RuntimeError("Configuration for Multiheadattention not supported")
-    srcs_axes = {module: _get_output_axis(module) for module in module_list}
-    weight = [transpose(m, axis) for m, axis in srcs_axes.items()]
+def get_weight_source(module):
+    transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
+    if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'):
+        raise RuntimeError("Configuration for Multiheadattention not supported")
+    weight = module.out_proj.weight if isinstance(module, nn.MultiheadAttention) else module.weight
+    axis = _get_output_axis(module)
+    weight = transpose(weight, axis)
     return weight


-def get_weight_sink(module_list):
-    transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1)
-    for i, module in enumerate(module_list):
-        if isinstance(module, nn.MultiheadAttention):
-            if hasattr(module, 'in_proj_weight'):
-                module_list[i] = WeightBiasTuple(module.in_proj_weight)
-            else:
-                raise RuntimeError("Configuration for Multiheadattention not supported")
-    sinks_axes = {module: _get_input_axis(module) for module in module_list}
-    weight = [transpose(m, axis) for m, axis in sinks_axes.items()]
+def get_weight_sink(module):
+    transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
+    if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'):
+        raise RuntimeError("Configuration for Multiheadattention not supported")
+    weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance(
+        module, nn.MultiheadAttention) else module.weight
+    axis = _get_input_axis(module)
+    weight = transpose(weight, axis)
     return weight
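
[Editor's note] For context, a minimal usage sketch of the refactored helpers (not part of the commit): after this change, get_weight_source and get_weight_sink take a single module and return a single tensor instead of mapping over lists. The layer sizes below are illustrative, and the expected shapes assume _get_output_axis/_get_input_axis resolve to 0 and 1 for nn.Linear, per their definitions earlier in equalize.py.

    # Hedged usage sketch, not from the commit. Assumes torch and a
    # post-commit brevitas are installed, and that nn.Linear's output
    # axis is 0 and its input axis is 1.
    import torch.nn as nn
    from brevitas.graph.equalize import get_weight_source, get_weight_sink

    linear = nn.Linear(in_features=8, out_features=16)  # weight shape [16, 8]
    src = get_weight_source(linear)  # output axis 0 -> weight returned as-is
    sink = get_weight_sink(linear)   # input axis 1 -> weight transposed
    print(src.shape, sink.shape)     # torch.Size([16, 8]) torch.Size([8, 16])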


@@ -669,8 +664,8 @@ def find_srcs_channel_dim(model, inp_node):
         # If we meet a supported module, determine the channel shape
         module = get_module(model, inp_node.target)
         # Since we are walking up, we consider the module as srcs
-        weight = get_weight_source([module])
-        channel = weight[0].shape[0]
+        weight = get_weight_source(module)
+        channel = weight.shape[0]
         return channel
     elif _is_add(inp_node):
         all_channels = []
elif _is_add(inp_node):
all_channels = []
@@ -732,14 +727,14 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
             continue
         if _is_supported_module(graph_model, node):
             module = get_module(graph_model, node.target)
-            weight = get_weight_source([module])
-            eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
+            weight = get_weight_source(module)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)

             # After we found a source, we need to check if it branches into multiple sinks
             state.add_srcs(node.target, module, eq_indexes)
             find_sinks(graph_model, node, state)
-            state.offset = state.offset if not state.update_offset else state.offset + weight[
-                0].shape[0]
+            state.offset = state.offset if not state.update_offset else state.offset + weight.shape[
+                0]
         elif _is_scale_invariant_module(
                 graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
             find_sinks(graph_model, node, state)
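
[Editor's note] An illustration (not part of the commit) of the offset bookkeeping above: each discovered source occupies weight.shape[0] channel slots, and state.offset advances by that amount so the next source's EqualizationIndexes start where the previous one ended. Reading the three positional arguments as (start, end, offset) follows the EqualizationIndexes definition earlier in this file; the channel counts below are made up.

    # Standalone sketch of the pattern above; channel counts are invented.
    offset = 0
    for num_channels in (16, 32):  # e.g. two sources with 16 and 32 out channels
        # mirrors: eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)
        print(f"EqualizationIndexes(start=0, end={num_channels}, offset={offset})")
        # mirrors: state.offset = state.offset + weight.shape[0]
        offset += num_channels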
@@ -783,8 +778,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
             continue
         if _is_supported_module(graph_model, node):
             module = get_module(graph_model, node.target)
-            weight = get_weight_sink([module])
-            eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
+            weight = get_weight_sink(module)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)
             # It is not possible to equalize through LayerNorm as sink
             if isinstance(module, (nn.LayerNorm,) + _batch_norm):
                 state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
@@ -850,8 +845,8 @@ def _extract_regions(
                 state.add_acts(node.target, module)
             else:
                 module = get_module(graph_model, node.target)
-                weight = get_weight_source([module])
-                eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0)
+                weight = get_weight_source(module)
+                eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
                 state.add_srcs(node.target, module, eq_indexes)
                 find_sinks(graph_model, node, state)
                 if len(state.sinks) > 0 and _UNSUPPORTED_OP not in state.sinks.keys():
@@ -991,8 +986,8 @@ def find_module(self, model, regions: List):
         """
         if isinstance(model,
                       _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
-            weight = get_weight_sink([model])
-            eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0)
+            weight = get_weight_sink(model)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
             region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
             regions.append(region)
         else:
