Skip to content

Commit

Permalink
fix issue for repetitive inputs (#70749)
Browse files Browse the repository at this point in the history
  • Loading branch information
LLee233 authored Jan 10, 2025
1 parent 266e3cd commit 38bdf53
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Quanter {
void AddQuantOps() {
if (IsNotPermittedOpType()) return;

std::vector<std::string> linked_xputs;
std::unordered_map<std::string, std::string> linked_xputs;

for (const auto& logical_xput : op_xputs) {
std::vector<std::string> quant_xput_names;
Expand All @@ -39,7 +39,12 @@ class Quanter {

const auto& physical_xputs_names = logical_xput.second;
for (const auto& physical_xput_name : physical_xputs_names) {
if (IsAlreadyLinked(linked_xputs, physical_xput_name)) continue;
// In case the input is repetitively used, where the input should be
// still added.
if (IsAlreadyLinked(linked_xputs, physical_xput_name)) {
quant_xput_names.emplace_back(linked_xputs[physical_xput_name]);
continue;
}

VarDesc quant_x_desc(
patterns::PDNodeName(get_op_type(), get_op_edge()));
Expand All @@ -52,7 +57,7 @@ class Quanter {
auto physical_xput_node = xputs_map[physical_xput_name];
link_nodes(physical_xput_node, quant_op, quant_x_node);
counter++;
linked_xputs.push_back(physical_xput_name);
linked_xputs[physical_xput_name] = xput_name;
}

set_edge(logical_xput_name, quant_xput_names);
Expand Down Expand Up @@ -87,10 +92,10 @@ class Quanter {
virtual void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) = 0;

bool IsAlreadyLinked(const std::vector<std::string>& node_names,
const std::string& node_name) const {
return std::find(node_names.begin(), node_names.end(), node_name) !=
node_names.end();
bool IsAlreadyLinked(
const std::unordered_map<std::string, std::string>& node_names_map,
const std::string& node_name) const {
return node_names_map.find(node_name) != node_names_map.end();
}

virtual ir::Node* create_quant_op(const std::string& input_name,
Expand Down

0 comments on commit 38bdf53

Please sign in to comment.