14#include "arcane/utils/FatalErrorException.h"
15#include "arcane/utils/MemoryView.h"
16#include "arcane/utils/ITraceMng.h"
18#include "arcane/core/IParallelMng.h"
19#include "arcane/core/internal/IParallelMngInternal.h"
20#include "arcane/core/parallel/IStat.h"
22#include "arcane/impl/IDataSynchronizeBuffer.h"
23#include "arcane/impl/IDataSynchronizeImplementation.h"
25#include "arccore/common/accelerator/RunQueue.h"
26#include "arccore/accelerator/AcceleratorUtils.h"
35namespace Accelerator::Cuda
37 void arcaneCheckNCCLErrors(
const TraceInfo& ti, ncclResult_t e)
39 ARCCORE_FATAL_IF((e != ncclSuccess),
"NCCL Error trace={0} e={1} str={2}", ti, e, ncclGetErrorString(e));
//! Checks the status returned by an NCCL call expression \a result.
//! The current function information (A_FUNCINFO) is passed as the TraceInfo
//! argument, so a failing call is reported with its call site.
#define ARCCORE_CHECK_NCCL(result) \
  Arcane::Accelerator::Cuda::arcaneCheckNCCLErrors(A_FUNCINFO, (result))
/*!
 * \brief Variable synchronization implementation that exchanges the
 * synchronization buffers with NCCL point-to-point calls.
 *
 * NOTE(review): the base class, access specifiers, the
 * beginSynchronize()/endSynchronize() declarations and the m_parallel_mng
 * member are not visible in this extract. The out-of-line definitions
 * below imply they exist.
 */
54class NCCLVariableSynchronizeDispatcher
//! Takes the parallel manager from the factory and creates the NCCL communicator.
60 explicit NCCLVariableSynchronizeDispatcher(
Factory* f);
//! Nothing to precompute for NCCL: intentionally empty.
64 void compute()
override {}
//! NCCL communicator created in the constructor by ncclCommInitRank().
//! NOTE(review): no ncclCommDestroy() is visible in this extract.
//! Confirm that the communicator is released somewhere.
71 ncclComm_t m_nccl_communicator;
// Factory constructor (partial extract): keeps the parallel manager
// that every dispatcher created by this factory will use.
83 : m_parallel_mng(mpi_pm)
// Creates a new dispatcher bound to this factory (partial extract).
// NOTE(review): ownership of this raw 'new' is presumably handed to a
// Ref (makeRef) on an elided line. Confirm that nothing leaks.
88 auto* x =
new NCCLVariableSynchronizeDispatcher(
this);
// Entry point that builds the NCCL synchronizer factory for the parallel
// manager 'mpi_pm'.
// NOTE(review): the return type and the body are elided in this extract.
101arcaneCreateNCCLVariableSynchronizerFactory(
IParallelMng* mpi_pm)
/*!
 * \brief Creates the NCCL communicator for all ranks of the factory's
 * parallel manager.
 *
 * Visible steps:
 * 1. Obtain an ncclUniqueId.
 * 2. Broadcast its raw bytes from rank 0 through IParallelMng.
 * 3. Call ncclCommInitRank() with that id.
 *
 * NOTE(review): the declarations of pm, my_id, nb_rank and my_rank are
 * elided. Presumably nb_rank = pm->commSize() and my_rank = pm->commRank();
 * confirm.
 */
110NCCLVariableSynchronizeDispatcher::
111NCCLVariableSynchronizeDispatcher(Factory* f)
112: m_parallel_mng(f->m_parallel_mng)
// NOTE(review): no rank guard is visible here, so every rank appears to
// generate an id, and only rank 0's id survives the broadcast below.
// NCCL documents creating the id once and distributing it. Consider
// restricting this call to rank 0 if the elided lines do not already do so.
122 ARCCORE_CHECK_NCCL(ncclGetUniqueId(&my_id));
// View the opaque ncclUniqueId as NCCL_UNIQUE_ID_BYTES chars so it can be
// broadcast as plain bytes.
123 ArrayView<char> id_as_bytes(NCCL_UNIQUE_ID_BYTES,
reinterpret_cast<char*
>(&my_id));
// Every rank receives rank 0's id, so all ranks join the same communicator.
124 pm->broadcast(id_as_bytes, 0);
// NOTE(review): per the NCCL documentation this initialization waits for
// all nb_rank ranks to call it with the same id. Verify that all ranks
// construct the dispatcher.
126 ARCCORE_CHECK_NCCL(ncclCommInitRank(&m_nccl_communicator, nb_rank, my_id, my_rank));
/*!
 * \brief Starts a synchronization by posting the NCCL transfers.
 *
 * Steps:
 * - Fill the send buffers (copyAllSend()).
 * - Inside one ncclGroupStart()/ncclGroupEnd() group, enqueue one ncclRecv
 *   and one ncclSend per neighbour rank.
 * - Use the CUDA stream behind the parallel manager's RunQueue for all calls.
 *
 * Buffers are transferred as raw bytes: type ncclInt8, count = byte size.
 *
 * NOTE(review): the function returns right after ncclGroupEnd() with no
 * visible stream synchronization. Completion is left to endSynchronize()
 * or to stream ordering; confirm.
 * NOTE(review): 'tm' is used but declared on an elided line.
 */
