#ifndef ARCCORE_MESSAGEPASSINGMPI_INTERNAL_MPITYPEDISPATCHERIMPL_H
#define ARCCORE_MESSAGEPASSINGMPI_INTERNAL_MPITYPEDISPATCHERIMPL_H

#include "arccore/message_passing_mpi/internal/MpiTypeDispatcher.h"

#include "arccore/message_passing_mpi/MpiDatatype.h"
#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
#include "arccore/message_passing_mpi/internal/MpiLock.h"

#include "arccore/message_passing/Request.h"
#include "arccore/message_passing/GatherMessageInfo.h"

#include "arccore/base/NotSupportedException.h"
#include "arccore/base/NotImplementedException.h"

#include "arccore/collections/Array.h"
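// Implementation of the template class 'MpiTypeDispatcher<Type>', which maps
// the generic message-passing operations onto MpiAdapter calls using the
// MPI_Datatype associated with 'Type'.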
namespace Arcane::MessagePassing::Mpi
{
template<class Type> MpiTypeDispatcher<Type>::
MpiTypeDispatcher(IMessagePassingMng* parallel_mng,MpiAdapter* adapter,MpiDatatype* datatype)
: m_parallel_mng(parallel_mng)
, m_adapter(adapter)
, m_datatype(datatype)
{
}
template<class Type> MpiTypeDispatcher<Type>::
~MpiTypeDispatcher()
{
  if (m_is_destroy_datatype)
    delete m_datatype;
}
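// Convention used by the collective operations below: a 'rank' equal to
// A_NULL_RANK (-1) selects the 'all' variant of the collective (every rank
// receives the result); otherwise 'rank' is the root of the operation.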
template<class Type> void MpiTypeDispatcher<Type>::
broadcast(Span<Type> send_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->broadcast(send_buf.data(),send_buf.size(),rank,type);
}
template<class Type> void MpiTypeDispatcher<Type>::
allGather(Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->allGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
}
template<class Type> void MpiTypeDispatcher<Type>::
gather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->gather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
}
template<class Type> void MpiTypeDispatcher<Type>::
allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
{
  _gatherVariable2(send_buf,recv_buf,-1);
}
template<class Type> void MpiTypeDispatcher<Type>::
gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
{
  _gatherVariable2(send_buf,recv_buf,rank);
}
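// Two-phase variable gather: a first gather (or allGather when rank is
// A_NULL_RANK) exchanges the number of elements contributed by each rank;
// the receiving rank(s) then compute the displacements, resize 'recv_buf'
// and perform the actual gatherVariable().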
template<class Type> void MpiTypeDispatcher<Type>::
_gatherVariable2(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 rank)
{
  Int32 comm_size = m_parallel_mng->commSize();
  UniqueArray<int> send_counts(comm_size);
  UniqueArray<int> send_indexes(comm_size);

  Int64 nb_elem = send_buf.size();
  int my_buf_count = (int)nb_elem;
  Span<const int> count_r(&my_buf_count,1);

  // Gather the number of elements sent by each rank.
  if (rank!=A_NULL_RANK)
    mpGather(m_parallel_mng,count_r,send_counts,rank);
  else
    mpAllGather(m_parallel_mng,count_r,send_counts);

  // Compute the displacements and resize the reception buffer on the
  // rank(s) that actually receive the data.
  if (rank==A_NULL_RANK || rank==m_adapter->commRank()){
    Int64 index = 0;
    for( Integer i=0, is=comm_size; i<is; ++i ){
      send_indexes[i] = (int)index;
      index += send_counts[i];
    }
    Int64 i64_total_elem = index;
    Int64 max_size = ARCCORE_INT64_MAX;
    if (i64_total_elem>max_size){
      ARCCORE_FATAL("Invalid size '{0}'",i64_total_elem);
    }
    Int64 total_elem = i64_total_elem;
    recv_buf.resize(total_elem);
  }
  gatherVariable(send_buf,recv_buf,send_counts,send_indexes,rank);
}
template<class Type> void MpiTypeDispatcher<Type>::
gatherVariable(Span<const Type> send_buf,Span<Type> recv_buf,Span<const Int32> send_counts,
               Span<const Int32> displacements,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  Int32 nb_elem = send_buf.smallView().size();
  if (rank!=A_NULL_RANK){
    m_adapter->gatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
                              displacements.data(),nb_elem,rank,type);
  }
  else{
    m_adapter->allGatherVariable(send_buf.data(),recv_buf.data(),send_counts.data(),
                                 displacements.data(),nb_elem,type);
  }
}
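// Symmetric operation of the variable gather: every rank first announces how
// many elements it expects, then the root scatters the matching slices of
// 'send_buf'.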
template<class Type> void MpiTypeDispatcher<Type>::
scatterVariable(Span<const Type> send_buf,Span<Type> recv_buf,Int32 root)
{
  MPI_Datatype type = m_datatype->datatype();

  Int32 comm_size = m_adapter->commSize();
  UniqueArray<int> recv_counts(comm_size);
  UniqueArray<int> recv_indexes(comm_size);

  Int64 nb_elem = recv_buf.size();
  int my_buf_count = m_adapter->toMPISize(nb_elem);
  Span<const int> count_r(&my_buf_count,1);

  // Gather on every rank the number of elements each rank expects.
  mpAllGather(m_parallel_mng,count_r,recv_counts);

  Integer index = 0;
  for( Integer i=0, is=comm_size; i<is; ++i ){
    recv_indexes[i] = index;
    index += recv_counts[i];
  }

  m_adapter->scatterVariable(send_buf.data(),recv_counts.data(),recv_indexes.data(),
                             recv_buf.data(),nb_elem,root,type);
}
template<class Type> void MpiTypeDispatcher<Type>::
allToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
{
  MPI_Datatype type = m_datatype->datatype();
  m_adapter->allToAll(send_buf.data(),recv_buf.data(),count,type);
}
template<class Type> void MpiTypeDispatcher<Type>::
allToAllVariable(Span<const Type> send_buf,
                 ConstArrayView<Int32> send_count,
                 ConstArrayView<Int32> send_index,
                 Span<Type> recv_buf,
                 ConstArrayView<Int32> recv_count,
                 ConstArrayView<Int32> recv_index)
{
  MPI_Datatype type = m_datatype->datatype();

  m_adapter->allToAllVariable(send_buf.data(),send_count.data(),
                              send_index.data(),recv_buf.data(),
                              recv_count.data(),
                              recv_index.data(),type);
}
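// Point-to-point operations. The rank-based overloads below use a hard-coded
// message tag of 100; the PointToPointMessageInfo overloads take the tag from
// the message itself. The MpiLock::Section presumably serializes the MPI
// calls when the adapter is shared between threads.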
template<class Type> Request MpiTypeDispatcher<Type>::
send(Span<const Type> send_buffer,Int32 rank,bool is_blocked)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
                               rank,sizeof(Type),type,100,is_blocked);
}
template<class Type> Request MpiTypeDispatcher<Type>::
receive(Span<Type> recv_buffer,Int32 rank,bool is_blocked)
{
  MPI_Datatype type = m_datatype->datatype();
  MpiLock::Section mls(m_adapter->mpiLock());
  return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                               rank,sizeof(Type),type,100,is_blocked);
}
template<class Type> Request MpiTypeDispatcher<Type>::
send(Span<const Type> send_buffer,const PointToPointMessageInfo& message)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 sizeof_type = sizeof(Type);
  MpiLock::Section mls(m_adapter->mpiLock());
  bool is_blocking = message.isBlocking();
  if (message.isRankTag()){
    return m_adapter->directSend(send_buffer.data(),send_buffer.size(),
                                 message.destinationRank().value(),
                                 sizeof_type,type,message.tag().value(),is_blocking);
  }
  if (message.isMessageId()){
    // Sending with a MessageId makes no sense: a MessageId identifies an
    // already probed incoming message on the receive side.
    ARCCORE_THROW(NotSupportedException,"Invalid generic send with MessageId");
  }
  ARCCORE_THROW(NotSupportedException,"Invalid message_info");
}
template<class Type> Request MpiTypeDispatcher<Type>::
receive(Span<Type> recv_buffer,const PointToPointMessageInfo& message)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 sizeof_type = sizeof(Type);
  MpiLock::Section mls(m_adapter->mpiLock());
  bool is_blocking = message.isBlocking();
  if (message.isRankTag()){
    return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                                 message.destinationRank().value(),sizeof_type,type,
                                 message.tag().value(),is_blocking);
  }
  if (message.isMessageId()){
    MessageId message_id = message.messageId();
    return m_adapter->directRecv(recv_buffer.data(),recv_buffer.size(),
                                 message_id,sizeof_type,type,is_blocking);
  }
  ARCCORE_THROW(NotSupportedException,"Invalid message_info");
}
template<class Type> Type MpiTypeDispatcher<Type>::
allReduce(eReduceType op,Type send_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Type recv_buf = send_buf;
  MPI_Op operation = m_datatype->reduceOperator(op);
  m_adapter->allReduce(&send_buf,&recv_buf,1,type,operation);
  return recv_buf;
}
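// Minimal usage sketch (assuming a valid 'MpiTypeDispatcher<double>*
// dispatcher' and the eReduceType enumerator 'ReduceSum'):
//   double global_sum = dispatcher->allReduce(ReduceSum, local_sum);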
template<class Type> void MpiTypeDispatcher<Type>::
allReduce(eReduceType op,Span<Type> send_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 s = send_buf.size();
  UniqueArray<Type> recv_buf(s);
  MPI_Op operation = m_datatype->reduceOperator(op);
  {
    MpiLock::Section mls(m_adapter->mpiLock());
    m_adapter->allReduce(send_buf.data(),recv_buf.data(),s,type,operation);
  }
  send_buf.copy(recv_buf);
}
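// The in-place span variant reduces into a temporary buffer and copies the
// result back instead of using MPI_IN_PLACE, which also leaves 'send_buf'
// untouched if the adapter call throws.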
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  Int64 s = send_buf.size();
  MPI_Op operation = m_datatype->reduceOperator(op);
  Request request;
  {
    MpiLock::Section mls(m_adapter->mpiLock());
    request = m_adapter->nonBlockingAllReduce(send_buf.data(),recv_buf.data(),s,type,operation);
  }
  return request;
}
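// As with any non-blocking MPI operation, 'send_buf' and 'recv_buf' must stay
// alive until the returned Request has completed.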
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingAllToAll(send_buf.data(),recv_buf.data(),count,type);
}
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllToAllVariable(Span<const Type> send_buf,
                            ConstArrayView<Int32> send_count,
                            ConstArrayView<Int32> send_index,
                            Span<Type> recv_buf,
                            ConstArrayView<Int32> recv_count,
                            ConstArrayView<Int32> recv_index)
{
  MPI_Datatype type = m_datatype->datatype();

  return m_adapter->nonBlockingAllToAllVariable(send_buf.data(),send_count.data(),
                                                send_index.data(),recv_buf.data(),
                                                recv_count.data(),
                                                recv_index.data(),type);
}
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingBroadcast(Span<Type> send_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingBroadcast(send_buf.data(),send_buf.size(),rank,type);
}
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingAllGather(Span<const Type> send_buf,Span<Type> recv_buf)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingAllGather(send_buf.data(),recv_buf.data(),send_buf.size(),type);
}
template<class Type> Request MpiTypeDispatcher<Type>::
nonBlockingGather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 rank)
{
  MPI_Datatype type = m_datatype->datatype();
  return m_adapter->nonBlockingGather(send_buf.data(),recv_buf.data(),send_buf.size(),rank,type);
}
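// gather(GatherMessageInfo&) dispatches to one of the variants above
// according to gather_info.mode(). Blocking variants return a default (null)
// Request.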
template<class Type> Request MpiTypeDispatcher<Type>::
gather(GatherMessageInfo<Type>& gather_info)
{
  if (!gather_info.isValid())
    return {};

  bool is_blocking = gather_info.isBlocking();
  MessageRank dest_rank = gather_info.destinationRank();
  bool is_all_variant = dest_rank.isNull();
  MessageRank my_rank(m_parallel_mng->commRank());

  auto send_buf = gather_info.sendBuffer();

  // Variable gather with a preliminary exchange to compute the size
  // contributed by each rank.
  if (gather_info.mode()==GatherMessageInfoBase::Mode::GatherVariableNeedComputeInfo) {
    if (!is_blocking)
      ARCCORE_THROW(NotSupportedException,
                    "non blocking version of AllGatherVariable or GatherVariable with compute info");
    Array<Type>* receive_array = gather_info.localReceptionBuffer();
    if (is_all_variant) {
      if (!receive_array)
        ARCCORE_FATAL("local reception buffer is null");
      this->allGatherVariable(send_buf, *receive_array);
    }
    else {
      UniqueArray<Type> unused_array;
      if (dest_rank==my_rank)
        this->gatherVariable(send_buf, *receive_array, dest_rank.value());
      else
        this->gatherVariable(send_buf, unused_array, dest_rank.value());
    }
    return {};
  }

  // Variable gather where counts and displacements are already known.
  if (gather_info.mode() == GatherMessageInfoBase::Mode::GatherVariable) {
    if (!is_blocking)
      ARCCORE_THROW(NotImplementedException,
                    "non blocking version of AllGatherVariable or GatherVariable");
    auto receive_buf = gather_info.receiveBuffer();
    auto displacements = gather_info.receiveDisplacement();
    auto receive_counts = gather_info.receiveCounts();
    gatherVariable(send_buf, receive_buf, receive_counts, displacements, dest_rank.value());
    return {};
  }

  // Plain gather with a fixed size per rank.
  if (gather_info.mode() == GatherMessageInfoBase::Mode::Gather) {
    auto receive_buf = gather_info.receiveBuffer();
    if (is_blocking) {
      if (is_all_variant)
        this->allGather(send_buf, receive_buf);
      else
        this->gather(send_buf, receive_buf, dest_rank.value());
      return {};
    }
    if (is_all_variant)
      return this->nonBlockingAllGather(send_buf, receive_buf);
    return this->nonBlockingGather(send_buf, receive_buf, dest_rank.value());
  }

  ARCCORE_THROW(NotImplementedException, "Unknown type() for GatherMessageInfo");
}

} // namespace Arcane::MessagePassing::Mpi

#endif