diff --git a/numba_rvsdg/core/datastructures/byte_flow.py b/numba_rvsdg/core/datastructures/byte_flow.py index fb08a79..c15ea06 100644 --- a/numba_rvsdg/core/datastructures/byte_flow.py +++ b/numba_rvsdg/core/datastructures/byte_flow.py @@ -1,18 +1,11 @@ import dis -from copy import deepcopy from dataclasses import dataclass -from typing import Generator, Callable +from typing import Callable from numba_rvsdg.core.datastructures.scfg import SCFG -from numba_rvsdg.core.datastructures.basic_block import RegionBlock from numba_rvsdg.core.datastructures.flow_info import FlowInfo from numba_rvsdg.core.utils import _logger, _LogWrap -from numba_rvsdg.core.transformations import ( - restructure_loop, - restructure_branch, -) - @dataclass(frozen=True) class ByteFlow: @@ -60,96 +53,3 @@ def from_bytecode(code: Callable) -> "ByteFlow": # type: ignore flowinfo = FlowInfo.from_bytecode(bc) scfg = flowinfo.build_basicblocks() return ByteFlow(bc=bc, scfg=scfg) - - def _join_returns(self) -> "ByteFlow": - """Joins the return blocks within the corresponding SCFG. - - This method creates a deep copy of the SCFG and performs - operation to join return blocks within the control flow. - It returns a new ByteFlow object with the updated SCFG. - - Returns - ------- - byteflow: ByteFlow - The new ByteFlow object with updated SCFG. - """ - scfg = deepcopy(self.scfg) - scfg.join_returns() - return ByteFlow(bc=self.bc, scfg=scfg) - - def _restructure_loop(self) -> "ByteFlow": - """Restructures the loops within the corresponding SCFG. - - Creates a deep copy of the SCFG and performs the operation to - restructure loop constructs within the control flow using - the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015. - It applies the restructuring operation to both the main SCFG - and any subregions within it. It returns a new ByteFlow object - with the updated SCFG. - - Returns - ------- - byteflow: ByteFlow - The new ByteFlow object with updated SCFG. - """ - scfg = deepcopy(self.scfg) - restructure_loop(scfg.region) - for region in _iter_subregions(scfg): - restructure_loop(region) - return ByteFlow(bc=self.bc, scfg=scfg) - - def _restructure_branch(self) -> "ByteFlow": - """Restructures the branches within the corresponding SCFG. - - Creates a deep copy of the SCFG and performs the operation to - restructure branch constructs within the control flow. It applies - the restructuring operation to both the main SCFG and any - subregions within it. It returns a new ByteFlow object with - the updated SCFG. - - Returns - ------- - byteflow: ByteFlow - The new ByteFlow object with updated SCFG. - """ - scfg = deepcopy(self.scfg) - restructure_branch(scfg.region) - for region in _iter_subregions(scfg): - restructure_branch(region) - return ByteFlow(bc=self.bc, scfg=scfg) - - def restructure(self) -> "ByteFlow": - """Applies join_returns, restructure_loop and restructure_branch - in the respective order on the SCFG. - - Creates a deep copy of the SCFG and applies a series of - restructuring operations to it. The operations include - joining return blocks, restructuring loop constructs, and - restructuring branch constructs. It returns a new ByteFlow - object with the updated SCFG. - - Returns - ------- - byteflow: ByteFlow - The new ByteFlow object with updated SCFG. - """ - scfg = deepcopy(self.scfg) - # close - scfg.join_returns() - # handle loop - restructure_loop(scfg.region) - for region in _iter_subregions(scfg): - restructure_loop(region) - # handle branch - restructure_branch(scfg.region) - for region in _iter_subregions(scfg): - restructure_branch(region) - return ByteFlow(bc=self.bc, scfg=scfg) - - -def _iter_subregions(scfg: SCFG) -> Generator[RegionBlock, SCFG, None]: - for node in scfg.graph.values(): - if isinstance(node, RegionBlock): - yield node - assert node.subregion is not None - yield from _iter_subregions(node.subregion) diff --git a/numba_rvsdg/core/datastructures/scfg.py b/numba_rvsdg/core/datastructures/scfg.py index 2bf4a9b..5e818b9 100644 --- a/numba_rvsdg/core/datastructures/scfg.py +++ b/numba_rvsdg/core/datastructures/scfg.py @@ -672,7 +672,8 @@ def join_returns(self) -> None: """Close the CFG. A closed CFG is a CFG with a unique entry and exit node that have no - predescessors and no successors respectively. + predescessors and no successors respectively. Transformation is applied + in-place. """ # for all nodes that contain a return return_nodes = [ @@ -683,6 +684,51 @@ def join_returns(self) -> None: return_solo_name = self.name_gen.new_block_name(SYNTH_RETURN) self.insert_SyntheticReturn(return_solo_name, return_nodes, []) + def iter_subregions(self) -> Generator[RegionBlock, "SCFG", None]: + """Iterate over all subregions of this CFG.""" + for node in self.graph.values(): + if isinstance(node, RegionBlock): + yield node + assert node.subregion is not None + yield from node.subregion.iter_subregions() + + def restructure_loop(self) -> None: + """Apply LOOP RESTRUCTURING transform. + + Performs the operation to restructure loop constructs using the + algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015. It + applies an in-place restructuring operation to both the main SCFG and + any subregions within it. + + """ + # Avoid cyclic imports + from numba_rvsdg.core.transformations import restructure_loop + + restructure_loop(self.region) + for region in self.iter_subregions(): + restructure_loop(region) + + def restructure_branch(self) -> None: + """Apply BRANCH RESTRUCTURING transform. + + Performs the operation to restructure branch constructs using the + algorithm BRANCH RESTRUCTURING from section 4.2 of Bahmann2015. It + applies an in-place restructuring operation to both the main SCFG and + any subregions within it. + + """ + # Avoid cyclic imports + from numba_rvsdg.core.transformations import restructure_branch + + restructure_branch(self.region) + for region in self.iter_subregions(): + restructure_branch(region) + + def restructure(self) -> None: + self.join_returns() + self.restructure_loop() + self.restructure_branch() + def join_tails_and_exits( self, tails: List[str], exits: List[str] ) -> Tuple[str, str]: diff --git a/numba_rvsdg/rendering/rendering.py b/numba_rvsdg/rendering/rendering.py index c4fad36..0ef79bb 100644 --- a/numba_rvsdg/rendering/rendering.py +++ b/numba_rvsdg/rendering/rendering.py @@ -365,14 +365,14 @@ def render_flow(flow: ByteFlow) -> None: """ ByteFlowRenderer().render_byteflow(flow).view("before") - cflow = flow._join_returns() - ByteFlowRenderer().render_byteflow(cflow).view("closed") + flow.scfg.join_returns() + ByteFlowRenderer().render_byteflow(flow).view("closed") - lflow = cflow._restructure_loop() - ByteFlowRenderer().render_byteflow(lflow).view("loop restructured") + flow.scfg.restructure_loop() + ByteFlowRenderer().render_byteflow(flow).view("loop restructured") - bflow = lflow._restructure_branch() - ByteFlowRenderer().render_byteflow(bflow).view("branch restructured") + flow.scfg.restructure_branch() + ByteFlowRenderer().render_byteflow(flow).view("branch restructured") def render_scfg(scfg: SCFG) -> None: diff --git a/numba_rvsdg/tests/test_figures.py b/numba_rvsdg/tests/test_figures.py index 3309c04..9476634 100644 --- a/numba_rvsdg/tests/test_figures.py +++ b/numba_rvsdg/tests/test_figures.py @@ -1,6 +1,5 @@ # mypy: ignore-errors -from numba_rvsdg.core.datastructures.byte_flow import ByteFlow from numba_rvsdg.core.datastructures.flow_info import FlowInfo from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.tests.test_utils import SCFGComparator @@ -441,11 +440,10 @@ def test_figure_3(self): ] flow = FlowInfo.from_bytecode(bc) scfg = flow.build_basicblocks() - byteflow = ByteFlow(bc=bc, scfg=scfg) - byteflow = byteflow.restructure() + scfg.restructure() x, _ = SCFG.from_yaml(fig_3_yaml) - self.assertSCFGEqual(x, byteflow.scfg) + self.assertSCFGEqual(x, scfg) def test_figure_4(self): # Figure 4 of the paper @@ -474,8 +472,7 @@ def test_figure_4(self): ] flow = FlowInfo.from_bytecode(bc) scfg = flow.build_basicblocks() - byteflow = ByteFlow(bc=bc, scfg=scfg) - byteflow = byteflow.restructure() + scfg.restructure() x, _ = SCFG.from_yaml(fig_4_yaml) - self.assertSCFGEqual(x, byteflow.scfg) + self.assertSCFGEqual(x, scfg) diff --git a/numba_rvsdg/tests/test_scc.py b/numba_rvsdg/tests/test_scc.py index f0de80d..974b587 100644 --- a/numba_rvsdg/tests/test_scc.py +++ b/numba_rvsdg/tests/test_scc.py @@ -54,7 +54,7 @@ def make_flow(func): def test_scc(): f = make_flow(scc) - f.restructure() + f.scfg.restructure() if __name__ == "__main__": diff --git a/numba_rvsdg/tests/test_scfg.py b/numba_rvsdg/tests/test_scfg.py index 71e167e..370550b 100644 --- a/numba_rvsdg/tests/test_scfg.py +++ b/numba_rvsdg/tests/test_scfg.py @@ -194,17 +194,14 @@ def foo(n): def test_concealed_region_view_iter(self): flow = ByteFlow.from_bytecode(self.foo) - restructured = flow._restructure_loop() + flow.scfg.restructure_loop() expected = [ ("python_bytecode_block_0", PythonBytecodeBlock), ("loop_region_0", RegionBlock), ("python_bytecode_block_3", PythonBytecodeBlock), ] received = list( - ( - (k, type(v)) - for k, v in restructured.scfg.concealed_region_view.items() - ) + ((k, type(v)) for k, v in flow.scfg.concealed_region_view.items()) ) self.assertEqual(expected, received) diff --git a/numba_rvsdg/tests/test_simulate.py b/numba_rvsdg/tests/test_simulate.py index 9986e71..5a0dbae 100644 --- a/numba_rvsdg/tests/test_simulate.py +++ b/numba_rvsdg/tests/test_simulate.py @@ -44,7 +44,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # if case self._run(foo, flow, {"x": 1}) @@ -59,7 +59,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) @@ -78,7 +78,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) @@ -97,7 +97,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) @@ -121,7 +121,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # no loop self._run(foo, flow, {"x": 0}) @@ -145,7 +145,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # loop bypass self._run(foo, flow, {"x": 0}) @@ -161,7 +161,7 @@ def foo(x, y): return (x > 0 and x < 10) or (y > 0 and y < 10) flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() self._run(foo, flow, {"x": 5, "y": 5}) @@ -175,7 +175,7 @@ def foo(s, e): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + flow.scfg.restructure() # no looping self._run(foo, flow, {"s": 0, "e": 0})