Skip to content

Commit

Permalink
[CINN] Implement the new AlignIterSpaceTactic
Browse files Browse the repository at this point in the history
  • Loading branch information
lshpku committed Jan 7, 2025
1 parent 03f4220 commit 14e960e
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 102 deletions.
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h"
#include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"
#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
Expand All @@ -33,6 +34,7 @@ void DynamicShapeGroupScheduler::Init() {
VLOG(4) << "original group func body: \n"
<< ir_sch_->GetModule().GetExprs()[0];
InitBuckets();
tactics_.emplace_back(CreateAlignIterSpaceTactic());
tactics_.emplace_back(CreateTileBroadcastTactic());
tactics_.emplace_back(CreateTileFirstGeneralTactic());
tactics_.emplace_back(CreateComputeInlineTactic());
Expand Down
227 changes: 172 additions & 55 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
// Copyright (c) 2025 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -13,89 +13,206 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"

namespace cinn {
namespace ir {
namespace {

/**
* Reorder the loops according to the memory-consistent order of input or output
* to make memory access as coalesced as possible.
*
* This tactic uses different alignment policies for Reduce and Trivial:
* 1) Reduce: align with the input, because after reduction, the output data is
* significantly smaller than the input data, so it's more critical to make
* input coalesced.
* 2) Trivial: align with the output, because discrete writes incur higher costs
* than discrete reads for the same volume of data due to the hardware design
* of cache. Therefore, we should ensure coalesced writes in priority.
*
 * Note: we reorder spatial and reduce loops separately, because we need to
* maintain the relative order between spatial and reduce loops, so as for later
* tactics to work properly. Thus, we use two lists sp_loop_perm & rd_loop_perm
* to record the permutation of spatial and reduce loops respectively.
*
*
* Examples:
* 1. Reduce
* Input:
* for (i, 0, 8): # S
* for (j, 0, 32): # S
* for (k, 0, 128): # R
* for (a, 0, 256): # R
* var_1[i, j] += var_0[j, a, k, i]
* Analysis:
* We align Reduce to the input `var_0[j, a, k, i]`. In the indices of var_0,
* the mapping from each index to the loop index is:
* indices[0] = j => loops[1] # S
* indices[1] = a => loops[3] # R
* indices[2] = k => loops[2] # R
* indices[3] = i => loops[0] # S
* To make the indices of var_0 consistent with its original memory layout, we
* need to permute the loops in the order {1, 3, 2, 0}. However, as we reorder
 * spatial and reduce loops separately, we split the permutation into sp & rd,
* getting sp_loop_perm = {1, 0} and rd_loop_perm = {3, 2}.
* Output:
* for (j, 0, 32): # S
* for (i, 0, 8): # S
* for (a, 0, 256): # R
* for (k, 0, 128): # R
* var_1[i, j] += var_0[j, a, k, i]
*
* 2. Trivial
* Input:
* for (i, 0, 32):
* for (j, 0, 128):
* for (k, 0, 256):
* var_1[k, i, j] = exp(var_0[j, i, k])
* Analysis:
* We align Trivial to the output `var_1[k, i, j]`. In the indices of var_1,
* the mapping from each index to the loop index is:
* indices[0] = k => loops[2]
* indices[1] = i => loops[0]
* indices[2] = j => loops[1]
* Like example 1, we should permute the loops in the order {2, 0, 1}. As this
* graph doesn't contain reduce loops, all we get is sp_loop_perm = {2, 0, 1},
* and rd_loop_perm = {}.
* Output:
* for (k, 0, 256):
* for (i, 0, 32):
* for (j, 0, 128):
* var_1[k, i, j] = exp(var_0[j, i, k])
*/
class AlignIterSpaceTactic final : public ScheduleTactic {
 public:
  // Computes sp_loop_perm_/rd_loop_perm_ from the group's config; called once
  // before Apply() is invoked on each schedule block.
  void Init(ScheduleContext* context, ir::IRSchedule* sch) override;

  // Reorders the loops of `block_id` according to the precomputed permutations.
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  std::string TacticName() const override { return "AlignIterSpaceTactic"; }

 private:
  /**
   * Get the common memory-consistent order of loops according to the outputs.
   * Returns an empty vector if not all outputs share the same order.
   */
  std::vector<int> GetCommonOutputLoopPerm(ir::IRSchedule* sch);

 private:
  ScheduleContext* context_;

  // The permutation of spatial and reduce loops, in order to achieve the
  // memory-consistent alignment.
  std::vector<int> sp_loop_perm_;
  std::vector<int> rd_loop_perm_;
};

void AlignIterSpaceTactic::Init(ScheduleContext* context) {
void AlignIterSpaceTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
}
sp_loop_perm_.clear();
rd_loop_perm_.clear();

void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
ir::Expr block = sch->GetBlock(block_id);
auto& loop_strides = context_->config.base_info->loop_strides;
auto& reduce_axis = context_->config.base_info->reduce_axis;
std::set<int> reduce_axis_set(reduce_axis.begin(), reduce_axis.end());

std::vector<ir::Expr> loops = sch->GetLoops(block_id);
ir::Expr src_total_extent{1};
for (const auto& loop : loops) {
src_total_extent = src_total_extent * loop.As<ir::For>()->extent;
}
ir::Expr target_sp_extent{1};
for (const auto& iter : context_->iter_space_info.sp_space) {
target_sp_extent = target_sp_extent * std::get<0>(iter);
if (!loop_strides.empty()) {
// If this is a Reduce, calculate the loop_perm by sorting the loops in the
// descending order of their strides according to the input, then split the
// loop_perm into sp_loop_perm & rd_loop_perm.
std::vector<int> loop_perm(loop_strides.size());
std::iota(loop_perm.begin(), loop_perm.end(), 0);
std::stable_sort(loop_perm.begin(), loop_perm.end(), [&](int a, int b) {
return loop_strides[a] > loop_strides[b];
});

for (int axis : loop_perm) {
if (reduce_axis_set.count(axis) > 0) {
rd_loop_perm_.push_back(axis);
} else if (loop_strides[axis] != 0) {
sp_loop_perm_.push_back(axis);
}
}
} else {
// If this is a Trvial, calculate the sp_loop_perm according to the output.
sp_loop_perm_ = GetCommonOutputLoopPerm(sch);
}
ir::Expr target_total_extent = ir_utils::IRCopy(target_sp_extent);
for (const auto& iter : context_->iter_space_info.rb_space) {
target_total_extent = target_total_extent * std::get<0>(iter);

VLOG(4) << "AlignIterSpaceTactic:\n"
<< "sp_loop_perm: " << utils::Join(sp_loop_perm_, ", ") << "\n"
<< "rd_loop_perm: " << utils::Join(rd_loop_perm_, ", ");
}

// Builds a mapping from each loop's iteration variable to the loop's position
// (0-based, outermost first) in the given loop nest.
std::unordered_map<ir::Var, int> GetLoopVarToIndex(
    const std::vector<ir::Expr>& loops) {
  std::unordered_map<ir::Var, int> loop_var2index;
  for (int i = 0; i < static_cast<int>(loops.size()); ++i) {
    loop_var2index[loops[i].As<ir::For>()->loop_var] = i;
  }
  return loop_var2index;
}
/**
 * Check whether this is an effective permutation.
 * A permutation is ineffective (a no-op reorder) if its entries are entirely
 * in non-descending order.
 */
bool IsPermutationEffective(const std::vector<int>& perm) {
  // Equivalent to scanning for an adjacent pair with perm[i-1] > perm[i].
  return !std::is_sorted(perm.begin(), perm.end());
}
/**
 * Get the common memory-consistent order of loops according to the outputs.
 * Returns an empty vector if any output has a composite (non-variable) index
 * or if the outputs disagree on the loop order.
 */
std::vector<int> AlignIterSpaceTactic::GetCommonOutputLoopPerm(
    ir::IRSchedule* sch) {
  std::vector<int> common_loop_perm;

  for (auto& block : sch->GetAllBlocks()) {
    std::string block_id = ir::analyzer::GetBlockName(block);
    // Only output tensors participate in the alignment decision.
    if (context_->output_names.count(block_id) == 0) continue;

    auto store = ir::analyzer::GetStoreOfSBlock(block);
    auto& indices = store.As<ir::Store>()->indices;
    std::unordered_map<ir::Var, ir::Expr> iter_var2iter_value =
        ir::analyzer::GetIterVarToValueOfSBlock(block);
    std::unordered_map<ir::Var, int> loop_var2index =
        GetLoopVarToIndex(sch->GetLoops(block));

    // Map each store index to the loop whose variable it directly uses,
    // yielding this output's preferred loop permutation.
    std::vector<int> loop_perm;
    for (auto& index : indices) {
      if (index.is_constant()) continue;  // constant indices impose no order
      if (!index.is_var()) return {};     // give up on composite indices
      ir::Expr iter_value = iter_var2iter_value[index.as_var_ref()];
      if (!iter_value.is_var()) return {};
      ir::Var loop_var = iter_value.as_var_ref();
      loop_perm.push_back(loop_var2index[loop_var]);
    }

    if (common_loop_perm.empty()) {
      common_loop_perm = std::move(loop_perm);
    } else if (common_loop_perm != loop_perm) {
      // Outputs disagree on the loop order; don't align at all.
      return {};
    }
  }

  return common_loop_perm;
}

// Reorders the loops of `block_id` using the permutations computed in Init().
// The spatial permutation is applied before the reduce permutation.
void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
                                 const std::string& block_id) {
  // reduce_init blocks follow their reduce block and are never reordered here.
  if (ir::IsReduceInitTensorName(block_id)) {
    return;
  }
  // Skip a permutation when it is already in order (a no-op reorder).
  for (const std::vector<int>* perm : {&sp_loop_perm_, &rd_loop_perm_}) {
    if (IsPermutationEffective(*perm)) {
      sch->Reorder(block_id, *perm);
    }
  }
}

} // namespace

// Factory: builds a fresh AlignIterSpaceTactic for the group scheduler.
std::unique_ptr<ScheduleTactic> CreateAlignIterSpaceTactic() {
  auto tactic = std::make_unique<AlignIterSpaceTactic>();
  return tactic;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
// Copyright (c) 2025 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -11,10 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
Expand Down
44 changes: 0 additions & 44 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class TileFirstGeneralTactic final : public ScheduleTactic {
std::string TacticName() const override { return "TileFirstGeneralTactic"; }

private:
void AlignToReduceInput(ir::IRSchedule* sch, const std::string& block_id);
void MergeFlattenAxis(ir::IRSchedule* sch, const std::string& block_id);
void MergeDiscreteFlattenAxis(ir::IRSchedule* sch,
const std::string& block_id);
Expand Down Expand Up @@ -128,11 +127,6 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
if (!can_apply_) return;
if (ir::IsReduceInitTensorName(block_id)) return;

AlignToReduceInput(sch, block_id);
VLOG(6) << "After AlignToReduceInput on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetLoops(block_id)[0];

if (UseContinuousDataTile(context_->config)) {
VLOG(4) << "Using ApplyContinuousDataTile";
ApplyContinuousDataTile(sch, block_id);
Expand Down Expand Up @@ -293,44 +287,6 @@ void TileFirstGeneralTactic::ApplyContinuousDataTile(
SetReduceType(sch, block_id);
}

// Reorders the loops of `block_id` to follow the memory-consistent order of
// the reduce input: spatial loops first, then reduce loops, each group sorted
// by descending stride on the input tensor. No-op when `loop_strides` is
// empty (it is only populated for reduce groups).
void TileFirstGeneralTactic::AlignToReduceInput(ir::IRSchedule* sch,
                                                const std::string& block_id) {
  const auto& loop_strides = context_->config.base_info->loop_strides;
  if (loop_strides.empty()) {
    // No stride info for this group — nothing to align.
    return;
  }

  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  std::vector<int64_t> loop_perm(loops.size());
  std::iota(loop_perm.begin(), loop_perm.end(), 0);

  // Whether the given loop axis is one of the group's reduce axes.
  const auto IsReduce = [&](int64_t axis) {
    auto& reduce_axis = context_->config.base_info->reduce_axis;
    return std::find(reduce_axis.begin(), reduce_axis.end(), axis) !=
           reduce_axis.end();
  };

  // Sort axes so that every spatial axis precedes every reduce axis, and
  // within each group axes are ordered by descending input stride.
  std::sort(loop_perm.begin(), loop_perm.end(), [&](int64_t a, int64_t b) {
    if (IsReduce(a) == IsReduce(b)) {
      return loop_strides[a] > loop_strides[b];
    }
    // Mixed pair: a comes first iff b is the reduce axis.
    return IsReduce(b);
  });
  VLOG(4) << "loop_perm: " << utils::Join(loop_perm, ", ");

  // Reorder S/R loops separately, otherwise reduce_init will be de-inlined.
  std::vector<Expr> sp_loops, rd_loops;
  for (auto i : loop_perm) {
    if (IsReduce(i)) {
      rd_loops.push_back(loops[i]);
    } else if (loop_strides[i] != 0) {
      // Zero-stride spatial axes are excluded; presumably they don't affect
      // input locality — confirm against how loop_strides is computed.
      sp_loops.push_back(loops[i]);
    }
  }
  sch->Reorder(sp_loops);
  sch->Reorder(rd_loops);
}

void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch,
const std::string& block_id) {
if (vec_flatten_axis_.size() >= 2) {
Expand Down

0 comments on commit 14e960e

Please sign in to comment.