14#include "arcane/utils/FatalErrorException.h"
15#include "arcane/utils/MemoryView.h"
16#include "arcane/utils/ITraceMng.h"
18#include "arcane/core/IParallelMng.h"
19#include "arcane/core/internal/IParallelMngInternal.h"
20#include "arcane/core/parallel/IStat.h"
22#include "arcane/impl/IDataSynchronizeBuffer.h"
23#include "arcane/impl/IDataSynchronizeImplementation.h"
25#include "arccore/common/accelerator/RunQueue.h"
26#include "arccore/accelerator/AcceleratorUtils.h"
35namespace Accelerator::Cuda
37 void arcaneCheckNCCLErrors(
const TraceInfo& ti, ncclResult_t e)
39 ARCCORE_FATAL_IF((e != ncclSuccess),
"NCCL Error trace={0} e={1} str={2}", ti, e, ncclGetErrorString(e));
43#define ARCCORE_CHECK_NCCL(result) \
44 Arcane::Accelerator::Cuda::arcaneCheckNCCLErrors(A_FUNCINFO, result)
53class NCCLVariableSynchronizeDispatcher
59 explicit NCCLVariableSynchronizeDispatcher(
Factory* f);
63 void compute()
override {}
70 ncclComm_t m_nccl_communicator;
82 : m_parallel_mng(mpi_pm)
87 auto* x =
new NCCLVariableSynchronizeDispatcher(
this);
100arcaneCreateNCCLVariableSynchronizerFactory(
IParallelMng* mpi_pm)
109NCCLVariableSynchronizeDispatcher::
110NCCLVariableSynchronizeDispatcher(Factory* f)
111: m_parallel_mng(f->m_parallel_mng)
121 ARCCORE_CHECK_NCCL(ncclGetUniqueId(&my_id));
122 ArrayView<char> id_as_bytes(NCCL_UNIQUE_ID_BYTES,
reinterpret_cast<char*
>(&my_id));
123 pm->broadcast(id_as_bytes, 0);
125 ARCCORE_CHECK_NCCL(ncclCommInitRank(&m_nccl_communicator, nb_rank, my_id, my_rank));
131void NCCLVariableSynchronizeDispatcher::
132beginSynchronize(IDataSynchronizeBuffer* ds_buf)
134 Integer nb_message = ds_buf->nbRank();
136 IParallelMng* pm = m_parallel_mng;
138 tm->
info() <<
"Doing NCCL Sync";
140 double prepare_time = 0.0;
141 cudaStream_t stream = 0;
144 RunQueue pm_queue = pm->_internalApi()->queue();
146 stream = Accelerator::AcceleratorUtils::toCudaNativeStream(pm_queue);
148 ARCCORE_CHECK_NCCL(ncclGroupStart());
151 ds_buf->copyAllSend();
154 for (
Integer i = 0; i < nb_message; ++i) {
155 Int32 target_rank = ds_buf->targetRank(i);
156 auto buf = ds_buf->receiveBuffer(i).bytes();
158 ARCCORE_CHECK_NCCL(ncclRecv(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
163 for (
Integer i = 0; i < nb_message; ++i) {
164 auto buf = ds_buf->sendBuffer(i).bytes();
165 Int32 target_rank = ds_buf->targetRank(i);
167 ARCCORE_CHECK_NCCL(ncclSend(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
172 ARCCORE_CHECK_NCCL(ncclGroupEnd());
174 tm->info() <<
"End begin synchronize";
175 pm->stat()->add(
"SyncPrepare", prepare_time, ds_buf->totalSendSize());
181void NCCLVariableSynchronizeDispatcher::
182endSynchronize(IDataSynchronizeBuffer* ds_buf)
184 IParallelMng* pm = m_parallel_mng;
186 double copy_time = 0.0;
187 double wait_time = 0.0;
188 ds_buf->copyAllReceive();
193 Int64 total_ghost_size = ds_buf->totalReceiveSize();
194 Int64 total_share_size = ds_buf->totalSendSize();
195 Int64 total_size = total_ghost_size + total_share_size;
196 pm->stat()->add(
"SyncCopy", copy_time, total_ghost_size);
197 pm->stat()->add(
"SyncWait", wait_time, total_size);
#define ARCCORE_FATAL_IF(cond,...)
Macro envoyant une exception FatalErrorException si cond est vrai.
Vue modifiable d'un tableau d'un type T.
Buffer générique pour la synchronisation de données.
Interface d'une fabrique dispatcher générique.
Interface du gestionnaire de parallélisme pour un sous-domaine.
virtual ITraceMng * traceMng() const =0
Gestionnaire de traces.
virtual Int32 commRank() const =0
Rang de cette instance dans le communicateur.
virtual Int32 commSize() const =0
Nombre d'instance dans le communicateur.
virtual TraceMessage info()=0
Flot pour un message d'information.
Référence à une instance.
@ CUDA
Politique d'exécution utilisant l'environnement CUDA.
-*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
std::int64_t Int64
Type entier signé sur 64 bits.
Int32 Integer
Type représentant un entier.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Créé une référence sur un pointeur.
std::int32_t Int32
Type entier signé sur 32 bits.