Arcane v3.16.0.0
Developer documentation
MpiTypeDispatcherImpl.h
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
//-----------------------------------------------------------------------------
// Copyright 2000-2025 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: Apache-2.0
//-----------------------------------------------------------------------------
/*---------------------------------------------------------------------------*/
/* MpiTypeDispatcherImpl.h                                     (C) 2000-2025 */
/*                                                                           */
/* Implementation of 'MpiTypeDispatcher'.                                    */
/*---------------------------------------------------------------------------*/
#ifndef ARCCORE_MESSAGEPASSINGMPI_INTERNAL_MPITYPEDISPATCHERIMPL_H
#define ARCCORE_MESSAGEPASSINGMPI_INTERNAL_MPITYPEDISPATCHERIMPL_H
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#include "arccore/message_passing_mpi/internal/MpiTypeDispatcher.h"

#include "arccore/message_passing_mpi/MpiDatatype.h"
#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
#include "arccore/message_passing_mpi/internal/MpiLock.h"

#include "arccore/message_passing/Request.h"
#include "arccore/message_passing/GatherMessageInfo.h"

#include "arccore/base/NotSupportedException.h"
#include "arccore/base/NotImplementedException.h"

#include "arccore/collections/Array.h"

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

namespace Arcane::MessagePassing::Mpi
{

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
template<class Type> MpiTypeDispatcher<Type>::
MpiTypeDispatcher(IMessagePassingMng* parallel_mng,MpiAdapter* adapter,MpiDatatype* datatype)
: m_parallel_mng(parallel_mng)
, m_adapter(adapter)
, m_datatype(datatype)
{
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

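// Note: the destructor below deletes the associated MpiDatatype only when
// this dispatcher owns it (m_is_destroy_datatype is set); otherwise the
// datatype remains owned by the caller that supplied it.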
template<class Type> MpiTypeDispatcher<Type>::
~MpiTypeDispatcher()
{
  if (m_is_destroy_datatype)
    delete m_datatype;
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
broadcast(Span<Type> send_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->broadcast(send_buf.data(),send_buf.size(),rank,type);
}
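
// Usage sketch (illustrative; 'dispatcher' is assumed to be a valid
// MpiTypeDispatcher<Real>*): broadcast 8 values from rank 0 to all ranks.
//
//   UniqueArray<Real> values(8);
//   dispatcher->broadcast(values.span(), 0); // rank 0 sends, others receive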

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
allGather(Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->allGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
gather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->gather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
}
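
// Note: as with MPI_Gather, recv_buf is only used on the root rank, where it
// must hold commSize() * send_buf.size() elements; every rank must pass the
// same send count.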

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
{
  _gatherVariable2(send_buf,recv_buf,-1);
}
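
// Usage sketch (illustrative names): ranks may contribute different counts;
// 'recv' is resized to the global total before the data exchange.
//
//   UniqueArray<Real> local(my_rank + 1);
//   UniqueArray<Real> recv;
//   dispatcher->allGatherVariable(local.constSpan(), recv);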
97
98/*---------------------------------------------------------------------------*/
99/*---------------------------------------------------------------------------*/
100
101template<class Type> void MpiTypeDispatcher<Type>::
102gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
103{
104 _gatherVariable2(send_buf,recv_buf,rank);
105}
106
107/*---------------------------------------------------------------------------*/
108/*---------------------------------------------------------------------------*/
109
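// Two-phase variable-size gather: a first (all)gather exchanges the per-rank
// element counts, displacements are then computed and the receive buffer is
// resized on the receiving rank(s), and the payload is finally exchanged via
// the sized gatherVariable() overload. Passing rank==A_NULL_RANK selects the
// all-gather variant.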
template<class Type> void MpiTypeDispatcher<Type>::
_gatherVariable2(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
{
  Int32 comm_size = m_parallel_mng->commSize();
  UniqueArray<int> send_counts(comm_size);
  UniqueArray<int> send_indexes(comm_size);

  Int64 nb_elem = send_buf.size();
  int my_buf_count = (int)nb_elem;
  Span<const int> count_r(&my_buf_count,1);

  // Get the number of elements sent by each rank
  if (rank!=A_NULL_RANK)
    mpGather(m_parallel_mng,count_r,send_counts,rank);
  else
    mpAllGather(m_parallel_mng,count_r,send_counts);

  // Fill the index (displacement) array
  if (rank==A_NULL_RANK || rank==m_adapter->commRank()){
    Int64 index = 0;
    for( Integer i=0, is=comm_size; i<is; ++i ){
      send_indexes[i] = (int)index;
      index += send_counts[i];
      //info() << " SEND " << i << " index=" << send_indexes[i] << " count=" << send_counts[i];
    }
    Int64 i64_total_elem = index;
    Int64 max_size = ARCCORE_INT64_MAX;
    if (i64_total_elem>max_size){
      ARCCORE_FATAL("Invalid size '{0}'",i64_total_elem);
    }
    Int64 total_elem = i64_total_elem;
    recv_buf.resize(total_elem);
  }
  gatherVariable(send_buf,recv_buf,send_counts,send_indexes,rank);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
gatherVariable(Span<const Type> send_buf,Span<Type> recv_buf,Span<const Int32> send_counts,
               Span<const Int32> displacements,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  Int32 nb_elem = send_buf.smallView().size();
  if (rank!=A_NULL_RANK){
    m_adapter->gatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
                              displacements.data(),nb_elem,rank,type);
  }
  else{
    m_adapter->allGatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
                                 displacements.data(),nb_elem,type);
  }
}
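
// Note: as with MPI_Gatherv, 'send_counts' and 'displacements' are only
// significant on the root rank; with rank==A_NULL_RANK the allgatherv-style
// variant is used and they must be valid on every rank.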

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
scatterVariable(Span<const Type> send_buf,Span<Type> recv_buf,Int32 root)
{
  MPI_Datatype type = m_datatype->datatype();

  Int32 comm_size = m_adapter->commSize();
  UniqueArray<int> recv_counts(comm_size);
  UniqueArray<int> recv_indexes(comm_size);

  Int64 nb_elem = recv_buf.size();
  int my_buf_count = m_adapter->toMPISize(nb_elem);
  Span<const int> count_r(&my_buf_count,1);

  // Get the number of elements expected by each rank
  mpAllGather(m_parallel_mng,count_r,recv_counts);

  // Fill the index (displacement) array
  int index = 0;
  for( Integer i=0, is=comm_size; i<is; ++i ){
    recv_indexes[i] = index;
    index += recv_counts[i];
  }

  m_adapter->scatterVariable(send_buf.data(),recv_counts.data(),recv_indexes.data(),
                             recv_buf.data(),nb_elem,root,type);
}
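
// Usage sketch (illustrative): each rank announces how much it wants by
// sizing its receive buffer; only the root's send buffer is read.
//
//   UniqueArray<Real> recv(local_size); // local_size may differ per rank
//   dispatcher->scatterVariable(send.constSpan(), recv.span(), /*root*/ 0);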

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
allToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->allToAll(send_buf.data(),recv_buf.data(),count,type);
}
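
// Note: as with MPI_Alltoall, 'count' is the number of elements exchanged
// with each rank, so both buffers must hold count * commSize() elements.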

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
allToAllVariable(Span<const Type> send_buf,
                 ConstArrayView<Int32> send_count,
                 ConstArrayView<Int32> send_index,
                 Span<Type> recv_buf,
                 ConstArrayView<Int32> recv_count,
                 ConstArrayView<Int32> recv_index
                 )
{
  MPI_Datatype type = m_datatype->datatype();

  m_adapter->allToAllVariable(send_buf.data(),send_count.data(),
                              send_index.data(),recv_buf.data(),
                              recv_count.data(),
                              recv_index.data(),type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
send(Span<const Type> send_buffer,Int32 rank,bool is_blocked)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
                               rank,sizeof(Type),type,100,is_blocked);
}
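
// Note: this legacy overload always uses the hard-coded message tag 100 (as
// does the matching receive() below); use the PointToPointMessageInfo
// overload to control the tag explicitly.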

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
receive(Span<Type> recv_buffer,Int32 rank,bool is_blocked)
{
  MPI_Datatype type = m_datatype->datatype();
  MpiLock::Section mls(m_adapter->mpiLock());
  return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                               rank,sizeof(Type),type,100,is_blocked);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
send(Span<const Type> send_buffer,const PointToPointMessageInfo& message)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 sizeof_type = sizeof(Type);
  MpiLock::Section mls(m_adapter->mpiLock());
  bool is_blocking = message.isBlocking();
  if (message.isRankTag()){
    return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
                                 message.destinationRank().value(),
                                 sizeof_type,type,message.tag().value(),is_blocking);
  }
  if (message.isMessageId()){
    // There is no send() variant taking a MessageId.
    ARCCORE_THROW(NotSupportedException,"Invalid generic send with MessageId");
  }
  ARCCORE_THROW(NotSupportedException,"Invalid message_info");
}
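
// Usage sketch (illustrative; assumes the usual Arccore constructors for
// PointToPointMessageInfo): non-blocking send to rank 3 with tag 42,
// completed later with mpWait().
//
//   PointToPointMessageInfo info(MessageRank(3), MessageTag(42), NonBlocking);
//   Request r = dispatcher->send(values.constSpan(), info);
//   mpWait(pm, r); // 'pm' is the IMessagePassingMng*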

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
receive(Span<Type> recv_buffer,const PointToPointMessageInfo& message)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 sizeof_type = sizeof(Type);
  MpiLock::Section mls(m_adapter->mpiLock());
  bool is_blocking = message.isBlocking();
  if (message.isRankTag()){
    return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                                 message.destinationRank().value(),sizeof_type,type,
                                 message.tag().value(),
                                 is_blocking);
  }
  if (message.isMessageId()){
    MessageId message_id = message.messageId();
    return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                                 message_id,sizeof_type,type,is_blocking);
  }
  ARCCORE_THROW(NotSupportedException,"Invalid message_info");
}
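
// Usage sketch (illustrative; assumes mpProbe() from Arccore's message
// passing API): receive a message previously matched with a probe; the
// MessageId carries the source rank and tag.
//
//   MessageId mid = mpProbe(pm, PointToPointMessageInfo(MessageRank(0)));
//   Request r = dispatcher->receive(buf.span(), PointToPointMessageInfo(mid));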

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Type MpiTypeDispatcher<Type>::
allReduce(eReduceType op,Type send_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Type recv_buf = send_buf;
  MPI_Op operation = m_datatype->reduceOperator(op);
  m_adapter->allReduce(&send_buf,&recv_buf,1,type,operation);
  return recv_buf;
}
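
// Usage sketch (illustrative): global sum of one value per rank.
//
//   Real local = 1.0; // per-rank contribution
//   Real total = dispatcher->allReduce(ReduceSum, local);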

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> void MpiTypeDispatcher<Type>::
allReduce(eReduceType op,Span<Type> send_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 s = send_buf.size();
  UniqueArray<Type> recv_buf(s);
  MPI_Op operation = m_datatype->reduceOperator(op);
  {
    MpiLock::Section mls(m_adapter->mpiLock());
    m_adapter->allReduce(send_buf.data(),recv_buf.data(),s,type,operation);
  }
  send_buf.copy(recv_buf);
}
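
// Note: the reduction is not done in place at the MPI level: results go to a
// temporary buffer under the MPI lock and are then copied back into send_buf.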

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 s = send_buf.size();
  MPI_Op operation = m_datatype->reduceOperator(op);
  Request request;
  {
    MpiLock::Section mls(m_adapter->mpiLock());
    request = m_adapter->nonBlockingAllReduce(send_buf.data(),recv_buf.data(),s,type,operation);
  }
  return request;
}
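
// Note: as for any non-blocking MPI operation, send_buf and recv_buf must
// stay valid until the returned Request has been waited on or tested.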

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingAllToAll(send_buf.data(),recv_buf.data(),count,type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllToAllVariable(Span<const Type> send_buf,
                            ConstArrayView<Int32> send_count,
                            ConstArrayView<Int32> send_index,
                            Span<Type> recv_buf,
                            ConstArrayView<Int32> recv_count,
                            ConstArrayView<Int32> recv_index
                            )
{
  MPI_Datatype type = m_datatype->datatype();

  return m_adapter->nonBlockingAllToAllVariable(send_buf.data(),send_count.data(),
                                                send_index.data(),recv_buf.data(),
                                                recv_count.data(),
                                                recv_index.data(),type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingBroadcast(Span<Type> send_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingBroadcast(send_buf.data(),send_buf.size(),rank,type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllGather(Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingAllGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingGather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingGather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template<class Type> Request MpiTypeDispatcher<Type>::
gather(GatherMessageInfo<Type>& gather_info)
{
  if (!gather_info.isValid())
    return {};

  bool is_blocking = gather_info.isBlocking();
  MessageRank dest_rank = gather_info.destinationRank();
  bool is_all_variant = dest_rank.isNull();
  MessageRank my_rank(m_parallel_mng->commRank());

  auto send_buf = gather_info.sendBuffer();

  // GatherVariable with a preliminary gather to determine how many elements
  // each rank will send.
  if (gather_info.mode()==GatherMessageInfoBase::Mode::GatherVariableNeedComputeInfo) {
    if (!is_blocking)
      ARCCORE_THROW(NotSupportedException,"non blocking version of AllGatherVariable or GatherVariable with compute info");
    Array<Type>* receive_array = gather_info.localReceptionBuffer();
    if (is_all_variant){
      if (!receive_array)
        ARCCORE_FATAL("local reception buffer is null");
      this->allGatherVariable(send_buf, *receive_array);
    }
    else{
      UniqueArray<Type> unused_array;
      if (dest_rank==my_rank)
        this->gatherVariable(send_buf, *receive_array, dest_rank.value());
      else
        this->gatherVariable(send_buf, unused_array, dest_rank.value());
    }
    return {};
  }

  // Classic GatherVariable: displacements and sizes are already known.
  if (gather_info.mode() == GatherMessageInfoBase::Mode::GatherVariable) {
    if (!is_blocking)
      ARCCORE_THROW(NotImplementedException, "non blocking version of AllGatherVariable or GatherVariable");
    auto receive_buf = gather_info.receiveBuffer();
    auto displacements = gather_info.receiveDisplacement();
    auto receive_counts = gather_info.receiveCounts();
    gatherVariable(send_buf, receive_buf, receive_counts, displacements, dest_rank.value());
    return {};
  }

  // Classic Gather.
  if (gather_info.mode() == GatherMessageInfoBase::Mode::Gather) {
    auto receive_buf = gather_info.receiveBuffer();
    if (is_blocking) {
      if (is_all_variant)
        this->allGather(send_buf, receive_buf);
      else
        this->gather(send_buf, receive_buf, dest_rank.value());
      return {};
    }
    else{
      if (is_all_variant)
        return this->nonBlockingAllGather(send_buf, receive_buf);
      else
        return this->nonBlockingGather(send_buf, receive_buf, dest_rank.value());
    }
  }

  ARCCORE_THROW(NotImplementedException,"Unknown type() for GatherMessageInfo");
}
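
// Dispatch summary: GatherVariableNeedComputeInfo first exchanges sizes and
// only supports blocking calls; GatherVariable uses caller-provided counts
// and displacements; Gather supports blocking and non-blocking forms. A null
// destination rank selects the "all" variants in every mode.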

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

} // End namespace Arcane::MessagePassing::Mpi

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#endif