14#include "arccore/message_passing_mpi/internal/MpiSerializeDispatcher.h"
16#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/internal/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/internal/SerializeMessageList.h"
22#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
24#include "arccore/serialize/BasicSerializer.h"
26#include "arccore/base/NotImplementedException.h"
27#include "arccore/base/FatalErrorException.h"
28#include "arccore/base/NotSupportedException.h"
29#include "arccore/base/ArgumentException.h"
30#include "arccore/base/PlatformUtils.h"
31#include "arccore/trace/ITraceMng.h"
// The serialization dispatcher declarations that follow are placed in this namespace.
36namespace Arcane::MessagePassing::Mpi
/*!
 * \brief Adapter presenting a byte buffer as a count of fixed-size MPI elements.
 *
 * The wrapped buffer length must be an exact multiple of an alignment size.
 * size() returns the number of such elements; that is the count passed to MPI
 * together with datatype().
 *
 * NOTE(review): this listing omits several original lines of the class:
 * access specifiers, the m_buffer initializer, the definitions of the locals
 * `size` and `align_size`, elementSize() (used by the callers) and some
 * members. `align_size` is presumably BasicSerializer::paddingSize() — confirm.
 */
60template <
typename SpanType>
61class SerializeByteConverter
// Wraps \a buffer; \a byte_serializer_datatype is what datatype() returns.
65 SerializeByteConverter(
Span<SpanType> buffer, MPI_Datatype byte_serializer_datatype)
67 , m_datatype(byte_serializer_datatype)
// The integer division below would silently drop trailing bytes, so a
// non-multiple size is rejected. ARCCORE_FATAL throws a FatalErrorException.
72 if ((size % align_size) != 0)
73 ARCCORE_FATAL(
"Buffer size '{0}' is not a multiple of '{1}' Invalid size", size, align_size);
74 m_final_size = size / align_size;
// Pointer to the first element of the wrapped buffer.
76 SpanType* data() {
return m_buffer.data(); }
// Number of aligned elements, i.e. the MPI count to use with datatype().
77 Int64 size()
const {
return m_final_size; }
// Size of the wrapped buffer in bytes (used in trace output in this file).
// NOTE(review): this multiplies by sizeof(Byte), not sizeof(SpanType). That is
// equivalent only for the Byte / const Byte instantiations used in this file.
78 Int64 messageSize()
const {
return m_buffer.size() *
sizeof(
Byte); }
// MPI datatype supplied at construction.
80 MPI_Datatype datatype()
const {
return m_datatype; }
85 MPI_Datatype m_datatype;
/*
 * SendSerializerSubRequest: sub-request that sends the whole serialized
 * buffer of a BasicSerializer to m_rank with tag m_mpi_tag.
 *
 * NOTE(review): this listing omits the class header, the constructor
 * parameters after `buf`, the m_dispatcher initializer, the method
 * signatures and several statements (see the notes below). Confirm against
 * the full source.
 */
