14#include "arccore/message_passing_mpi/internal/MpiSerializeDispatcher.h"
16#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/internal/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/SerializeMessageList.h"
22#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
24#include "arccore/serialize/BasicSerializer.h"
26#include "arccore/base/NotImplementedException.h"
27#include "arccore/base/FatalErrorException.h"
28#include "arccore/base/NotSupportedException.h"
29#include "arccore/base/ArgumentException.h"
30#include "arccore/base/PlatformUtils.h"
31#include "arccore/trace/ITraceMng.h"
36namespace Arcane::MessagePassing::Mpi
// Wraps a serializer byte buffer (Span<SpanType>) together with the MPI
// datatype used to transfer it (byte_serializer_datatype).
// The buffer length must be a multiple of an alignment size (align_size);
// otherwise construction is a fatal error. size() then returns the number of
// align_size-sized elements, while messageSize() returns the length in bytes.
// NOTE(review): the lines defining 'size' and 'align_size' are not visible in
// this view; align_size is presumably BasicSerializer::paddingSize() - confirm.
// NOTE(review): callers also use elementSize() (see _recvSerializerBytes and
// _sendSerializerBytes), whose definition is not visible here.
60template<
typename SpanType>
61class SerializeByteConverter
64 SerializeByteConverter(
Span<SpanType> buffer,MPI_Datatype byte_serializer_datatype)
// m_final_size starts at -1 and is overwritten once the size check passes.
65 : m_buffer(buffer), m_datatype(byte_serializer_datatype), m_final_size(-1)
// Reject buffers that cannot be split into whole align_size elements.
69 if ((size%align_size)!=0)
70 ARCCORE_FATAL(
"Buffer size '{0}' is not a multiple of '{1}' Invalid size",size,align_size);
// Element count expressed in units of align_size.
71 m_final_size = size / align_size;
// Pointer to the start of the wrapped buffer.
73 SpanType* data() {
return m_buffer.data(); }
// Number of align_size-sized elements; this is the count handed to the
// MPI adapter calls (directSend/directRecv/broadcast).
74 Int64 size()
const {
return m_final_size; }
// Length of the wrapped buffer in bytes.
75 Int64 messageSize()
const {
return m_buffer.size() *
sizeof(
Byte); }
// MPI datatype describing one transferred element.
77 MPI_Datatype datatype()
const {
return m_datatype; }
80 MPI_Datatype m_datatype;
// Sub-request that sends a serializer's global buffer.
// sendMessage() posts a non-blocking send (is_blocking=false) to m_rank with
// tag m_mpi_tag and keeps the resulting request in m_send_request.
// NOTE(review): the class header and several lines of this class (including
// the 'rank'/'mpi_tag' constructor parameters) are not visible in this view.
100 SendSerializerSubRequest(MpiSerializeDispatcher* pm,
BasicSerializer* buf,
102 : m_dispatcher(pm), m_serialize_buffer(buf), m_rank(rank), m_mpi_tag(mpi_tag) {}
// Fragment: when the message has not been sent yet, a line not visible here
// runs (presumably sendMessage()); the send request is then returned.
108 if (!m_is_message_sent)
110 return m_send_request;
// Sending the same message twice is a fatal error.
115 if (m_is_message_sent)
116 ARCCORE_FATAL(
"Message already sent");
117 bool do_print = m_dispatcher->m_is_trace_serializer;
// Trace output; presumably guarded by do_print on a line not visible here.
119 ITraceMng* tm = m_dispatcher->traceMng();
120 tm->
info() <<
" SendSerializerSubRequest::sendMessage()"
121 <<
" rank=" << m_rank <<
" tag=" << m_mpi_tag;
// Send the whole global buffer without blocking and remember the request.
123 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
124 m_send_request = m_dispatcher->_sendSerializerBytes(bytes,m_rank,m_mpi_tag,
false);
125 m_is_message_sent =
true;
128 MpiSerializeDispatcher* m_dispatcher;
// Guards against sending more than once (see the check in sendMessage()).
133 bool m_is_message_sent =
false;
// Sub-request attached to a non-blocking serializer receive.
// executeOnCompletion() reads the received total size. If it is at most the
// dispatcher's m_serialize_buffer_size, setFromSizes() is called directly.
// Otherwise the buffer is enlarged to that size, a second non-blocking
// receive is posted, and a follow-up sub-request with action 2 is created.
// NOTE(review): the class header and several lines are not visible here.
144 ReceiveSerializerSubRequest(MpiSerializeDispatcher* d,
BasicSerializer* buf,
147 , m_serialize_buffer(buf)
157 bool is_trace = m_dispatcher->m_is_trace_serializer;
158 ITraceMng* tm = m_dispatcher->traceMng();
160 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion()"
161 <<
" rank=" << rank <<
" wanted_tag=" << m_mpi_tag <<
" action=" << m_action;
// 'sbuf' and 'rank' are defined on lines not visible here (presumably the
// serialize buffer and the source rank of the completed receive).
165 Int64 total_recv_size = sbuf->totalSize();
168 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
// The whole message fit in the first receive buffer.
172 if (total_recv_size<=m_dispatcher->m_serialize_buffer_size){
173 sbuf->setFromSizes();
// Otherwise grow the buffer to the announced total size and receive the
// remaining data with a second non-blocking receive.
177 sbuf->preallocate(total_recv_size);
178 auto bytes = sbuf->globalBuffer();
// NOTE(review): this receives on m_mpi_tag, whereas legacyReceiveSerializer
// uses nextSerializeTag(mpi_tag) for its second message. Presumably whoever
// created this sub-request already passed the follow-up tag - confirm against
// receiveSerializer()/sendSerializer() (those lines are not visible).
182 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag,
false);
// Follow-up sub-request (action 2) for the second receive.
183 ISubRequest* sr =
new ReceiveSerializerSubRequest(m_dispatcher, m_serialize_buffer, m_mpi_tag, 2);
// Reached on a path whose condition is not visible here (presumably the
// action-2 completion).
188 m_serialize_buffer->setFromSizes();
195 MpiSerializeDispatcher* m_dispatcher =
nullptr;
// Constructor (the parameter line is not visible here). Initial state:
// - m_trace is taken from the adapter's trace manager;
// - m_serialize_buffer_size = 50000: the totalSize() threshold at or below
//   which a serializer is sent as a single message (see sendSerializer);
// - m_max_serialize_buffer_size starts at that threshold and is raised by
//   _checkBigMessage when a larger message is seen;
// - m_byte_serializer_datatype stays MPI_DATATYPE_NULL until it is created
//   and committed elsewhere in this class.
207MpiSerializeDispatcher::
210, m_message_passing_mng(message_passing_mng)
211, m_trace(adapter->traceMng())
212, m_serialize_buffer_size(50000)
214, m_max_serialize_buffer_size(m_serialize_buffer_size)
215, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
// Destructor: frees the MPI datatype if one was created.
// NOTE(review): MPI_Type_free is erroneous after MPI_Finalize; confirm this
// object is always destroyed while MPI is still initialized.
223MpiSerializeDispatcher::
224~MpiSerializeDispatcher()
226 if (m_byte_serializer_datatype!=MPI_DATATYPE_NULL)
227 MPI_Type_free(&m_byte_serializer_datatype);
// Initialization (the function name line is not visible here).
// Builds the MPI datatype used for serializer byte transfers, commits it and
// stores it in m_byte_serializer_datatype; the destructor frees it.
// NOTE(review): the line that creates mpi_datatype is not visible; it is
// presumably derived from the serializer padding/alignment size - confirm.
242void MpiSerializeDispatcher::
246 MPI_Datatype mpi_datatype;
248 MPI_Type_commit(&mpi_datatype);
249 m_byte_serializer_datatype = mpi_datatype;
// Enables serializer tracing; the condition guarding this is not visible.
252 m_is_trace_serializer =
true;
// Sends a serializer using the legacy protocol. Only rank/tag addressing is
// accepted; any other message kind is a fatal error.
// - If totalSize() <= m_serialize_buffer_size, the whole global buffer is
//   sent as one message with tag mpi_tag.
// - Otherwise a copy of the sizes buffer is sent first with mpi_tag. That
//   request is stored in m_sub_requests under the MPI lock, so that
//   checkFinishedSubRequests() can poll it. The full global buffer is then
//   sent with nextSerializeTag(mpi_tag), and that second request is returned.
// The matching receive sequence appears in legacyReceiveSerializer().
// NOTE(review): the error message says "isRangTag" (typo for isRankTag); it
// is left unchanged here because it is runtime text.
258Request MpiSerializeDispatcher::
259legacySendSerializer(ISerializer* values,
const PointToPointMessageInfo& message)
261 if (!message.isRankTag())
262 ARCCORE_FATAL(
"Only message.isRangTag()==true are allowed for legacy mode");
264 MessageRank rank = message.destinationRank();
265 MessageTag mpi_tag = message.tag();
266 bool is_blocking = message.isBlocking();
268 BasicSerializer* sbuf = _castSerializer(values);
269 ITraceMng* tm = m_trace;
271 Span<Byte> bytes = sbuf->globalBuffer();
273 Int64 total_size = sbuf->totalSize();
// Record the largest message size seen so far.
274 _checkBigMessage(total_size);
276 if (m_is_trace_serializer)
277 tm->info() <<
"legacySendSerializer(): sending to "
278 <<
" rank=" << rank <<
" bytes " << bytes.size()
279 << BasicSerializer::SizesPrinter(*sbuf)
280 <<
" tag=" << mpi_tag <<
" is_blocking=" << is_blocking;
// Small message: send everything at once.
284 if (total_size<=m_serialize_buffer_size){
285 if (m_is_trace_serializer)
286 tm->info() <<
"Small message size=" << bytes.size();
287 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Big message, first part: a copy of the sizes buffer.
293 auto x = sbuf->copyAndGetSizesBuffer();
294 if (m_is_trace_serializer)
295 tm->info() <<
"Big message first size=" << x.size();
296 Request r = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
// Keep the first request so it can later be tested/released by
// checkFinishedSubRequests().
298 SerializeSubRequest* sub_request =
new SerializeSubRequest();
299 sub_request->m_request = r;
302 MpiLock::Section ls(m_adapter->mpiLock());
303 m_sub_requests.add(sub_request);
// Big message, second part: the full buffer on the follow-up tag.
308 if (m_is_trace_serializer)
309 tm->info() <<
"Big message second size=" << bytes.size();
310 return _sendSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),is_blocking);
// Posts a receive into 'bytes' for the message identified by message_id.
// The buffer is viewed through SerializeByteConverter, so the count passed to
// the adapter is in elements of m_byte_serializer_datatype (not in bytes).
// Returns the request produced by m_adapter->directRecv().
316Request MpiSerializeDispatcher::
317_recvSerializerBytes(Span<Byte> bytes,MessageId message_id,
bool is_blocking)
319 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
320 MPI_Datatype dt = sbc.datatype();
321 if (m_is_trace_serializer)
322 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
323 <<
" message_id=" << message_id <<
" is_blocking=" << is_blocking;
324 return m_adapter->directRecv(sbc.data(),sbc.size(),message_id,sbc.elementSize(),dt,is_blocking);
// Posts a receive into 'bytes' from 'rank' with tag 'tag'.
// As in the MessageId overload, the count is in elements of
// m_byte_serializer_datatype. The resulting request is traced when tracing is
// enabled.
// NOTE(review): the return statement is not visible here; presumably it
// returns r.
330Request MpiSerializeDispatcher::
331_recvSerializerBytes(Span<Byte> bytes,MessageRank rank,MessageTag tag,
bool is_blocking)
333 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
334 MPI_Datatype dt = sbc.datatype();
335 if (m_is_trace_serializer)
336 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
337 <<
" rank=" << rank <<
" tag=" << tag <<
" is_blocking=" << is_blocking;
338 Request r = m_adapter->directRecv(sbc.data(),sbc.size(),rank.value(),
339 sbc.elementSize(),dt,tag.value(),is_blocking);
340 if (m_is_trace_serializer)
341 m_trace->info() <<
"_recvSerializerBytes: request=" << r;
// Sends 'bytes' to 'rank' with tag 'tag'.
// The byte buffer is converted to a count of m_byte_serializer_datatype
// elements. The trace shows both the original byte length (orig_size) and the
// element count (second_size).
// NOTE(review): the trailing parameter line (is_blocking) and the return
// statement are not visible here; presumably it returns r.
348Request MpiSerializeDispatcher::
349_sendSerializerBytes(Span<const Byte> bytes,MessageRank rank,MessageTag tag,
352 SerializeByteConverter<const Byte> sbc(bytes,m_byte_serializer_datatype);
353 MPI_Datatype dt = sbc.datatype();
354 if (m_is_trace_serializer)
355 m_trace->info() <<
"_sendSerializerBytes: orig_size=" << bytes.size()
356 <<
" rank=" << rank <<
" tag=" << tag
357 <<
" second_size=" << sbc.size()
358 <<
" message_size=" << sbc.messageSize();
359 Request r = m_adapter->directSend(sbc.data(),sbc.size(),rank.value(),
360 sbc.elementSize(),dt,tag.value(),is_blocking);
361 if (m_is_trace_serializer)
362 m_trace->info() <<
"_sendSerializerBytes: request=" << r;
// Blocking receive that mirrors legacySendSerializer().
// 1. Preallocate m_serialize_buffer_size and receive on mpi_tag.
// 2. If the received totalSize() fits in that size, call setFromSizes().
// 3. Otherwise preallocate the announced total size, receive the full buffer
//    on nextSerializeTag(mpi_tag), then call setFromSizes().
369void MpiSerializeDispatcher::
370legacyReceiveSerializer(ISerializer* values,MessageRank rank,MessageTag mpi_tag)
372 BasicSerializer* sbuf = _castSerializer(values);
373 ITraceMng* tm = m_trace;
375 if (m_is_trace_serializer)
376 tm->info() <<
"legacyReceiveSerializer() begin receive"
377 <<
" rank=" << rank <<
" tag=" << mpi_tag;
// First receive into a buffer of the default size.
378 sbuf->preallocate(m_serialize_buffer_size);
379 Span<Byte> bytes = sbuf->globalBuffer();
381 _recvSerializerBytes(bytes,rank,mpi_tag,
true);
382 Int64 total_recv_size = sbuf->totalSize();
384 if (m_is_trace_serializer)
385 tm->info() <<
"legacyReceiveSerializer total_size=" << total_recv_size
387 << BasicSerializer::SizesPrinter(*sbuf);
// Whole message already received.
391 if (total_recv_size<=m_serialize_buffer_size){
392 sbuf->setFromSizes();
// Overflow: grow the buffer and receive the full data on the follow-up tag.
396 if (m_is_trace_serializer)
397 tm->info() <<
"Receive overflow buffer: " << total_recv_size;
398 sbuf->preallocate(total_recv_size);
399 bytes = sbuf->globalBuffer();
400 _recvSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),
true);
401 sbuf->setFromSizes();
402 if (m_is_trace_serializer)
403 tm->info() <<
"End receive overflow buffer: " << total_recv_size;
// Polls every pending SerializeSubRequest with m_adapter->testRequest().
// Entries that are kept are collected in new_sub_requests, which then
// replaces m_sub_requests.
// NOTE(review): the condition selecting kept entries (presumably
// !is_finished) and any handling of finished entries (e.g. deletion) are on
// lines not visible here.
// NOTE(review): legacySendSerializer adds to m_sub_requests under the MPI
// lock; whether this function also takes that lock is not visible here.
409void MpiSerializeDispatcher::
410checkFinishedSubRequests()
415 UniqueArray<SerializeSubRequest*> new_sub_requests;
416 for(
Integer i=0, n=m_sub_requests.size(); i<n; ++i ){
417 SerializeSubRequest* ssr = m_sub_requests[i];
418 bool is_finished = m_adapter->testRequest(ssr->m_request);
420 new_sub_requests.add(ssr);
426 m_sub_requests = new_sub_requests;
// Tracks the largest message size seen so far.
// When message_size exceeds m_max_serialize_buffer_size, the maximum is
// updated and the new size is logged as a "big buffer".
432void MpiSerializeDispatcher::
433_checkBigMessage(
Int64 message_size)
435 if (message_size>m_max_serialize_buffer_size){
436 m_max_serialize_buffer_size = message_size;
437 m_trace->info() <<
"big buffer: " << message_size;
// Convenience overload: same as sendSerializer(s, message, false), so large
// serializers may be split into several messages.
444Request MpiSerializeDispatcher::
445sendSerializer(
const ISerializer* s,
const PointToPointMessageInfo& message)
447 return sendSerializer(s,message,
false);
// Sends serializer s to message.destinationRank() with message.tag().
// If totalSize() <= m_serialize_buffer_size, or force_one_message is true,
// the whole global buffer is sent as a single message. Otherwise a copy of
// the sizes buffer is sent first (r1) with mpi_tag; the rest of this path is
// on lines not visible here.
// The const_cast is needed because _castSerializer() takes a non-const
// ISerializer*.
453Request MpiSerializeDispatcher::
454sendSerializer(
const ISerializer* s,
const PointToPointMessageInfo& message,
455 bool force_one_message)
457 BasicSerializer* sbuf = _castSerializer(
const_cast<ISerializer*
>(s));
459 MessageRank rank = message.destinationRank();
460 MessageTag mpi_tag = message.tag();
461 bool is_blocking = message.isBlocking();
463 ITraceMng* tm = m_trace;
465 Span<const Byte> bytes = sbuf->globalBuffer();
466 Int64 total_size = sbuf->totalSize();
467 _checkBigMessage(total_size);
469 if (m_is_trace_serializer)
470 tm->info() <<
"sendSerializer(): sending to "
471 <<
" p2p_message=" << message
472 <<
" rank=" << rank <<
" bytes " << bytes.size()
473 << BasicSerializer::SizesPrinter(*sbuf)
474 <<
" tag=" << mpi_tag
475 <<
" total_size=" << total_size;
// Single-message path.
480 if (total_size<=m_serialize_buffer_size || force_one_message){
481 if (m_is_trace_serializer)
482 tm->info() <<
"Small message size=" << bytes.size();
483 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
// Multi-message path, first part: the sizes buffer on mpi_tag.
488 auto x = sbuf->copyAndGetSizesBuffer();
489 Request r1 = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
// Posts a receive for serializer s.
// The buffer is first preallocated to m_serialize_buffer_size. The receive is
// addressed either by rank/tag or by MessageId; any other addressing mode
// throws NotSupportedException.
// NOTE(review): the peer rank is read from message.destinationRank(); this
// presumably names the peer in both send and receive directions - confirm.
// The declaration of 'r' and the rest of the function are not visible here.
500Request MpiSerializeDispatcher::
501receiveSerializer(ISerializer* s,
const PointToPointMessageInfo& message)
503 BasicSerializer* sbuf = _castSerializer(s);
504 MessageRank rank = message.destinationRank();
505 MessageTag tag = message.tag();
506 bool is_blocking = message.isBlocking();
508 sbuf->preallocate(m_serialize_buffer_size);
509 Span<Byte> bytes = sbuf->globalBuffer();
512 if (message.isRankTag())
513 r = _recvSerializerBytes(bytes,rank,tag,is_blocking);
514 else if (message.isMessageId())
515 r = _recvSerializerBytes(bytes,message.messageId(),is_blocking);
517 ARCCORE_THROW(NotSupportedException,
"Only message.isRankTag() or message.isMessageId() is supported");
// Broadcasts serializer 'values' from rank 'rank'.
// Two broadcasts are issued:
// 1. totalSize() as a single Int64.
// 2. The global buffer, as elements of m_byte_serializer_datatype.
// The root path and receiver path below are presumably selected by
// is_broadcaster; the branch lines are not visible here.
// NOTE(review): this relies on globalBuffer() having the same length on the
// root and on receivers after preallocate(total_size), since broadcast
// counts must match on every rank - confirm.
526void MpiSerializeDispatcher::
527broadcastSerializer(ISerializer* values,MessageRank rank)
529 BasicSerializer* sbuf = _castSerializer(values);
530 ITraceMng* tm = m_trace;
531 MessageRank my_rank(m_adapter->commRank());
532 bool is_broadcaster = (rank==my_rank);
534 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(
Int64());
// Root path: broadcast the total size, then the bytes.
540 Int64 total_size = sbuf->totalSize();
541 Span<Byte> bytes = sbuf->globalBuffer();
542 _checkBigMessage(total_size);
543 ArrayView<Int64> total_size_buf(1,&total_size);
544 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.value(),int64_datatype);
545 if (m_is_trace_serializer)
546 tm->info() <<
"MpiSerializeDispatcher::broadcastSerializer(): sending "
547 << BasicSerializer::SizesPrinter(*sbuf);
548 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
549 m_adapter->broadcast(sbc.data(),sbc.size(),rank.value(),sbc.datatype());
// Receiver path: get the total size, allocate, receive the bytes, then
// call setFromSizes().
552 Int64 total_size = 0;
553 ArrayView<Int64> total_size_buf(1,&total_size);
554 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.value(),int64_datatype);
555 sbuf->preallocate(total_size);
556 Span<Byte> bytes = sbuf->globalBuffer();
557 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
558 m_adapter->broadcast(sbc.data(),sbc.size(),rank.value(),sbc.datatype());
559 sbuf->setFromSizes();
560 if (m_is_trace_serializer)
561 tm->info() <<
"MpiSerializeDispatcher::broadcastSerializer(): receiving from "
562 <<
" rank=" << rank <<
" bytes " << bytes.size()
563 << BasicSerializer::SizesPrinter(*sbuf);
// Downcasts an ISerializer to the BasicSerializer this dispatcher requires.
// When the dynamic_cast fails, ArgumentException is thrown; the null check
// guarding the throw is on a line not visible here.
570BasicSerializer* MpiSerializeDispatcher::
571_castSerializer(ISerializer* serializer)
573 BasicSerializer* sbuf =
dynamic_cast<BasicSerializer*
>(serializer);
575 ARCCORE_THROW(ArgumentException,
"Can not cast 'ISerializer' to 'BasicSerializer'");
// Creates a new internal::SerializeMessageList bound to this dispatcher's
// message passing manager.
// NOTE(review): the return statement is not visible here; presumably it
// wraps x in a Ref (e.g. via makeRef) - confirm.
582Ref<ISerializeMessageList> MpiSerializeDispatcher::
583createSerializeMessageListRef()
585 ISerializeMessageList* x =
new internal::SerializeMessageList(m_message_passing_mng);
static ARCCORE_CONSTEXPR Integer paddingSize()
Taille du padding et de l'alignement.
Interface du gestionnaire de traces.
virtual TraceMessage info()=0
Flot pour un message d'information.
Interface du gestionnaire des échanges de messages.
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback appelé lorsque la requête associée est terminée.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback appelé lorsque la requête associée est terminée.
constexpr __host__ __device__ SizeType size() const noexcept
Retourne la taille du tableau.
Vue d'un tableau d'éléments de type T.
std::int64_t Int64
Type entier signé sur 64 bits.
Int32 Integer
Type représentant un entier.
unsigned char Byte
Type d'un octet.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Créé une référence sur un pointeur.
std::int32_t Int32
Type entier signé sur 32 bits.