
[Operator Mechanism] PR68945: _C_ops.c_concat in dynamic graph bug fix #70870

Open
wants to merge 12 commits into base: develop
1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -593,6 +593,7 @@ void PirInterpreter::UpdateNcclOpNum() {
"pd_op.barrier_grad",
"pd_op.alltoall_grad",
"pd_op.global_gather_grad",
"pd_op.c_concat_grad",
"pd_op.distributed_fused_lamb_grad",
"pd_op.margin_cross_entropy_grad",
"pd_op.sync_batch_norm_grad",
6 changes: 6 additions & 0 deletions paddle/phi/api/generator/api_gen.py
@@ -417,7 +417,13 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#endif
64 changes: 64 additions & 0 deletions paddle/phi/api/generator/dist_api_gen.py
@@ -81,6 +81,46 @@
}}
"""

NCCL_COMMCONTEXT_INIT = """
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
const auto & comm_context_manager_ = phi::distributed::CommContextManager::GetInstance();
if (nranks > 1 && !comm_context_manager_.Has(std::to_string(ring_id))) {{
auto store = phi::distributed::CreateOrGetGlobalTCPStore();
phi::distributed::CommContextManager::CreateNCCLCommContext(
store, std::to_string(ring_id), rank, nranks);
}}
#endif
"""

SET_NCCL_COMMCONTEXT = """
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
const auto & comm_context_manager = phi::distributed::CommContextManager::GetInstance();
phi::distributed::NCCLCommContext* comm_context = nullptr;
if (comm_context_manager.Has(std::to_string(ring_id))) {{
comm_context = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
common::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id(%d) attr.",
std::to_string(ring_id)));
if (!comm_context->GetDevContext() || !comm_context->GetDevContext()->GetCommContext())
{{
auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{}", kernel_data_type);
}}
Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
dev_context->SetCommContext(comm_context);
}}
}}
#endif
"""

# 1. InferSPMD
SINGLE_DIST_META_IN_TEMPLATE = """
auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());"""
@@ -860,6 +900,24 @@ def process_data_type_args(args_item):
input_args=input_args, mesh=mesh, kernel_code=kernel_select_code
)

# Current initialization only considers the case where the op's parameters contain ring_id, nranks and rank.
# Other cases will be addressed in the future.
if 'ring_id' in self.attrs['names']:
if (
'rank' in self.attrs['names']
and 'nranks' in self.attrs['names']
):
if_condition_code = (
if_condition_code
+ '\n'
+ self.generate_nccl_commcontext_init_code()
)
if_condition_code = (
if_condition_code
+ '\n'
+ self.generate_set_nccl_commcontext_code()
)

return kernel_key_item_init + if_condition_code

def generate_specialized_infer_spmd_code(self) -> str:
@@ -1310,6 +1368,12 @@ def generate_kernel_selection_code(self) -> str:
self.api, self.kernel['func'][0], self.kernel['func'][0]
)

def generate_nccl_commcontext_init_code(self) -> str:
return NCCL_COMMCONTEXT_INIT.format(self.kernel['func'][0])

def generate_set_nccl_commcontext_code(self) -> str:
return SET_NCCL_COMMCONTEXT.format(self.kernel['func'][0], self.api)

def generate_reshard_input_code(self) -> str:
input_reshard_code = ""
if self.generate_infer_spmd is True:
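The two templates added above (NCCL_COMMCONTEXT_INIT and SET_NCCL_COMMCONTEXT) are ordinary Python format strings: doubled braces survive as literal C++ braces, and the bare {} placeholders in SET_NCCL_COMMCONTEXT are filled with the kernel function name and the API name by generate_set_nccl_commcontext_code. A minimal sketch of that substitution mechanism, using a shortened stand-in template rather than the real generator class:

# Shortened stand-in for the SET_NCCL_COMMCONTEXT template; not the real generator code.
TEMPLATE = """
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (comm_context_manager.Has(std::to_string(ring_id))) {{
    auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
        "{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
  }}
#endif
"""

def render(kernel_func: str) -> str:
    # str.format turns '{{' / '}}' into literal C++ braces and fills the bare
    # '{}' with the kernel function name, mirroring
    # SET_NCCL_COMMCONTEXT.format(self.kernel['func'][0], self.api) above.
    return TEMPLATE.format(kernel_func)

print(render("c_concat"))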
6 changes: 6 additions & 0 deletions paddle/phi/api/generator/dist_bw_api_gen.py
@@ -517,7 +517,13 @@ def source_include(header_file_path, fw_header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#endif
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -321,6 +321,12 @@
data_type : out_grad
no_need_buffer : input

- backward_op : c_concat_grad
forward : c_concat (Tensor x, int rank, int nranks, int ring_id, bool use_calc_stream, bool use_model_parallel) -> Tensor(out)
args : (Tensor out_grad, int rank = 0, int nranks = 1, int ring_id = 0, bool use_model_parallel = true)
output : Tensor(x_grad)
invoke: c_split(out_grad, rank, nranks, ring_id, use_model_parallel)

- backward_op : cast_grad
forward : cast (Tensor x, DataType dtype) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
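The c_concat_grad entry above reuses c_split for its implementation: the forward c_concat gathers each rank's shard along the last axis across nranks model-parallel ranks, so the gradient that flows back to a given rank is simply that rank's slice of out_grad, which is exactly what c_split computes. A single-process NumPy sketch of this relationship (shapes only, no real collectives):

import numpy as np

# Single-process sketch of why c_concat's backward can be a c_split.
nranks, batch, hidden = 4, 2, 3
shards = [np.random.rand(batch, hidden) for _ in range(nranks)]

# Forward: c_concat gathers the per-rank shards along the last axis.
out = np.concatenate(shards, axis=-1)  # shape (batch, hidden * nranks)

# Backward: the gradient w.r.t. rank r's shard is the r-th slice of out_grad,
# i.e. a c_split of out_grad with the same rank/nranks arguments.
out_grad = np.ones_like(out)
for rank in range(nranks):
    x_grad = np.split(out_grad, nranks, axis=-1)[rank]
    assert x_grad.shape == shards[rank].shape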
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -521,6 +521,7 @@
drop_empty_grad : [input_grad]

- op : c_concat
backward: c_concat_grad
inputs :
x : X
outputs :
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
@@ -762,7 +762,7 @@
param : [x, nranks]
kernel :
func : c_concat
traits : paddle::dialect::ForwardOnlyTrait
backward: c_concat_grad

- op : c_identity
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
17 changes: 2 additions & 15 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -25,6 +25,7 @@
LayerHelper,
_create_tensor,
in_dynamic_mode,
in_dynamic_or_pir_mode,
in_pir_mode,
)
from paddle.nn import Layer
@@ -139,21 +140,7 @@ def _c_concat(tensor, group=None):
rank = group.rank
nranks = group.nranks

if in_dynamic_mode():
return _legacy_C_ops.c_concat(
tensor,
'ring_id',
ring_id,
'use_calc_stream',
True,
'rank',
rank,
'nranks',
nranks,
'use_model_parallel',
True,
)
elif in_pir_mode():
if in_dynamic_or_pir_mode():
return _C_ops.c_concat(tensor, rank, nranks, ring_id, True, True)
else:
op_type = 'c_concat'
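With this change, _c_concat no longer goes through _legacy_C_ops in dynamic mode; dynamic and PIR modes now share the same _C_ops.c_concat(tensor, rank, nranks, ring_id, True, True) call, which is what picks up the newly registered c_concat_grad during backward. A hedged usage sketch, assuming a multi-GPU job launched with python -m paddle.distributed.launch (the setup below is illustrative and omits the fleet hybrid-parallel configuration):

# Illustrative only: exercises _c_concat in dynamic mode so that the new
# c_concat_grad path is used for the backward pass. Assumes >= 2 GPUs and the
# default process group created by init_parallel_env().
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.layers.mpu import mp_ops

dist.init_parallel_env()

local = paddle.rand([2, 8])
local.stop_gradient = False

# Forward: gather the rank-local shards along the last axis.
gathered = mp_ops._c_concat(local, group=None)

# Backward: with c_concat_grad registered, the gradient flows back to the
# rank-local shard via c_split.
gathered.sum().backward()
print(local.grad.shape)  # [2, 8]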