Skip to content

Commit

Permalink
[Snippets] Added type support to LoopPort (#28310)
Browse files Browse the repository at this point in the history
### Details:
- *Moved fields of the class `LoopPort` to private section and added
getters/setters*
- *Implemented `Type` to `LoopPort` to distinguish not incremented ports
due to double ptr increment and not incremented ports of Brgemm.*
- *These changes fix inefficient calculation of Buffer allocation size
in dynamic MHA Subgraphs. More details are described in the ticket
159913*

### Tickets:
 - *157326*
 - *159913*
  • Loading branch information
a-sidorova authored Jan 10, 2025
1 parent af72b13 commit e1357f1
Show file tree
Hide file tree
Showing 31 changed files with 483 additions and 344 deletions.
4 changes: 1 addition & 3 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ using LoopInfoPtr = std::shared_ptr<LoopInfo>;
*/
class LoopInfo : public std::enable_shared_from_this<LoopInfo> {
public:
enum {UNDEFINED_DIM_IDX = std::numeric_limits<size_t>::max()};

LoopInfo() = default;
LoopInfo(size_t work_amount, size_t increment, const std::vector<LoopPort>& entries, const std::vector<LoopPort>& exits);
LoopInfo(size_t work_amount, size_t increment, const std::vector<ExpressionPort>& entries, const std::vector<ExpressionPort>& exits);
Expand Down Expand Up @@ -66,7 +64,7 @@ class LoopInfo : public std::enable_shared_from_this<LoopInfo> {

/**
* @brief Returns dimension index if dimension indices for all input and output ports are equal.
* Otherwise returns UNDEFINED_DIM_IDX.
* Otherwise returns LoopPort::UNDEFINED_DIM_IDX.
* @return index
*/
size_t get_dim_idx() const;
Expand Down
56 changes: 48 additions & 8 deletions src/common/snippets/include/snippets/lowered/loop_port.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,71 @@

#include "snippets/lowered/expression_port.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/utils/utils.hpp"


namespace ov {
namespace snippets {
namespace lowered {

/* The structure describes port of Loop: expression port that connected to Expressions from other Loops.
/* The class describes port of Loop: expression port that connected to Expressions from other Loops.
*/
struct LoopPort {
class LoopPort {
public:
enum {UNDEFINED_DIM_IDX = std::numeric_limits<size_t>::max()};
enum class Type {
Incremented, // Loop port which data ptr should be incremented after each Loop iteration
NotIncremented, // Loop port which data ptr should not be incremented (for example, to avoid double increment)
NotProcessed, // LoopPort which doesn't process the dim by `dim_idx` (UNDEFINED_DIM_IDX) and is used only for Loop bound definition
};

LoopPort() = default;
LoopPort(const ExpressionPort& port, bool is_incremented = true, size_t dim_idx = 0);

template<LoopPort::Type T,
typename std::enable_if<T == Type::Incremented || T == Type::NotIncremented, bool>::type = true>
static LoopPort create(const ExpressionPort& port, size_t dim_idx = 0) {
return LoopPort(port, dim_idx, T);
}

template<LoopPort::Type T,
typename std::enable_if<T == Type::NotProcessed, bool>::type = true>
static LoopPort create(const ExpressionPort& port) {
return LoopPort(port, UNDEFINED_DIM_IDX, Type::NotProcessed);
}

std::shared_ptr<LoopPort> clone_with_new_expr(const ExpressionPtr& new_expr) const;

friend bool operator==(const LoopPort& lhs, const LoopPort& rhs);
friend bool operator!=(const LoopPort& lhs, const LoopPort& rhs);
friend bool operator<(const LoopPort& lhs, const LoopPort& rhs);

std::shared_ptr<ExpressionPort> expr_port = {};
// True if after each Loop iteration the corresponding data pointer should be incremented.
// Otherwise, the data pointer shift is skipped
bool is_incremented = true;
size_t dim_idx = 0; // The numeration starts from the end (dim_idx = 0 -> is the most inner dimension)
const std::shared_ptr<ExpressionPort>& get_expr_port() const { return m_expr_port; }
Type get_type() const { return m_type; }
size_t get_dim_idx() const;

void set_expr_port(std::shared_ptr<ExpressionPort> p);
void set_dim_idx(size_t idx);

template<LoopPort::Type T,
typename std::enable_if<T == Type::Incremented || T == Type::NotIncremented, bool>::type = true>
void convert_to_type() {
OPENVINO_ASSERT(is_processed(), "NotProcessed LoopPort cannot change type!");
m_type = T;
}

bool is_processed() const;
bool is_incremented() const;

private:
LoopPort(const ExpressionPort& port, size_t dim_idx, Type type);

std::shared_ptr<ExpressionPort> m_expr_port = {};
size_t m_dim_idx = 0; // The numeration starts from the end (dim_idx = 0 -> is the most inner dimension)
Type m_type = Type::Incremented;
};

std::ostream& operator<<(std::ostream& out, const LoopPort::Type& type);

} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class InsertBuffers : public RangedPass {
const LinearIR::constExprIt& begin_it,
const LinearIR::constExprIt& end_it,
const LoopManagerPtr& loop_manager,
const std::vector<LoopPort>& loop_entries,
const std::vector<LoopPort>& loop_exits) const;
const std::vector<ExpressionPort>& loop_entries,
const std::vector<ExpressionPort>& loop_exits) const;

static LinearIR::constExprIt insertion_position(const LinearIR& linear_ir,
const LoopManagerPtr& loop_manager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ void BufferExpression::init_allocation_size(const std::shared_ptr<LoopManager>&
const auto& subtensor = ov::snippets::utils::get_projected_subtensor(parent_port);

auto hard_equal = [&parent_port](const LoopPort& port) {
return *port.expr_port == parent_port;
return *port.get_expr_port() == parent_port;
};
auto soft_equal = [&](const LoopPort& loop_port) {
const auto& port = *loop_port.expr_port;
const auto& port = *loop_port.get_expr_port();
// Check semantic of LoopPort
if (parent_port.get_index() != port.get_index() ||
port.get_expr()->get_node()->get_type_info() != parent_port.get_expr()->get_node()->get_type_info())
Expand Down Expand Up @@ -109,8 +109,10 @@ void BufferExpression::init_allocation_size(const std::shared_ptr<LoopManager>&
OPENVINO_ASSERT(it != output_ports.end(), "compute_allocation_shape: output port of parent loop can not be found");
}
const auto& loop_port = *it;
const auto& dim_idx = loop_port.dim_idx;
if (loop_port.is_incremented && dim_idx < rank) {
if (!loop_port.is_processed())
continue;
const auto& dim_idx = loop_port.get_dim_idx();
if (dim_idx < rank) {
if (const auto& unified_loop_info = ov::as_type_ptr<UnifiedLoopInfo>(loop_info))
m_allocation_size = utils::dynamic_safe_mul(m_allocation_size, unified_loop_info->get_work_amount());
else if (const auto& expanded_loop_info = ov::as_type_ptr<ExpandedLoopInfo>(loop_info))
Expand Down
48 changes: 28 additions & 20 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ LoopInfo::LoopInfo(size_t work_amount, size_t increment, const std::vector<Expre
m_input_ports.reserve(entries.size());
m_output_ports.reserve(exits.size());
for (const auto& port : entries)
m_input_ports.emplace_back(port);
m_input_ports.push_back(LoopPort::create<LoopPort::Type::Incremented>(port));
for (const auto& port : exits)
m_output_ports.emplace_back(port);
m_output_ports.push_back(LoopPort::create<LoopPort::Type::Incremented>(port));
}

bool LoopInfo::is_dynamic() const {
Expand All @@ -30,14 +30,22 @@ bool LoopInfo::is_dynamic() const {

size_t LoopInfo::get_dim_idx() const {
OPENVINO_ASSERT(!m_input_ports.empty(), "Loop info must have at least one input port");
auto equal_dim_idxes = [&](const LoopPort& p) {
return !p.is_incremented || p.dim_idx == m_input_ports[0].dim_idx;
};

auto is_processed = [](const LoopPort& p) { return p.is_processed(); };
auto is_processed_it = std::find_if(m_input_ports.begin(), m_input_ports.end(), is_processed);
if (is_processed_it == m_input_ports.end()) {
is_processed_it = std::find_if(m_output_ports.begin(), m_output_ports.end(), is_processed);
if (is_processed_it == m_output_ports.end())
return LoopPort::UNDEFINED_DIM_IDX;
}
const auto dim_idx = is_processed_it->get_dim_idx();

auto equal_dim_idxes = [&](const LoopPort& p) { return !p.is_processed() || p.get_dim_idx() == dim_idx; };
if (std::all_of(m_input_ports.begin(), m_input_ports.end(), equal_dim_idxes) &&
std::all_of(m_output_ports.begin(), m_output_ports.end(), equal_dim_idxes)) {
return m_input_ports[0].dim_idx;
return dim_idx;
} else {
return UNDEFINED_DIM_IDX;
return LoopPort::UNDEFINED_DIM_IDX;
}
}

Expand All @@ -60,7 +68,7 @@ size_t LoopInfo::get_increment() const {
std::vector<bool> LoopInfo::get_is_incremented() const {
std::vector<bool> values;
values.reserve(get_input_count() + get_output_count());
iterate_through_ports([&values](const LoopPort& port) { values.push_back(port.is_incremented); });
iterate_through_ports([&values](const LoopPort& port) { values.push_back(port.is_incremented()); });
return values;
}

Expand All @@ -81,14 +89,14 @@ void LoopInfo::set_increment(size_t increment) {
}

void LoopInfo::set_dim_idx(size_t dim_idx) {
auto setter = [dim_idx](LoopPort& port) { port.dim_idx = dim_idx; };
auto setter = [dim_idx](LoopPort& port) { if (port.is_processed()) port.set_dim_idx(dim_idx); };
std::for_each(m_input_ports.begin(), m_input_ports.end(), setter);
std::for_each(m_output_ports.begin(), m_output_ports.end(), setter);
}

template<>
std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const LoopPort& loop_port) {
auto& ports = loop_port.expr_port->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
auto& ports = loop_port.get_expr_port()->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
const auto it = std::find_if(ports.begin(), ports.end(),
[&loop_port](const LoopPort& port) { return port == loop_port; });
OPENVINO_ASSERT(it != ports.end(), "Failed find_loop_port: existing loop port has not been found");
Expand All @@ -99,7 +107,7 @@ template<>
std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const ExpressionPort& expr_port) {
auto& ports = expr_port.get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
const auto it = std::find_if(ports.begin(), ports.end(),
[&expr_port](const LoopPort& port) { return *port.expr_port == expr_port; });
[&expr_port](const LoopPort& port) { return *port.get_expr_port() == expr_port; });
return it;
}

Expand All @@ -118,7 +126,7 @@ namespace {
void validate_new_target_ports(const std::vector<LoopPort>& target_ports, ExpressionPort::Type target_type) {
OPENVINO_ASSERT(target_ports.empty() ||
std::all_of(target_ports.cbegin(), target_ports.cend(),
[&target_type](const LoopPort& target_port) { return target_type == target_port.expr_port->get_type(); }));
[&target_type](const LoopPort& target_port) { return target_type == target_port.get_expr_port()->get_type(); }));
}
void validate_new_target_ports(const std::vector<ExpressionPort>& target_ports, ExpressionPort::Type target_type) {
OPENVINO_ASSERT(target_ports.empty() ||
Expand All @@ -128,7 +136,7 @@ void validate_new_target_ports(const std::vector<ExpressionPort>& target_ports,
} // namespace

void LoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
const auto& actual_port_type = actual_port.expr_port->get_type();
const auto& actual_port_type = actual_port.get_expr_port()->get_type();
validate_new_target_ports(target_ports, actual_port_type);

auto& ports = actual_port_type == ExpressionPort::Input ? m_input_ports : m_output_ports;
Expand All @@ -153,7 +161,7 @@ void LoopInfo::replace_with_new_ports(const ExpressionPort& actual_port, const s
std::transform(target_loop_ports.begin(), target_loop_ports.end(), target_ports.begin(), target_loop_ports.begin(),
[](LoopPort loop_port, const ExpressionPort& expr_port) {
LoopPort copy = std::move(loop_port); // to save loop port parameters
copy.expr_port = std::make_shared<ExpressionPort>(expr_port);
copy.set_expr_port(std::make_shared<ExpressionPort>(expr_port));
return copy;
});
port_it = ports.erase(port_it);
Expand All @@ -164,7 +172,7 @@ std::vector<LoopPort> LoopInfo::clone_loop_ports(const ExpressionMap& expr_map,
std::vector<LoopPort> cloned_port_points;
cloned_port_points.reserve(loop_ports.size());
for (const auto& p : loop_ports) {
const auto& expr = p.expr_port->get_expr().get();
const auto& expr = p.get_expr_port()->get_expr().get();
OPENVINO_ASSERT(expr_map.count(expr), "Can't clone LoopInfo: old expression is not in the map");
const auto& new_expr = expr_map.at(expr);
cloned_port_points.emplace_back(*p.clone_with_new_expr(new_expr));
Expand Down Expand Up @@ -309,8 +317,8 @@ std::vector<size_t> get_port_index_order(const std::vector<LoopPort>& ports) {
std::iota(new_indexes.begin(), new_indexes.end(), 0);
std::sort(new_indexes.begin(), new_indexes.end(),
[ports](size_t l, size_t r) {
const auto& expr_port_l = ports[l].expr_port;
const auto& expr_port_r = ports[r].expr_port;
const auto& expr_port_l = ports[l].get_expr_port();
const auto& expr_port_r = ports[r].get_expr_port();
if (expr_port_l->get_expr() == expr_port_r->get_expr())
return expr_port_l->get_index() < expr_port_r->get_index();
return expr_port_l->get_expr()->get_exec_num() < expr_port_r->get_expr()->get_exec_num();
Expand Down Expand Up @@ -340,7 +348,7 @@ UnifiedLoopInfo::LoopPortInfo UnifiedLoopInfo::get_loop_port_info(const Expressi
const auto& ports = is_input ? m_input_ports : m_output_ports;
const auto& descs = is_input ? m_input_port_descs : m_output_port_descs;
const auto it = std::find_if(ports.begin(), ports.end(),
[&expr_port](const LoopPort& port) { return *port.expr_port == expr_port; });
[&expr_port](const LoopPort& port) { return *port.get_expr_port() == expr_port; });
const auto index = static_cast<size_t>(std::distance(ports.cbegin(), it));
OPENVINO_ASSERT(index < ports.size() && index < descs.size(), "LoopPortInfo has not been found!");
return {ports[index], descs[index]};
Expand All @@ -354,10 +362,10 @@ void UnifiedLoopInfo::replace_with_cloned_descs(size_t actual_port_idx, size_t n
}

void UnifiedLoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
const auto& actual_port_type = actual_port.expr_port->get_type();
const auto& actual_port_type = actual_port.get_expr_port()->get_type();
validate_new_target_ports(target_ports, actual_port_type);

const auto is_input = actual_port.expr_port->get_type() == ExpressionPort::Input;
const auto is_input = actual_port.get_expr_port()->get_type() == ExpressionPort::Input;
auto& ports = is_input ? m_input_ports : m_output_ports;
auto port_it = find_loop_port(actual_port);

Expand Down
8 changes: 4 additions & 4 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LoopManager::get_loop_bo
OPENVINO_ASSERT(!entries.empty(), "Loop must have input ports");
OPENVINO_ASSERT(!exits.empty(), "Loop must have output ports");

const auto& entry_expr = entries.front().expr_port->get_expr();
const auto& entry_expr = entries.front().get_expr_port()->get_expr();
auto loop_begin_pos = linear_ir.find(entry_expr);
// Some operations in Loop can be before first input ports: Scalars, VectorBuffer.
// We should iterate by them till the expr is in the corresponding Loop
Expand All @@ -103,7 +103,7 @@ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LoopManager::get_loop_bo
prev_loop_ids = (*std::prev(loop_begin_pos))->get_loop_ids();
}

const auto& exit_expr = exits.back().expr_port->get_expr();
const auto& exit_expr = exits.back().get_expr_port()->get_expr();
auto loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr));
// There might be LoopEnd with another `loop_id` but in the target Loop as well.
auto current_loop_ids = (*loop_end_pos)->get_loop_ids();
Expand Down Expand Up @@ -312,14 +312,14 @@ void LoopManager::fuse_loop_ports(std::vector<LoopPort>& output_ports,

std::vector<LoopPort> new_output_ports;
for (const auto& output_port : output_ports) {
const auto consumers_inputs = output_port.expr_port->get_connected_ports();
const auto consumers_inputs = output_port.get_expr_port()->get_connected_ports();

std::set<LoopPort> mapped_input_ports;
std::set<ExpressionPtr> outside_consumers;
for (const auto& consumer_input : consumers_inputs) {
const auto input_port_it = std::find_if(input_ports.begin(), input_ports.end(),
[&consumer_input](const LoopPort& port) {
return *port.expr_port.get() == consumer_input;
return *port.get_expr_port().get() == consumer_input;
});
if (input_port_it != input_ports.end()) {
mapped_input_ports.insert(*input_port_it);
Expand Down
Loading

0 comments on commit e1357f1

Please sign in to comment.