14#include "arccore/message_passing_mpi/internal/MpiSerializeDispatcher.h"
16#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/internal/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/SerializeMessageList.h"
22#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
24#include "arccore/serialize/BasicSerializer.h"
26#include "arccore/base/NotImplementedException.h"
27#include "arccore/base/FatalErrorException.h"
28#include "arccore/base/NotSupportedException.h"
29#include "arccore/base/ArgumentException.h"
30#include "arccore/base/PlatformUtils.h"
31#include "arccore/trace/ITraceMng.h"
36namespace Arcane::MessagePassing::Mpi
// Wraps a span of serializer bytes so it can be exchanged with MPI using the
// datatype passed as 'byte_serializer_datatype'. The element count reported by
// size() is the byte count divided by an alignment size. The constructor
// therefore requires the buffer size to be a multiple of that alignment.
// NOTE(review): the lines computing 'size' and 'align_size' are not visible in
// this view; presumably align_size is BasicSerializer::paddingSize() — confirm.
60template<
typename SpanType>
61class SerializeByteConverter
// m_final_size is initialized to -1, then set in the constructor body.
64 SerializeByteConverter(
Span<SpanType> buffer,MPI_Datatype byte_serializer_datatype)
65 : m_buffer(buffer), m_datatype(byte_serializer_datatype), m_final_size(-1)
// Fatal error when the byte size is not a multiple of the alignment size.
69 if ((size%align_size)!=0)
70 ARCCORE_FATAL(
"Buffer size '{0}' is not a multiple of '{1}' Invalid size",size,align_size);
71 m_final_size = size / align_size;
// Pointer to the first element of the wrapped buffer.
73 SpanType* data() {
return m_buffer.data(); }
// Number of 'datatype()' elements to hand to MPI (byte count / alignment).
74 Int64 size()
const {
return m_final_size; }
// Size of the wrapped buffer in bytes (Byte is unsigned char, so sizeof is 1).
75 Int64 messageSize()
const {
return m_buffer.size() *
sizeof(
Byte); }
// MPI datatype given at construction.
77 MPI_Datatype datatype()
const {
return m_datatype; }
80 MPI_Datatype m_datatype;
// Sub-request that sends the serializer's whole global buffer as a separate,
// non-blocking message. In this file it is created by the big-message path of
// the send that traces "sendSerializer()", with tag nextSerializeTag(mpi_tag).
100 SendSerializerSubRequest(MpiSerializeDispatcher* pm,
BasicSerializer* buf,
102 : m_dispatcher(pm), m_serialize_buffer(buf), m_rank(rank), m_mpi_tag(mpi_tag) {}
// Completion callback: when the message has not been sent yet it is sent here.
// The call itself is not visible in this view — TODO confirm. The stored send
// request is then returned.
108 if (!m_is_message_sent)
110 return m_send_request;
// Sends the global buffer exactly once; a second call is a fatal error.
115 if (m_is_message_sent)
116 ARCCORE_FATAL(
"Message already sent");
// NOTE(review): the trace below is presumably guarded by 'do_print'; the guard
// line is not visible here.
117 bool do_print = m_dispatcher->m_is_trace_serializer;
119 ITraceMng* tm = m_dispatcher->traceMng();
120 tm->
info() <<
" SendSerializerSubRequest::sendMessage()"
121 <<
" rank=" << m_rank <<
" tag=" << m_mpi_tag;
// The last argument 'false' is 'is_blocking': the send is non-blocking, and its
// request is kept in m_send_request.
123 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
124 m_send_request = m_dispatcher->_sendSerializerBytes(bytes,m_rank,m_mpi_tag,
false);
125 m_is_message_sent =
true;
// Guards against sending the message more than once.
133 bool m_is_message_sent =
false;
// Sub-request that drives the receive side of a serializer message sent in one
// or two parts. 'm_action' selects the step; the dispatch lines are not
// visible here — confirm. After the first message, the receive either finishes
// or posts a receive for the full buffer. After that second receive completes,
// the buffer is finalized.
144 ReceiveSerializerSubRequest(MpiSerializeDispatcher* d,
BasicSerializer* buf,
147 , m_serialize_buffer(buf)
157 bool is_trace = m_dispatcher->m_is_trace_serializer;
158 ITraceMng* tm = m_dispatcher->traceMng();
160 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion()"
161 <<
" rank=" << rank <<
" wanted_tag=" << m_mpi_tag <<
" action=" << m_action;
// Total serialized size read from the buffer after the first receive.
165 Int64 total_recv_size = sbuf->totalSize();
168 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
// Same threshold the sender uses to decide on a single message: if it fits,
// rebuild the serializer from what was already received.
172 if (total_recv_size<=m_dispatcher->m_serialize_buffer_size){
173 sbuf->setFromSizes();
// Otherwise grow the buffer to the announced size and post a non-blocking
// receive ('false' = is_blocking) for the full data. A new sub-request with
// action 2 handles its completion.
// NOTE(review): this receive uses m_mpi_tag, but the send side sends the second
// part on nextSerializeTag(mpi_tag). Confirm that the tag stored here is
// already the shifted one.
177 sbuf->preallocate(total_recv_size);
178 auto bytes = sbuf->globalBuffer();
182 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag,
false);
183 ISubRequest* sr =
new ReceiveSerializerSubRequest(m_dispatcher, m_serialize_buffer, m_mpi_tag, 2);
// Final step: the full buffer has arrived, so rebuild the serializer from it.
188 m_serialize_buffer->setFromSizes();
195 MpiSerializeDispatcher* m_dispatcher =
nullptr;
// Constructor (initializer-list fragment). m_serialize_buffer_size defaults to
// 50000; it is the threshold between one-message and two-message sends, and
// presumably a byte count — confirm. The MPI datatype starts as
// MPI_DATATYPE_NULL and is freed by the destructor if it was created.
207MpiSerializeDispatcher::
210, m_message_passing_mng(message_passing_mng)
211, m_trace(adapter->traceMng())
212, m_serialize_buffer_size(50000)
// C++ initializes members in declaration order, so m_serialize_buffer_size must
// be declared before m_max_serialize_buffer_size for this read to be valid.
// TODO confirm in the header.
214, m_max_serialize_buffer_size(m_serialize_buffer_size)
215, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
// Destructor: frees the MPI datatype if one was created; it starts as
// MPI_DATATYPE_NULL.
// NOTE(review): MPI_Type_free must not be called after MPI_Finalize. Confirm
// this object is destroyed before MPI shuts down.
223MpiSerializeDispatcher::
224~MpiSerializeDispatcher()
226 if (m_byte_serializer_datatype!=MPI_DATATYPE_NULL)
227 MPI_Type_free(&m_byte_serializer_datatype);
// Sets up the datatype that is later passed to SerializeByteConverter. The
// datatype is built (the construction call is not visible here), committed,
// and stored in m_byte_serializer_datatype; the destructor frees it.
// NOTE(review): the function name and any condition guarding the trace flag
// assignment are not visible in this view.
242void MpiSerializeDispatcher::
246 MPI_Datatype mpi_datatype;
248 MPI_Type_commit(&mpi_datatype);
249 m_byte_serializer_datatype = mpi_datatype;
252 m_is_trace_serializer =
true;
// Legacy send of a serializer to one rank; only (rank, tag) messages are accepted.
// - Small messages (total size <= m_serialize_buffer_size) are sent in one
//   message on 'mpi_tag'.
// - Bigger messages are split. A copy of the sizes buffer is sent on 'mpi_tag'
//   and tracked in m_sub_requests. The whole global buffer is then sent on
//   nextSerializeTag(mpi_tag).
// Returns the request of the last send. The receive side of this protocol is
// legacyReceiveSerializer.
258Request MpiSerializeDispatcher::
259legacySendSerializer(ISerializer* values,
const PointToPointMessageInfo& message)
// NOTE(review): the error text says 'isRangTag'; the check actually calls isRankTag().
261 if (!message.isRankTag())
262 ARCCORE_FATAL(
"Only message.isRangTag()==true are allowed for legacy mode");
264 MessageRank rank = message.destinationRank();
265 MessageTag mpi_tag = message.tag();
266 bool is_blocking = message.isBlocking();
268 BasicSerializer* sbuf = _castSerializer(values);
269 ITraceMng* tm = m_trace;
271 Span<Byte> bytes = sbuf->globalBuffer();
// Record (and trace) the largest message size seen so far.
273 Int64 total_size = sbuf->totalSize();
274 _checkBigMessage(total_size);
276 if (m_is_trace_serializer)
277 tm->info() <<
"legacySendSerializer(): sending to "
278 <<
" rank=" << rank <<
" bytes " << bytes.size()
279 << BasicSerializer::SizesPrinter(*sbuf)
280 <<
" tag=" << mpi_tag <<
" is_blocking=" << is_blocking;
// Small message: everything goes in one send on 'mpi_tag'.
284 if (total_size<=m_serialize_buffer_size){
285 if (m_is_trace_serializer)
286 tm->info() <<
"Small message size=" << bytes.size();
287 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Big message, first part: only a copy of the sizes buffer, sent on 'mpi_tag'.
293 auto x = sbuf->copyAndGetSizesBuffer();
294 if (m_is_trace_serializer)
295 tm->info() <<
"Big message first size=" << x.size();
296 Request r = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
// Keep the first request so checkFinishedSubRequests() can test it later.
// NOTE(review): the line creating 'sub_request' is not visible. Presumably it
// also keeps 'x' alive until a non-blocking send completes — confirm.
299 sub_request->m_request = r;
// m_sub_requests is modified while holding the adapter's MPI lock section.
302 MpiLock::Section ls(m_adapter->mpiLock());
303 m_sub_requests.add(sub_request);
// Big message, second part: the whole global buffer on the next serialize tag.
308 if (m_is_trace_serializer)
309 tm->info() <<
"Big message second size=" << bytes.size();
310 return _sendSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),is_blocking);
// Posts a receive into 'bytes' for the message identified by 'message_id'.
// Uses the dispatcher's byte datatype; the element count comes from
// SerializeByteConverter (byte size / alignment). Returns the adapter's request.
316Request MpiSerializeDispatcher::
317_recvSerializerBytes(Span<Byte> bytes,MessageId message_id,
bool is_blocking)
319 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
320 MPI_Datatype dt = sbc.datatype();
321 if (m_is_trace_serializer)
322 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
323 <<
" message_id=" << message_id <<
" is_blocking=" << is_blocking;
324 return m_adapter->directRecv(sbc.data(),sbc.size(),message_id,sbc.elementSize(),dt,is_blocking);
// Posts a receive into 'bytes' from 'rank' with 'tag', using the dispatcher's
// byte datatype. The element count comes from SerializeByteConverter
// (byte size / alignment).
// NOTE(review): the return statement is not visible here; presumably it
// returns 'r'.
330Request MpiSerializeDispatcher::
331_recvSerializerBytes(Span<Byte> bytes,MessageRank rank,MessageTag tag,
bool is_blocking)
333 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
334 MPI_Datatype dt = sbc.datatype();
335 if (m_is_trace_serializer)
336 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
337 <<
" rank=" << rank <<
" tag=" << tag <<
" is_blocking=" << is_blocking;
338 Request r = m_adapter->directRecv(sbc.data(),sbc.size(),rank.value(),
339 sbc.elementSize(),dt,tag.value(),is_blocking);
340 if (m_is_trace_serializer)
341 m_trace->info() <<
"_recvSerializerBytes: request=" << r;
// Sends 'bytes' to 'rank' with 'tag', using the dispatcher's byte datatype.
// The element count comes from SerializeByteConverter (byte size / alignment).
// NOTE(review): the 'is_blocking' parameter line and the return statement are
// not visible here; presumably the function returns 'r'.
348Request MpiSerializeDispatcher::
349_sendSerializerBytes(Span<const Byte> bytes,MessageRank rank,MessageTag tag,
352 SerializeByteConverter<const Byte> sbc(bytes,m_byte_serializer_datatype);
353 MPI_Datatype dt = sbc.datatype();
354 if (m_is_trace_serializer)
355 m_trace->info() <<
"_sendSerializerBytes: orig_size=" << bytes.size()
356 <<
" rank=" << rank <<
" tag=" << tag
357 <<
" second_size=" << sbc.size()
358 <<
" message_size=" << sbc.messageSize();
359 Request r = m_adapter->directSend(sbc.data(),sbc.size(),rank.value(),
360 sbc.elementSize(),dt,tag.value(),is_blocking);
361 if (m_is_trace_serializer)
362 m_trace->info() <<
"_sendSerializerBytes: request=" << r;
// Legacy blocking receive of a serializer sent by legacySendSerializer.
// 1. Receive up to m_serialize_buffer_size on 'mpi_tag'.
// 2. If the announced total size fits, rebuild the serializer directly.
// 3. Otherwise grow the buffer and receive the full data on
//    nextSerializeTag(mpi_tag).
// All receives here are blocking ('true').
369void MpiSerializeDispatcher::
370legacyReceiveSerializer(ISerializer* values,MessageRank rank,MessageTag mpi_tag)
372 BasicSerializer* sbuf = _castSerializer(values);
373 ITraceMng* tm = m_trace;
375 if (m_is_trace_serializer)
376 tm->info() <<
"legacyReceiveSerializer() begin receive"
377 <<
" rank=" << rank <<
" tag=" << mpi_tag;
// First receive into a buffer of the default size.
378 sbuf->preallocate(m_serialize_buffer_size);
379 Span<Byte> bytes = sbuf->globalBuffer();
381 _recvSerializerBytes(bytes,rank,mpi_tag,
true);
382 Int64 total_recv_size = sbuf->totalSize();
384 if (m_is_trace_serializer)
385 tm->info() <<
"legacyReceiveSerializer total_size=" << total_recv_size
387 << BasicSerializer::SizesPrinter(*sbuf);
// Same threshold as the sender's single-message case.
391 if (total_recv_size<=m_serialize_buffer_size){
392 sbuf->setFromSizes();
// Overflow: reallocate to the announced size and receive the full buffer on the
// next serialize tag.
396 if (m_is_trace_serializer)
397 tm->info() <<
"Receive overflow buffer: " << total_recv_size;
398 sbuf->preallocate(total_recv_size);
399 bytes = sbuf->globalBuffer();
400 _recvSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),
true);
401 sbuf->setFromSizes();
402 if (m_is_trace_serializer)
403 tm->info() <<
"End receive overflow buffer: " << total_recv_size;
// Tests each tracked sub-request's MPI request with testRequest(), then replaces
// m_sub_requests with the list collected in 'new_sub_requests'.
// NOTE(review): the condition selecting which sub-requests are kept is not
// visible; presumably the unfinished ones are kept and the finished ones are
// released. Also confirm whether the MPI lock is held here, as it is in
// legacySendSerializer when adding.
409void MpiSerializeDispatcher::
410checkFinishedSubRequests()
415 UniqueArray<SerializeSubRequest*> new_sub_requests;
416 for(
Integer i=0, n=m_sub_requests.size(); i<n; ++i ){
418 bool is_finished = m_adapter->testRequest(ssr->m_request);
420 new_sub_requests.add(ssr);
426 m_sub_requests = new_sub_requests;
// Records 'message_size' in m_max_serialize_buffer_size when it exceeds the
// largest size seen so far, and logs it unconditionally (not only in trace mode).
432void MpiSerializeDispatcher::
433_checkBigMessage(
Int64 message_size)
435 if (message_size>m_max_serialize_buffer_size){
436 m_max_serialize_buffer_size = message_size;
437 m_trace->info() <<
"big buffer: " << message_size;
// Tail of the point-to-point send that traces "sendSerializer()"; the beginning
// of its signature is not visible here.
// - When the total size fits in m_serialize_buffer_size, or 'force_one_message'
//   is set, everything is sent in one message on 'mpi_tag'.
// - Otherwise a copy of the sizes buffer is sent first on 'mpi_tag'. A
//   SendSerializerSubRequest is then created to send the whole buffer on
//   nextSerializeTag(mpi_tag).
455 bool force_one_message)
// Record (and trace) the largest message size seen so far.
466 Int64 total_size = sbuf->totalSize();
467 _checkBigMessage(total_size);
469 if (m_is_trace_serializer)
470 tm->
info() <<
"sendSerializer(): sending to "
471 <<
" p2p_message=" << message
472 <<
" rank=" << rank <<
" bytes " << bytes.
size()
474 <<
" tag=" << mpi_tag
475 <<
" total_size=" << total_size;
// Single-message path.
480 if (total_size<=m_serialize_buffer_size || force_one_message){
481 if (m_is_trace_serializer)
482 tm->
info() <<
"Small message size=" << bytes.
size();
483 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Two-message path: sizes first. The whole buffer is sent by the sub-request on
// the next serialize tag. NOTE(review): how r1 and x2 are combined into the
// returned request is not visible here.
488 auto x = sbuf->copyAndGetSizesBuffer();
489 Request r1 = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
490 auto* x2 =
new SendSerializerSubRequest(
this,sbuf,rank,nextSerializeTag(mpi_tag));
// Tail of the point-to-point receive. The buffer is first sized to the default
// m_serialize_buffer_size. The receive is then posted either by (rank, tag) or
// by MessageId; any other kind of message raises NotSupportedException.
// NOTE(review): the branch conditions are not visible. Presumably they are
// message.isRankTag() and message.isMessageId(), as the error text states.
508 sbuf->preallocate(m_serialize_buffer_size);
513 r = _recvSerializerBytes(bytes,rank,tag,is_blocking);
515 r = _recvSerializerBytes(bytes,message.
messageId(),is_blocking);
517 ARCCORE_THROW(
NotSupportedException,
"Only message.isRankTag() or message.isMessageId() is supported");
// Broadcast of a serializer from 'rank' to all ranks, in two collective steps.
// 1. The total size is broadcast as an Int64.
// 2. The serializer bytes are broadcast with the converter's datatype.
// The root sends its buffer; the other ranks preallocate to the received size,
// receive the bytes, and rebuild the serializer with setFromSizes().
526void MpiSerializeDispatcher::
532 bool is_broadcaster = (rank==my_rank);
534 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(
Int64());
// Broadcaster side: publish the size first, then the data.
540 Int64 total_size = sbuf->totalSize();
542 _checkBigMessage(total_size);
544 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.
value(),int64_datatype);
545 if (m_is_trace_serializer)
546 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): sending "
549 m_adapter->broadcast(sbc.data(),sbc.size(),rank.
value(),sbc.datatype());
// Receiver side: get the size, size the buffer, get the data, then rebuild.
552 Int64 total_size = 0;
554 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.
value(),int64_datatype);
555 sbuf->preallocate(total_size);
558 m_adapter->broadcast(sbc.data(),sbc.size(),rank.
value(),sbc.datatype());
559 sbuf->setFromSizes();
560 if (m_is_trace_serializer)
561 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): receiving from "
562 <<
" rank=" << rank <<
" bytes " << bytes.
size()
// Downcasts an ISerializer to the BasicSerializer implementation this
// dispatcher relies on. Throws ArgumentException when the cast fails; the null
// check line is not visible in this view.
570BasicSerializer* MpiSerializeDispatcher::
571_castSerializer(ISerializer* serializer)
573 BasicSerializer* sbuf =
dynamic_cast<BasicSerializer*
>(serializer);
575 ARCCORE_THROW(ArgumentException,
"Can not cast 'ISerializer' to 'BasicSerializer'");
Vue modifiable d'un tableau d'un type T.
Implémentation basique de 'ISerializer'.
static ARCCORE_CONSTEXPR Integer paddingSize()
Taille du padding et de l'alignement.
Interface d'un sérialiseur.
Interface du gestionnaire de traces.
virtual TraceMessage info()=0
Flot pour un message d'information.
Interface du gestionnaire des échanges de messages.
Interface d'une liste de messages de sérialisation.
Sous-requête d'une requête.
Int32 value() const
Valeur du rang.
int commRank() const
Rang de cette instance dans le communicateur.
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback appelé lorsque la requête associée est terminée.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback appelé lorsque la requête associée est terminée.
Request receiveSerializer(ISerializer *s, const PointToPointMessageInfo &message) override
Message de réception.
Ref< ISerializeMessageList > createSerializeMessageListRef() override
Crée une liste de messages de sérialisation.
Request sendSerializer(const ISerializer *s, const PointToPointMessageInfo &message) override
Message d'envoi.
Wrappeur pour envoyer un tableau d'octets d'un sérialiseur.
Informations pour envoyer/recevoir un message point à point.
MessageId messageId() const
Identifiant du message.
MessageTag tag() const
Tag du message.
bool isBlocking() const
Indique si le message est bloquant.
bool isRankTag() const
Vrai si l'instance a été créée avec un couple (rank,tag). Dans ce cas rank() et tag() sont valides.
bool isMessageId() const
Vrai si l'instance a été créée avec un MessageId. Dans ce cas messageId() est valide.
MessageRank destinationRank() const
Rang de la destination du message.
Informations de complétion d'une sous-requête.
MessageRank sourceRank() const
Rang d'origine de la requête.
Liste de messages de sérialisation.
Exception lorsqu'une opération n'est pas supportée.
Référence à une instance.
constexpr __host__ __device__ SizeType size() const noexcept
Retourne la taille du tableau.
Vue d'un tableau d'éléments de type T.
std::int64_t Int64
Type entier signé sur 64 bits.
Int32 Integer
Type représentant un entier.
unsigned char Byte
Type d'un octet.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Crée une référence sur un pointeur.
std::int32_t Int32
Type entier signé sur 32 bits.