14#include "arccore/message_passing_mpi/MpiSerializeDispatcher.h"
16#include "arccore/message_passing_mpi/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/MpiSerializeMessageList.h"
19#include "arccore/message_passing_mpi/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
22#include "arccore/serialize/BasicSerializer.h"
23#include "arccore/base/NotImplementedException.h"
24#include "arccore/base/FatalErrorException.h"
25#include "arccore/base/NotSupportedException.h"
26#include "arccore/base/ArgumentException.h"
27#include "arccore/base/PlatformUtils.h"
28#include "arccore/trace/ITraceMng.h"
33namespace Arccore::MessagePassing::Mpi
// Adapts a span of bytes for an MPI transfer that uses the serializer
// datatype: size() is the element count passed to MPI together with
// datatype(). NOTE(review): this listing is missing several lines of the
// class (constructor signature, definitions of `size` and `align_size`,
// elementSize(), remaining member declarations); comments describe only
// what is visible.
57template<
typename SpanType>
// Stores the buffer and datatype; m_final_size stays -1 until computed below.
62 : m_buffer(buffer), m_datatype(byte_serializer_datatype), m_final_size(-1)
// The byte count must be an exact multiple of align_size, otherwise the
// integer division below would silently drop trailing bytes.
66 if ((size%align_size)!=0)
67 ARCCORE_FATAL(
"Buffer size '{0}' is not a multiple of '{1}' Invalid size",size,align_size);
// NOTE(review): assumes the datatype's extent equals align_size — confirm.
68 m_final_size = size / align_size;
// Pointer to the first byte of the wrapped buffer.
70 SpanType* data() {
return m_buffer.
data(); }
// Element count for MPI (byte count / align_size), not a byte count.
71 Int64 size()
const {
return m_final_size; }
// Size of the wrapped buffer in bytes.
72 Int64 messageSize()
const {
return m_buffer.
size() *
sizeof(
Byte); }
// MPI datatype supplied at construction.
74 MPI_Datatype datatype()
const {
return m_datatype; }
77 MPI_Datatype m_datatype;
// Fragment of SendSerializerSubRequest: a sub-request that sends the whole
// serializer buffer (globalBuffer()) to `m_rank` with tag `m_mpi_tag`.
// NOTE(review): the class header, constructor signature and some body lines
// (e.g. source line 106) are missing from this listing.
99 : m_dispatcher(pm), m_serialize_buffer(buf), m_rank(rank), m_mpi_tag(mpi_tag) {}
// Completion-callback fragment: ends by returning the posted send request.
// NOTE(review): the statement under the `if` is not visible; presumably it
// calls sendMessage() — confirm.
105 if (!m_is_message_sent)
107 return m_send_request;
// sendMessage(): posts the send of the full buffer; posting twice is fatal.
112 if (m_is_message_sent)
113 ARCCORE_FATAL(
"Message already sent");
114 bool do_print = m_dispatcher->m_is_trace_serializer;
116 ITraceMng* tm = m_dispatcher->traceMng();
117 tm->
info() <<
" SendSerializerSubRequest::sendMessage()"
118 <<
" rank=" << m_rank <<
" tag=" << m_mpi_tag;
120 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
// Last argument is `is_blocking`: this send is non-blocking.
121 m_send_request = m_dispatcher->_sendSerializerBytes(bytes,m_rank,m_mpi_tag,
false);
122 m_is_message_sent =
true;
125 MpiSerializeDispatcher* m_dispatcher;
// Guards against posting the send more than once.
130 bool m_is_message_sent =
false;
// Fragment of ReceiveSerializerSubRequest. NOTE(review): the class header,
// constructor signature and several body lines (definitions of `rank` and
// `sbuf`, creation of `sr`, the else/return structure) are missing from
// this listing.
144 , m_serialize_buffer(buf)
// executeOnCompletion(): called once the pending receive has completed.
154 bool is_trace = m_dispatcher->m_is_trace_serializer;
155 ITraceMng* tm = m_dispatcher->traceMng();
157 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion()"
158 <<
" rank=" << rank <<
" wanted_tag=" << m_mpi_tag <<
" action=" << m_action;
// Total serialized size as read from the data received so far. If it fits in
// the default buffer size, the data is complete and only the sizes need to be
// applied. Otherwise the buffer is grown to the full size and a non-blocking
// receive of the whole buffer is posted, chained to sub-request `sr`.
162 Int64 total_recv_size = sbuf->totalSize();
165 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
169 if (total_recv_size<=m_dispatcher->m_serialize_buffer_size){
170 sbuf->setFromSizes();
// Large message path.
174 sbuf->preallocate(total_recv_size);
175 auto bytes = sbuf->globalBuffer();
// NOTE(review): this receive uses m_mpi_tag, while the sender posts the full
// buffer with nextSerializeTag(tag); presumably this sub-request was built
// with the already-incremented tag — confirm against its construction.
179 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag,
false);
181 r2.setSubRequest(makeRef(sr));
// Presumably reached when the follow-up receive completes (branch lines not
// visible): apply the received sizes to the serializer.
185 m_serialize_buffer->setFromSizes();
// Constructor fragment. NOTE(review): the parameter line and some
// initializers are not visible in this listing.
204MpiSerializeDispatcher::
207, m_trace(adapter->traceMng())
// Threshold compared against BasicSerializer::totalSize(): larger payloads
// are split into a first message plus a second one on nextSerializeTag().
208, m_serialize_buffer_size(50000)
// NOTE(review): reads m_serialize_buffer_size. This is only well-defined if
// that member is declared before m_max_serialize_buffer_size in the class,
// because members are initialized in declaration order, not list order.
210, m_max_serialize_buffer_size(m_serialize_buffer_size)
// MPI_DATATYPE_NULL marks "not created yet"; the destructor checks for it.
211, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
// Destructor: frees the serializer MPI datatype if one was created.
219MpiSerializeDispatcher::
220~MpiSerializeDispatcher()
222 if (m_byte_serializer_datatype!=MPI_DATATYPE_NULL)
223 MPI_Type_free(&m_byte_serializer_datatype);
// Returns the tag following `tag` (value + 1). It is used as the tag of the
// second message when a large serializer is sent in two parts.
229MessageTag MpiSerializeDispatcher::
230nextSerializeTag(MessageTag tag)
232 return MessageTag(tag.value()+1);
// Commits the MPI datatype used for serializer byte transfers, stores it in
// m_byte_serializer_datatype and sets m_is_trace_serializer.
// NOTE(review): the function-name line, the call that creates `mpi_datatype`
// and any condition guarding the tracing flag are missing from this listing.
238void MpiSerializeDispatcher::
242 MPI_Datatype mpi_datatype;
244 MPI_Type_commit(&mpi_datatype);
245 m_byte_serializer_datatype = mpi_datatype;
248 m_is_trace_serializer =
true;
// Legacy send of a serializer to message.destinationRank() with
// message.tag(). Only (rank,tag) messages are accepted.
// - If totalSize() <= m_serialize_buffer_size, the whole buffer is sent as
//   one message.
// - Otherwise a first message with `mpi_tag` carries the sizes buffer
//   (copyAndGetSizesBuffer()). Its request is recorded in m_sub_requests
//   under the MPI lock. The full buffer is then sent with
//   nextSerializeTag(mpi_tag), and that request is returned.
// NOTE(review): some lines of this function are missing from this listing.
254Request MpiSerializeDispatcher::
255legacySendSerializer(ISerializer* values,
const PointToPointMessageInfo& message)
// Error message names the real accessor, isRankTag() (the original text
// said "isRangTag").
257 if (!message.isRankTag())
258 ARCCORE_FATAL(
"Only message.isRankTag()==true are allowed for legacy mode");
260 MessageRank rank = message.destinationRank();
261 MessageTag mpi_tag = message.tag();
262 bool is_blocking = message.isBlocking();
264 BasicSerializer* sbuf = _castSerializer(values);
265 ITraceMng* tm = m_trace;
267 Span<Byte> bytes = sbuf->globalBuffer();
269 Int64 total_size = sbuf->totalSize();
270 _checkBigMessage(total_size);
272 if (m_is_trace_serializer)
273 tm->info() <<
"legacySendSerializer(): sending to "
274 <<
" rank=" << rank <<
" bytes " << bytes.size()
275 << BasicSerializer::SizesPrinter(*sbuf)
276 <<
" tag=" << mpi_tag <<
" is_blocking=" << is_blocking;
// Small message: a single send is enough.
280 if (total_size<=m_serialize_buffer_size){
281 if (m_is_trace_serializer)
282 tm->info() <<
"Small message size=" << bytes.size();
283 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Large message, first part: the sizes buffer on `mpi_tag`.
// NOTE(review): for a non-blocking send `x` must stay valid until the request
// completes; presumably it views storage owned by `sbuf` — confirm.
289 auto x = sbuf->copyAndGetSizesBuffer();
290 if (m_is_trace_serializer)
291 tm->info() <<
"Big message first size=" << x.size();
292 Request r = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
// The first request is not returned to the caller; it is tracked here and
// tested later by checkFinishedSubRequests().
// NOTE(review): raw `new`; deletion is not visible in this listing.
294 SerializeSubRequest* sub_request =
new SerializeSubRequest();
295 sub_request->m_request = r;
298 MpiLock::Section ls(m_adapter->mpiLock());
299 m_sub_requests.add(sub_request);
// Large message, second part: the full buffer on the next tag.
304 if (m_is_trace_serializer)
305 tm->info() <<
"Big message second size=" << bytes.size();
306 return _sendSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),is_blocking);
// Posts a receive into `bytes` for the message identified by `message_id`,
// using the serializer datatype. Blocking or not depending on `is_blocking`.
312Request MpiSerializeDispatcher::
313_recvSerializerBytes(Span<Byte> bytes,MessageId message_id,
bool is_blocking)
315 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
316 MPI_Datatype dt = sbc.datatype();
317 if (m_is_trace_serializer)
318 m_trace->
info() <<
"_recvSerializerBytes: size=" << bytes.size()
319 <<
" message_id=" << message_id <<
" is_blocking=" << is_blocking;
// The count is sbc.size(), i.e. elements of `dt`, not bytes.
320 return m_adapter->directRecv(sbc.data(),sbc.size(),message_id,sbc.elementSize(),dt,is_blocking);
// Posts a receive into `bytes` from `rank` with `tag`, using the serializer
// datatype, and traces the resulting request when tracing is enabled.
// NOTE(review): the closing lines (presumably `return r;`) are missing from
// this listing.
326Request MpiSerializeDispatcher::
327_recvSerializerBytes(Span<Byte> bytes,MessageRank rank,MessageTag tag,
bool is_blocking)
329 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
330 MPI_Datatype dt = sbc.datatype();
331 if (m_is_trace_serializer)
332 m_trace->
info() <<
"_recvSerializerBytes: size=" << bytes.size()
333 <<
" rank=" << rank <<
" tag=" << tag <<
" is_blocking=" << is_blocking;
// The count is sbc.size(), i.e. elements of `dt`, not bytes.
334 Request r = m_adapter->directRecv(sbc.data(),sbc.size(),rank.value(),
335 sbc.elementSize(),dt,tag.value(),is_blocking);
336 if (m_is_trace_serializer)
337 m_trace->
info() <<
"_recvSerializerBytes: request=" << r;
// Posts a send of `bytes` to `rank` with `tag`, using the serializer
// datatype, and traces the resulting request when tracing is enabled.
// NOTE(review): the last parameter line (the `is_blocking` flag used below)
// and the closing lines (presumably `return r;`) are missing from this
// listing.
344Request MpiSerializeDispatcher::
345_sendSerializerBytes(Span<const Byte> bytes,MessageRank rank,MessageTag tag,
348 SerializeByteConverter<const Byte> sbc(bytes,m_byte_serializer_datatype);
349 MPI_Datatype dt = sbc.datatype();
350 if (m_is_trace_serializer)
351 m_trace->
info() <<
"_sendSerializerBytes: orig_size=" << bytes.size()
352 <<
" rank=" << rank <<
" tag=" << tag
353 <<
" second_size=" << sbc.size()
354 <<
" message_size=" << sbc.messageSize();
// The count is sbc.size(), i.e. elements of `dt`, not bytes.
355 Request r = m_adapter->directSend(sbc.data(),sbc.size(),rank.value(),
356 sbc.elementSize(),dt,tag.value(),is_blocking);
357 if (m_is_trace_serializer)
358 m_trace->
info() <<
"_sendSerializerBytes: request=" << r;
// Blocking legacy receive of a serializer from `rank` with `mpi_tag`.
// First receives into a buffer of m_serialize_buffer_size. If the total size
// read from that data is larger, the buffer is grown and the full payload is
// received with nextSerializeTag(mpi_tag). Mirrors legacySendSerializer().
365void MpiSerializeDispatcher::
366legacyReceiveSerializer(ISerializer* values,MessageRank rank,MessageTag mpi_tag)
368 BasicSerializer* sbuf = _castSerializer(values);
369 ITraceMng* tm = m_trace;
371 if (m_is_trace_serializer)
372 tm->
info() <<
"legacyReceiveSerializer() begin receive"
373 <<
" rank=" << rank <<
" tag=" << mpi_tag;
374 sbuf->preallocate(m_serialize_buffer_size);
375 Span<Byte> bytes = sbuf->globalBuffer();
// Blocking receive (`true`), so the returned Request is not kept.
377 _recvSerializerBytes(bytes,rank,mpi_tag,
true);
378 Int64 total_recv_size = sbuf->totalSize();
380 if (m_is_trace_serializer)
381 tm->info() <<
"legacyReceiveSerializer total_size=" << total_recv_size
383 << BasicSerializer::SizesPrinter(*sbuf);
// Small message: everything was received in the first buffer.
387 if (total_recv_size<=m_serialize_buffer_size){
388 sbuf->setFromSizes();
// Large message: receive the full payload on the next tag.
392 if (m_is_trace_serializer)
393 tm->info() <<
"Receive overflow buffer: " << total_recv_size;
394 sbuf->preallocate(total_recv_size);
395 bytes = sbuf->globalBuffer();
396 _recvSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),
true);
397 sbuf->setFromSizes();
398 if (m_is_trace_serializer)
399 tm->info() <<
"End receive overflow buffer: " << total_recv_size;
// Tests each pending sub-request with testRequest() and replaces
// m_sub_requests with the entries that are kept.
// NOTE(review): the condition around add(), the handling of finished
// entries, and any locking (legacySendSerializer adds under the MPI lock) are
// not visible in this listing. Presumably unfinished entries are kept and
// finished ones released — confirm.
405void MpiSerializeDispatcher::
406checkFinishedSubRequests()
411 UniqueArray<SerializeSubRequest*> new_sub_requests;
412 for(
Integer i=0, n=m_sub_requests.size(); i<n; ++i ){
413 SerializeSubRequest* ssr = m_sub_requests[i];
414 bool is_finished = m_adapter->testRequest(ssr->m_request);
416 new_sub_requests.add(ssr);
422 m_sub_requests = new_sub_requests;
// Records the largest message size seen so far. Logs "big buffer" each time
// a new maximum is reached.
428void MpiSerializeDispatcher::
429_checkBigMessage(Int64 message_size)
431 if (message_size>m_max_serialize_buffer_size){
432 m_max_serialize_buffer_size = message_size;
433 m_trace->
info() <<
"big buffer: " << message_size;
// Fragment of sendSerializer(). NOTE(review): the start of the signature and
// the lines defining sbuf, rank, mpi_tag, bytes, is_blocking and tm are
// missing from this listing.
451 bool force_one_message)
462 Int64 total_size = sbuf->totalSize();
463 _checkBigMessage(total_size);
465 if (m_is_trace_serializer)
466 tm->
info() <<
"sendSerializer(): sending to "
467 <<
" p2p_message=" << message
468 <<
" rank=" << rank <<
" bytes " << bytes.
size()
470 <<
" tag=" << mpi_tag
471 <<
" total_size=" << total_size;
// Single message when small enough, or when the caller forces it.
476 if (total_size<=m_serialize_buffer_size || force_one_message){
477 if (m_is_trace_serializer)
478 tm->
info() <<
"Small message size=" << bytes.
size();
479 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Large message: send the sizes buffer first on `mpi_tag`. The full buffer
// is sent by a SendSerializerSubRequest on nextSerializeTag(mpi_tag),
// attached to the first request.
// NOTE(review): for a non-blocking send `x` must stay valid until r1
// completes; presumably it views storage owned by `sbuf` — confirm.
484 auto x = sbuf->copyAndGetSizesBuffer();
485 Request r1 = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
486 auto* x2 =
new SendSerializerSubRequest(
this,sbuf,rank,nextSerializeTag(mpi_tag));
489 r1.setSubRequest(makeRef<ISubRequest>(x2));
// Fragment of receiveSerializer(). NOTE(review): the signature and the lines
// declaring sbuf, bytes, r, sr and the if/else conditions are missing from
// this listing. Receives into a buffer of the default size. As the exception
// text states, only (rank,tag) or MessageId messages are supported.
504 sbuf->preallocate(m_serialize_buffer_size);
509 r = _recvSerializerBytes(bytes,rank,tag,is_blocking);
511 r = _recvSerializerBytes(bytes,message.
messageId(),is_blocking);
513 ARCCORE_THROW(
NotSupportedException,
"Only message.isRankTag() or message.isMessageId() is supported");
// Follow-up work after this receive completes is delegated to sub-request sr.
515 r.setSubRequest(makeRef<ISubRequest>(sr));
// Fragment of broadcastSerializer(). NOTE(review): the signature and several
// lines (definitions of sbuf, my_rank, tm, total_size_buf, bytes, the
// broadcaster-side `sbc`, and the if/else structure) are missing from this
// listing.
// The root rank broadcasts the total size as an Int64, then the buffer. The
// other ranks receive the size, preallocate, receive the buffer and call
// setFromSizes().
522void MpiSerializeDispatcher::
528 bool is_broadcaster = (rank==my_rank);
530 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(
Int64());
// Root side.
536 Int64 total_size = sbuf->totalSize();
538 _checkBigMessage(total_size);
540 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.
value(),int64_datatype);
541 if (m_is_trace_serializer)
542 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): sending "
545 m_adapter->broadcast(sbc.data(),sbc.size(),rank.
value(),sbc.datatype());
// Receiver side: the size first, so the buffer can be allocated before the
// data broadcast.
548 Int64 total_size = 0;
550 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.
value(),int64_datatype);
551 sbuf->preallocate(total_size);
553 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
554 m_adapter->broadcast(sbc.data(),sbc.size(),rank.
value(),sbc.datatype());
555 sbuf->setFromSizes();
556 if (m_is_trace_serializer)
557 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): receiving from "
558 <<
" rank=" << rank <<
" bytes " << bytes.
size()
// Downcasts `serializer` to BasicSerializer with dynamic_cast and throws
// ArgumentException when that is not possible.
// NOTE(review): the null check guarding the throw and the return statement
// are missing from this listing.
566BasicSerializer* MpiSerializeDispatcher::
567_castSerializer(ISerializer* serializer)
569 BasicSerializer* sbuf =
dynamic_cast<BasicSerializer*
>(serializer);
571 ARCCORE_THROW(ArgumentException,
"Can not cast 'ISerializer' to 'BasicSerializer'");
static ARCCORE_CONSTEXPR Integer paddingSize()
Taille du padding et de l'alignement.
Interface d'un sérialiseur.
Interface du gestionnaire de traces.
virtual TraceMessage info()=0
Flot pour un message d'information.
Interface d'une liste de messages de sérialisation.
Int32 value() const
Valeur du rang.
int commRank() const
Rang de cette instance dans le communicateur.
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback appelé lorsque la requête associée est terminée.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback appelé lorsque la requête associée est terminée.
Request receiveSerializer(ISerializer *s, const PointToPointMessageInfo &message) override
Message de réception.
Request sendSerializer(const ISerializer *s, const PointToPointMessageInfo &message) override
Message d'envoi.
Ref< ISerializeMessageList > createSerializeMessageListRef() override
Crée une liste de messages de sérialisation.
Implémentation MPI de la gestion des 'ISerializeMessage'.
Informations pour envoyer/recevoir un message point à point.
MessageId messageId() const
Identifiant du message.
MessageRank destinationRank() const
Rang de la destination du message.
bool isBlocking() const
Indique si le message est bloquant.
bool isMessageId() const
Vrai si l'instance a été créée avec un MessageId. Dans ce cas messageId() est valide.
MessageTag tag() const
Tag du message.
bool isRankTag() const
Vrai si l'instance a été créée avec un couple (rank,tag). Dans ce cas rank() et tag() sont valides.
constexpr __host__ __device__ SizeType size() const noexcept
Retourne la taille du tableau.
constexpr __host__ __device__ pointer data() const noexcept
Pointeur sur le début de la vue.
Vue d'un tableau d'éléments de type T.
Int32 Integer
Type représentant un entier.
std::int64_t Int64
Type entier signé sur 64 bits.
std::int32_t Int32
Type entier signé sur 32 bits.
unsigned char Byte
Type d'un octet.