Arcane  v3.14.10.0
Documentation utilisateur
Chargement...
Recherche...
Aucune correspondance
MpiTypeDispatcherImpl.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* MpiTypeDispatcherImpl.h (C) 2000-2024 */
9/* */
10/* Implémentation de 'MpiTypeDispatcher'. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCCORE_MESSAGEPASSINGMPI_MPITYPEDISPATCHERIMPL_H
13#define ARCCORE_MESSAGEPASSINGMPI_MPITYPEDISPATCHERIMPL_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16
#include "arccore/message_passing_mpi/MpiTypeDispatcher.h"
#include "arccore/message_passing_mpi/MpiDatatype.h"
#include "arccore/message_passing_mpi/MpiAdapter.h"
#include "arccore/message_passing_mpi/MpiLock.h"

#include "arccore/message_passing/Request.h"
#include "arccore/message_passing/GatherMessageInfo.h"

#include "arccore/base/NotSupportedException.h"
#include "arccore/base/NotImplementedException.h"

#include "arccore/collections/Array.h"

#include <limits>
30
31/*---------------------------------------------------------------------------*/
32/*---------------------------------------------------------------------------*/
33
34namespace Arccore::MessagePassing::Mpi
35{
36
37/*---------------------------------------------------------------------------*/
38/*---------------------------------------------------------------------------*/
39
40template<class Type> MpiTypeDispatcher<Type>::
41MpiTypeDispatcher(IMessagePassingMng* parallel_mng,MpiAdapter* adapter,MpiDatatype* datatype)
42: m_parallel_mng(parallel_mng)
43, m_adapter(adapter)
44, m_datatype(datatype)
45{
46}
47
48/*---------------------------------------------------------------------------*/
49/*---------------------------------------------------------------------------*/
50
51template<class Type> MpiTypeDispatcher<Type>::
52~MpiTypeDispatcher()
53{
54 if (m_is_destroy_datatype)
55 delete m_datatype;
56}
57
58/*---------------------------------------------------------------------------*/
59/*---------------------------------------------------------------------------*/
60
61template<class Type> void MpiTypeDispatcher<Type>::
62broadcast(Span<Type> send_buf,Int32 rank)
63{
64 MPI_Datatype type = m_datatype->datatype();
65 m_adapter->broadcast(send_buf.data(),send_buf.size(),rank,type);
66}
67
68/*---------------------------------------------------------------------------*/
69/*---------------------------------------------------------------------------*/
70
71template<class Type> void MpiTypeDispatcher<Type>::
72allGather(Span<const Type> send_buf,Span<Type> recv_buf)
73{
74 MPI_Datatype type = m_datatype->datatype();
75 m_adapter->allGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
76}
77
78/*---------------------------------------------------------------------------*/
79/*---------------------------------------------------------------------------*/
80
81template<class Type> void MpiTypeDispatcher<Type>::
82gather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
83{
84 MPI_Datatype type = m_datatype->datatype();
85 m_adapter->gather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
86}
87
88/*---------------------------------------------------------------------------*/
89/*---------------------------------------------------------------------------*/
90
91template<class Type> void MpiTypeDispatcher<Type>::
92allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
93{
94 _gatherVariable2(send_buf,recv_buf,-1);
95}
96
97/*---------------------------------------------------------------------------*/
98/*---------------------------------------------------------------------------*/
99
100template<class Type> void MpiTypeDispatcher<Type>::
101gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
102{
103 _gatherVariable2(send_buf,recv_buf,rank);
104}
105
106/*---------------------------------------------------------------------------*/
107/*---------------------------------------------------------------------------*/
108
109template<class Type> void MpiTypeDispatcher<Type>::
110_gatherVariable2(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
111{
112 Int32 comm_size = m_parallel_mng->commSize();
113 UniqueArray<int> send_counts(comm_size);
114 UniqueArray<int> send_indexes(comm_size);
115
116 Int64 nb_elem = send_buf.size();
117 int my_buf_count = (int)nb_elem;
118 Span<const int> count_r(&my_buf_count,1);
119
120 // Récupère le nombre d'éléments de chaque processeur
121 if (rank!=A_NULL_RANK)
122 mpGather(m_parallel_mng,count_r,send_counts,rank);
123 else
124 mpAllGather(m_parallel_mng,count_r,send_counts);
125
126 // Remplit le tableau des index
127 if (rank==A_NULL_RANK || rank==m_adapter->commRank()){
128 Int64 index = 0;
129 for( Integer i=0, is=comm_size; i<is; ++i ){
130 send_indexes[i] = (int)index;
131 index += send_counts[i];
132 //info() << " SEND " << i << " index=" << send_indexes[i] << " count=" << send_counts[i];
133 }
134 Int64 i64_total_elem = index;
135 Int64 max_size = ARCCORE_INT64_MAX;
136 if (i64_total_elem>max_size){
137 ARCCORE_FATAL("Invalid size '{0}'",i64_total_elem);
138 }
139 Int64 total_elem = i64_total_elem;
140 recv_buf.resize(total_elem);
141 }
142 gatherVariable(send_buf,recv_buf,send_counts,send_indexes,rank);
143}
144
145/*---------------------------------------------------------------------------*/
146/*---------------------------------------------------------------------------*/
147
148template<class Type> void MpiTypeDispatcher<Type>::
149gatherVariable(Span<const Type> send_buf,Span<Type> recv_buf,Span<const Int32> send_counts,
150 Span<const Int32> displacements,Int32 rank)
151{
152 MPI_Datatype type = m_datatype->datatype();
153 Int32 nb_elem = send_buf.smallView().size();
154 if (rank!=A_NULL_RANK){
155 m_adapter->gatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
156 displacements.data(),nb_elem,rank,type);
157 }
158 else{
159 m_adapter->allGatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
160 displacements.data(),nb_elem,type);
161 }
162}
163
164/*---------------------------------------------------------------------------*/
165/*---------------------------------------------------------------------------*/
166
167template<class Type> void MpiTypeDispatcher<Type>::
168scatterVariable(Span<const Type> send_buf,Span<Type> recv_buf,Int32 root)
169{
170 MPI_Datatype type = m_datatype->datatype();
171
172 Int32 comm_size = m_adapter->commSize();
173 UniqueArray<int> recv_counts(comm_size);
174 UniqueArray<int> recv_indexes(comm_size);
175
176 Int64 nb_elem = recv_buf.size();
177 int my_buf_count = m_adapter->toMPISize(nb_elem);
178 Span<const int> count_r(&my_buf_count,1);
179
180 // Récupère le nombre d'éléments de chaque processeur
181 mpAllGather(m_parallel_mng,count_r,recv_counts);
182
183 // Remplit le tableau des index
184 int index = 0;
185 for( Integer i=0, is=comm_size; i<is; ++i ){
186 recv_indexes[i] = index;
187 index += recv_counts[i];
188 }
189
190 m_adapter->scatterVariable(send_buf.data(),recv_counts.data(),recv_indexes.data(),
191 recv_buf.data(),nb_elem,root,type);
192}
193
194/*---------------------------------------------------------------------------*/
195/*---------------------------------------------------------------------------*/
196
197template<class Type> void MpiTypeDispatcher<Type>::
198allToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
199{
200 MPI_Datatype type = m_datatype->datatype();
201 m_adapter->allToAll(send_buf.data(),recv_buf.data(),count,type);
202}
203
204/*---------------------------------------------------------------------------*/
205/*---------------------------------------------------------------------------*/
206
207template<class Type> void MpiTypeDispatcher<Type>::
208allToAllVariable(Span<const Type> send_buf,
209 Int32ConstArrayView send_count,
210 Int32ConstArrayView send_index,
211 Span<Type> recv_buf,
212 Int32ConstArrayView recv_count,
213 Int32ConstArrayView recv_index
214 )
215{
216 MPI_Datatype type = m_datatype->datatype();
217
218 m_adapter->allToAllVariable(send_buf.data(),send_count.data(),
219 send_index.data(),recv_buf.data(),
220 recv_count.data(),
221 recv_index.data(),type);
222}
223
224/*---------------------------------------------------------------------------*/
225/*---------------------------------------------------------------------------*/
226
227template<class Type> Request MpiTypeDispatcher<Type>::
228send(Span<const Type> send_buffer,Int32 rank,bool is_blocked)
229{
230 MPI_Datatype type = m_datatype->datatype();
231 return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
232 rank,sizeof(Type),type,100,is_blocked);
233}
234
235/*---------------------------------------------------------------------------*/
236/*---------------------------------------------------------------------------*/
237
238template<class Type> Request MpiTypeDispatcher<Type>::
239receive(Span<Type> recv_buffer,Int32 rank,bool is_blocked)
240{
241 MPI_Datatype type = m_datatype->datatype();
242 MpiLock::Section mls(m_adapter->mpiLock());
243 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
244 rank,sizeof(Type),type,100,is_blocked);
245}
246
247/*---------------------------------------------------------------------------*/
248/*---------------------------------------------------------------------------*/
249
250template<class Type> Request MpiTypeDispatcher<Type>::
251send(Span<const Type> send_buffer,const PointToPointMessageInfo& message)
252{
253 MPI_Datatype type = m_datatype->datatype();
254 Int64 sizeof_type = sizeof(Type);
255 MpiLock::Section mls(m_adapter->mpiLock());
256 bool is_blocking = message.isBlocking();
257 if (message.isRankTag()){
258 return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
259 message.destinationRank().value(),
260 sizeof_type,type,message.tag().value(),is_blocking);
261 }
262 if (message.isMessageId()){
263 // Le send avec un MessageId n'existe pas.
264 ARCCORE_THROW(NotSupportedException,"Invalid generic send with MessageId");
265 }
266 ARCCORE_THROW(NotSupportedException,"Invalid message_info");
267}
268
269/*---------------------------------------------------------------------------*/
270/*---------------------------------------------------------------------------*/
271
272template<class Type> Request MpiTypeDispatcher<Type>::
273receive(Span<Type> recv_buffer,const PointToPointMessageInfo& message)
274{
275 MPI_Datatype type = m_datatype->datatype();
276 Int64 sizeof_type = sizeof(Type);
277 MpiLock::Section mls(m_adapter->mpiLock());
278 bool is_blocking = message.isBlocking();
279 if (message.isRankTag()){
280 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
281 message.destinationRank().value(),sizeof_type,type,
282 message.tag().value(),
283 is_blocking);
284 }
285 if (message.isMessageId()){
286 MessageId message_id = message.messageId();
287 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
288 message_id,sizeof_type,type,is_blocking);
289 }
290 ARCCORE_THROW(NotSupportedException,"Invalid message_info");
291}
292
293/*---------------------------------------------------------------------------*/
294/*---------------------------------------------------------------------------*/
295
296template<class Type> Type MpiTypeDispatcher<Type>::
297allReduce(eReduceType op,Type send_buf)
298{
299 MPI_Datatype type = m_datatype->datatype();
300 Type recv_buf = send_buf;
301 MPI_Op operation = m_datatype->reduceOperator(op);
302 m_adapter->allReduce(&send_buf,&recv_buf,1,type,operation);
303 return recv_buf;
304}
305
306/*---------------------------------------------------------------------------*/
307/*---------------------------------------------------------------------------*/
308
309template<class Type> void MpiTypeDispatcher<Type>::
310allReduce(eReduceType op,Span<Type> send_buf)
311{
312 MPI_Datatype type = m_datatype->datatype();
313 Int64 s = send_buf.size();
314 UniqueArray<Type> recv_buf(s);
315 MPI_Op operation = m_datatype->reduceOperator(op);
316 {
317 MpiLock::Section mls(m_adapter->mpiLock());
318 m_adapter->allReduce(send_buf.data(),recv_buf.data(),s,type,operation);
319 }
320 send_buf.copy(recv_buf);
321}
322
323/*---------------------------------------------------------------------------*/
324/*---------------------------------------------------------------------------*/
325
326template<class Type> Request MpiTypeDispatcher<Type>::
327nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
328{
329 MPI_Datatype type = m_datatype->datatype();
330 Int64 s = send_buf.size();
331 MPI_Op operation = m_datatype->reduceOperator(op);
332 Request request;
333 {
334 MpiLock::Section mls(m_adapter->mpiLock());
335 request = m_adapter->nonBlockingAllReduce(send_buf.data(),recv_buf.data(),s,type,operation);
336 }
337 return request;
338}
339
340/*---------------------------------------------------------------------------*/
341/*---------------------------------------------------------------------------*/
342
343template<class Type> Request MpiTypeDispatcher<Type>::
344nonBlockingAllToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
345{
346 MPI_Datatype type = m_datatype->datatype();
347 return m_adapter->nonBlockingAllToAll(send_buf.data(),recv_buf.data(),count,type);
348}
349
350/*---------------------------------------------------------------------------*/
351/*---------------------------------------------------------------------------*/
352
353template<class Type> Request MpiTypeDispatcher<Type>::
354nonBlockingAllToAllVariable(Span<const Type> send_buf,
355 Int32ConstArrayView send_count,
356 Int32ConstArrayView send_index,
357 Span<Type> recv_buf,
358 Int32ConstArrayView recv_count,
359 Int32ConstArrayView recv_index
360 )
361{
362 MPI_Datatype type = m_datatype->datatype();
363
364 return m_adapter->nonBlockingAllToAllVariable(send_buf.data(),send_count.data(),
365 send_index.data(),recv_buf.data(),
366 recv_count.data(),
367 recv_index.data(),type);
368}
369
370/*---------------------------------------------------------------------------*/
371/*---------------------------------------------------------------------------*/
372
373template<class Type> Request MpiTypeDispatcher<Type>::
374nonBlockingBroadcast(Span<Type> send_buf,Int32 rank)
375{
376 MPI_Datatype type = m_datatype->datatype();
377 return m_adapter->nonBlockingBroadcast(send_buf.data(),send_buf.size(),rank,type);
378}
379
380/*---------------------------------------------------------------------------*/
381/*---------------------------------------------------------------------------*/
382
383template<class Type> Request MpiTypeDispatcher<Type>::
384nonBlockingAllGather(Span<const Type> send_buf,Span<Type> recv_buf)
385{
386 MPI_Datatype type = m_datatype->datatype();
387 return m_adapter->nonBlockingAllGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
388}
389
390/*---------------------------------------------------------------------------*/
391/*---------------------------------------------------------------------------*/
392
393template<class Type> Request MpiTypeDispatcher<Type>::
394nonBlockingGather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
395{
396 MPI_Datatype type = m_datatype->datatype();
397 return m_adapter->nonBlockingGather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
398}
399
400/*---------------------------------------------------------------------------*/
401/*---------------------------------------------------------------------------*/
402
403template<class Type> Request MpiTypeDispatcher<Type>::
404gather(GatherMessageInfo<Type>& gather_info)
405{
406 if (!gather_info.isValid())
407 return {};
408
409 bool is_blocking = gather_info.isBlocking();
410 MessageRank dest_rank = gather_info.destinationRank();
411 bool is_all_variant = dest_rank.isNull();
412 MessageRank my_rank(m_parallel_mng->commRank());
413
414 auto send_buf = gather_info.sendBuffer();
415
416 // GatherVariable avec envoi gather préliminaire pour connaitre la taille
417 // que doit envoyer chaque rang.
418 if (gather_info.mode()==GatherMessageInfoBase::Mode::GatherVariableNeedComputeInfo) {
419 if (!is_blocking)
420 ARCCORE_THROW(NotSupportedException,"non blocking version of AllGatherVariable or GatherVariable with compute info");
421 Array<Type>* receive_array = gather_info.localReceptionBuffer();
422 if (is_all_variant){
423 if (!receive_array)
424 ARCCORE_FATAL("local reception buffer is null");
425 this->allGatherVariable(send_buf, *receive_array);
426 }
427 else{
428 UniqueArray<Type> unused_array;
429 if (dest_rank==my_rank)
430 this->gatherVariable(send_buf, *receive_array, dest_rank.value());
431 else
432 this->gatherVariable(send_buf, unused_array, dest_rank.value());
433 }
434 return {};
435 }
436
437 // GatherVariable classique avec connaissance du déplacement et des tailles
438 if (gather_info.mode() == GatherMessageInfoBase::Mode::GatherVariable) {
439 if (!is_blocking)
440 ARCCORE_THROW(NotImplementedException, "non blocking version of AllGatherVariable or GatherVariable");
441 auto receive_buf = gather_info.receiveBuffer();
442 auto displacements = gather_info.receiveDisplacement();
443 auto receive_counts = gather_info.receiveCounts();
444 gatherVariable(send_buf, receive_buf, receive_counts, displacements, dest_rank.value());
445 return {};
446 }
447
448 // Gather classique
449 if (gather_info.mode() == GatherMessageInfoBase::Mode::Gather) {
450 auto receive_buf = gather_info.receiveBuffer();
451 if (is_blocking) {
452 if (is_all_variant)
453 this->allGather(send_buf, receive_buf);
454 else
455 this->gather(send_buf, receive_buf, dest_rank.value());
456 return {};
457 }
458 else{
459 if (is_all_variant)
460 return this->nonBlockingAllGather(send_buf, receive_buf);
461 else
462 return this->nonBlockingGather(send_buf, receive_buf, dest_rank.value());
463 }
464 }
465
466 ARCCORE_THROW(NotImplementedException,"Unknown type() for GatherMessageInfo");
467}
468
469/*---------------------------------------------------------------------------*/
470/*---------------------------------------------------------------------------*/
471
472} // End namespace Arccore::MessagePassing::Mpi
473
474/*---------------------------------------------------------------------------*/
475/*---------------------------------------------------------------------------*/
476
477#endif
Liste des fonctions d'échange de message.
void mpAllGather(IMessagePassingMng *pm, const ISerializer *send_serializer, ISerializer *receive_serialize)
Message allGather() pour une sérialisation.
Definition Messages.cc:296
C ARCCORE_MESSAGEPASSING_EXPORT void mpGather(IMessagePassingMng *pm, Span< const char > send_buf, Span< char > recv_buf, Int32 rank)