[OneDNN][PIR] add permitted input name filter for orphaned op (#70628)
* add permitted input name filter for orphaned op

* modify headers

* add unit test for CI coverage

* format

* increase benefit of PlacementPattern
zhanglirong1999 authored Jan 10, 2025
1 parent 38bdf53 commit 08506e2
Showing 2 changed files with 79 additions and 2 deletions.
22 changes: 20 additions & 2 deletions paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc
@@ -51,7 +51,7 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
public:
explicit OneDNNBf16PlacementPattern(pir::IrContext* context)
: pir::RewritePattern(MatchAnyOpTypeTag(),
-                        1 /*benefit*/,
+                        5 /*benefit*/,
context,
{} /*generated_names*/) {}
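
Note: the benefit bump from 1 to 5 should make OneDNNBf16PlacementPattern be tried ahead of lower-benefit patterns, assuming PIR follows the usual convention that higher-benefit RewritePatterns are matched first.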

@@ -280,13 +280,31 @@ class RemoveOrphanedPattern : public pir::RewritePattern {
"pd_op.fetch",
"pd_op.assign"});

const std::vector<std::string> permitted_input_names = {
"x", "y", "input", "residual_param", "residual_data"};
auto op_name = op->name();
auto op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
if (!op_info) return false;
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_(op_name),
paddle::dialect::IsLegacyOp(op_name));
auto input_names = yaml_parser.InputNames();
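// input_names[i] is operand i's declared name from the op's yaml definition
// (e.g. "x", "y", "residual_param"); it is matched against permitted_input_names.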

if (op->num_operands()) {
for (uint32_t i = 0; i < op->num_operands(); i++) {
if (!op->operand_source(i) || !op->operand_source(i).type()) {
continue;
}
std::string input_name = input_names[i];
auto iter = std::find(permitted_input_names.begin(),
permitted_input_names.end(),
input_name);
if (iter == permitted_input_names.end()) {
// Inputs not in permitted_input_names may stay fp32; only permitted
// inputs must come from bf16-capable producers, so skip the rest.
continue;
}
auto* prev_op = pir::GetDefiningOpForInput(op, i);
// if (!prev_op) continue;
// Some ops do not need to be processed
std::string prev_name = prev_op->name();
if (constant_op.count(prev_name)) {
…
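
For readability, the filter added above boils down to the following standalone sketch (plain Python, illustrative only; the real check runs on PIR operations in C++ and the helper name below is hypothetical):

PERMITTED_INPUT_NAMES = {"x", "y", "input", "residual_param", "residual_data"}

def inputs_that_must_be_bf16(op_input_names):
    """Mirror the C++ filter: operands whose yaml name is not in the
    permitted list are skipped (they may stay fp32); only the remaining
    ones need a bf16-capable producer."""
    return [name for name in op_input_names if name in PERMITTED_INPUT_NAMES]

# e.g. for a conv2d-like op, only "input" and "residual_param" are checked:
print(inputs_that_must_be_bf16(["input", "filter", "bias", "residual_param"]))
# -> ['input', 'residual_param']
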
59 changes: 59 additions & 0 deletions test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py
@@ -1169,5 +1169,64 @@ def test_check_output(self):
self.check_pass_correct(rtol=1e-02, atol=1e-02)


class TestConv2dBf16PlacementPass(PassTest):
def is_program_valid(self, program=None):
return True

def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(
name='x', shape=[5, 5, 5, 5], dtype='float32'
)
w_attr = paddle.ParamAttr(
learning_rate=0.0,
initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
)
conv2d = paddle.nn.Conv2D(
in_channels=5,
out_channels=1,
kernel_size=[1, 1],
groups=1,
stride=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
data_format='NCHW',
bias_attr=False,
weight_attr=w_attr,
)

out = conv2d(x)
out = paddle.assign(out)
self.pass_attr_list = [
{'onednn_placement_pass': {}},
{'cpu_bfloat16_placement_pass': {}},
{'cpu_bfloat16_pass': {}},
{'cpu_bfloat16_type_placement_pass': {}},
]
self.feeds = {
"x": np.random.random((5, 5, 5, 5)).astype("float32"),
"bias": np.random.random(1).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"onednn_op.conv2d": 1,
"pd_op.conv2d": 0,
}
return [main_prog, start_prog]

def sample_program(self):
yield self.build_ir_program(), False

def setUp(self):
self.places.append(paddle.CPUPlace())
self.skip_accuracy_verification = True

def test_check_output(self):
self.check_pass_correct()


if __name__ == "__main__":
unittest.main()
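
The new case can also be run on its own with the standard unittest loader (a usage sketch; the module path assumes the test file shown above is on the import path):

import unittest

# Load and run only the new placement test class.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_cpu_bfloat16_pir_pass.TestConv2dBf16PlacementPass"
)
unittest.TextTestRunner(verbosity=2).run(suite)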
