12#ifndef ARCCORE_MESSAGEPASSINGMPI_MPITYPEDISPATCHERIMPL_H
13#define ARCCORE_MESSAGEPASSINGMPI_MPITYPEDISPATCHERIMPL_H
17#include "arccore/message_passing_mpi/MpiTypeDispatcher.h"
18#include "arccore/message_passing_mpi/MpiDatatype.h"
19#include "arccore/message_passing_mpi/MpiAdapter.h"
20#include "arccore/message_passing_mpi/MpiLock.h"
23#include "arccore/message_passing/Request.h"
24#include "arccore/message_passing/GatherMessageInfo.h"
26#include "arccore/base/NotSupportedException.h"
27#include "arccore/base/NotImplementedException.h"
29#include "arccore/collections/Array.h"
34namespace Arccore::MessagePassing::Mpi
40template<
class Type> MpiTypeDispatcher<Type>::
41MpiTypeDispatcher(IMessagePassingMng* parallel_mng,MpiAdapter* adapter,MpiDatatype* datatype)
42: m_parallel_mng(parallel_mng)
51template<
class Type> MpiTypeDispatcher<Type>::
54 if (m_is_destroy_datatype)
61template<
class Type>
void MpiTypeDispatcher<Type>::
62broadcast(Span<Type> send_buf,Int32 rank)
64 MPI_Datatype type = m_datatype->datatype();
65 m_adapter->broadcast(send_buf.data(),send_buf.size(),rank,type);
71template<
class Type>
void MpiTypeDispatcher<Type>::
72allGather(Span<const Type> send_buf,Span<Type> recv_buf)
74 MPI_Datatype type = m_datatype->datatype();
75 m_adapter->allGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
81template<
class Type>
void MpiTypeDispatcher<Type>::
82gather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
84 MPI_Datatype type = m_datatype->datatype();
85 m_adapter->gather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
91template<
class Type>
void MpiTypeDispatcher<Type>::
92allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
94 _gatherVariable2(send_buf,recv_buf,-1);
100template<
class Type>
void MpiTypeDispatcher<Type>::
101gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
103 _gatherVariable2(send_buf,recv_buf,rank);
109template<
class Type>
void MpiTypeDispatcher<Type>::
110_gatherVariable2(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
112 Int32 comm_size = m_parallel_mng->commSize();
113 UniqueArray<int> send_counts(comm_size);
114 UniqueArray<int> send_indexes(comm_size);
116 Int64 nb_elem = send_buf.size();
117 int my_buf_count = (int)nb_elem;
118 Span<const int> count_r(&my_buf_count,1);
121 if (rank!=A_NULL_RANK)
122 mpGather(m_parallel_mng,count_r,send_counts,rank);
127 if (rank==A_NULL_RANK || rank==m_adapter->commRank()){
129 for( Integer i=0, is=comm_size; i<is; ++i ){
130 send_indexes[i] = (int)index;
131 index += send_counts[i];
134 Int64 i64_total_elem = index;
135 Int64 max_size = ARCCORE_INT64_MAX;
136 if (i64_total_elem>max_size){
137 ARCCORE_FATAL(
"Invalid size '{0}'",i64_total_elem);
139 Int64 total_elem = i64_total_elem;
140 recv_buf.resize(total_elem);
142 gatherVariable(send_buf,recv_buf,send_counts,send_indexes,rank);
148template<
class Type>
void MpiTypeDispatcher<Type>::
149gatherVariable(Span<const Type> send_buf,Span<Type> recv_buf,Span<const Int32> send_counts,
150 Span<const Int32> displacements,Int32 rank)
152 MPI_Datatype type = m_datatype->datatype();
153 Int32 nb_elem = send_buf.smallView().size();
154 if (rank!=A_NULL_RANK){
155 m_adapter->gatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
156 displacements.data(),nb_elem,rank,type);
159 m_adapter->allGatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
160 displacements.data(),nb_elem,type);
167template<
class Type>
void MpiTypeDispatcher<Type>::
168scatterVariable(Span<const Type> send_buf,Span<Type> recv_buf,Int32 root)
170 MPI_Datatype type = m_datatype->datatype();
172 Int32 comm_size = m_adapter->commSize();
173 UniqueArray<int> recv_counts(comm_size);
174 UniqueArray<int> recv_indexes(comm_size);
176 Int64 nb_elem = recv_buf.size();
177 int my_buf_count = m_adapter->toMPISize(nb_elem);
178 Span<const int> count_r(&my_buf_count,1);
185 for( Integer i=0, is=comm_size; i<is; ++i ){
186 recv_indexes[i] = index;
187 index += recv_counts[i];
190 m_adapter->scatterVariable(send_buf.data(),recv_counts.data(),recv_indexes.data(),
191 recv_buf.data(),nb_elem,root,type);
197template<
class Type>
void MpiTypeDispatcher<Type>::
198allToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
200 MPI_Datatype type = m_datatype->datatype();
201 m_adapter->allToAll(send_buf.data(),recv_buf.data(),count,type);
207template<
class Type>
void MpiTypeDispatcher<Type>::
208allToAllVariable(Span<const Type> send_buf,
209 Int32ConstArrayView send_count,
210 Int32ConstArrayView send_index,
212 Int32ConstArrayView recv_count,
213 Int32ConstArrayView recv_index
216 MPI_Datatype type = m_datatype->datatype();
218 m_adapter->allToAllVariable(send_buf.data(),send_count.data(),
219 send_index.data(),recv_buf.data(),
221 recv_index.data(),type);
227template<
class Type> Request MpiTypeDispatcher<Type>::
228send(Span<const Type> send_buffer,Int32 rank,
bool is_blocked)
230 MPI_Datatype type = m_datatype->datatype();
231 return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
232 rank,
sizeof(Type),type,100,is_blocked);
238template<
class Type> Request MpiTypeDispatcher<Type>::
239receive(Span<Type> recv_buffer,Int32 rank,
bool is_blocked)
241 MPI_Datatype type = m_datatype->datatype();
242 MpiLock::Section mls(m_adapter->mpiLock());
243 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
244 rank,
sizeof(Type),type,100,is_blocked);
250template<
class Type> Request MpiTypeDispatcher<Type>::
251send(Span<const Type> send_buffer,
const PointToPointMessageInfo& message)
253 MPI_Datatype type = m_datatype->datatype();
254 Int64 sizeof_type =
sizeof(Type);
255 MpiLock::Section mls(m_adapter->mpiLock());
256 bool is_blocking = message.isBlocking();
257 if (message.isRankTag()){
258 return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
259 message.destinationRank().value(),
260 sizeof_type,type,message.tag().value(),is_blocking);
262 if (message.isMessageId()){
264 ARCCORE_THROW(NotSupportedException,
"Invalid generic send with MessageId");
266 ARCCORE_THROW(NotSupportedException,
"Invalid message_info");
272template<
class Type> Request MpiTypeDispatcher<Type>::
273receive(Span<Type> recv_buffer,
const PointToPointMessageInfo& message)
275 MPI_Datatype type = m_datatype->datatype();
276 Int64 sizeof_type =
sizeof(Type);
277 MpiLock::Section mls(m_adapter->mpiLock());
278 bool is_blocking = message.isBlocking();
279 if (message.isRankTag()){
280 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
281 message.destinationRank().value(),sizeof_type,type,
282 message.tag().value(),
285 if (message.isMessageId()){
286 MessageId message_id = message.messageId();
287 return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
288 message_id,sizeof_type,type,is_blocking);
290 ARCCORE_THROW(NotSupportedException,
"Invalid message_info");
296template<
class Type> Type MpiTypeDispatcher<Type>::
297allReduce(eReduceType op,Type send_buf)
299 MPI_Datatype type = m_datatype->datatype();
300 Type recv_buf = send_buf;
301 MPI_Op operation = m_datatype->reduceOperator(op);
302 m_adapter->allReduce(&send_buf,&recv_buf,1,type,operation);
309template<
class Type>
void MpiTypeDispatcher<Type>::
310allReduce(eReduceType op,Span<Type> send_buf)
312 MPI_Datatype type = m_datatype->datatype();
313 Int64 s = send_buf.size();
314 UniqueArray<Type> recv_buf(s);
315 MPI_Op operation = m_datatype->reduceOperator(op);
317 MpiLock::Section mls(m_adapter->mpiLock());
318 m_adapter->allReduce(send_buf.data(),recv_buf.data(),s,type,operation);
320 send_buf.copy(recv_buf);
326template<
class Type> Request MpiTypeDispatcher<Type>::
327nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
329 MPI_Datatype type = m_datatype->datatype();
330 Int64 s = send_buf.size();
331 MPI_Op operation = m_datatype->reduceOperator(op);
334 MpiLock::Section mls(m_adapter->mpiLock());
335 request = m_adapter->nonBlockingAllReduce(send_buf.data(),recv_buf.data(),s,type,operation);
343template<
class Type> Request MpiTypeDispatcher<Type>::
344nonBlockingAllToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
346 MPI_Datatype type = m_datatype->datatype();
347 return m_adapter->nonBlockingAllToAll(send_buf.data(),recv_buf.data(),count,type);
353template<
class Type> Request MpiTypeDispatcher<Type>::
354nonBlockingAllToAllVariable(Span<const Type> send_buf,
355 Int32ConstArrayView send_count,
356 Int32ConstArrayView send_index,
358 Int32ConstArrayView recv_count,
359 Int32ConstArrayView recv_index
362 MPI_Datatype type = m_datatype->datatype();
364 return m_adapter->nonBlockingAllToAllVariable(send_buf.data(),send_count.data(),
365 send_index.data(),recv_buf.data(),
367 recv_index.data(),type);
373template<
class Type> Request MpiTypeDispatcher<Type>::
374nonBlockingBroadcast(Span<Type> send_buf,Int32 rank)
376 MPI_Datatype type = m_datatype->datatype();
377 return m_adapter->nonBlockingBroadcast(send_buf.data(),send_buf.size(),rank,type);
383template<
class Type> Request MpiTypeDispatcher<Type>::
384nonBlockingAllGather(Span<const Type> send_buf,Span<Type> recv_buf)
386 MPI_Datatype type = m_datatype->datatype();
387 return m_adapter->nonBlockingAllGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
393template<
class Type> Request MpiTypeDispatcher<Type>::
394nonBlockingGather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
396 MPI_Datatype type = m_datatype->datatype();
397 return m_adapter->nonBlockingGather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
403template<
class Type> Request MpiTypeDispatcher<Type>::
404gather(GatherMessageInfo<Type>& gather_info)
406 if (!gather_info.isValid())
409 bool is_blocking = gather_info.isBlocking();
410 MessageRank dest_rank = gather_info.destinationRank();
411 bool is_all_variant = dest_rank.isNull();
412 MessageRank my_rank(m_parallel_mng->commRank());
414 auto send_buf = gather_info.sendBuffer();
418 if (gather_info.mode()==GatherMessageInfoBase::Mode::GatherVariableNeedComputeInfo) {
420 ARCCORE_THROW(NotSupportedException,
"non blocking version of AllGatherVariable or GatherVariable with compute info");
421 Array<Type>* receive_array = gather_info.localReceptionBuffer();
424 ARCCORE_FATAL(
"local reception buffer is null");
425 this->allGatherVariable(send_buf, *receive_array);
428 UniqueArray<Type> unused_array;
429 if (dest_rank==my_rank)
430 this->gatherVariable(send_buf, *receive_array, dest_rank.value());
432 this->gatherVariable(send_buf, unused_array, dest_rank.value());
438 if (gather_info.mode() == GatherMessageInfoBase::Mode::GatherVariable) {
440 ARCCORE_THROW(NotImplementedException,
"non blocking version of AllGatherVariable or GatherVariable");
441 auto receive_buf = gather_info.receiveBuffer();
442 auto displacements = gather_info.receiveDisplacement();
443 auto receive_counts = gather_info.receiveCounts();
444 gatherVariable(send_buf, receive_buf, receive_counts, displacements, dest_rank.value());
449 if (gather_info.mode() == GatherMessageInfoBase::Mode::Gather) {
450 auto receive_buf = gather_info.receiveBuffer();
453 this->allGather(send_buf, receive_buf);
455 this->gather(send_buf, receive_buf, dest_rank.value());
460 return this->nonBlockingAllGather(send_buf, receive_buf);
462 return this->nonBlockingGather(send_buf, receive_buf, dest_rank.value());
466 ARCCORE_THROW(NotImplementedException,
"Unknown type() for GatherMessageInfo");
Liste des fonctions d'échange de message.
void mpAllGather(IMessagePassingMng *pm, const ISerializer *send_serializer, ISerializer *receive_serialize)
Message allGather() pour une sérialisation.
C void mpGather(IMessagePassingMng *pm, Span< const char > send_buf, Span< char > recv_buf, Int32 rank)