Skip to content

Commit

Permalink
Merge pull request #112 from esc/refactor_transforms
Browse files Browse the repository at this point in the history
enable `scfg.restructure()`
  • Loading branch information
esc authored Apr 16, 2024
2 parents e54eeec + 510c8fd commit 52033dd
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 129 deletions.
102 changes: 1 addition & 101 deletions numba_rvsdg/core/datastructures/byte_flow.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import dis
from copy import deepcopy
from dataclasses import dataclass
from typing import Generator, Callable
from typing import Callable

from numba_rvsdg.core.datastructures.scfg import SCFG
from numba_rvsdg.core.datastructures.basic_block import RegionBlock
from numba_rvsdg.core.datastructures.flow_info import FlowInfo
from numba_rvsdg.core.utils import _logger, _LogWrap

from numba_rvsdg.core.transformations import (
restructure_loop,
restructure_branch,
)


@dataclass(frozen=True)
class ByteFlow:
Expand Down Expand Up @@ -60,96 +53,3 @@ def from_bytecode(code: Callable) -> "ByteFlow": # type: ignore
flowinfo = FlowInfo.from_bytecode(bc)
scfg = flowinfo.build_basicblocks()
return ByteFlow(bc=bc, scfg=scfg)

def _join_returns(self) -> "ByteFlow":
"""Joins the return blocks within the corresponding SCFG.
This method creates a deep copy of the SCFG and performs
operation to join return blocks within the control flow.
It returns a new ByteFlow object with the updated SCFG.
Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
"""
scfg = deepcopy(self.scfg)
scfg.join_returns()
return ByteFlow(bc=self.bc, scfg=scfg)

def _restructure_loop(self) -> "ByteFlow":
"""Restructures the loops within the corresponding SCFG.
Creates a deep copy of the SCFG and performs the operation to
restructure loop constructs within the control flow using
the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015.
It applies the restructuring operation to both the main SCFG
and any subregions within it. It returns a new ByteFlow object
with the updated SCFG.
Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
"""
scfg = deepcopy(self.scfg)
restructure_loop(scfg.region)
for region in _iter_subregions(scfg):
restructure_loop(region)
return ByteFlow(bc=self.bc, scfg=scfg)

def _restructure_branch(self) -> "ByteFlow":
"""Restructures the branches within the corresponding SCFG.
Creates a deep copy of the SCFG and performs the operation to
restructure branch constructs within the control flow. It applies
the restructuring operation to both the main SCFG and any
subregions within it. It returns a new ByteFlow object with
the updated SCFG.
Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
"""
scfg = deepcopy(self.scfg)
restructure_branch(scfg.region)
for region in _iter_subregions(scfg):
restructure_branch(region)
return ByteFlow(bc=self.bc, scfg=scfg)

def restructure(self) -> "ByteFlow":
"""Applies join_returns, restructure_loop and restructure_branch
in the respective order on the SCFG.
Creates a deep copy of the SCFG and applies a series of
restructuring operations to it. The operations include
joining return blocks, restructuring loop constructs, and
restructuring branch constructs. It returns a new ByteFlow
object with the updated SCFG.
Returns
-------
byteflow: ByteFlow
The new ByteFlow object with updated SCFG.
"""
scfg = deepcopy(self.scfg)
# close
scfg.join_returns()
# handle loop
restructure_loop(scfg.region)
for region in _iter_subregions(scfg):
restructure_loop(region)
# handle branch
restructure_branch(scfg.region)
for region in _iter_subregions(scfg):
restructure_branch(region)
return ByteFlow(bc=self.bc, scfg=scfg)