// Constructor: keeps \a buf as the serializer whose global buffer will be sent.
105 SendSerializerSubRequest(MpiSerializeDispatcher* pm,
BasicSerializer* buf,
108 , m_serialize_buffer(buf)
// Completion path (presumably executeOnCompletion()): returns the stored send
// request. NOTE(review): the statement guarded by the `if` below (presumably a
// call to sendMessage()) is not visible here.
117 if (!m_is_message_sent)
119 return m_send_request;
// sendMessage(): posts the send at most once (m_is_message_sent).
// NOTE(review): the statement controlled by `if (m_is_message_sent)`
// (presumably an early return) and any `if (do_print)` guard on the trace
// below are not visible here.
126 if (m_is_message_sent)
128 bool do_print = m_dispatcher->m_is_trace_serializer;
130 ITraceMng* tm = m_dispatcher->traceMng();
131 tm->
info() <<
" SendSerializerSubRequest::sendMessage()"
132 <<
" rank=" << m_rank <<
" tag=" << m_mpi_tag;
// Send the full global buffer as a non-blocking send (is_blocking=false).
// The resulting request is kept in m_send_request.
134 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
135 m_send_request = m_dispatcher->_sendSerializerBytes(bytes, m_rank, m_mpi_tag,
false);
136 m_is_message_sent =
true;
// Back-pointer to the dispatcher; nothing visible here releases it.
141 MpiSerializeDispatcher* m_dispatcher;
// Set once the send has been posted, so it is never posted twice.
146 bool m_is_message_sent =
false;
/*
 * ReceiveSerializerSubRequest: sub-request run when a receive of serialized
 * data completes. m_action identifies the stage; the instance created below
 * for the second (payload) receive uses action 2.
 *
 * NOTE(review): this listing omits the class header, the constructor
 * parameters after `buf`, the executeOnCompletion() signature, the
 * definitions of the locals `rank` and `sbuf`, the branch on m_action, the
 * `if (is_trace)` guards and the return statements. Confirm against the full
 * source.
 */
// Constructor: \a buf is the serializer that receives the data.
157 ReceiveSerializerSubRequest(MpiSerializeDispatcher* d,
BasicSerializer* buf,
160 , m_serialize_buffer(buf)
170 bool is_trace = m_dispatcher->m_is_trace_serializer;
171 ITraceMng* tm = m_dispatcher->traceMng();
173 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion()"
174 <<
" rank=" << rank <<
" wanted_tag=" << m_mpi_tag <<
" action=" << m_action;
// After the first receive, the serializer reports the total size it expects.
178 Int64 total_recv_size = sbuf->totalSize();
181 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
// Everything fit in the initial buffer: finalize the serializer from the
// received sizes.
185 if (total_recv_size <= m_dispatcher->m_serialize_buffer_size) {
186 sbuf->setFromSizes();
// Otherwise grow the buffer to the announced size. Then post a second,
// non-blocking receive on m_mpi_tag, handled by a new sub-request with action 2.
// NOTE(review): the tag is reused unchanged here, whereas the legacy path sends
// the payload on nextSerializeTag(). Presumably this object was already built
// with that tag — confirm.
// NOTE(review): `sr` is allocated with raw new. Presumably ownership passes to
// the Request built on the lines not visible here — confirm.
190 sbuf->preallocate(total_recv_size);
191 auto bytes = sbuf->globalBuffer();
195 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag,
false);
196 ISubRequest* sr =
new ReceiveSerializerSubRequest(m_dispatcher, m_serialize_buffer, m_mpi_tag, 2);
// Second stage (presumably action 2): the payload has arrived; finalize.
201 m_serialize_buffer->setFromSizes();
// Back-pointer to the dispatcher; nothing visible here releases it.
208 MpiSerializeDispatcher* m_dispatcher =
nullptr;
/*!
 * Constructor. NOTE(review): the parameter list and the first initializer
 * lines are not visible in this listing.
 *
 * m_serialize_buffer_size (50000) is the threshold used by the send paths.
 * Serializers whose totalSize() is <= this value are sent in one message;
 * larger ones are sent as a sizes message followed by a payload message.
 * Receivers also preallocate this size for the first receive.
 */
220MpiSerializeDispatcher::
223, m_message_passing_mng(message_passing_mng)
224, m_trace(adapter->traceMng())
225, m_serialize_buffer_size(50000)
// NOTE(review): members are initialized in declaration order, not in the order
// listed here. This read of m_serialize_buffer_size is only valid if that member
// is declared before m_max_serialize_buffer_size in the class — verify in the
// header.
227, m_max_serialize_buffer_size(m_serialize_buffer_size)
// The byte datatype is committed later (MPI_Type_commit elsewhere in this file).
// MPI_DATATYPE_NULL lets the destructor know whether it was created.
228, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
/*!
 * Destructor: frees the byte-serializer datatype if one was created.
 * NOTE(review): MPI_Type_free is an MPI call, so this presumably has to run
 * before MPI_Finalize — confirm that the dispatcher is destroyed first.
 */
236MpiSerializeDispatcher::
237~MpiSerializeDispatcher()
239 if (m_byte_serializer_datatype != MPI_DATATYPE_NULL)
240 MPI_Type_free(&m_byte_serializer_datatype);
/*!
 * Creates the MPI datatype used to transfer serialized bytes and possibly
 * enables serialization tracing.
 * NOTE(review): this listing omits three things: the function name (presumably
 * initialize()), the call that constructs `mpi_datatype` before MPI_Type_commit,
 * and the condition that sets m_is_trace_serializer. Confirm against the full
 * source.
 */
255void MpiSerializeDispatcher::
259 MPI_Datatype mpi_datatype;
// An MPI datatype must be committed before it is used in communication.
// The committed type is stored in m_byte_serializer_datatype, which the
// destructor frees.
261 MPI_Type_commit(&mpi_datatype);
262 m_byte_serializer_datatype = mpi_datatype;
// Turns on the info() traces guarded by m_is_trace_serializer in this file.
265 m_is_trace_serializer =
true;
/*!
 * \brief Sends a serializer using the legacy one- or two-message protocol.
 *
 * Only rank/tag addressing is supported. Any other PointToPointMessageInfo
 * raises a FatalErrorException.
 *
 * - If totalSize() <= m_serialize_buffer_size, the whole global buffer is
 *   sent in one message on message.tag().
 * - Otherwise a copy of the sizes buffer is sent first on the same tag. That
 *   request is recorded in m_sub_requests (tested by checkFinishedSubRequests()).
 *   The full buffer is then sent on nextSerializeTag(tag), and the request for
 *   this second send is returned.
 *
 * NOTE(review): some original lines are not visible in this listing (braces,
 * and possibly an `if (!is_blocking)` guard around the sub-request
 * registration). Confirm against the full source.
 */
271Request MpiSerializeDispatcher::
272legacySendSerializer(ISerializer* values,
const PointToPointMessageInfo& message)
// Legacy mode addresses messages by rank/tag only.
274 if (!message.isRankTag())
275 ARCCORE_FATAL(
"Only message.isRankTag()==true are allowed for legacy mode");
277 MessageRank rank = message.destinationRank();
278 MessageTag mpi_tag = message.tag();
279 bool is_blocking = message.isBlocking();
281 BasicSerializer* sbuf = _castSerializer(values);
282 ITraceMng* tm = m_trace;
284 Span<Byte> bytes = sbuf->globalBuffer();
286 Int64 total_size = sbuf->totalSize();
287 _checkBigMessage(total_size);
289 if (m_is_trace_serializer)
290 tm->info() <<
"legacySendSerializer(): sending to "
291 <<
" rank=" << rank <<
" bytes " << bytes.size()
292 << BasicSerializer::SizesPrinter(*sbuf)
293 <<
" tag=" << mpi_tag <<
" is_blocking=" << is_blocking;
// Small message: a single send carries everything.
297 if (total_size <= m_serialize_buffer_size) {
298 if (m_is_trace_serializer)
299 tm->info() <<
"Small message size=" << bytes.size();
300 return _sendSerializerBytes(bytes, rank, mpi_tag, is_blocking);
// Big message: first send the sizes so that the receiver can allocate.
// NOTE(review): copyAndGetSizesBuffer() presumably returns storage that stays
// valid while a non-blocking send is in flight — confirm.
306 auto x = sbuf->copyAndGetSizesBuffer();
307 if (m_is_trace_serializer)
308 tm->info() <<
"Big message first size=" << x.size();
309 Request r = _sendSerializerBytes(x, rank, mpi_tag, is_blocking);
// Keep the first request so checkFinishedSubRequests() can test it later.
// m_sub_requests is modified under the adapter's MPI lock.
311 SerializeSubRequest* sub_request =
new SerializeSubRequest();
312 sub_request->m_request = r;
315 MpiLock::Section ls(m_adapter->mpiLock());
316 m_sub_requests.add(sub_request);
// Then send the full buffer on the next serialization tag.
321 if (m_is_trace_serializer)
322 tm->info() <<
"Big message second size=" << bytes.size();
323 return _sendSerializerBytes(bytes, rank, nextSerializeTag(mpi_tag), is_blocking);
/*!
 * Posts a receive into \a bytes for the message identified by \a message_id,
 * blocking or not according to \a is_blocking. Returns the adapter's request.
 *
 * The span is received as sbc.size() elements of the byte-serializer
 * datatype. SerializeByteConverter raises a fatal error if the span length is
 * not a multiple of its alignment size.
 */
329Request MpiSerializeDispatcher::
330_recvSerializerBytes(Span<Byte> bytes, MessageId message_id,
bool is_blocking)
332 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
333 MPI_Datatype dt = sbc.datatype();
334 if (m_is_trace_serializer)
335 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
336 <<
" message_id=" << message_id <<
" is_blocking=" << is_blocking;
337 return m_adapter->directRecv(sbc.data(), sbc.size(), message_id, sbc.elementSize(), dt, is_blocking);
/*!
 * Posts a receive into \a bytes from \a rank with \a tag, blocking or not
 * according to \a is_blocking.
 *
 * The span is received as sbc.size() elements of the byte-serializer
 * datatype. The span length must be a multiple of the converter's alignment
 * size.
 * NOTE(review): the final return of `r` is not visible in this listing.
 */
343Request MpiSerializeDispatcher::
344_recvSerializerBytes(Span<Byte> bytes, MessageRank rank, MessageTag tag,
bool is_blocking)
346 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
347 MPI_Datatype dt = sbc.datatype();
348 if (m_is_trace_serializer)
349 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
350 <<
" rank=" << rank <<
" tag=" << tag <<
" is_blocking=" << is_blocking;
351 Request r = m_adapter->directRecv(sbc.data(), sbc.size(), rank.value(),
352 sbc.elementSize(), dt, tag.value(), is_blocking);
353 if (m_is_trace_serializer)
354 m_trace->info() <<
"_recvSerializerBytes: request=" << r;
/*!
 * Sends \a bytes to \a rank with \a tag through the adapter's directSend().
 *
 * The span is sent as sbc.size() elements of the byte-serializer datatype.
 * The trace reports the original byte count, that element count and the
 * converter's messageSize().
 * NOTE(review): the signature line declaring `is_blocking` and the final
 * return of `r` are not visible in this listing.
 */
361Request MpiSerializeDispatcher::
362_sendSerializerBytes(Span<const Byte> bytes, MessageRank rank, MessageTag tag,
365 SerializeByteConverter<const Byte> sbc(bytes, m_byte_serializer_datatype);
366 MPI_Datatype dt = sbc.datatype();
367 if (m_is_trace_serializer)
368 m_trace->info() <<
"_sendSerializerBytes: orig_size=" << bytes.size()
369 <<
" rank=" << rank <<
" tag=" << tag
370 <<
" second_size=" << sbc.size()
371 <<
" message_size=" << sbc.messageSize();
372 Request r = m_adapter->directSend(sbc.data(), sbc.size(), rank.value(),
373 sbc.elementSize(), dt, tag.value(), is_blocking);
374 if (m_is_trace_serializer)
375 m_trace->info() <<
"_sendSerializerBytes: request=" << r;
/*!
 * Blocking receive matching the legacy send protocol.
 *
 * First receives into a buffer of m_serialize_buffer_size on \a mpi_tag. If
 * the announced total size fits, the serializer is finalized directly.
 * Otherwise the buffer is grown to the total size and the payload is received
 * with a second blocking receive on nextSerializeTag(mpi_tag).
 * NOTE(review): braces and the `else` between the two branches are not
 * visible in this listing.
 */
382void MpiSerializeDispatcher::
383legacyReceiveSerializer(ISerializer* values, MessageRank rank, MessageTag mpi_tag)
385 BasicSerializer* sbuf = _castSerializer(values);
386 ITraceMng* tm = m_trace;
388 if (m_is_trace_serializer)
389 tm->info() <<
"legacyReceiveSerializer() begin receive"
390 <<
" rank=" << rank <<
" tag=" << mpi_tag;
// First (blocking) receive into the default-sized buffer.
391 sbuf->preallocate(m_serialize_buffer_size);
392 Span<Byte> bytes = sbuf->globalBuffer();
394 _recvSerializerBytes(bytes, rank, mpi_tag,
true);
395 Int64 total_recv_size = sbuf->totalSize();
397 if (m_is_trace_serializer)
398 tm->info() <<
"legacyReceiveSerializer total_size=" << total_recv_size
400 << BasicSerializer::SizesPrinter(*sbuf);
// Everything fit in the first message.
403 if (total_recv_size <= m_serialize_buffer_size) {
404 sbuf->setFromSizes();
// Overflow: grow to the announced size and receive the payload on the next
// serialization tag.
408 if (m_is_trace_serializer)
409 tm->info() <<
"Receive overflow buffer: " << total_recv_size;
410 sbuf->preallocate(total_recv_size);
411 bytes = sbuf->globalBuffer();
412 _recvSerializerBytes(bytes, rank, nextSerializeTag(mpi_tag),
true);
413 sbuf->setFromSizes();
414 if (m_is_trace_serializer)
415 tm->info() <<
"End receive overflow buffer: " << total_recv_size;
/*!
 * Tests every pending SerializeSubRequest with the adapter's testRequest().
 * m_sub_requests is then replaced by the entries that are kept.
 *
 * NOTE(review): three things are not visible in this listing: the condition on
 * is_finished before add() (presumably `if (!is_finished)`), the handling of
 * finished entries (presumably deletion), and any lock. legacySendSerializer()
 * adds to m_sub_requests under the adapter's MPI lock, so this function
 * presumably takes the same lock — confirm.
 */
421void MpiSerializeDispatcher::
422checkFinishedSubRequests()
427 UniqueArray<SerializeSubRequest*> new_sub_requests;
428 for (
Integer i = 0, n = m_sub_requests.size(); i < n; ++i) {
429 SerializeSubRequest* ssr = m_sub_requests[i];
430 bool is_finished = m_adapter->testRequest(ssr->m_request);
432 new_sub_requests.add(ssr);
438 m_sub_requests = new_sub_requests;
/*!
 * Records \a message_size as the largest serialized message seen so far.
 * Each time a new maximum is reached it is logged through info(); this log is
 * not guarded by m_is_trace_serializer.
 */
444void MpiSerializeDispatcher::
445_checkBigMessage(
Int64 message_size)
447 if (message_size > m_max_serialize_buffer_size) {
448 m_max_serialize_buffer_size = message_size;
449 m_trace->info() <<
"big buffer: " << message_size;
/*!
 * Sends \a s as described by \a message. Large serializers may be split into
 * two messages (forwards with force_one_message = false).
 */
456Request MpiSerializeDispatcher::
457sendSerializer(
const ISerializer* s,
const PointToPointMessageInfo& message)
459 return sendSerializer(s, message,
false);
/*!
 * Sends serializer \a s to message.destinationRank() with message.tag().
 *
 * A single send is used when totalSize() is <= m_serialize_buffer_size, or
 * whenever \a force_one_message is true. Larger messages first send a copy of
 * the sizes buffer on the same tag.
 * NOTE(review): the rest of the large-message path (after the first send) is
 * not visible in this listing. Presumably the payload is sent on
 * nextSerializeTag(tag) through a SendSerializerSubRequest — confirm.
 */
465Request MpiSerializeDispatcher::
466sendSerializer(
const ISerializer* s,
const PointToPointMessageInfo& message,
467 bool force_one_message)
// const_cast because _castSerializer() takes a non-const ISerializer*.
// NOTE(review): copyAndGetSizesBuffer() below may modify the serializer's
// internal state despite the const parameter — confirm.
469 BasicSerializer* sbuf = _castSerializer(
const_cast<ISerializer*
>(s));
471 MessageRank rank = message.destinationRank();
472 MessageTag mpi_tag = message.tag();
473 bool is_blocking = message.isBlocking();
475 ITraceMng* tm = m_trace;
477 Span<const Byte> bytes = sbuf->globalBuffer();
478 Int64 total_size = sbuf->totalSize();
479 _checkBigMessage(total_size);
481 if (m_is_trace_serializer)
482 tm->info() <<
"sendSerializer(): sending to "
483 <<
" p2p_message=" << message
484 <<
" rank=" << rank <<
" bytes " << bytes.size()
485 << BasicSerializer::SizesPrinter(*sbuf)
486 <<
" tag=" << mpi_tag
487 <<
" total_size=" << total_size;
// Single-message path.
491 if (total_size <= m_serialize_buffer_size || force_one_message) {
492 if (m_is_trace_serializer)
493 tm->info() <<
"Small message size=" << bytes.size();
494 return _sendSerializerBytes(bytes, rank, mpi_tag, is_blocking);
// Two-message path: send the sizes first so that the receiver can allocate.
499 auto x = sbuf->copyAndGetSizesBuffer();
500 Request r1 = _sendSerializerBytes(x, rank, mpi_tag, is_blocking);
/*!
 * Posts the first receive of a serialized message into \a s, after
 * preallocating m_serialize_buffer_size.
 *
 * Addressing is either rank/tag or message id. Any other kind throws
 * NotSupportedException.
 * NOTE(review): the declaration of `r`, the `else` before ARCCORE_THROW, and
 * the code after the receive are not visible in this listing. That code
 * presumably wraps `r` in a ReceiveSerializerSubRequest that handles large
 * messages.
 * NOTE(review): the peer rank comes from message.destinationRank() even though
 * this is a receive. Presumably PointToPointMessageInfo uses that accessor for
 * the remote rank in both directions — confirm.
 */
511Request MpiSerializeDispatcher::
512receiveSerializer(ISerializer* s,
const PointToPointMessageInfo& message)
514 BasicSerializer* sbuf = _castSerializer(s);
515 MessageRank rank = message.destinationRank();
516 MessageTag tag = message.tag();
517 bool is_blocking = message.isBlocking();
// Receive into a default-sized buffer first; larger messages need a second
// receive.
519 sbuf->preallocate(m_serialize_buffer_size);
520 Span<Byte> bytes = sbuf->globalBuffer();
523 if (message.isRankTag())
524 r = _recvSerializerBytes(bytes, rank, tag, is_blocking);
525 else if (message.isMessageId())
526 r = _recvSerializerBytes(bytes, message.messageId(), is_blocking);
528 ARCCORE_THROW(NotSupportedException,
"Only message.isRankTag() or message.isMessageId() is supported");
/*!
 * Broadcasts serializer \a values from rank \a rank through m_adapter->broadcast().
 * Presumably every rank of the communicator must call this with the same
 * \a rank, as for an MPI broadcast — confirm.
 *
 * Two broadcasts are performed:
 * - the total size as one Int64, so that receivers can preallocate;
 * - the global buffer, as elements of the byte-serializer datatype.
 * Receivers then finalize with setFromSizes().
 * NOTE(review): the `} else {` between the broadcaster and receiver branches is
 * not visible in this listing. The element counts of the second broadcast must
 * match on all ranks. This presumably holds because globalBuffer() after
 * preallocate(totalSize) has the sender's length — confirm.
 */
537void MpiSerializeDispatcher::
538broadcastSerializer(ISerializer* values, MessageRank rank)
540 BasicSerializer* sbuf = _castSerializer(values);
541 ITraceMng* tm = m_trace;
542 MessageRank my_rank(m_adapter->commRank());
543 bool is_broadcaster = (rank == my_rank);
545 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(
Int64());
// Broadcaster side: send the size, then the bytes.
550 if (is_broadcaster) {
551 Int64 total_size = sbuf->totalSize();
552 Span<Byte> bytes = sbuf->globalBuffer();
553 _checkBigMessage(total_size);
554 ArrayView<Int64> total_size_buf(1, &total_size);
555 m_adapter->broadcast(total_size_buf.data(), total_size_buf.size(), rank.value(), int64_datatype);
556 if (m_is_trace_serializer)
557 tm->info() <<
"MpiSerializeDispatcher::broadcastSerializer(): sending "
558 << BasicSerializer::SizesPrinter(*sbuf);
559 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
560 m_adapter->broadcast(sbc.data(), sbc.size(), rank.value(), sbc.datatype());
// Receiver side: get the size, allocate, receive the bytes, finalize.
563 Int64 total_size = 0;
564 ArrayView<Int64> total_size_buf(1, &total_size);
565 m_adapter->broadcast(total_size_buf.data(), total_size_buf.size(), rank.value(), int64_datatype);
566 sbuf->preallocate(total_size);
567 Span<Byte> bytes = sbuf->globalBuffer();
568 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
569 m_adapter->broadcast(sbc.data(), sbc.size(), rank.value(), sbc.datatype());
570 sbuf->setFromSizes();
571 if (m_is_trace_serializer)
572 tm->info() <<
"MpiSerializeDispatcher::broadcastSerializer(): receiving from "
573 <<
" rank=" << rank <<
" bytes " << bytes.size()
574 << BasicSerializer::SizesPrinter(*sbuf);
/*!
 * Downcasts \a serializer to BasicSerializer, the concrete type whose buffers
 * this dispatcher transfers. Throws ArgumentException when the cast fails.
 * NOTE(review): the null check guarding ARCCORE_THROW and the return statement
 * are not visible in this listing.
 */
581BasicSerializer* MpiSerializeDispatcher::
582_castSerializer(ISerializer* serializer)
584 BasicSerializer* sbuf =
dynamic_cast<BasicSerializer*
>(serializer);
586 ARCCORE_THROW(ArgumentException,
"Can not cast 'ISerializer' to 'BasicSerializer'");
/*!
 * Creates a new serialize message list bound to this dispatcher's message
 * passing manager.
 * NOTE(review): the return statement is not visible in this listing.
 * Presumably it wraps `x` with makeRef() so that the Ref owns it — confirm.
 */
593Ref<ISerializeMessageList> MpiSerializeDispatcher::
594createSerializeMessageListRef()
596 ISerializeMessageList* x =
new internal::SerializeMessageList(m_message_passing_mng);
#define ARCCORE_FATAL(...)
Macro throwing a FatalErrorException.
#define ARCCORE_THROW(exception_class,...)
Macro to throw an exception with formatting.
static ARCCORE_CONSTEXPR Integer paddingSize()
Padding and alignment size.
virtual TraceMessage info()=0
Stream for an information message.
Interface of the message passing manager.
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback called when the associated request is finished.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback called when the associated request is finished.
constexpr __host__ __device__ SizeType size() const noexcept
Returns the size of the array.
View of an array of elements of type T.
std::int64_t Int64
Signed integer type of 64 bits.
Int32 Integer
Type representing an integer.
unsigned char Byte
Type of a byte.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Creates a reference from a pointer.
std::int32_t Int32
Signed integer type of 32 bits.