Indicate vectorize() outer dimensions
When the outer dimension "x_vo" of a vectorized split is assigned as a
gpu_blocks() dimension, inform the main grouping algorithm of that outer
dimension. Avoid compute_at at "y_o", the outer gpu_blocks() level; use
the inner gpu_blocks() level instead.

When 4 or more GPU block levels are requested, schedule only the
innermost levels.
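
A GPU grid exposes at most three block dimensions, so a sketch of that case
(again with invented names and factors, not the autoscheduler's exact output)
maps only the three innermost tile loops and leaves the rest serial:

    #include "Halide.h"
    using namespace Halide;

    int main() {
        // Illustrative 4-D pipeline only; not generated by mullapudi2016.
        Func g("g");
        Var x("x"), y("y"), z("z"), w("w");
        g(x, y, z, w) = x + y + z + w;

        Var xo("xo"), xi("xi"), yo("yo"), yi("yi"),
            zo("zo"), zi("zi"), wo("wo"), wi("wi");
        g.tile(x, y, xo, yo, xi, yi, 8, 8)
         .split(z, zo, zi, 4)
         .split(w, wo, wi, 4)
         .reorder(wi, xi, yi, zi, xo, yo, zo, wo)
         .gpu_blocks(xo, yo, zo)    // three innermost outer loops -> blockIdx.{x,y,z}
         .gpu_threads(xi, yi, zi);  // wo stays a serial loop outside the kernel,
                                    // wi a serial loop inside each thread
        g.compile_to_lowered_stmt("g.stmt", {}, Text, Target("host-cuda"));
        return 0;
    }

This mirrors the new default: case in GPUTileHelper below, which binds only
vars[0], vars[1], and vars[2].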
antonysigma committed Aug 24, 2023
1 parent 67ae68f commit 9bd065d
54 changes: 39 additions & 15 deletions src/autoschedulers/mullapudi2016/AutoSchedule.cpp
@@ -945,10 +945,10 @@ class GPUTileHelper {
return;
}

internal_assert(vars.size() <= 3);

std::stringstream oss;
switch (vars.size()) {
case 0:
return;
case 1: {
const auto &[v, outer, inner, factor, strategy] = vars.front();
f.split(v, outer, inner, factor, strategy);
@@ -996,7 +996,7 @@ class GPUTileHelper {

break;
}
case 3: {
default: {
const auto &x = vars[0];
const auto &y = vars[1];
const auto &z = vars[2];
@@ -1086,6 +1086,7 @@ class GPUTilingDedup {
bool is_initial_order = true;
std::vector<VarOrRVar> ordering;

std::set<std::string> is_split;
std::set<std::string> outer_vars;
std::set<std::string> inner_vars;

@@ -1099,6 +1100,11 @@
return inner_vars.find(variable_name) != inner_vars.end();
}

/** True if Func::parallel(v_o) is pending. */
bool is_parallel(const std::string &variable_name) const {
return parallelize.find(variable_name) != parallelize.end();
}

bool is_update() const {
return f.name().find("update") != std::string::npos;
}
@@ -1114,12 +1120,16 @@
continue;
}

if (is_compute_at) continue;

// Skip all gpu_blocks if the current Stage is "compute_at" another
// stage, in which the gpu_blocks are already specified.
if (!is_compute_at && is_outer(v_name)) {
// if (is_outer(v_name) || is_parallel(v_name)) {
if (is_outer(v_name)) {
// Mark as gpu blocks;
f.gpu_blocks(v);
sched.push_schedule(f.name(), stage_num, "gpu_blocks(" + v_name + ")", {v_name});
continue;
}
}
}
@@ -1193,6 +1203,7 @@ class GPUTilingDedup {
*/
void has_split(const VarOrRVar &v, const VarOrRVar &vo, const VarOrRVar &vi, const Expr &factor, TailStrategy strategy) {
debug(2) << f.name() << ".split(" << v.name() << "," << factor << ")\n";
is_split.emplace(v.name());
outer_vars.emplace(vo.name());
inner_vars.emplace(vi.name());

@@ -1247,7 +1258,7 @@ class GPUTilingDedup {
sched.push_schedule(f.name(), stage_num, oss.str(), var_list);
}

const bool is_already_split = (inner_vars.size() + outer_vars.size() > 0);
const bool is_already_split = (!is_split.empty());
if (is_already_split) {
// If the Mullapudi's auto-splitting algorithm already computes the
// tile size, we simply mark the inner dims as gpu_threads();
@@ -1634,7 +1645,7 @@ struct Partitioner {

// Loop over the dimensions of function stage 'f_handle' starting from innermost
// and vectorize the first pure dimension encountered.
void vectorize_stage(
std::optional<pair<VarOrRVar, VarOrRVar>> vectorize_stage(
const Group &g, Stage f_handle, int stage_num, Definition def,
const Function &func, bool is_group_output, const Target &t, set<string> &rvars,
map<string, Expr> &estimates, AutoSchedule &sched, GPUTilingDedup &gpu_tiling);
@@ -2824,11 +2835,11 @@ pair<VarOrRVar, VarOrRVar> Partitioner::split_dim(
return make_pair(inner, outer);
}

void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
Definition def, const Function &func, bool is_group_output,
const Target &t, set<string> &rvars,
map<string, Expr> &estimates, AutoSchedule &sched,
GPUTilingDedup &gpu_tiling) {
std::optional<pair<VarOrRVar, VarOrRVar>> Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
Definition def, const Function &func, bool is_group_output,
const Target &t, set<string> &rvars,
map<string, Expr> &estimates, AutoSchedule &sched,
GPUTilingDedup &gpu_tiling) {
vector<Dim> &dims = def.schedule().dims();
int vec_dim_index = -1;

@@ -2902,7 +2913,11 @@ void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
debug(1) << "Outer dim vectorization of var \"" << vec_dim_name
<< "\" in function \"" << f_handle.name() << "\"\n";
}

return make_pair(inner, outer);
}

return std::nullopt;
}

// Return true if the vars/rvars in 'ordering' are in the same order as the
@@ -3184,8 +3199,16 @@ void Partitioner::generate_group_cpu_schedule(
}
}

vectorize_stage(g, f_handle, g.output.stage_num, def, g_out, true, t,
rvars, stg_estimates, sched, gpu_tiling);
{
auto vectorized_split = vectorize_stage(g, f_handle, g.output.stage_num, def, g_out, true, t,
rvars, stg_estimates, sched, gpu_tiling);

if (t.has_gpu_feature() && vectorized_split) {
auto [v_i, v_o] = *vectorized_split;
inner_dims.emplace_back(std::move(v_i));
outer_dims.emplace_back(std::move(v_o));
}
}

// Parallelize definition
Expr def_par = 1;
@@ -3296,14 +3319,15 @@ void Partitioner::generate_group_cpu_schedule(
mem_handle = Func(mem.func).update(mem.stage_num - 1);
} else {
if (!outer_dims.empty()) {
string sanitized_g_out = get_sanitized_name(g_out.name());
if (tile_inner_var.is_rvar) {
Func(mem.func).compute_at(Func(g_out), tile_inner_var.rvar);
debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ", " << tile_inner_var.rvar << ")\n";
} else {
Func(mem.func).compute_at(Func(g_out), tile_inner_var.var);
debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ", " << tile_inner_var.var << ")\n";
}

string sanitized_g_out = get_sanitized_name(g_out.name());
debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ")\n";
sched.push_schedule(mem_handle.name(), mem.stage_num,
"compute_at(" + sanitized_g_out + ", " + tile_inner_var.name() + ")",
{sanitized_g_out, tile_inner_var.name()});