Arcane  4.1.12.0
Developer documentation
Loading...
Searching...
No Matches
NCCLVariableSynchronizeDispatcher.cc
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* NCCLVariableSynchronizeDispatcher.cc (C) 2000-2025 */
9/* */
10/* Specific management of variable synchronizations via NCCL. */
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
14#include "arcane/utils/FatalErrorException.h"
15#include "arcane/utils/MemoryView.h"
16#include "arcane/utils/ITraceMng.h"
17
18#include "arcane/core/IParallelMng.h"
19#include "arcane/core/internal/IParallelMngInternal.h"
20#include "arcane/core/parallel/IStat.h"
21
22#include "arcane/impl/IDataSynchronizeBuffer.h"
23#include "arcane/impl/IDataSynchronizeImplementation.h"
24
25#include "arccore/common/accelerator/RunQueue.h"
26#include "arccore/accelerator/AcceleratorUtils.h"
27
28#include <nccl.h>
29
30/*---------------------------------------------------------------------------*/
31/*---------------------------------------------------------------------------*/
32
33namespace Arcane
34{
35namespace Accelerator::Cuda
36{
37 void arcaneCheckNCCLErrors(const TraceInfo& ti, ncclResult_t e)
38 {
39 ARCCORE_FATAL_IF((e != ncclSuccess), "NCCL Error trace={0} e={1} str={2}", ti, e, ncclGetErrorString(e));
40 }
41} // namespace Accelerator::Cuda
42
// Forwards the status code of an NCCL call, together with A_FUNCINFO, to
// Arcane::Accelerator::Cuda::arcaneCheckNCCLErrors().
// NOTE(review): A_FUNCINFO is presumably the trace information of the
// calling function - confirm against its definition.
43#define ARCCORE_CHECK_NCCL(result) \
44 Arcane::Accelerator::Cuda::arcaneCheckNCCLErrors(A_FUNCINFO, result)
45
46/*---------------------------------------------------------------------------*/
47/*---------------------------------------------------------------------------*/
48
// Dispatcher performing variable synchronizations through NCCL
// point-to-point calls (ncclSend / ncclRecv).
// NOTE(review): the listing jumps from line 54 to line 56, so the
// base-class clause is not visible in this extract; the 'override'
// specifiers below imply that one exists - confirm against the real file.
54class NCCLVariableSynchronizeDispatcher
56{
57 public:
58
// Factory creating instances of this class (defined later in this file).
59 class Factory;
// Builds the dispatcher from the parallel manager held by the factory \a f.
60 explicit NCCLVariableSynchronizeDispatcher(Factory* f);
61
62 protected:
63
// Intentionally empty: nothing is computed here.
64 void compute() override {}
// Posts the NCCL receives and sends for the buffers of \a ds_buf.
65 void beginSynchronize(IDataSynchronizeBuffer* ds_buf) override;
// Copies the received buffers back and records the statistics.
66 void endSynchronize(IDataSynchronizeBuffer* ds_buf) override;
67
68 private:
69
// Parallel manager used for ranks, traces and statistics.
70 IParallelMng* m_parallel_mng = nullptr;
// NCCL communicator initialised by the constructor (ncclCommInitRank).
// NOTE(review): it has no default initializer and no destructor is declared
// in this class, so nothing visible here releases it - confirm whether
// ncclCommDestroy should be called.
71 ncclComm_t m_nccl_communicator;
72};
73
74/*---------------------------------------------------------------------------*/
75/*---------------------------------------------------------------------------*/
76
// Body of the factory that creates NCCLVariableSynchronizeDispatcher instances.
// NOTE(review): the class header (listing lines 77-78) is missing from this
// extract - confirm the class name and base class against the real file.
79{
80 public:
81
// Stores the parallel manager that the created dispatchers will use.
82 explicit Factory(IParallelMng* mpi_pm)
83 : m_parallel_mng(mpi_pm)
84 {}
85
// Allocates a new dispatcher; its constructor reads m_parallel_mng from
// this factory.
// NOTE(review): listing line 89 is missing from this extract; it presumably
// returns 'x' wrapped in a Ref - confirm.
86 Ref<IDataSynchronizeImplementation> createInstance() override
87 {
88 auto* x = new NCCLVariableSynchronizeDispatcher(this);
90 }
91
92 public:
93
// Public so that the dispatcher constructor can read it directly.
94 IParallelMng* m_parallel_mng = nullptr;
95};
96
97/*---------------------------------------------------------------------------*/
98/*---------------------------------------------------------------------------*/
99
// Entry point allocating a NCCLVariableSynchronizeDispatcher::Factory bound
// to the parallel manager \a mpi_pm.
// NOTE(review): listing lines 100 (linkage / return type) and 104
// (presumably the return statement) are missing from this extract - confirm
// against the real file.
101arcaneCreateNCCLVariableSynchronizerFactory(IParallelMng* mpi_pm)
102{
103 auto* x = new NCCLVariableSynchronizeDispatcher::Factory(mpi_pm);
105}
106
107/*---------------------------------------------------------------------------*/
108/*---------------------------------------------------------------------------*/
109
110NCCLVariableSynchronizeDispatcher::
111NCCLVariableSynchronizeDispatcher(Factory* f)
112: m_parallel_mng(f->m_parallel_mng)
113{
114 IParallelMng* pm = m_parallel_mng;
115 Int32 my_rank = pm->commRank();
116 Int32 nb_rank = pm->commSize();
117
118 // TODO: We should verify that there is exactly one MPI rank per GPU
119 // because NCCL does not support multiple ranks on the same GPU.
120
121 ncclUniqueId my_id;
122 ARCCORE_CHECK_NCCL(ncclGetUniqueId(&my_id));
123 ArrayView<char> id_as_bytes(NCCL_UNIQUE_ID_BYTES, reinterpret_cast<char*>(&my_id));
124 pm->broadcast(id_as_bytes, 0);
125
126 ARCCORE_CHECK_NCCL(ncclCommInitRank(&m_nccl_communicator, nb_rank, my_id, my_rank));
127}
128
129/*---------------------------------------------------------------------------*/
130/*---------------------------------------------------------------------------*/
131
132void NCCLVariableSynchronizeDispatcher::
133beginSynchronize(IDataSynchronizeBuffer* ds_buf)
134{
135 Integer nb_message = ds_buf->nbRank();
136
137 IParallelMng* pm = m_parallel_mng;
138 ITraceMng* tm = pm->traceMng();
139 tm->info() << "Doing NCCL Sync";
140
141 double prepare_time = 0.0;
142 cudaStream_t stream = 0;
143
144 // If IParallelMng has a CUDA RunQueue, we use it.
145 RunQueue pm_queue = pm->_internalApi()->queue();
146 if (pm_queue.executionPolicy() == Accelerator::eExecutionPolicy::CUDA)
147 stream = Accelerator::AcceleratorUtils::toCudaNativeStream(pm_queue);
148 ;
149 ARCCORE_CHECK_NCCL(ncclGroupStart());
150 {
151 // Recopy the send buffers into \a var_values
152 ds_buf->copyAllSend();
153
154 // Post the receive messages
155 for (Integer i = 0; i < nb_message; ++i) {
156 Int32 target_rank = ds_buf->targetRank(i);
157 auto buf = ds_buf->receiveBuffer(i).bytes();
158 if (!buf.empty()) {
159 ARCCORE_CHECK_NCCL(ncclRecv(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
160 }
161 }
162
163 // Post the send messages
164 for (Integer i = 0; i < nb_message; ++i) {
165 auto buf = ds_buf->sendBuffer(i).bytes();
166 Int32 target_rank = ds_buf->targetRank(i);
167 if (!buf.empty()) {
168 ARCCORE_CHECK_NCCL(ncclSend(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
169 }
170 }
171 }
172 // Blocks until all messages are finished
173 ARCCORE_CHECK_NCCL(ncclGroupEnd());
174
175 tm->info() << "End begin synchronize";
176 pm->stat()->add("SyncPrepare", prepare_time, ds_buf->totalSendSize());
177}
178
179/*---------------------------------------------------------------------------*/
180/*---------------------------------------------------------------------------*/
181
182void NCCLVariableSynchronizeDispatcher::
183endSynchronize(IDataSynchronizeBuffer* ds_buf)
184{
185 IParallelMng* pm = m_parallel_mng;
186
187 double copy_time = 0.0;
188 double wait_time = 0.0;
189 ds_buf->copyAllReceive();
190
191 // Ensures that buffer copies are properly finished
192 ds_buf->barrier();
193
194 Int64 total_ghost_size = ds_buf->totalReceiveSize();
195 Int64 total_share_size = ds_buf->totalSendSize();
196 Int64 total_size = total_ghost_size + total_share_size;
197 pm->stat()->add("SyncCopy", copy_time, total_ghost_size);
198 pm->stat()->add("SyncWait", wait_time, total_size);
199}
200
201/*---------------------------------------------------------------------------*/
202/*---------------------------------------------------------------------------*/
203
204} // End namespace Arcane
205
206/*---------------------------------------------------------------------------*/
207/*---------------------------------------------------------------------------*/
#define ARCCORE_FATAL_IF(cond,...)
Macro throwing a FatalErrorException if cond is true.
Modifiable view of an array of type T.
Generic buffer for data synchronization.
Interface for a generic dispatcher factory.
Interface of the parallelism manager for a subdomain.
virtual ITraceMng * traceMng() const =0
Trace manager.
virtual Int32 commRank() const =0
Rank of this instance in the communicator.
virtual Int32 commSize() const =0
Number of instances in the communicator.
virtual TraceMessage info()=0
Stream for an information message.
Reference to an instance.
@ CUDA
Execution policy using the CUDA environment.
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
std::int64_t Int64
Signed integer type of 64 bits.
Int32 Integer
Type representing an integer.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Creates a reference on a pointer.
std::int32_t Int32
Signed integer type of 32 bits.