14#include "arccore/message_passing_mpi/internal/MpiSerializeDispatcher.h"
16#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/internal/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/internal/SerializeMessageList.h"
22#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
24#include "arccore/serialize/BasicSerializer.h"
26#include "arccore/base/NotImplementedException.h"
27#include "arccore/base/FatalErrorException.h"
28#include "arccore/base/NotSupportedException.h"
29#include "arccore/base/ArgumentException.h"
30#include "arccore/base/PlatformUtils.h"
31#include "arccore/trace/ITraceMng.h"
36namespace Arcane::MessagePassing::Mpi
// Adapter that exposes a byte span as a (pointer, element count, MPI datatype)
// triple for the MPI send/receive/broadcast calls of MpiSerializeDispatcher.
// SpanType is 'Byte' for receives/broadcasts and 'const Byte' for sends.
60template <
typename SpanType>
61class SerializeByteConverter
// buffer: bytes to transfer.
// byte_serializer_datatype: MPI datatype describing one element of the
// message (the dispatcher passes its m_byte_serializer_datatype).
65 SerializeByteConverter(
Span<SpanType> buffer, MPI_Datatype byte_serializer_datatype)
67 , m_datatype(byte_serializer_datatype)
// The buffer length must be an exact multiple of 'align_size', otherwise a
// FatalErrorException is raised. NOTE(review): the lines defining 'size' and
// 'align_size' are not visible; presumably buffer.size() and
// BasicSerializer::paddingSize() - TODO confirm.
72 if ((size % align_size) != 0)
73 ARCCORE_FATAL(
"Buffer size '{0}' is not a multiple of '{1}' Invalid size", size, align_size);
// Element count handed to MPI: number of 'align_size'-byte chunks.
74 m_final_size = size / align_size;
// Pointer to the first byte of the wrapped buffer.
76 SpanType* data() {
return m_buffer.data(); }
// Number of MPI elements (of datatype()) in the message - not a byte count.
77 Int64 size()
const {
return m_final_size; }
// Size of the wrapped buffer in bytes (in the visible code, only used in trace
// output). NOTE(review): uses sizeof(Byte) instead of sizeof(SpanType); both are
// identical for the Byte / const Byte instantiations used in this file.
78 Int64 messageSize()
const {
return m_buffer.size() *
sizeof(
Byte); }
// MPI datatype given at construction.
80 MPI_Datatype datatype()
const {
return m_datatype; }
85 MPI_Datatype m_datatype;
// Sub-request that sends the full serializer buffer (the payload of a large
// message) as a non-blocking send. sendSerializer() creates it with the tag
// nextSerializeTag(mpi_tag), after the sizes message has been posted on mpi_tag.
105 SendSerializerSubRequest(MpiSerializeDispatcher* pm,
BasicSerializer* buf,
108 , m_serialize_buffer(buf)
// Returns the request of the payload send. NOTE(review): the line between the
// test and the return is not visible; presumably it calls sendMessage() when
// the payload has not been sent yet - TODO confirm.
117 if (!m_is_message_sent)
119 return m_send_request;
// sendMessage(): guard so that the payload is posted at most once.
126 if (m_is_message_sent)
128 bool do_print = m_dispatcher->m_is_trace_serializer;
130 ITraceMng* tm = m_dispatcher->traceMng();
131 tm->
info() <<
" SendSerializerSubRequest::sendMessage()"
132 <<
" rank=" << m_rank <<
" tag=" << m_mpi_tag;
// Post a non-blocking send (last argument is 'is_blocking') of the whole
// serializer buffer on this sub-request's tag.
134 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
135 m_send_request = m_dispatcher->_sendSerializerBytes(bytes, m_rank, m_mpi_tag,
false);
136 m_is_message_sent =
true;
// True once the payload send has been posted; m_send_request is then valid.
146 bool m_is_message_sent =
false;
// Sub-request used to finish receiving a serialized message. Its 4th constructor
// argument is an action code (printed as m_action). The value 2 is used for the
// follow-up sub-request that receives the payload of a large message; the value
// used for the initial sub-request is not visible here.
157 ReceiveSerializerSubRequest(MpiSerializeDispatcher* d,
BasicSerializer* buf,
160 , m_serialize_buffer(buf)
// executeOnCompletion(): called when the associated receive has finished.
170 bool is_trace = m_dispatcher->m_is_trace_serializer;
171 ITraceMng* tm = m_dispatcher->traceMng();
173 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion()"
174 <<
" rank=" << rank <<
" wanted_tag=" << m_mpi_tag <<
" action=" << m_action;
// Total size announced by the message that has just been received.
178 Int64 total_recv_size = sbuf->totalSize();
181 tm->
info() <<
" ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
// Small message: everything already fit in the first receive buffer.
185 if (total_recv_size <= m_dispatcher->m_serialize_buffer_size) {
186 sbuf->setFromSizes();
// Large message: grow the buffer to the announced size and post a non-blocking
// receive for the payload. A new sub-request with action 2 handles its completion.
// NOTE(review): the payload is received on m_mpi_tag itself, whereas the sender
// posts it on nextSerializeTag(mpi_tag). Confirm that this sub-request is
// constructed with the payload tag.
190 sbuf->preallocate(total_recv_size);
191 auto bytes = sbuf->globalBuffer();
195 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag,
false);
196 ISubRequest* sr =
new ReceiveSerializerSubRequest(m_dispatcher, m_serialize_buffer, m_mpi_tag, 2);
// Presumably the action-2 path (branch condition not visible): the payload has
// arrived, so rebuild the serializer from the received sizes - TODO confirm.
201 m_serialize_buffer->setFromSizes();
208 MpiSerializeDispatcher* m_dispatcher =
nullptr;
// Constructor (partially visible): initializes the size thresholds and the MPI
// datatype handle.
220MpiSerializeDispatcher::
223, m_message_passing_mng(message_passing_mng)
224, m_trace(adapter->traceMng())
// Messages whose total size is <= this value are sent as a single message.
// Larger ones are split into a sizes message followed by a payload message.
225, m_serialize_buffer_size(50000)
// Largest message size seen so far; updated by _checkBigMessage().
// NOTE(review): this initializer reads m_serialize_buffer_size, so it relies on
// that member being declared before this one in the class - confirm in the header.
227, m_max_serialize_buffer_size(m_serialize_buffer_size)
// Null until the datatype is created and committed by a later call; the
// destructor frees it only when it is non-null.
228, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
// Destructor: releases the MPI datatype used for serializer byte transfers, if
// one was created (it stays MPI_DATATYPE_NULL otherwise).
236MpiSerializeDispatcher::
237~MpiSerializeDispatcher()
239 if (m_byte_serializer_datatype != MPI_DATATYPE_NULL)
240 MPI_Type_free(&m_byte_serializer_datatype);
// Sets up the MPI datatype used for all serializer byte transfers.
// NOTE(review): the function name and the MPI datatype constructor call are not
// visible in this excerpt.
255void MpiSerializeDispatcher::
259 MPI_Datatype mpi_datatype;
// A derived MPI datatype must be committed before it can be used in
// communication. It is released with MPI_Type_free in the destructor.
261 MPI_Type_commit(&mpi_datatype);
262 m_byte_serializer_datatype = mpi_datatype;
// Enables serializer tracing (the enabling condition is not visible here).
265 m_is_trace_serializer =
true;
// Sends a serializer using the legacy protocol.
// - Only (rank, tag) messages are accepted; any other kind is a fatal error.
// - If the total size is <= m_serialize_buffer_size, the whole buffer is sent
//   in a single message on the message tag.
// - Otherwise a copy of the sizes is sent first on the message tag, then the
//   whole buffer on nextSerializeTag(tag).
// Returns the request of the last send that was posted.
271Request MpiSerializeDispatcher::
272legacySendSerializer(ISerializer* values,
const PointToPointMessageInfo& message)
// NOTE(review): the message text says 'isRangTag'; the method is isRankTag().
274 if (!message.isRankTag())
275 ARCCORE_FATAL(
"Only message.isRangTag()==true are allowed for legacy mode");
277 MessageRank rank = message.destinationRank();
278 MessageTag mpi_tag = message.tag();
279 bool is_blocking = message.isBlocking();
281 BasicSerializer* sbuf = _castSerializer(values);
282 ITraceMng* tm = m_trace;
284 Span<Byte> bytes = sbuf->globalBuffer();
286 Int64 total_size = sbuf->totalSize();
// Tracks (and logs) the largest message size seen so far.
287 _checkBigMessage(total_size);
289 if (m_is_trace_serializer)
290 tm->info() <<
"legacySendSerializer(): sending to "
291 <<
" rank=" << rank <<
" bytes " << bytes.size()
292 << BasicSerializer::SizesPrinter(*sbuf)
293 <<
" tag=" << mpi_tag <<
" is_blocking=" << is_blocking;
// Small message: a single send is enough.
297 if (total_size <= m_serialize_buffer_size) {
298 if (m_is_trace_serializer)
299 tm->info() <<
"Small message size=" << bytes.size();
300 return _sendSerializerBytes(bytes, rank, mpi_tag, is_blocking);
// Large message, step 1: send a copy of the sizes on the original tag.
306 auto x = sbuf->copyAndGetSizesBuffer();
307 if (m_is_trace_serializer)
308 tm->info() <<
"Big message first size=" << x.size();
309 Request r = _sendSerializerBytes(x, rank, mpi_tag, is_blocking);
// The request is stored in a sub-request registered (under the MPI lock) in
// m_sub_requests. checkFinishedSubRequests() drops it once the request is
// complete. Presumably this keeps the copied sizes buffer alive until then;
// the sub-request creation line is not visible - TODO confirm.
312 sub_request->m_request = r;
315 MpiLock::Section ls(m_adapter->mpiLock());
316 m_sub_requests.add(sub_request);
// Large message, step 2: send the whole buffer on the next serialize tag.
321 if (m_is_trace_serializer)
322 tm->info() <<
"Big message second size=" << bytes.size();
323 return _sendSerializerBytes(bytes, rank, nextSerializeTag(mpi_tag), is_blocking);
// Posts a receive into 'bytes' for the message identified by 'message_id'
// (presumably obtained from an earlier probe - TODO confirm).
// The buffer is described to MPI as sbc.size() elements of the serializer
// datatype. Returns the request produced by MpiAdapter::directRecv().
329Request MpiSerializeDispatcher::
330_recvSerializerBytes(Span<Byte> bytes, MessageId message_id,
bool is_blocking)
332 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
333 MPI_Datatype dt = sbc.datatype();
334 if (m_is_trace_serializer)
335 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
336 <<
" message_id=" << message_id <<
" is_blocking=" << is_blocking;
337 return m_adapter->directRecv(sbc.data(), sbc.size(), message_id, sbc.elementSize(), dt, is_blocking);
// Posts a receive into 'bytes' of a message from 'rank' with tag 'tag'.
// The buffer is described to MPI as sbc.size() elements of the serializer
// datatype. 'is_blocking' is forwarded to MpiAdapter::directRecv().
343Request MpiSerializeDispatcher::
344_recvSerializerBytes(Span<Byte> bytes, MessageRank rank, MessageTag tag,
bool is_blocking)
346 SerializeByteConverter<Byte> sbc(bytes, m_byte_serializer_datatype);
347 MPI_Datatype dt = sbc.datatype();
348 if (m_is_trace_serializer)
349 m_trace->info() <<
"_recvSerializerBytes: size=" << bytes.size()
350 <<
" rank=" << rank <<
" tag=" << tag <<
" is_blocking=" << is_blocking;
351 Request r = m_adapter->directRecv(sbc.data(), sbc.size(), rank.value(),
352 sbc.elementSize(), dt, tag.value(), is_blocking);
353 if (m_is_trace_serializer)
354 m_trace->info() <<
"_recvSerializerBytes: request=" << r;
// Posts a send of 'bytes' to 'rank' with tag 'tag' through
// MpiAdapter::directSend(). The buffer is described to MPI as sbc.size()
// elements of the serializer datatype. Returns the resulting request.
361Request MpiSerializeDispatcher::
362_sendSerializerBytes(Span<const Byte> bytes, MessageRank rank, MessageTag tag,
365 SerializeByteConverter<const Byte> sbc(bytes, m_byte_serializer_datatype);
366 MPI_Datatype dt = sbc.datatype();
// Trace: original byte size, element count given to MPI, and byte size.
367 if (m_is_trace_serializer)
368 m_trace->info() <<
"_sendSerializerBytes: orig_size=" << bytes.size()
369 <<
" rank=" << rank <<
" tag=" << tag
370 <<
" second_size=" << sbc.size()
371 <<
" message_size=" << sbc.messageSize();
372 Request r = m_adapter->directSend(sbc.data(), sbc.size(), rank.value(),
373 sbc.elementSize(), dt, tag.value(), is_blocking);
374 if (m_is_trace_serializer)
375 m_trace->info() <<
"_sendSerializerBytes: request=" << r;
// Blocking receive matching legacySendSerializer().
// First receives into a buffer of m_serialize_buffer_size. If the announced
// total size fits in it, the serializer is rebuilt directly. Otherwise the
// buffer is grown to the announced size and the payload is received (blocking)
// on nextSerializeTag(mpi_tag).
382void MpiSerializeDispatcher::
383legacyReceiveSerializer(ISerializer* values, MessageRank rank, MessageTag mpi_tag)
385 BasicSerializer* sbuf = _castSerializer(values);
386 ITraceMng* tm = m_trace;
388 if (m_is_trace_serializer)
389 tm->info() <<
"legacyReceiveSerializer() begin receive"
390 <<
" rank=" << rank <<
" tag=" << mpi_tag;
391 sbuf->preallocate(m_serialize_buffer_size);
392 Span<Byte> bytes = sbuf->globalBuffer();
// Blocking receive: the returned request is not needed afterwards.
394 _recvSerializerBytes(bytes, rank, mpi_tag,
true);
395 Int64 total_recv_size = sbuf->totalSize();
397 if (m_is_trace_serializer)
398 tm->info() <<
"legacyReceiveSerializer total_size=" << total_recv_size
400 << BasicSerializer::SizesPrinter(*sbuf);
// Small message: the first receive contained everything.
403 if (total_recv_size <= m_serialize_buffer_size) {
404 sbuf->setFromSizes();
// Large message: the first receive only carried the sizes; receive the
// payload on the next serialize tag.
408 if (m_is_trace_serializer)
409 tm->info() <<
"Receive overflow buffer: " << total_recv_size;
410 sbuf->preallocate(total_recv_size);
411 bytes = sbuf->globalBuffer();
412 _recvSerializerBytes(bytes, rank, nextSerializeTag(mpi_tag),
true);
413 sbuf->setFromSizes();
414 if (m_is_trace_serializer)
415 tm->info() <<
"End receive overflow buffer: " << total_recv_size;
// Tests every pending sub-request's MPI request and keeps only those that are
// not finished yet. NOTE(review): what happens to finished sub-requests (e.g.
// deletion) and any locking happen in lines not visible here - TODO confirm.
421void MpiSerializeDispatcher::
422checkFinishedSubRequests()
427 UniqueArray<SerializeSubRequest*> new_sub_requests;
428 for (
Integer i = 0, n = m_sub_requests.size(); i < n; ++i) {
430 bool is_finished = m_adapter->testRequest(ssr->m_request);
// Still in flight: keep it for the next check.
432 new_sub_requests.add(ssr);
438 m_sub_requests = new_sub_requests;
// Records 'message_size' as the new maximum when it exceeds the largest size
// seen so far (initially m_serialize_buffer_size), and logs it.
// It does not change how the message is sent.
444void MpiSerializeDispatcher::
445_checkBigMessage(
Int64 message_size)
447 if (message_size > m_max_serialize_buffer_size) {
448 m_max_serialize_buffer_size = message_size;
449 m_trace->info() <<
"big buffer: " << message_size;
467 bool force_one_message)
// Point-to-point send of a serializer (signature start not visible; the trace
// text names it sendSerializer()).
// - If the total size is <= m_serialize_buffer_size, or force_one_message is
//   true, the whole buffer is sent in one message on mpi_tag.
// - Otherwise the sizes are sent on mpi_tag, and the payload send is deferred
//   to a SendSerializerSubRequest using nextSerializeTag(mpi_tag).
478 Int64 total_size = sbuf->totalSize();
479 _checkBigMessage(total_size);
481 if (m_is_trace_serializer)
482 tm->
info() <<
"sendSerializer(): sending to "
483 <<
" p2p_message=" << message
484 <<
" rank=" << rank <<
" bytes " << bytes.size()
486 <<
" tag=" << mpi_tag
487 <<
" total_size=" << total_size;
// Single-message path.
491 if (total_size <= m_serialize_buffer_size || force_one_message) {
492 if (m_is_trace_serializer)
493 tm->
info() <<
"Small message size=" << bytes.size();
494 return _sendSerializerBytes(bytes, rank, mpi_tag, is_blocking);
// Two-message path: sizes first. The sub-request later sends the payload on
// the next serialize tag; how it is attached to r1 is not visible here.
499 auto x = sbuf->copyAndGetSizesBuffer();
500 Request r1 = _sendSerializerBytes(x, rank, mpi_tag, is_blocking);
501 auto* x2 =
new SendSerializerSubRequest(
this, sbuf, rank, nextSerializeTag(mpi_tag));
// receiveSerializer() fragment: the first receive always targets a buffer of
// m_serialize_buffer_size bytes.
519 sbuf->preallocate(m_serialize_buffer_size);
// Receive by (rank, tag) or by MessageId. The branch condition is not visible
// (presumably message.isRankTag() / message.isMessageId() - TODO confirm).
524 r = _recvSerializerBytes(bytes, rank, tag, is_blocking);
526 r = _recvSerializerBytes(bytes, message.
messageId(), is_blocking);
// Broadcasts a serializer from 'rank' to every rank using two broadcasts:
// first the total size as an Int64, then the whole buffer described by a
// SerializeByteConverter. Receivers preallocate the announced size before the
// second broadcast, then rebuild the serializer with setFromSizes().
// NOTE(review): the function name is taken from its trace messages; the
// signature lines are not visible.
537void MpiSerializeDispatcher::
543 bool is_broadcaster = (rank == my_rank);
545 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(
Int64());
550 if (is_broadcaster) {
551 Int64 total_size = sbuf->totalSize();
553 _checkBigMessage(total_size);
// Broadcast 1: total size (total_size_buf presumably wraps total_size; its
// definition is not visible).
555 m_adapter->broadcast(total_size_buf.data(), total_size_buf.size(), rank.
value(), int64_datatype);
556 if (m_is_trace_serializer)
557 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): sending "
// Broadcast 2: payload.
560 m_adapter->broadcast(sbc.data(), sbc.size(), rank.
value(), sbc.datatype());
// Receiving ranks: get the size, size the buffer, then get the payload.
563 Int64 total_size = 0;
565 m_adapter->broadcast(total_size_buf.data(), total_size_buf.size(), rank.
value(), int64_datatype);
566 sbuf->preallocate(total_size);
569 m_adapter->broadcast(sbc.data(), sbc.size(), rank.
value(), sbc.datatype());
570 sbuf->setFromSizes();
571 if (m_is_trace_serializer)
572 tm->
info() <<
"MpiSerializeDispatcher::broadcastSerializer(): receiving from "
573 <<
" rank=" << rank <<
" bytes " << bytes.size()
// Downcasts an ISerializer to the BasicSerializer implementation this dispatcher
// requires. Throws ArgumentException when the cast fails (the guarding condition
// line is not visible; presumably a null check on 'sbuf').
581BasicSerializer* MpiSerializeDispatcher::
582_castSerializer(ISerializer* serializer)
584 BasicSerializer* sbuf =
dynamic_cast<BasicSerializer*
>(serializer);
586 ARCCORE_THROW(ArgumentException,
"Can not cast 'ISerializer' to 'BasicSerializer'");
#define ARCCORE_FATAL(...)
Macro throwing a FatalErrorException.
#define ARCCORE_THROW(exception_class,...)
Macro to throw an exception with formatting.
Modifiable view of an array of type T.
Basic implementation of 'ISerializer'.
static ARCCORE_CONSTEXPR Integer paddingSize()
Padding and alignment size.
virtual TraceMessage info()=0
Stream for an information message.
Interface of the message passing manager.
Interface for a serialization message list.
Sub-request of a request.
Int32 value() const
Rank value.
int commRank() const
Rank of this instance in the communicator.
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback called when the associated request is finished.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback called when the associated request is finished.
Request receiveSerializer(ISerializer *s, const PointToPointMessageInfo &message) override
Receiving message.
Ref< ISerializeMessageList > createSerializeMessageListRef() override
Create a list of serialization messages.
Request sendSerializer(const ISerializer *s, const PointToPointMessageInfo &message) override
Sending message.
Wrapper for sending a byte array from a serializer.
Information for sending/receiving a point-to-point message.
MessageId messageId() const
Message identifier.
MessageTag tag() const
Message tag.
bool isBlocking() const
Indicates if the message is blocking.
bool isRankTag() const
True if the instance was created with a pair (rank,tag). In this case rank() and tag() are valid.
bool isMessageId() const
True if the instance was created with a MessageId. In this case messageId() is valid.
MessageRank destinationRank() const
Message destination rank.
Completion information for a sub-request.
MessageRank sourceRank() const
Source rank of the request.
Serialization message list.
Exception when an operation is not supported.
Reference to an instance.
constexpr __host__ __device__ SizeType size() const noexcept
Returns the size of the array.
View of an array of elements of type T.
std::int64_t Int64
Signed integer type of 64 bits.
Int32 Integer
Type representing an integer.
unsigned char Byte
Type of a byte.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Creates a reference on a pointer.
std::int32_t Int32
Signed integer type of 32 bits.