def _iter_subregions(scfg: SCFG) -> Generator[RegionBlock, SCFG, None]:
for node in scfg.graph.values():
if isinstance(node, RegionBlock):
yield node
assert node.subregion is not None
yield from _iter_subregions(node.subregion)
48 changes: 47 additions & 1 deletion numba_rvsdg/core/datastructures/scfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,8 @@ def join_returns(self) -> None:
"""Close the CFG.
A closed CFG is a CFG with a unique entry and exit node that have no
predescessors and no successors respectively.
predescessors and no successors respectively. Transformation is applied
in-place.
"""
# for all nodes that contain a return
return_nodes = [
Expand All @@ -683,6 +684,51 @@ def join_returns(self) -> None:
return_solo_name = self.name_gen.new_block_name(SYNTH_RETURN)
self.insert_SyntheticReturn(return_solo_name, return_nodes, [])

def iter_subregions(self) -> Generator[RegionBlock, "SCFG", None]:
"""Iterate over all subregions of this CFG."""
for node in self.graph.values():
if isinstance(node, RegionBlock):
yield node
assert node.subregion is not None
yield from node.subregion.iter_subregions()

def restructure_loop(self) -> None:
"""Apply LOOP RESTRUCTURING transform.
Performs the operation to restructure loop constructs using the
algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015. It
applies an in-place restructuring operation to both the main SCFG and
any subregions within it.
"""
# Avoid cyclic imports
from numba_rvsdg.core.transformations import restructure_loop

restructure_loop(self.region)
for region in self.iter_subregions():
restructure_loop(region)

def restructure_branch(self) -> None:
"""Apply BRANCH RESTRUCTURING transform.
Performs the operation to restructure branch constructs using the
algorithm BRANCH RESTRUCTURING from section 4.2 of Bahmann2015. It
applies an in-place restructuring operation to both the main SCFG and
any subregions within it.
"""
# Avoid cyclic imports
from numba_rvsdg.core.transformations import restructure_branch

restructure_branch(self.region)
for region in self.iter_subregions():
restructure_branch(region)

def restructure(self) -> None:
self.join_returns()
self.restructure_loop()
self.restructure_branch()

def join_tails_and_exits(
self, tails: List[str], exits: List[str]
) -> Tuple[str, str]:
Expand Down
12 changes: 6 additions & 6 deletions numba_rvsdg/rendering/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,14 @@ def render_flow(flow: ByteFlow) -> None:
"""
ByteFlowRenderer().render_byteflow(flow).view("before")

cflow = flow._join_returns()
ByteFlowRenderer().render_byteflow(cflow).view("closed")
flow.scfg.join_returns()
ByteFlowRenderer().render_byteflow(flow).view("closed")

lflow = cflow._restructure_loop()
ByteFlowRenderer().render_byteflow(lflow).view("loop restructured")
flow.scfg.restructure_loop()
ByteFlowRenderer().render_byteflow(flow).view("loop restructured")

bflow = lflow._restructure_branch()
ByteFlowRenderer().render_byteflow(bflow).view("branch restructured")
flow.scfg.restructure_branch()
ByteFlowRenderer().render_byteflow(flow).view("branch restructured")


def render_scfg(scfg: SCFG) -> None:
Expand Down
11 changes: 4 additions & 7 deletions numba_rvsdg/tests/test_figures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# mypy: ignore-errors

from numba_rvsdg.core.datastructures.byte_flow import ByteFlow
from numba_rvsdg.core.datastructures.flow_info import FlowInfo
from numba_rvsdg.core.datastructures.scfg import SCFG
from numba_rvsdg.tests.test_utils import SCFGComparator
Expand Down Expand Up @@ -441,11 +440,10 @@ def test_figure_3(self):
]
flow = FlowInfo.from_bytecode(bc)
scfg = flow.build_basicblocks()
byteflow = ByteFlow(bc=bc, scfg=scfg)
byteflow = byteflow.restructure()
scfg.restructure()

x, _ = SCFG.from_yaml(fig_3_yaml)
self.assertSCFGEqual(x, byteflow.scfg)
self.assertSCFGEqual(x, scfg)

def test_figure_4(self):
# Figure 4 of the paper
Expand Down Expand Up @@ -474,8 +472,7 @@ def test_figure_4(self):
]
flow = FlowInfo.from_bytecode(bc)
scfg = flow.build_basicblocks()
byteflow = ByteFlow(bc=bc, scfg=scfg)
byteflow = byteflow.restructure()
scfg.restructure()

x, _ = SCFG.from_yaml(fig_4_yaml)
self.assertSCFGEqual(x, byteflow.scfg)
self.assertSCFGEqual(x, scfg)
2 changes: 1 addition & 1 deletion numba_rvsdg/tests/test_scc.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def make_flow(func):

def test_scc():
f = make_flow(scc)
f.restructure()
f.scfg.restructure()


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions numba_rvsdg/tests/test_scfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,14 @@ def foo(n):

def test_concealed_region_view_iter(self):
flow = ByteFlow.from_bytecode(self.foo)
restructured = flow._restructure_loop()
flow.scfg.restructure_loop()
expected = [
("python_bytecode_block_0", PythonBytecodeBlock),
("loop_region_0", RegionBlock),
("python_bytecode_block_3", PythonBytecodeBlock),
]
received = list(
(
(k, type(v))
for k, v in restructured.scfg.concealed_region_view.items()
)
((k, type(v)) for k, v in flow.scfg.concealed_region_view.items())
)
self.assertEqual(expected, received)

Expand Down
16 changes: 8 additions & 8 deletions numba_rvsdg/tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# if case
self._run(foo, flow, {"x": 1})
Expand All @@ -59,7 +59,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# loop bypass case
self._run(foo, flow, {"x": 0})
Expand All @@ -78,7 +78,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# loop bypass case
self._run(foo, flow, {"x": 0})
Expand All @@ -97,7 +97,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# loop bypass case
self._run(foo, flow, {"x": 0})
Expand All @@ -121,7 +121,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# no loop
self._run(foo, flow, {"x": 0})
Expand All @@ -145,7 +145,7 @@ def foo(x):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# loop bypass
self._run(foo, flow, {"x": 0})
Expand All @@ -161,7 +161,7 @@ def foo(x, y):
return (x > 0 and x < 10) or (y > 0 and y < 10)

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

self._run(foo, flow, {"x": 5, "y": 5})

Expand All @@ -175,7 +175,7 @@ def foo(s, e):
return c

flow = ByteFlow.from_bytecode(foo)
flow = flow.restructure()
flow.scfg.restructure()

# no looping
self._run(foo, flow, {"s": 0, "e": 0})
Expand Down

0 comments on commit 52033dd

Please sign in to comment.