Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify HLS loop dead node elimination #726

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
namespace jlm::hls
{

static bool
IsPassthroughLoopVar(const rvsdg::ThetaNode::LoopVar & loopvar)
{
return loopvar.pre->nusers() == 1 && loopvar.post->origin() == loopvar.pre;
}

static bool
IsPassthroughArgument(const rvsdg::output & argument)
{
Expand Down Expand Up @@ -115,22 +121,6 @@ RemoveUnusedStatesFromLambda(llvm::lambda::node & lambdaNode)
remove(&lambdaNode);
}

static void
RemovePassthroughArgument(const rvsdg::RegionArgument & argument)
{
auto origin = argument.input()->origin();
auto result = dynamic_cast<rvsdg::RegionResult *>(*argument.begin());
argument.region()->node()->output(result->output()->index())->divert_users(origin);

auto inputIndex = argument.input()->index();
auto outputIndex = result->output()->index();
auto region = argument.region();
region->RemoveResult(result->index());
region->RemoveArgument(argument.index());
region->node()->RemoveInput(inputIndex);
region->node()->RemoveOutput(outputIndex);
}

static void
RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode)
{
Expand Down Expand Up @@ -178,15 +168,16 @@ RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode)
static void
RemoveUnusedStatesFromThetaNode(rvsdg::ThetaNode & thetaNode)
{
auto thetaSubregion = thetaNode.subregion();
for (int i = thetaSubregion->narguments() - 1; i >= 0; --i)
std::vector<rvsdg::ThetaNode::LoopVar> loopvars;
for (const auto & loopvar : thetaNode.GetLoopVars())
{
auto & argument = *thetaSubregion->argument(i);
if (IsPassthroughArgument(argument))
if (IsPassthroughLoopVar(loopvar))
{
RemovePassthroughArgument(argument);
loopvar.output->divert_users(loopvar.input->origin());
loopvars.push_back(loopvar);
}
}
thetaNode.RemoveLoopVars(std::move(loopvars));
}

static void
Expand Down
6 changes: 2 additions & 4 deletions jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ distribute_constant(const rvsdg::SimpleOperation & op, rvsdg::simple_output * ou
loopvar.output->divert_users(
rvsdg::SimpleNode::create_normalized(out->region(), op, {})[0]);
distribute_constant(op, arg_replacement);
theta->subregion()->RemoveResult(loopvar.post->index());
theta->subregion()->RemoveArgument(loopvar.pre->index());
theta->RemoveInput(loopvar.input->index());
theta->RemoveOutput(loopvar.output->index());
loopvar.post->divert_to(loopvar.pre);
theta->RemoveLoopVars({ loopvar });
changed = true;
break;
}
Expand Down
7 changes: 1 addition & 6 deletions tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,7 @@ TestTheta()
jlm::hls::RemoveUnusedStates(*rvsdgModule);

// Assert
// This assert is only here so that we do not forget this test when we refactor the code
assert(thetaNode->ninputs() == 1);

// FIXME: This transformation is broken for theta nodes. For the setup above, it
// removes all inputs/outputs, except the predicate. However, the only
// input and output it should remove are input 1 and output 0, respectively.
assert(thetaNode->ninputs() == 3);
}

static void
Expand Down
Loading