Arcane  v4.1.2.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
NCCLVariableSynchronizeDispatcher.cc
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2025 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* NCCLVariableSynchronizeDispatcher.cc (C) 2000-2025 */
9/* */
10/* Gestion spécifique des synchronisations des variables via NCCL. */
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
14#include "arcane/utils/FatalErrorException.h"
15#include "arcane/utils/MemoryView.h"
16#include "arcane/utils/ITraceMng.h"
17
18#include "arcane/core/IParallelMng.h"
19#include "arcane/core/internal/IParallelMngInternal.h"
20#include "arcane/core/parallel/IStat.h"
21
22#include "arcane/impl/IDataSynchronizeBuffer.h"
23#include "arcane/impl/IDataSynchronizeImplementation.h"
24
25#include "arccore/common/accelerator/RunQueue.h"
26#include "arccore/accelerator/AcceleratorUtils.h"
27
28#include <nccl.h>
29
30/*---------------------------------------------------------------------------*/
31/*---------------------------------------------------------------------------*/
32
33namespace Arcane
34{
35namespace Accelerator::Cuda
36{
37 void arcaneCheckNCCLErrors(const TraceInfo& ti, ncclResult_t e)
38 {
39 ARCCORE_FATAL_IF((e != ncclSuccess), "NCCL Error trace={0} e={1} str={2}", ti, e, ncclGetErrorString(e));
40 }
41} // namespace Accelerator::Cuda
42
43#define ARCCORE_CHECK_NCCL(result) \
44 Arcane::Accelerator::Cuda::arcaneCheckNCCLErrors(A_FUNCINFO, result)
45
46/*---------------------------------------------------------------------------*/
47/*---------------------------------------------------------------------------*/
53class NCCLVariableSynchronizeDispatcher
55{
56 public:
57
58 class Factory;
59 explicit NCCLVariableSynchronizeDispatcher(Factory* f);
60
61 protected:
62
63 void compute() override {}
64 void beginSynchronize(IDataSynchronizeBuffer* ds_buf) override;
65 void endSynchronize(IDataSynchronizeBuffer* ds_buf) override;
66
67 private:
68
69 IParallelMng* m_parallel_mng = nullptr;
70 ncclComm_t m_nccl_communicator;
71};
72
73/*---------------------------------------------------------------------------*/
74/*---------------------------------------------------------------------------*/
75
78{
79 public:
80
81 explicit Factory(IParallelMng* mpi_pm)
82 : m_parallel_mng(mpi_pm)
83 {}
84
85 Ref<IDataSynchronizeImplementation> createInstance() override
86 {
87 auto* x = new NCCLVariableSynchronizeDispatcher(this);
89 }
90
91 public:
92
93 IParallelMng* m_parallel_mng = nullptr;
94};
95
96/*---------------------------------------------------------------------------*/
97/*---------------------------------------------------------------------------*/
98
100arcaneCreateNCCLVariableSynchronizerFactory(IParallelMng* mpi_pm)
101{
102 auto* x = new NCCLVariableSynchronizeDispatcher::Factory(mpi_pm);
104}
105
106/*---------------------------------------------------------------------------*/
107/*---------------------------------------------------------------------------*/
108
109NCCLVariableSynchronizeDispatcher::
110NCCLVariableSynchronizeDispatcher(Factory* f)
111: m_parallel_mng(f->m_parallel_mng)
112{
113 IParallelMng* pm = m_parallel_mng;
114 Int32 my_rank = pm->commRank();
115 Int32 nb_rank = pm->commSize();
116
117 // TODO: Il faudrait vérifier qu'on a bien un GPU par rang MPI
118 // car NCCL ne supporte pas qu'il y ait plusieurs rangs sur le même GPU.
119
120 ncclUniqueId my_id;
121 ARCCORE_CHECK_NCCL(ncclGetUniqueId(&my_id));
122 ArrayView<char> id_as_bytes(NCCL_UNIQUE_ID_BYTES, reinterpret_cast<char*>(&my_id));
123 pm->broadcast(id_as_bytes, 0);
124
125 ARCCORE_CHECK_NCCL(ncclCommInitRank(&m_nccl_communicator, nb_rank, my_id, my_rank));
126}
127
128/*---------------------------------------------------------------------------*/
129/*---------------------------------------------------------------------------*/
130
131void NCCLVariableSynchronizeDispatcher::
132beginSynchronize(IDataSynchronizeBuffer* ds_buf)
133{
134 Integer nb_message = ds_buf->nbRank();
135
136 IParallelMng* pm = m_parallel_mng;
137 ITraceMng* tm = pm->traceMng();
138 tm->info() << "Doing NCCL Sync";
139
140 double prepare_time = 0.0;
141 cudaStream_t stream = 0;
142
143 // Si le IParallelMng a une RunQueue Cuda, on l'utilise.
144 RunQueue pm_queue = pm->_internalApi()->queue();
145 if (pm_queue.executionPolicy() == Accelerator::eExecutionPolicy::CUDA)
146 stream = Accelerator::AcceleratorUtils::toCudaNativeStream(pm_queue);
147;
148 ARCCORE_CHECK_NCCL(ncclGroupStart());
149 {
150 // Recopie les buffers d'envoi dans \a var_values
151 ds_buf->copyAllSend();
152
153 // Poste les messages de réception
154 for (Integer i = 0; i < nb_message; ++i) {
155 Int32 target_rank = ds_buf->targetRank(i);
156 auto buf = ds_buf->receiveBuffer(i).bytes();
157 if (!buf.empty()) {
158 ARCCORE_CHECK_NCCL(ncclRecv(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
159 }
160 }
161
162 // Poste les messages d'envoi
163 for (Integer i = 0; i < nb_message; ++i) {
164 auto buf = ds_buf->sendBuffer(i).bytes();
165 Int32 target_rank = ds_buf->targetRank(i);
166 if (!buf.empty()) {
167 ARCCORE_CHECK_NCCL(ncclSend(buf.data(), buf.size(), ncclInt8, target_rank, m_nccl_communicator, stream));
168 }
169 }
170 }
171 // Bloque jusqu'à ce que tous les messages soient terminés
172 ARCCORE_CHECK_NCCL(ncclGroupEnd());
173
174 tm->info() << "End begin synchronize";
175 pm->stat()->add("SyncPrepare", prepare_time, ds_buf->totalSendSize());
176}
177
178/*---------------------------------------------------------------------------*/
179/*---------------------------------------------------------------------------*/
180
181void NCCLVariableSynchronizeDispatcher::
182endSynchronize(IDataSynchronizeBuffer* ds_buf)
183{
184 IParallelMng* pm = m_parallel_mng;
185
186 double copy_time = 0.0;
187 double wait_time = 0.0;
188 ds_buf->copyAllReceive();
189
190 // S'assure que les copies des buffers sont bien terminées
191 ds_buf->barrier();
192
193 Int64 total_ghost_size = ds_buf->totalReceiveSize();
194 Int64 total_share_size = ds_buf->totalSendSize();
195 Int64 total_size = total_ghost_size + total_share_size;
196 pm->stat()->add("SyncCopy", copy_time, total_ghost_size);
197 pm->stat()->add("SyncWait", wait_time, total_size);
198}
199
200/*---------------------------------------------------------------------------*/
201/*---------------------------------------------------------------------------*/
202
203} // End namespace Arcane
204
205/*---------------------------------------------------------------------------*/
206/*---------------------------------------------------------------------------*/
#define ARCCORE_FATAL_IF(cond,...)
Macro envoyant une exception FatalErrorException si cond est vrai.
Vue modifiable d'un tableau d'un type T.
Buffer générique pour la synchronisation de données.
Interface d'une fabrique dispatcher générique.
Interface du gestionnaire de parallélisme pour un sous-domaine.
virtual ITraceMng * traceMng() const =0
Gestionnaire de traces.
virtual Int32 commRank() const =0
Rang de cette instance dans le communicateur.
virtual Int32 commSize() const =0
Nombre d'instance dans le communicateur.
virtual TraceMessage info()=0
Flot pour un message d'information.
Référence à une instance.
@ CUDA
Politique d'exécution utilisant l'environnement CUDA.
-*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
std::int64_t Int64
Type entier signé sur 64 bits.
Int32 Integer
Type représentant un entier.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Créé une référence sur un pointeur.
std::int32_t Int32
Type entier signé sur 32 bits.