Skip to content

Commit

Permalink
Rewrite the rfactor scheduling directive (#8490)
Browse files Browse the repository at this point in the history
The old implementation suffered from several serious issues. It
duplicated substantial amounts of the logic in ApplySplit.cpp, the
way it handled adapting the predicate to the reducing func was
unprincipled, and it confused dims and vars in a way that could
segfault. It also left the order of pure dimensions unspecified.
The new implementation orders pure dimensions by following the existing dims list.

We now disallow rfactor() on funcs whose schedules fuse an RVar with a Var.
The implementation also relies on a new or_condition_over_domain
helper function and drops the purify() scheduling directive.

Fixes #7854
  • Loading branch information
alexreinking authored Dec 27, 2024
1 parent 50354c3 commit a9f82db
Show file tree
Hide file tree
Showing 19 changed files with 463 additions and 556 deletions.
2 changes: 1 addition & 1 deletion python_bindings/src/halide/halide_/PyStage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void define_stage(py::module &m) {
.def("dump_argument_list", &Stage::dump_argument_list)
.def("name", &Stage::name)

.def("rfactor", (Func(Stage::*)(std::vector<std::pair<RVar, Var>>)) & Stage::rfactor,
.def("rfactor", (Func(Stage::*)(const std::vector<std::pair<RVar, Var>> &)) & Stage::rfactor,
py::arg("preserved"))
.def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor,
py::arg("r"), py::arg("v"))
Expand Down
9 changes: 1 addition & 8 deletions src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
}
} break;
case Split::RenameVar:
case Split::PurifyRVar:
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution);
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt);
break;
Expand All @@ -167,10 +166,7 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
}

vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &split, const string &prefix) {
// Define the bounds on the split dimensions using the bounds
// on the function args. If it is a purify, we should use the bounds
// from the dims instead.

// Define the bounds on the split dimensions using the bounds on the function args.
vector<std::pair<string, Expr>> let_stmts;

Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent");
Expand Down Expand Up @@ -201,9 +197,6 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent);
break;
case Split::PurifyRVar:
// Do nothing for purify
break;
}

return let_stmts;
Expand Down
24 changes: 8 additions & 16 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <algorithm>
#include <iterator>
#include <numeric>

namespace Halide {
namespace Internal {
Expand Down Expand Up @@ -297,7 +298,6 @@ class BoundsInference : public IRMutator {
}

// Default case (no specialization)
vector<Expr> predicates = def.split_predicate();
for (const ReductionVariable &rv : def.schedule().rvars()) {
rvars.insert(rv);
}
Expand All @@ -308,23 +308,15 @@ class BoundsInference : public IRMutator {
}
vecs[1] = def.values();

vector<Expr> predicates = def.split_predicate();
for (size_t i = 0; i < result.size(); ++i) {
for (const Expr &val : vecs[i]) {
if (!predicates.empty()) {
Expr cond_val = Call::make(val.type(),
Internal::Call::if_then_else,
{likely(predicates[0]), val},
Internal::Call::PureIntrinsic);
for (size_t i = 1; i < predicates.size(); ++i) {
cond_val = Call::make(cond_val.type(),
Internal::Call::if_then_else,
{likely(predicates[i]), cond_val},
Internal::Call::PureIntrinsic);
}
result[i].emplace_back(const_true(), cond_val);
} else {
result[i].emplace_back(const_true(), val);
}
Expr cond_val = std::accumulate(
predicates.begin(), predicates.end(), val,
[](const auto &acc, const auto &pred) {
return Call::make(acc.type(), Call::if_then_else, {likely(pred), acc}, Call::PureIntrinsic);
});
result[i].emplace_back(const_true(), cond_val);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Derivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call(
// f(r.x) = ... && r is associative
// => f(x) = ...
if (var != nullptr && var->reduction_domain.defined() &&
var->reduction_domain.split_predicate().empty()) {
is_const_one(var->reduction_domain.predicate())) {
ReductionDomain rdom = var->reduction_domain;
int rvar_id = -1;
for (int rid = 0; rid < (int)rdom.domain().size(); rid++) {
Expand Down
2 changes: 0 additions & 2 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ Split::SplitType Deserializer::deserialize_split_type(Serialize::SplitType split
return Split::SplitType::RenameVar;
case Serialize::SplitType::FuseVars:
return Split::SplitType::FuseVars;
case Serialize::SplitType::PurifyRVar:
return Split::SplitType::PurifyRVar;
default:
user_error << "unknown split type " << (int)split_type << "\n";
return Split::SplitType::SplitVar;
Expand Down
Loading

0 comments on commit a9f82db

Please sign in to comment.