132void NCCLVariableSynchronizeDispatcher::
133beginSynchronize(IDataSynchronizeBuffer* ds_buf)
// One message per neighbour rank.
135 Integer nb_message = ds_buf->nbRank();
137 IParallelMng* pm = m_parallel_mng;
139 tm->
info() <<
"Doing NCCL Sync";
// NOTE(review): prepare_time is never updated in the visible code, so the
// "SyncPrepare" statistic below always records 0.0.
141 double prepare_time = 0.0;
142 cudaStream_t stream = 0;
// Use the native CUDA stream of the parallel manager's queue for every
// NCCL call below.
145 RunQueue pm_queue = pm->_internalApi()->queue();
147 stream = Accelerator::AcceleratorUtils::toCudaNativeStream(pm_queue);
// All sends and receives below are issued as a single NCCL group.
149 ARCCORE_CHECK_NCCL(ncclGroupStart());
// Fill the send buffers before the sends are enqueued.
152 ds_buf->copyAllSend();
// Receives: message i comes from rank ds_buf->targetRank(i).
155 for (
Integer i = 0; i < nb_message; ++i) {
156 Int32 target_rank = ds_buf->targetRank(i);
157 auto buf = ds_buf->receiveBuffer(i).bytes();
159 ARCCORE_CHECK_NCCL(ncclRecv(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
// Sends: message i goes to the same rank ds_buf->targetRank(i).
164 for (
Integer i = 0; i < nb_message; ++i) {
165 auto buf = ds_buf->sendBuffer(i).bytes();
166 Int32 target_rank = ds_buf->targetRank(i);
168 ARCCORE_CHECK_NCCL(ncclSend(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
// Close the group; the grouped transfers are presumably launched
// asynchronously on 'stream' at this point.
173 ARCCORE_CHECK_NCCL(ncclGroupEnd());
175 tm->info() <<
"End begin synchronize";
176 pm->stat()->add(
"SyncPrepare", prepare_time, ds_buf->totalSendSize());
/*!
 * \brief Completes a synchronization started by beginSynchronize().
 *
 * Calls copyAllReceive(), which presumably copies the received data back
 * into the synchronized variables. Then records the "SyncCopy" and
 * "SyncWait" statistics.
 *
 * NOTE(review): no explicit wait on the NCCL transfers (or on the CUDA
 * stream) is visible before copyAllReceive(). Correctness presumably
 * relies on copyAllReceive() running on the same queue/stream; confirm.
 * Lines 190-193 of the original are elided in this extract.
 */
182void NCCLVariableSynchronizeDispatcher::
183endSynchronize(IDataSynchronizeBuffer* ds_buf)
185 IParallelMng* pm = m_parallel_mng;
// NOTE(review): copy_time and wait_time are never updated in the visible
// code, so both statistics below record 0.0.
187 double copy_time = 0.0;
188 double wait_time = 0.0;
189 ds_buf->copyAllReceive();
// Ghost (received) and shared (sent) sizes. "SyncWait" is reported with
// their sum.
194 Int64 total_ghost_size = ds_buf->totalReceiveSize();
195 Int64 total_share_size = ds_buf->totalSendSize();
196 Int64 total_size = total_ghost_size + total_share_size;
197 pm->stat()->add(
"SyncCopy", copy_time, total_ghost_size);
198 pm->stat()->add(
"SyncWait", wait_time, total_size);
#define ARCCORE_FATAL_IF(cond,...)
Macro throwing a FatalErrorException if cond is true.
Modifiable view of an array of type T.
Generic buffer for data synchronization.
Interface for a generic dispatcher factory.
Interface of the parallelism manager for a subdomain.
virtual ITraceMng * traceMng() const =0
Trace manager.
virtual Int32 commRank() const =0
Rank of this instance in the communicator.
virtual Int32 commSize() const =0
Number of instances in the communicator.
virtual TraceMessage info()=0
Stream for an information message.
Reference to an instance.
@ CUDA
Execution policy using the CUDA environment.
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
std::int64_t Int64
Signed integer type of 64 bits.
Int32 Integer
Type representing an integer.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Creates a reference on a pointer.
std::int32_t Int32
Signed integer type of 32 bits.