diff --git a/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc index 33471068ecbcb4..365dfd5e216f33 100644 --- a/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc @@ -25,32 +25,6 @@ #include "paddle/pir/include/pass/pass_registry.h" #include "paddle/pir/include/pattern_rewrite/pattern_match.h" -/* -fuse malmul + add to fc_xpu -For example: -graph: - - x w - \ / - | - mul - | - | - bias --- add - | - | - output ------------------------------------------------------- -After the pass is applied: - x w - \ / - | - bias--- fc_xpu - | - | - Output -*/ - namespace { int ConvertActivationType(const std::string &act_type) { @@ -76,6 +50,8 @@ int ConvertActivationType(const std::string &act_type) { return static_cast(xpu::Activation_t::SWISH); } else if (act_type == "relu6") { return static_cast(xpu::Activation_t::RELU6); + } else if (act_type == "swish_glu") { + return static_cast(xpu::Activation_t::SWISH_GLU); } else { PADDLE_THROW(common::errors::Unimplemented( "Not support convert activation_type(%s).", act_type)); @@ -83,13 +59,38 @@ int ConvertActivationType(const std::string &act_type) { return -1; } -class FCXpuFusePattern : public paddle::drr::DrrPatternBase { +/* +fuse matmul + add to fc_xpu +For example: +graph: + + x w + \ / + | + mul + | + | + bias --- add + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + bias--- fc_xpu + | + | + Output +*/ +class FcXpuFuseAddPattern : public paddle::drr::DrrPatternBase { private: bool transpose_w_; public: - explicit FCXpuFusePattern(bool transpose_w) : transpose_w_(transpose_w) {} - std::string name() const override { return "FCXpuFusePattern"; } + explicit FcXpuFuseAddPattern(bool transpose_w) : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseAddPattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) 
const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -185,14 +186,274 @@ class FCXpuFusePattern : public paddle::drr::DrrPatternBase { } }; +/* +fuse matmul + add + act to fc_xpu +For example: +graph: + + x w + \ / + | + mul + | + | + bias --- add + | + | + act + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + bias--- fc_xpu + | + | + Output +*/ +class FcXpuFuseAddActPattern : public paddle::drr::DrrPatternBase { + private: + bool transpose_w_; + + public: + explicit FcXpuFuseAddActPattern(bool transpose_w) + : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseAddActPattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + add({&pat.Tensor("mul_out"), &pat.Tensor("bias")}, + {&pat.Tensor("add_out")}); + const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name()); + swiglu({&pat.Tensor("add_out"), &pat.InputNoneTensor()}, + {&pat.Tensor("act_out")}); + + // Constraints + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_shape = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + if (transpose_w_ != match_ctx.Attr("transpose_y")) { + return false; + } + return (w_shape.size() == 2 && x_shape.size() >= 2 && + bias_shape.size() == 1); + }); + + // Result pattern + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &in_num_col_dims_attr = + res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> 
int { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_shape.size() - 1; + }); + + if (!transpose_w_) { + // prepare weight, transpose it if necessary + const auto &perm_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + if (w_shape.size() == 2) { + return {1, 0}; + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Not support convert w_shape.size()(%d).", w_shape.size())); + } + }); + const auto &transpose_op = + res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}}); + res.Tensor("w_trans") = transpose_op(res.Tensor("w")); + VLOG(3) << "transpose weight for fc_xpu op"; + } + + const auto &out_dtype_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType { + auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x")); + // 目前仅支持以下几种非量化的情况 + if (x_dtype.isa()) { + return phi::DataType::FLOAT32; + } else if (x_dtype.isa()) { + return phi::DataType::FLOAT16; + } else if (x_dtype.isa()) { + return phi::DataType::BFLOAT16; + } else { + return phi::DataType::UNDEFINED; + } + }); + // only support float32 bias now + const auto &cast_op = res.Op(paddle::dialect::CastOp::name(), + {{"dtype", res.DataTypeAttr("float32")}}); + res.Tensor("bias_fp32") = cast_op(res.Tensor("bias")); + + const auto &fc_xpu = res.Op( + paddle::dialect::FcXpuOp::name(), + {{ + {"in_num_col_dims", in_num_col_dims_attr}, + {"transpose_x", pat.Attr("transpose_x")}, + {"alpha", res.Float32Attr(1.0f)}, + {"beta", res.Float32Attr(0.f)}, + {"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))}, + {"act_alpha", res.Float32Attr(0.0f)}, + {"out_dtype", out_dtype_attr}, + }}); + fc_xpu( + { + &res.Tensor("x"), + &res.InputNoneTensor(), + transpose_w_ ? 
&res.Tensor("w") : &res.Tensor("w_trans"), + &res.InputNoneTensor(), + &res.Tensor("bias_fp32"), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + }, + {&res.Tensor("act_out"), &res.Tensor("out_max")}); + } +}; + +/* +fuse matmul + act to fc_xpu +For example: +graph: + + x w + \ / + | + mul + | + | + act + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + fc_xpu + | + | + Output +*/ + +class FcXpuFuseActPattern : public paddle::drr::DrrPatternBase { + private: + bool transpose_w_; + + public: + explicit FcXpuFuseActPattern(bool transpose_w) : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseActPattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")}); + + const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name()); + swiglu({&pat.Tensor("mul_out"), &pat.InputNoneTensor()}, + {&pat.Tensor("act_out")}); + + // Constraints + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + if (transpose_w_ != match_ctx.Attr("transpose_y")) { + return false; + } + return (w_shape.size() == 2 && x_shape.size() >= 2); + }); + + // Result pattern + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &in_num_col_dims_attr = + res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> int { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_shape.size() - 1; + }); + + if (!transpose_w_) { + // prepare weight, transpose it if necessary + const auto &perm_attr = res.ComputeAttr( + 
[](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + if (w_shape.size() == 2) { + return {1, 0}; + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Not support convert w_shape.size()(%d).", w_shape.size())); + } + }); + const auto &transpose_op = + res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}}); + res.Tensor("w_trans") = transpose_op(res.Tensor("w")); + VLOG(3) << "transpose weight for fc_xpu op"; + } + + const auto &out_dtype_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType { + auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x")); + // 目前仅支持以下几种非量化的情况 + if (x_dtype.isa()) { + return phi::DataType::FLOAT32; + } else if (x_dtype.isa()) { + return phi::DataType::FLOAT16; + } else if (x_dtype.isa()) { + return phi::DataType::BFLOAT16; + } else { + return phi::DataType::UNDEFINED; + } + }); + + const auto &fc_xpu = res.Op( + paddle::dialect::FcXpuOp::name(), + {{ + {"in_num_col_dims", in_num_col_dims_attr}, + {"transpose_x", pat.Attr("transpose_x")}, + {"alpha", res.Float32Attr(1.0f)}, + {"beta", res.Float32Attr(0.f)}, + {"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))}, + {"act_alpha", res.Float32Attr(0.0f)}, + {"out_dtype", out_dtype_attr}, + }}); + fc_xpu( + { + &res.Tensor("x"), + &res.InputNoneTensor(), + transpose_w_ ? 
&res.Tensor("w") : &res.Tensor("w_trans"), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + }, + {&res.Tensor("act_out"), &res.Tensor("out_max")}); + } +}; + class FCXpuFusePass : public pir::PatternRewritePass { public: FCXpuFusePass() : pir::PatternRewritePass("fc_xpu_fuse_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); - ps.Add(paddle::drr::Create(context, false)); - ps.Add(paddle::drr::Create(context, true)); + ps.Add(paddle::drr::Create(context, false)); + ps.Add(paddle::drr::Create(context, true)); + ps.Add(paddle::drr::Create(context, false)); + ps.Add(paddle::drr::Create(context, true)); + ps.Add(paddle::drr::Create(context, false)); + ps.Add(paddle::drr::Create(context, true)); return ps; } }; diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 4ce91b1d918643..f062ad0f1624cc 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -842,6 +842,9 @@ void FcXPUInferMeta(const MetaTensor& x, out_shape[i] = static_cast(x.dims()[i]); } out_shape[in_num_col_dims] = static_cast(w.dims()[0]); + if (act_type == 23 /*phi::backends::xpu::Activation_t::SWISH_GLU*/) { + out_shape[in_num_col_dims] = out_shape[in_num_col_dims] / 2; + } out->set_dims(DDim(out_shape.data(), static_cast(out_shape.size()))); out->set_dtype(out_dtype); out->set_layout(x.layout()); diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc index b2047d6ec99c7e..92ca3d648060e7 100644 --- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc @@ -156,58 +156,132 @@ void FcXPUKernelImpl(const Context& ctx, w_len); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); } - int r = - xblas::fc_fusion( - ctx.x_context(), - x_data_fp16, - w_data_fp16, - out_data, - m, - n, - k, - transpose_x, - true, - x_max_data ? 
x_max_data : xte_x_maxptr, - w_max_data ? w_max_data : xte_w_maxptr, - out_max_data, - transpose_x ? m : k, - k, - n, - alpha, - beta, - bias_data, - act, - xte_scale_x, - xte_scale_w); + baidu::xpu::xblas::FcFusionTensor tensor_a1{ + x_data_fp16, + x_max_data ? x_max_data : xte_x_maxptr, + transpose_x ? k : m, + transpose_x ? m : k, + transpose_x ? m : k, + transpose_x}; + baidu::xpu::xblas::FcFusionTensor tensor_b1{ + w_data_fp16, w_max_data ? w_max_data : xte_w_maxptr, n, k, k, true}; + baidu::xpu::xblas::FcFusionTensor tensor_c1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionTensor tensor_d1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionDesc desc{alpha, + beta}; + + baidu::xpu::xblas::FcFusionEpilogue epilogue1{ + act, bias_data, xte_scale_x, xte_scale_w, 0, 0, out_max_data}; + + if (act_type == xpu::Activation_t::SWISH_GLU) { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n / 2, n / 2, false}; + } else { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n, n, false}; + } + + int r = baidu::xpu::xblas::fc_fusion(ctx.x_context(), + tensor_a1, + tensor_b1, + tensor_c1, + tensor_d1, + desc, + epilogue1); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); } } if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") == nullptr) { - int r = xpu:: - fc_fusion( // TX/TW/TY/TGEMM - ctx.x_context(), // ctx - x_data, // x - w_data, // w - out_data, // y - m, // m - n, // n - k, // k - transpose_x, // x_trans - true, // w_trans - x_max_data, // x_maxptr - w_max_data, // w_maxptr - out_max_data, // y_maxptr - transpose_x ? 
m : k, // ldx - k, // ldw - n, // ldy - alpha, // alpha - beta, // beta - bias_data, // bias - act, // act - scale_max_data); // scale + if constexpr (((std::is_same::value && + std::is_same::value && + std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value && + std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value && + std::is_same::value && + std::is_same::value))) { + int r = xpu:: + fc_fusion( // TX/TW/TY/TGEMM + ctx.x_context(), // ctx + x_data, // x + w_data, // w + out_data, // y + m, // m + n, // n + k, // k + transpose_x, // x_trans + true, // w_trans + x_max_data, // x_maxptr + w_max_data, // w_maxptr + out_max_data, // y_maxptr + transpose_x ? m : k, // ldx + k, // ldw + n, // ldy + alpha, // alpha + beta, // beta + bias_data, // bias + act, // act + scale_max_data); // scale + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu"); + } else { + baidu::xpu::xblas::FcFusionTensor tensor_a1{ + x_data, + x_max_data, + transpose_x ? k : m, + transpose_x ? m : k, + transpose_x ? 
m : k, + transpose_x}; + baidu::xpu::xblas::FcFusionTensor tensor_b1{ + w_data, w_max_data, n, k, k, true}; + baidu::xpu::xblas::FcFusionTensor tensor_c1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionTensor tensor_d1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionDesc desc{alpha, + beta}; - PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu"); + baidu::xpu::xblas::FcFusionEpilogue epilogue1{ + act, bias_data, scale_max_data, nullptr, 0, 0, out_max_data}; + + if (act_type == xpu::Activation_t::SWISH_GLU) { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n / 2, n / 2, false}; + } else { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n, n, false}; + } + int r = baidu::xpu::xblas::fc_fusion(ctx.x_context(), + tensor_a1, + tensor_b1, + tensor_c1, + tensor_d1, + desc, + epilogue1); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); + } } #else int r = diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h index 8ecace50e79bce..0db82f20704d03 100644 --- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h +++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h @@ -292,110 +292,31 @@ static void xblas_fc_wrapper(xpu::Context* ctx, #endif } else { #ifdef PADDLE_WITH_XPU_XRE5 - bool is_xte = false; - if constexpr (std::is_same::value) { - if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) { - is_xte = true; - - const int MAXPTR_N = ctx->max_ptr_size(); - int x_len = m * k; - XPUTypeFP16* x_fp16 = nullptr; - x_fp16 = RAII_GUARD.alloc_l3_or_gm(x_len); - PADDLE_ENFORCE_XDNN_NOT_NULL(x_fp16); - int w_len = k * n; - XPUTypeFP16* w_fp16 = nullptr; - w_fp16 = RAII_GUARD.alloc_l3_or_gm(w_len); - PADDLE_ENFORCE_XDNN_NOT_NULL(w_fp16); - - float* xte_scale_x = nullptr; - float* xte_scale_w = nullptr; - xte_scale_x = RAII_GUARD.alloc_l3_or_gm(1); - PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_x); - xte_scale_w = RAII_GUARD.alloc_l3_or_gm(1); - 
PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_w); - - float* xte_x_maxptr = nullptr; - float* xte_w_maxptr = nullptr; - if (x_maxptr == nullptr) { - xte_x_maxptr = RAII_GUARD.alloc_l3_or_gm(MAXPTR_N); - PADDLE_ENFORCE_XDNN_NOT_NULL(xte_x_maxptr); - int r = xpu::findmax(ctx, x, xte_x_maxptr, x_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax"); - r = xpu::cast_te(ctx, x, xte_x_maxptr, x_fp16, xte_scale_x, x_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); - } else { - r = xpu::cast_te(ctx, x, x_maxptr, x_fp16, xte_scale_x, x_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); - } - if (w_maxptr == nullptr) { - xte_w_maxptr = RAII_GUARD.alloc_l3_or_gm(MAXPTR_N); - PADDLE_ENFORCE_XDNN_NOT_NULL(xte_w_maxptr); - r = xpu::findmax(ctx, w, xte_w_maxptr, w_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax"); - r = xpu::cast_te(ctx, w, xte_w_maxptr, w_fp16, xte_scale_w, w_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); - } else { - r = xpu::cast_te(ctx, w, w_maxptr, w_fp16, xte_scale_w, w_len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); - } - - r = xblas:: - fc_fusion( - ctx, - x_fp16, - w_fp16, - y, - m, - n, - k, - x_trans, - w_trans, - x_maxptr ? x_maxptr : xte_x_maxptr, - w_maxptr ? 
w_maxptr : xte_w_maxptr, - y_maxptr, - ldx, - ldw, - ldy, - alpha, - beta, - bias, - act, - xte_scale_x, - xte_scale_w, - scale_x_mode, - scale_w_mode); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); - } - } + r = xblas::fc_fusion(ctx, + x, + w, + y, + m, + n, + k, + x_trans, + w_trans, + x_maxptr, + w_maxptr, + y_maxptr, + ldx, + ldw, + ldy, + alpha, + beta, + bias, + act, + scale_x, + scale_w, + scale_x_mode, + scale_w_mode); - if (!is_xte) { - r = xblas::fc_fusion(ctx, - x, - w, - y, - m, - n, - k, - x_trans, - w_trans, - x_maxptr, - w_maxptr, - y_maxptr, - ldx, - ldw, - ldy, - alpha, - beta, - bias, - act, - scale_x, - scale_w, - scale_x_mode, - scale_w_mode); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); - } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); #else r = xpu::fc_fusion(ctx, x, diff --git a/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py b/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py index ca1f6e6df4f920..54b93cbe1938b6 100644 --- a/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py +++ b/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py @@ -23,7 +23,7 @@ paddle.enable_static() -class TestFCXpuFusePattern(PassTest): +class TestFcXpuFuseAddPattern(PassTest): r""" x w \ /