14#include "arcane/utils/ArcanePrecomp.h"
16#include "arcane/utils/Array.h"
17#include "arcane/utils/FatalErrorException.h"
18#include "arcane/utils/NotImplementedException.h"
19#include "arcane/utils/NotSupportedException.h"
20#include "arcane/utils/ArgumentException.h"
21#include "arcane/utils/TraceInfo.h"
22#include "arcane/utils/ITraceMng.h"
23#include "arcane/utils/ValueConvert.h"
25#include "arcane/parallel/mpithread/HybridMessageQueue.h"
26#include "arcane/parallel/mpi/MpiParallelMng.h"
28#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
// Debug-trace helper. When m_debug_level is at least needed_debug_level, the
// arguments are formatted with String::format (the literal "Hybrid " is
// prepended to format_str), written to info(), and the trace manager is flushed.
// NOTE(review): the closing lines of the macro are not visible in this listing.
34#define TRACE_DEBUG(needed_debug_level, format_str, ...) \
35 if (m_debug_level >= needed_debug_level) { \
36 info() << String::format("Hybrid " format_str, __VA_ARGS__); \
37 traceMng()->flush(); \
// Constructor initializer list (the signature line is not visible in this listing).
// Traces go through the trace manager of the MPI parallel manager, and the MPI
// adapter is taken from that same manager. local_nb_rank is stored and also
// used to build m_rank_tag_builder.
52: TraceAccessor(mpi_pm->traceMng())
53, m_thread_queue(thread_queue)
54, m_mpi_parallel_mng(mpi_pm)
55, m_mpi_adapter(mpi_pm->adapter())
56, m_local_nb_rank(local_nb_rank)
57, m_rank_tag_builder(local_nb_rank)
// If the environment variable ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE can be
// parsed as an Int32, m_is_allow_null_rank_for_any_source becomes true for any
// non-zero value and false for zero.
60 if (
auto v = Convert::Type<Int32>::tryParseFromEnvironment(
"ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE",
true))
61 m_is_allow_null_rank_for_any_source = v.value() != 0;
// Rank checking helper taking a MessageRank.
// NOTE(review): only the signature is visible in this listing; presumably it
// rejects invalid ranks — confirm against the full source.
67void HybridMessageQueue::
68_checkValidRank(MessageRank rank)
// Source checking helper for a point-to-point message.
// NOTE(review): the check applied to `source` is not visible in this listing.
77void HybridMessageQueue::
78_checkValidSource(
const PointToPointMessageInfo& message)
// The rank being checked is the emitter rank of the message.
80 MessageRank source = message.emiterRank();
// Builds the message handed to the thread (shared memory) queue: a copy of
// `message` whose emitter and destination ranks are replaced by the local
// ranks found in `fri`.
88PointToPointMessageInfo HybridMessageQueue::
89_buildSharedMemoryMessage(
const PointToPointMessageInfo& message,
90 const SourceDestinationFullRankInfo& fri)
// Start from a copy so all other fields of `message` are kept.
92 PointToPointMessageInfo p2p_message(message);
93 p2p_message.setEmiterRank(fri.source().localRank());
94 p2p_message.setDestinationRank(fri.destination().localRank());
// Builds the message handed to the MPI layer: a copy of `message` whose
// emitter and destination ranks are replaced by the MPI ranks found in `fri`.
101PointToPointMessageInfo HybridMessageQueue::
102_buildMPIMessage(
const PointToPointMessageInfo& message,
103 const SourceDestinationFullRankInfo& fri)
// Start from a copy so all other fields of `message` are kept.
105 PointToPointMessageInfo p2p_message(message);
106 p2p_message.setEmiterRank(fri.source().mpiRank());
107 p2p_message.setDestinationRank(fri.destination().mpiRank());
// Waits for the requests in `requests`.
// Requests are sorted by their creator: the ones created by the MPI adapter
// and the ones created by the thread queue are waited on separately, each by
// its own backend.
114void HybridMessageQueue::
115waitAll(ArrayView<Request> requests)
118 Integer nb_request = requests.size();
119 UniqueArray<Request> mpi_requests;
120 UniqueArray<Request> thread_requests;
121 for (Integer i = 0; i < nb_request; ++i) {
122 Request r = requests[i];
// The creator of a request identifies the backend that owns it.
125 IRequestCreator* creator = r.creator();
126 if (creator == m_mpi_adapter) {
129 else if (creator == m_thread_queue)
130 thread_requests.add(r);
// A backend is only called when it has at least one request.
135 if (mpi_requests.size() != 0)
136 m_mpi_adapter->waitAllRequests(mpi_requests);
137 if (thread_requests.size() != 0)
138 m_thread_queue->waitAll(thread_requests);
// NOTE(review): the body of this final loop is not visible in this listing.
141 for (Request r : requests)
// Tests or waits for some of the requests in `requests`, delegating the work
// to _testOrWaitSome(). The call is repeated while it reports zero completed
// requests.
148void HybridMessageQueue::
149waitSome(Int32 rank, ArrayView<Request> requests, ArrayView<bool> requests_done,
150 bool is_non_blocking)
154 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1} nb_done={2} is_non_blocking={3}",
155 rank, requests.size(), nb_done, is_non_blocking);
156 nb_done = _testOrWaitSome(rank, requests, requests_done);
// NOTE(review): the statement guarded by this test is not visible here;
// presumably it leaves the loop in non-blocking mode or when
// _testOrWaitSome() returns -1 — confirm against the full source.
157 if (is_non_blocking || nb_done == (-1))
159 }
while (nb_done == 0);
// Splits `requests` between the MPI adapter and the thread (SHM) queue
// according to the creator of each request, asks each backend which of its
// requests are done, and reports the result in `requests_done` using the
// positions of the original `requests` array.
// Fix: the MPI trace below now has a `{1}` placeholder so that
// mpi_done_indexes, which was already passed as an argument, is printed.
// NOTE(review): several lines (filling of the request arrays, counting of
// completed requests, return statements) are not visible in this listing.
166_testOrWaitSome(Int32 rank, ArrayView<Request> requests, ArrayView<bool> requests_done)
168 Integer nb_request = requests.size();
169 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1}", rank, nb_request);
174 UniqueArray<Request> mpi_requests;
175 UniqueArray<Request> shm_requests;
// For each backend, the position of its requests in the global `requests` array.
177 UniqueArray<Integer> mpi_requests_index;
178 UniqueArray<Integer> shm_requests_index;
181 for (Integer i = 0; i < nb_request; ++i) {
182 Request r = requests[i];
// The creator of a request identifies the backend that owns it.
185 IRequestCreator* creator = r.creator();
186 if (creator == m_mpi_adapter) {
188 mpi_requests_index.add(i);
190 else if (creator == m_thread_queue) {
192 shm_requests_index.add(i);
198 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} nb_mpi={1} nb_shm={2}",
199 rank, mpi_requests.size(), shm_requests.size());
// NOTE(review): the statement guarded by this test is not visible here.
206 if (mpi_requests.size() == 0 && shm_requests.size() == 0)
217 UniqueArray<bool> mpi_done_indexes;
218 Integer nb_mpi_request = mpi_requests.size();
220 if (nb_mpi_request != 0) {
221 mpi_done_indexes.resize(nb_mpi_request);
222 mpi_done_indexes.fill(
false);
// NOTE(review): the trailing `true` presumably selects non-blocking mode — confirm.
223 m_mpi_adapter->waitSomeRequests(mpi_requests, mpi_done_indexes,
true);
224 TRACE_DEBUG(2,
"Hybrid: MPI wait some requests n={0} after={1}", nb_mpi_request, mpi_done_indexes);
225 for (Integer i = 0; i < nb_mpi_request; ++i) {
// Map the backend-local index back to the caller's array.
226 Integer index_in_global = mpi_requests_index[i];
227 if (mpi_done_indexes[i]) {
228 requests_done[index_in_global] =
true;
229 requests[index_in_global].reset();
231 TRACE_DEBUG(1,
"MPI rank={0} set done i={1} in_global={2}",
232 rank, i, index_in_global);
// Copies the backend's version of the request back into the caller's array.
// NOTE(review): the branch introducing this line is not visible here.
235 requests[index_in_global] = mpi_requests[i];
239 UniqueArray<bool> shm_done_indexes;
240 Integer nb_shm_request = shm_requests.size();
241 TRACE_DEBUG(2,
"SHM wait some requests n={0}", nb_shm_request);
242 if (shm_requests.size() != 0) {
243 shm_done_indexes.resize(nb_shm_request);
244 shm_done_indexes.fill(
false);
// NOTE(review): the trailing `true` presumably selects non-blocking mode — confirm.
245 m_thread_queue->waitSome(rank, shm_requests, shm_done_indexes,
true);
246 for (Integer i = 0; i < nb_shm_request; ++i) {
// Map the backend-local index back to the caller's array.
247 Integer index_in_global = shm_requests_index[i];
248 if (shm_done_indexes[i]) {
249 requests_done[index_in_global] =
true;
250 requests[index_in_global].reset();
252 TRACE_DEBUG(1,
"SHM rank={0} set done i={1} in_global={2}",
253 rank, i, index_in_global);
// Copies the backend's version of the request back into the caller's array.
// NOTE(review): the branch introducing this line is not visible here.
256 requests[index_in_global] = shm_requests[i];
// Posts a receive for a message described by a rank and a tag.
// Throws NotSupportedException when the destination rank of `message` is null.
// When both ends are on the same MPI rank the receive is given to the thread
// queue; otherwise it goes through MPI, with the tag rewritten by
// m_rank_tag_builder.tagForReceive().
265Request HybridMessageQueue::
266_addReceiveRankTag(
const PointToPointMessageInfo& message, ReceiveBufferInfo buf_info)
273 if (message.destinationRank().isNull())
274 ARCANE_THROW(NotSupportedException,
"Receive with any rank. Use probe() and MessageId instead");
276 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
277 bool is_same_mpi_rank = fri.isSameMpiRank();
// Same MPI rank: shared memory path through the thread queue.
279 if (is_same_mpi_rank) {
280 TRACE_DEBUG(1,
"** MPITMQ SHM ADD RECV S queue={0} message={1}",
this, message);
281 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message, fri));
282 return m_thread_queue->addReceive(p2p_message, buf_info);
// MPI path using the serializer carried by buf_info.
// NOTE(review): the test selecting this path is not visible in this listing.
285 ISerializer* serializer = buf_info.serializer();
287 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV S queue={0} message={1}",
this, message);
288 PointToPointMessageInfo p2p_message(_buildMPIMessage(message, fri));
289 p2p_message.setTag(m_rank_tag_builder.tagForReceive(MessageTag(message.tag()), fri));
290 return m_mpi_parallel_mng->receiveSerializer(serializer, p2p_message);
// MPI path using the raw memory buffer, received as `size` elements of type char.
293 ByteSpan buf = buf_info.memoryBuffer();
294 Int64 size = buf.size();
296 TRACE_DEBUG(1,
"** MPITMQ THREAD ADD RECV B queue={0} message={1} size={2} same_mpi?={3}",
297 this, message, size, fri.isSameMpiRank());
300 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
301 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(message.tag(), fri);
302 Request r = m_mpi_adapter->directRecv(buf.data(), size, fri.destination().mpiRankValue(),
sizeof(
char),
303 char_data_type, mpi_tag.value(),
false);
// Posts a receive for a message identified by a MessageId.
// Raises a fatal error when the rank stored in the MessageId differs from the
// destination rank of `message`. When both ends are on the same MPI rank the
// receive is given to the thread queue; otherwise it goes through MPI.
311Request HybridMessageQueue::
312_addReceiveMessageId(
const PointToPointMessageInfo& message, ReceiveBufferInfo buf_info)
314 MessageId message_id = message.messageId();
315 MessageId::SourceInfo si(message_id.sourceInfo());
// Consistency check between the MessageId and the message itself.
317 if (si.rank() != message.destinationRank())
318 ARCANE_FATAL(
"Incohence between messsage_id rank and destination rank x1={0} x2={1}",
319 si.rank(), message.destinationRank());
// NOTE(review): the argument list of this trace is not visible in this listing.
321 TRACE_DEBUG(1,
"** MPITMQ ADD_RECV (message_id) queue={0} message={1}",
324 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
// Same MPI rank: shared memory path through the thread queue.
325 if (fri.isSameMpiRank()) {
326 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message, fri));
327 return m_thread_queue->addReceive(p2p_message, buf_info);
330 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV (message_id) queue={0} message={1}",
this, message);
// MPI path using the serializer carried by buf_info.
// NOTE(review): the test selecting this path is not visible in this listing.
332 ISerializer* serializer = buf_info.serializer();
334 PointToPointMessageInfo p2p_message(_buildMPIMessage(message, fri));
336 TRACE_DEBUG(1,
"** MPI ADD RECV Serializer (message_id) message={0} p2p_message={1}",
337 message, p2p_message);
338 return m_mpi_parallel_mng->receiveSerializer(serializer, p2p_message);
// MPI path using the raw memory buffer, received as `size` elements of type char.
341 ByteSpan buf = buf_info.memoryBuffer();
342 Int64 size = buf.size();
345 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
// The MessageId given to MPI is a copy whose source rank is replaced by the MPI rank.
346 MessageId mpi_message(message_id);
347 MessageId::SourceInfo mpi_si(si);
348 mpi_si.setRank(fri.destination().mpiRank());
349 mpi_message.setSourceInfo(mpi_si);
350 return m_mpi_adapter->directRecv(buf.data(), size, mpi_message,
sizeof(
char),
351 char_data_type,
false);
// Entry point for posting a receive. After checking the source of `message`,
// the call is dispatched to _addReceiveRankTag() or _addReceiveMessageId()
// depending on how the message is described. Throws NotSupportedException
// when the message is neither a rank/tag message nor a MessageId message.
358Request HybridMessageQueue::
359addReceive(
const PointToPointMessageInfo& message, ReceiveBufferInfo buf)
361 _checkValidSource(message);
// NOTE(review): the statement guarded by this test is not visible in this listing.
363 if (!message.isValid())
366 if (message.isRankTag())
367 return _addReceiveRankTag(message, buf);
369 if (message.isMessageId())
370 return _addReceiveMessageId(message, buf);
372 ARCANE_THROW(NotSupportedException,
"Invalid message_info");
// Posts a send. Only rank/tag messages are accepted: a fatal error is raised
// when message.isRankTag() is false. When both ends are on the same MPI rank
// the send is given to the thread queue; otherwise it goes through MPI with a
// tag built by m_rank_tag_builder.tagForSend().
378Request HybridMessageQueue::
379addSend(
const PointToPointMessageInfo& message, SendBufferInfo buf_info)
// NOTE(review): the statements guarded by the next two tests are not visible
// in this listing.
381 if (!message.isValid())
383 if (message.destinationRank().isNull())
385 if (!message.isRankTag())
386 ARCCORE_FATAL(
"Invalid message_info for sending: message.isRankTag() is false");
388 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
// Same MPI rank: shared memory path through the thread queue.
391 if (fri.isSameMpiRank()) {
392 TRACE_DEBUG(1,
"** MPITMQ SHM ADD SEND S queue={0} message={1}",
this, message);
393 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message, fri));
394 return m_thread_queue->addSend(p2p_message, buf_info);
398 MessageTag mpi_tag = m_rank_tag_builder.tagForSend(message.tag(), fri);
// MPI path using the serializer carried by buf_info.
// NOTE(review): the test selecting this path is not visible in this listing.
399 const ISerializer* serializer = buf_info.serializer();
401 PointToPointMessageInfo p2p_message(_buildMPIMessage(message, fri));
402 p2p_message.setTag(mpi_tag);
403 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND Serializer queue={0} message={1} p2p_message={2}",
404 this, message, p2p_message);
405 return m_mpi_parallel_mng->sendSerializer(serializer, p2p_message);
// MPI path using the raw memory buffer, sent as `size` elements of type char.
408 ByteConstSpan buf = buf_info.memoryBuffer();
409 Int64 size = buf.size();
414 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
416 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND B queue={0} message={1} size={2} mpi_tag={3} mpi_rank={4}",
417 this, message, size, mpi_tag, fri.destination().mpiRank());
419 return m_mpi_adapter->directSend(buf.data(), size, fri.destination().mpiRankValue(),
420 sizeof(
char), char_data_type, mpi_tag.value(),
false);
// Probes for a message matching `message` and returns its MessageId.
// Only rank/tag messages are accepted (fatal error otherwise), and a null
// user tag raises NotImplementedException. The thread queue is probed for
// messages inside the same MPI rank and the MPI adapter for the others. When
// a message is found, the rank stored in the returned MessageId is replaced
// by `found_dest`.
// Fix: the "Probe null rank found" trace passed four arguments
// (ret_tag, mpi_rank, local_rank, ret_tag) to a three-placeholder format, so
// the tag was printed as the MPI rank and the other values were shifted; the
// arguments now match the placeholders.
// NOTE(review): several lines (branch headers, return statements) are not
// visible in this listing.
427MP::MessageId HybridMessageQueue::
428probe(
const MP::PointToPointMessageInfo& message)
430 TRACE_DEBUG(1,
"Probe msg='{0}' queue={1} is_valid={2}",
431 message,
this, message.isValid());
433 MessageRank orig = message.emiterRank();
// NOTE(review): the statement guarded by this test is not visible here.
437 if (!message.isValid())
441 if (!message.isRankTag())
442 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
444 MessageRank dest = message.destinationRank();
445 MessageTag user_tag = message.tag();
446 bool is_blocking = message.isBlocking();
// NOTE(review): the condition guarding this throw is not visible here;
// presumably it tests is_blocking — confirm.
448 ARCANE_THROW(NotImplementedException,
"blocking probe");
449 if (user_tag.isNull())
450 ARCANE_THROW(NotImplementedException,
"probe with ANY_TAG");
451 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
452 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
453 MessageId message_id;
454 Int32 found_dest = dest.value();
// A null rank is treated as "any source", but is only accepted when
// m_is_allow_null_rank_for_any_source is set.
455 const bool is_any_source = dest.isNull() || dest.isAnySource();
456 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
457 ARCANE_FATAL(
"Can not use probe() with null rank. Use MessageRank::anySourceRank() instead");
// Any-source case, thread queue first.
461 MP::PointToPointMessageInfo p2p_message(message);
462 p2p_message.setEmiterRank(orig_fri.localRank());
463 message_id = m_thread_queue->probe(p2p_message);
464 if (message_id.isValid()) {
// Rank computed as (MPI rank * number of local ranks + local rank).
468 found_dest = orig_fri.mpiRankValue() * m_local_nb_rank + message_id.sourceInfo().rank().value();
469 TRACE_DEBUG(2,
"Probe with null_rank (thread) orig={0} found_dest={1} tag={2}",
470 orig, found_dest, user_tag);
// Any-source case, MPI: one probe per value of z in [0, m_local_nb_rank),
// each with the tag built by tagForReceive() for that z.
478 for (Integer z = 0, zn = m_local_nb_rank; z < zn; ++z) {
479 MP::PointToPointMessageInfo mpi_message(message);
480 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag, orig_fri.localRank(), MessageRank(z));
481 mpi_message.setTag(mpi_tag);
482 TRACE_DEBUG(2,
"Probe with null_rank orig={0} dest={1} tag={2}", orig, dest, mpi_tag);
483 message_id = m_mpi_adapter->probeMessage(mpi_message);
484 if (message_id.isValid()) {
// The local rank is decoded from the tag returned by MPI.
487 MessageRank mpi_rank = message_id.sourceInfo().rank();
488 MessageTag ret_tag = message_id.sourceInfo().tag();
489 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
490 found_dest = mpi_rank.value() * m_local_nb_rank + local_rank;
491 TRACE_DEBUG(2,
"Probe null rank found mpi_rank={0} local_rank={1} tag={2}",
492 mpi_rank, local_rank, ret_tag);
// Explicit rank, same MPI rank: probe the thread queue with local ranks.
501 if (orig_fri.mpiRank() == dest_fri.mpiRank()) {
502 MP::PointToPointMessageInfo p2p_message(message);
503 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
504 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
505 message_id = m_thread_queue->probe(p2p_message);
// Explicit rank, different MPI rank: probe MPI with the rewritten tag and rank.
508 MP::PointToPointMessageInfo mpi_message(message);
509 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag, orig_fri, dest_fri);
510 mpi_message.setTag(mpi_tag);
511 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
512 TRACE_DEBUG(2,
"Probe orig={0} dest={1} mpi_tag={2} user_tag={3}", orig, dest, mpi_tag, user_tag);
513 message_id = m_mpi_adapter->probeMessage(mpi_message);
// Replace the backend rank in the result by found_dest.
516 if (message_id.isValid()) {
519 MessageId::SourceInfo si = message_id.sourceInfo();
520 si.setRank(MessageRank(found_dest));
521 message_id.setSourceInfo(si);
// Probes for a message matching `message` and returns a MessageSourceInfo.
// Only rank/tag messages are accepted (fatal error otherwise), and a null
// user tag raises NotImplementedException. The thread queue is probed for
// messages inside the same MPI rank and the MPI adapter for the others. For
// results coming from MPI the tag is set back to the user tag, and when a
// message is found its rank is replaced by `found_dest`.
// Fix: the "LegacyProbe null rank found" trace passed four arguments
// (ret_tag, mpi_rank, local_rank, ret_tag) to a three-placeholder format, so
// the tag was printed as the MPI rank and the other values were shifted; the
// arguments now match the placeholders.
// NOTE(review): several lines (branch headers, some statements) are not
// visible in this listing.
529MP::MessageSourceInfo HybridMessageQueue::
530legacyProbe(
const MP::PointToPointMessageInfo& message)
532 TRACE_DEBUG(1,
"LegacyProbe msg='{0}' queue={1} is_valid={2}",
533 message,
this, message.isValid());
535 MessageRank orig = message.emiterRank();
// NOTE(review): the statement guarded by this test is not visible here.
539 if (!message.isValid())
543 if (!message.isRankTag())
544 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
546 const MessageRank dest = message.destinationRank();
547 const MessageTag user_tag = message.tag();
548 const bool is_blocking = message.isBlocking();
// NOTE(review): the condition guarding this throw is not visible here;
// presumably it tests is_blocking — confirm.
550 ARCANE_THROW(NotImplementedException,
"blocking probe");
551 if (user_tag.isNull())
552 ARCANE_THROW(NotImplementedException,
"legacyProbe with ANY_TAG");
553 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
554 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
555 MP::MessageSourceInfo message_source_info;
556 Int32 found_dest = dest.value();
// A null rank is treated as "any source", but is only accepted when
// m_is_allow_null_rank_for_any_source is set.
557 const bool is_any_source = dest.isNull() || dest.isAnySource();
558 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
559 ARCANE_FATAL(
"Can not use legacyProbe() with null rank. Use MessageRank::anySourceRank() instead");
// Any-source case, thread queue first.
563 MP::PointToPointMessageInfo p2p_message(message);
564 p2p_message.setEmiterRank(orig_fri.localRank());
565 message_source_info = m_thread_queue->legacyProbe(p2p_message);
566 if (message_source_info.isValid()) {
// Rank computed as (MPI rank * number of local ranks + local rank).
570 found_dest = orig_fri.mpiRankValue() * m_local_nb_rank + message_source_info.rank().value();
571 TRACE_DEBUG(2,
"LegacyProbe with null_rank (thread) orig={0} found_dest={1} tag={2}",
572 orig, found_dest, user_tag);
// Any-source case, MPI: one probe per value of z in [0, m_local_nb_rank),
// each with the tag built by tagForReceive() for that z.
580 for (Integer z = 0, zn = m_local_nb_rank; z < zn; ++z) {
581 MP::PointToPointMessageInfo mpi_message(message);
582 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag, orig_fri.localRank(), MessageRank(z));
583 mpi_message.setTag(mpi_tag);
584 TRACE_DEBUG(2,
"LegacyProbe with null_rank orig={0} dest={1} tag={2}", orig, dest, mpi_tag);
585 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
586 if (message_source_info.isValid()) {
// The local rank is decoded from the tag returned by MPI.
589 MessageRank mpi_rank = message_source_info.rank();
590 MessageTag ret_tag = message_source_info.tag();
591 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
592 found_dest = mpi_rank.value() * m_local_nb_rank + local_rank;
593 TRACE_DEBUG(2,
"LegacyProbe null rank found mpi_rank={0} local_rank={1} tag={2}",
594 mpi_rank, local_rank, ret_tag);
// The caller gets its own tag back, not the tag used on the MPI side.
596 message_source_info.setTag(user_tag);
// Explicit rank, same MPI rank: probe the thread queue with local ranks.
605 if (orig_fri.mpiRank() == dest_fri.mpiRank()) {
606 MP::PointToPointMessageInfo p2p_message(message);
607 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
608 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
609 TRACE_DEBUG(2,
"LegacyProbe SHM orig={0} dest={1} tag={2}", orig, dest, user_tag);
610 message_source_info = m_thread_queue->legacyProbe(p2p_message);
// Explicit rank, different MPI rank: probe MPI with the rewritten tag and rank.
613 MP::PointToPointMessageInfo mpi_message(message);
614 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag, orig_fri, dest_fri);
615 mpi_message.setTag(mpi_tag);
616 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
617 TRACE_DEBUG(2,
"LegacyProbe MPI orig={0} dest={1} mpi_tag={2} user_tag={3}", orig, dest, mpi_tag, user_tag);
618 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
619 if (message_source_info.isValid()) {
// The caller gets its own tag back, not the tag used on the MPI side.
621 message_source_info.setTag(user_tag);
// Replace the backend rank in the result by found_dest.
625 if (message_source_info.isValid()) {
627 message_source_info.setRank(MessageRank(found_dest));
629 TRACE_DEBUG(2,
"LegacyProbe has matched message? = {0}", message_source_info.isValid());
630 return message_source_info;
// Stream output for FullRankInfo, written as
// "(local=<m_local_rank>,global=<m_global_rank>,mpi=<m_mpi_rank>)".
// Returns the stream to allow chaining.
636std::ostream& operator<<(std::ostream& o,
const FullRankInfo& fri)
638 return o <<
"(local=" << fri.m_local_rank <<
",global="
639 << fri.m_global_rank <<
",mpi=" << fri.m_mpi_rank <<
")";
#define ARCANE_THROW(exception_class,...)
Macro for throwing an exception with formatting.
#define ARCANE_FATAL(...)
Macro throwing a FatalErrorException.
#define ARCCORE_FATAL(...)
Macro throwing a FatalErrorException.
Interface of a message queue with threads.
Declarations of types and methods used by message exchange mechanisms.
Int32 Integer
Type representing an integer.
std::int32_t Int32
Signed integer type of 32 bits.