diff --git a/paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc b/paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc
index ce0f873be31c74..620d77aad5ec2d 100644
--- a/paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc
+++ b/paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc
@@ -51,7 +51,7 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
  public:
   explicit OneDNNBf16PlacementPattern(pir::IrContext* context)
       : pir::RewritePattern(MatchAnyOpTypeTag(),
-                            1 /*benefit*/,
+                            5 /*benefit*/,
                             context,
                             {} /*generated_names*/) {}
 
@@ -280,13 +280,31 @@ class RemoveOrphanedPattern : public pir::RewritePattern {
                                              "pd_op.fetch",
                                              "pd_op.assign"});
 
+    const std::vector<std::string> permitted_input_names = {
+        "x", "y", "input", "residual_param", "residual_data"};
+    auto op_name = op->name();
+    auto op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
+    if (!op_info) return false;
+    paddle::dialect::OpYamlInfoParser yaml_parser(
+        op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
+            ->get_op_info_(op_name),
+        paddle::dialect::IsLegacyOp(op_name));
+    auto input_names = yaml_parser.InputNames();
+
     if (op->num_operands()) {
       for (uint32_t i = 0; i < op->num_operands(); i++) {
         if (!op->operand_source(i) || !op->operand_source(i).type()) {
           continue;
         }
+        std::string input_name = input_names[i];
+        auto iter = std::find(permitted_input_names.begin(),
+                              permitted_input_names.end(),
+                              input_name);
+        if (iter == permitted_input_names.end()) {
+          // Only inputs in permitted_input_names must be bf16; others can stay fp32
+          continue;
+        }
         auto* prev_op = pir::GetDefiningOpForInput(op, i);
-        // if (!prev_op) continue;
         // Some ops do not need to be processed
         std::string prev_name = prev_op->name();
         if (constant_op.count(prev_name)) {
diff --git a/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py b/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py
index d8de881c364980..0162a1d2ef50e1 100644
--- a/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py
+++ b/test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py
@@ -1169,5 +1169,64 @@ def test_check_output(self):
         self.check_pass_correct(rtol=1e-02, atol=1e-02)
 
 
+class TestConv2dBf16PlacementPass(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[5, 5, 5, 5], dtype='float32'
+                )
+                w_attr = paddle.ParamAttr(
+                    learning_rate=0.0,
+                    initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
+                )
+                conv2d = paddle.nn.Conv2D(
+                    in_channels=5,
+                    out_channels=1,
+                    kernel_size=[1, 1],
+                    groups=1,
+                    stride=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    data_format='NCHW',
+                    bias_attr=False,
+                    weight_attr=w_attr,
+                )
+
+                out = conv2d(x)
+                out = paddle.assign(out)
+                self.pass_attr_list = [
+                    {'onednn_placement_pass': {}},
+                    {'cpu_bfloat16_placement_pass': {}},
+                    {'cpu_bfloat16_pass': {}},
+                    {'cpu_bfloat16_type_placement_pass': {}},
+                ]
+                self.feeds = {
+                    "x": np.random.random((5, 5, 5, 5)).astype("float32"),
+                    "bias": np.random.random(1).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.conv2d": 1,
+                    "pd_op.conv2d": 0,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+        self.skip_accuracy_verification = True
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
 if __name__ == "__main__":
     unittest.main()