14#include "arcane/utils/ArcanePrecomp.h"
16#include "arcane/utils/Array.h"
17#include "arcane/utils/FatalErrorException.h"
18#include "arcane/utils/NotImplementedException.h"
19#include "arcane/utils/NotSupportedException.h"
20#include "arcane/utils/ArgumentException.h"
21#include "arcane/utils/TraceInfo.h"
22#include "arcane/utils/ITraceMng.h"
23#include "arcane/utils/ValueConvert.h"
25#include "arcane/parallel/mpithread/HybridMessageQueue.h"
26#include "arcane/parallel/mpi/MpiParallelMng.h"
28#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
// TRACE_DEBUG(level, fmt, args...): when m_debug_level >= level, writes
// String::format("Hybrid " fmt, args...) to info(). `fmt` must be a string
// literal, because it is concatenated with the literal "Hybrid ".
// NOTE(review): several format strings in this file already start with
// "Hybrid:", so those traces come out as "Hybrid Hybrid: ...".
// NOTE(review): the macro appears to expand to a bare `if (...) { ... }`
// (its closing line is not visible here). A do { ... } while (0) wrapper
// would make it safe as a single statement inside an unbraced if/else.
34#define TRACE_DEBUG(needed_debug_level,format_str,...) \
35 if (m_debug_level>=needed_debug_level){ \
36 info() << String::format("Hybrid " format_str,__VA_ARGS__);\
52: TraceAccessor(mpi_pm->traceMng())
// Member initialization:
// - traces go through the trace manager of the MPI parallel manager;
// - the MPI adapter comes from that same manager;
// - the rank/tag builder is sized with the number of local ranks.
53, m_thread_queue(thread_queue)
54, m_mpi_parallel_mng(mpi_pm)
55, m_mpi_adapter(mpi_pm->adapter())
56, m_local_nb_rank(local_nb_rank)
57, m_rank_tag_builder(local_nb_rank)
// Optional environment override. When ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE
// parses as an Int32, a non-zero value lets probe()/legacyProbe() accept a
// null destination rank as "any source" instead of raising a fatal error.
60 if (
auto v = Convert::Type<Int32>::tryParseFromEnvironment(
"ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE",
true))
61 m_is_allow_null_rank_for_any_source = v.value() != 0;
// Validates a single rank argument. The checks performed are not visible in
// this excerpt.
67void HybridMessageQueue::
68_checkValidRank(MessageRank rank)
// Validates the source of a point-to-point message, starting from its
// emitter rank. The remaining checks are not visible in this excerpt.
// Called by addReceive() before dispatching.
77void HybridMessageQueue::
78_checkValidSource(
const PointToPointMessageInfo& message)
80 MessageRank source = message.emiterRank();
// Builds the message handed to the thread (shared-memory) queue: a copy of
// `message` whose emitter and destination ranks are replaced by the local
// ranks stored in `fri`.
88PointToPointMessageInfo HybridMessageQueue::
89_buildSharedMemoryMessage(
const PointToPointMessageInfo& message,
90 const SourceDestinationFullRankInfo& fri)
92 PointToPointMessageInfo p2p_message(message);
93 p2p_message.setEmiterRank(fri.source().localRank());
94 p2p_message.setDestinationRank(fri.destination().localRank());
// Builds the message handed to the MPI layer: a copy of `message` whose
// emitter and destination ranks are replaced by the MPI ranks stored in `fri`.
// Callers adjust the tag themselves when needed.
101PointToPointMessageInfo HybridMessageQueue::
102_buildMPIMessage(
const PointToPointMessageInfo& message,
103 const SourceDestinationFullRankInfo& fri)
105 PointToPointMessageInfo p2p_message(message);
106 p2p_message.setEmiterRank(fri.source().mpiRank());
107 p2p_message.setDestinationRank(fri.destination().mpiRank());
// Blocks until every request in `requests` has completed.
// Requests are grouped by the object that created them:
// - requests created by m_mpi_adapter are waited with waitAllRequests();
// - requests created by m_thread_queue are waited with m_thread_queue->waitAll().
// Each group is only waited on when it is non-empty.
114void HybridMessageQueue::
115waitAll(ArrayView<Request> requests)
118 Integer nb_request = requests.size();
119 UniqueArray<Request> mpi_requests;
120 UniqueArray<Request> thread_requests;
121 for( Integer i=0; i<nb_request; ++i ){
122 Request r = requests[i];
125 IRequestCreator* creator = r.creator();
126 if (creator==m_mpi_adapter) {
129 else if (creator==m_thread_queue)
130 thread_requests.add(r);
135 if (mpi_requests.size()!=0)
136 m_mpi_adapter->waitAllRequests(mpi_requests);
137 if (thread_requests.size()!=0)
138 m_thread_queue->waitAll(thread_requests);
// NOTE(review): `r` below is a copy of each element. If the loop body (not
// visible here) mutates it, e.g. to reset it, the entries of `requests`
// themselves are left unchanged. Iterate by reference if that is the intent.
141 for( Request r : requests )
// Waits until at least one request in `requests` completes. Completed
// entries are reported through `requests_done` by _testOrWaitSome().
// - The loop repeats while _testOrWaitSome() reports 0 completions.
// - It stops after one pass when `is_non_blocking` is set.
// - It also stops when _testOrWaitSome() returns -1 (presumably "no
//   pending request to wait for"; confirm).
// Note: the trace prints nb_done before it is updated in the current pass.
148void HybridMessageQueue::
149waitSome(Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done,
150 bool is_non_blocking)
154 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1} nb_done={2} is_non_blocking={3}",
155 rank,requests.size(),nb_done,is_non_blocking);
156 nb_done = _testOrWaitSome(rank,requests,requests_done);
157 if (is_non_blocking || nb_done==(-1))
159 }
while (nb_done==0);
// Runs one test/wait pass over `requests`.
// 1. Requests are split by creator (m_mpi_adapter or m_thread_queue),
//    remembering each request's index in `requests`.
// 2. Each group is passed to the matching waitSome call.
// 3. For every completed request, requests_done[index] is set to true and
//    requests[index] is reset.
// 4. Otherwise the entry is written back from the per-group array,
//    presumably to pick up state updated by the callee.
// NOTE(review): the return value is not visible here. waitSome() loops while
// it is 0 and stops on -1.
166_testOrWaitSome(Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done)
168 Integer nb_request = requests.size();
169 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1}",rank,nb_request);
174 UniqueArray<Request> mpi_requests;
175 UniqueArray<Request> shm_requests;
// Position of each grouped request in the caller's `requests` array.
177 UniqueArray<Integer> mpi_requests_index;
178 UniqueArray<Integer> shm_requests_index;
181 for( Integer i=0; i<nb_request; ++i ){
182 Request r = requests[i];
185 IRequestCreator* creator = r.creator();
186 if (creator==m_mpi_adapter){
188 mpi_requests_index.add(i);
190 else if (creator==m_thread_queue){
192 shm_requests_index.add(i);
198 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} nb_mpi={1} nb_shm={2}",
199 rank,mpi_requests.size(),shm_requests.size());
205 if (mpi_requests.size()==0 && shm_requests.size()==0)
216 UniqueArray<bool> mpi_done_indexes;
217 Integer nb_mpi_request = mpi_requests.size();
219 if (nb_mpi_request!=0){
220 mpi_done_indexes.resize(nb_mpi_request);
221 mpi_done_indexes.fill(
false);
// The trailing `true` appears to select non-blocking mode (compare
// waitSome()'s is_non_blocking parameter); confirm against the callee.
222 m_mpi_adapter->waitSomeRequests(mpi_requests,mpi_done_indexes,
true);
// {1} prints the completion flags filled in by waitSomeRequests().
223 TRACE_DEBUG(2,
"Hybrid: MPI wait some requests n={0} after={1}",nb_mpi_request,mpi_done_indexes);
224 for( Integer i=0; i<nb_mpi_request; ++i ){
225 Integer index_in_global = mpi_requests_index[i];
226 if (mpi_done_indexes[i]){
227 requests_done[index_in_global] =
true;
228 requests[index_in_global].reset();
230 TRACE_DEBUG(1,
"MPI rank={0} set done i={1} in_global={2}",
231 rank,i,index_in_global);
234 requests[index_in_global] = mpi_requests[i];
238 UniqueArray<bool> shm_done_indexes;
239 Integer nb_shm_request = shm_requests.size();
240 TRACE_DEBUG(2,
"SHM wait some requests n={0}",nb_shm_request);
241 if (shm_requests.size()!=0){
242 shm_done_indexes.resize(nb_shm_request);
243 shm_done_indexes.fill(
false);
// Same trailing `true` convention as the MPI call above.
244 m_thread_queue->waitSome(rank,shm_requests,shm_done_indexes,
true);
245 for( Integer i=0; i<nb_shm_request; ++i ){
246 Integer index_in_global = shm_requests_index[i];
247 if (shm_done_indexes[i]){
248 requests_done[index_in_global] =
true;
249 requests[index_in_global].reset();
251 TRACE_DEBUG(1,
"SHM rank={0} set done i={1} in_global={2}",
252 rank,i,index_in_global);
255 requests[index_in_global] = shm_requests[i];
// Posts a receive identified by (destination rank, tag).
// - A null destination rank is rejected: any-source receives must go through
//   probe() and a MessageId.
// - If both ranks live in the same MPI process, the receive goes to the
//   thread queue with local ranks.
// - Otherwise it goes through MPI, using the MPI rank and a tag encoded by
//   m_rank_tag_builder.tagForReceive():
//   - via m_mpi_parallel_mng->receiveSerializer() for the serializer path;
//   - via m_mpi_adapter->directRecv() on the raw byte buffer otherwise.
// NOTE(review): the "THREAD ADD RECV B" trace label below is emitted on the
// MPI byte-buffer path.
264Request HybridMessageQueue::
265_addReceiveRankTag(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
272 if (message.destinationRank().isNull())
273 ARCANE_THROW(NotSupportedException,
"Receive with any rank. Use probe() and MessageId instead");
275 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
276 bool is_same_mpi_rank = fri.isSameMpiRank();
278 if (is_same_mpi_rank){
279 TRACE_DEBUG(1,
"** MPITMQ SHM ADD RECV S queue={0} message={1}",
this,message);
280 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
281 return m_thread_queue->addReceive(p2p_message,buf_info);
284 ISerializer* serializer = buf_info.serializer();
286 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV S queue={0} message={1}",
this,message);
287 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
288 p2p_message.setTag(m_rank_tag_builder.tagForReceive(MessageTag(message.tag()),fri));
289 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
292 ByteSpan buf = buf_info.memoryBuffer();
293 Int64 size = buf.size();
295 TRACE_DEBUG(1,
"** MPITMQ THREAD ADD RECV B queue={0} message={1} size={2} same_mpi?={3}",
296 this,message,size,fri.isSameMpiRank());
// The buffer is received as `size` elements of sizeof(char) with the MPI char
// datatype. The trailing `false` presumably means non-blocking; confirm
// against MpiAdapter::directRecv.
299 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
300 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(message.tag(),fri);
301 Request r = m_mpi_adapter->directRecv(buf.data(),size,fri.destination().mpiRankValue(),
sizeof(
char),
302 char_data_type,mpi_tag.value(),
false);
// Posts a receive for a message previously identified by a MessageId.
// - The source rank stored in the MessageId must equal the message's
//   destination rank; otherwise a fatal error is raised.
// - If both ranks are in the same MPI process, the receive goes to the thread
//   queue with local ranks.
// - Otherwise MPI is used:
//   - the serializer path goes through m_mpi_parallel_mng->receiveSerializer();
//   - the byte-buffer path copies the MessageId, replaces its source rank with
//     the MPI rank, and calls m_mpi_adapter->directRecv().
// NOTE(review): the fatal message text contains typos ("Incohence",
// "messsage_id"). It is left unchanged here because it is runtime output.
310Request HybridMessageQueue::
311_addReceiveMessageId(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
313 MessageId message_id = message.messageId();
314 MessageId::SourceInfo si(message_id.sourceInfo());
316 if (si.rank()!=message.destinationRank())
317 ARCANE_FATAL(
"Incohence between messsage_id rank and destination rank x1={0} x2={1}",
318 si.rank(),message.destinationRank());
320 TRACE_DEBUG(1,
"** MPITMQ ADD_RECV (message_id) queue={0} message={1}",
323 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
324 if (fri.isSameMpiRank()){
325 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
326 return m_thread_queue->addReceive(p2p_message,buf_info);
329 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV (message_id) queue={0} message={1}",
this,message);
331 ISerializer* serializer = buf_info.serializer();
333 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
335 TRACE_DEBUG(1,
"** MPI ADD RECV Serializer (message_id) message={0} p2p_message={1}",
336 message,p2p_message);
337 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
340 ByteSpan buf = buf_info.memoryBuffer();
341 Int64 size = buf.size();
344 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
// The MPI layer expects the sender's MPI rank, not the rank stored in the
// MessageId.
345 MessageId mpi_message(message_id);
346 MessageId::SourceInfo mpi_si(si);
347 mpi_si.setRank(fri.destination().mpiRank());
348 mpi_message.setSourceInfo(mpi_si);
349 return m_mpi_adapter->directRecv(buf.data(),size,mpi_message,
sizeof(
char),
350 char_data_type,
false);
// Entry point for receives.
// - Validates the message source.
// - Dispatches rank/tag messages to _addReceiveRankTag() and MessageId-based
//   messages to _addReceiveMessageId().
// - Throws NotSupportedException for any other kind of message info.
357Request HybridMessageQueue::
358addReceive(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf)
360 _checkValidSource(message);
362 if (!message.isValid())
365 if (message.isRankTag())
366 return _addReceiveRankTag(message,buf);
368 if (message.isMessageId())
369 return _addReceiveMessageId(message,buf);
371 ARCANE_THROW(NotSupportedException,
"Invalid message_info");
// Posts a send. Only rank/tag messages are accepted; other kinds raise a
// fatal error.
// - Same MPI process: the send goes to the thread queue with local ranks.
// - Otherwise it goes through MPI with the destination MPI rank and a tag
//   encoded by m_rank_tag_builder.tagForSend():
//   - via m_mpi_parallel_mng->sendSerializer() for the serializer path;
//   - via m_mpi_adapter->directSend() on the raw byte buffer otherwise.
377Request HybridMessageQueue::
378addSend(
const PointToPointMessageInfo& message,SendBufferInfo buf_info)
380 if (!message.isValid())
382 if (message.destinationRank().isNull())
384 if (!message.isRankTag())
385 ARCCORE_FATAL(
"Invalid message_info for sending: message.isRankTag() is false");
387 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
390 if (fri.isSameMpiRank()){
391 TRACE_DEBUG(1,
"** MPITMQ SHM ADD SEND S queue={0} message={1}",
this,message);
392 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
393 return m_thread_queue->addSend(p2p_message,buf_info);
// The MPI tag encodes the user tag together with rank information from `fri`.
397 MessageTag mpi_tag = m_rank_tag_builder.tagForSend(message.tag(),fri);
398 const ISerializer* serializer = buf_info.serializer();
400 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
401 p2p_message.setTag(mpi_tag);
402 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND Serializer queue={0} message={1} p2p_message={2}",
403 this,message,p2p_message);
404 return m_mpi_parallel_mng->sendSerializer(serializer,p2p_message);
407 ByteConstSpan buf = buf_info.memoryBuffer();
408 Int64 size = buf.size();
// The buffer is sent as `size` elements of sizeof(char) with the MPI char
// datatype. The trailing `false` presumably means non-blocking; confirm
// against MpiAdapter::directSend.
413 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
415 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND B queue={0} message={1} size={2} mpi_tag={3} mpi_rank={4}",
416 this,message,size,mpi_tag,fri.destination().mpiRank());
418 return m_mpi_adapter->directSend(buf.data(),size,fri.destination().mpiRankValue(),
419 sizeof(
char),char_data_type,mpi_tag.value(),
false);
// Probes for a pending message matching `message` (rank + tag) without
// receiving it, and returns its MessageId.
// - ANY_TAG (null user tag) throws NotImplementedException.
// - A null destination rank is fatal unless m_is_allow_null_rank_for_any_source
//   is set.
// - A null rank or MessageRank::anySourceRank() selects the any-source path.
//   In the visible code this probes the thread queue, then MPI once per local
//   rank z, with the tag from tagForReceive(user_tag, local rank, z).
// - For a specific rank, the thread queue is probed when both ranks share an
//   MPI rank; otherwise MPI is probed with the destination MPI rank and an
//   encoded tag.
// When a message is found, the source rank stored in the MessageId is
// rewritten to the global rank `found_dest`.
426MP::MessageId HybridMessageQueue::
427probe(
const MP::PointToPointMessageInfo& message)
429 TRACE_DEBUG(1,
"Probe msg='{0}' queue={1} is_valid={2}",
430 message,
this,message.isValid());
432 MessageRank orig = message.emiterRank();
436 if (!message.isValid())
440 if (!message.isRankTag())
441 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
443 MessageRank dest = message.destinationRank();
444 MessageTag user_tag = message.tag();
445 bool is_blocking = message.isBlocking();
448 if (user_tag.isNull())
449 ARCANE_THROW(NotImplementedException,
"probe with ANY_TAG");
450 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
451 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
452 MessageId message_id;
453 Int32 found_dest = dest.value();
454 const bool is_any_source = dest.isNull() || dest.isAnySource();
455 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
456 ARCANE_FATAL(
"Can not use probe() with null rank. Use MessageRank::anySourceRank() instead");
460 MP::PointToPointMessageInfo p2p_message(message);
461 p2p_message.setEmiterRank(orig_fri.localRank());
462 message_id = m_thread_queue->probe(p2p_message);
463 if (message_id.isValid()){
// Thread-queue hit: the reported source is a local rank in the same MPI
// process as `orig`. Convert it to a global rank.
467 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_id.sourceInfo().rank().value();
468 TRACE_DEBUG(2,
"Probe with null_rank (thread) orig={0} found_dest={1} tag={2}",
469 orig,found_dest,user_tag);
// MPI any-source: the tag encodes local ranks, so one probe is issued per
// possible local sender rank z.
477 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
478 MP::PointToPointMessageInfo mpi_message(message);
479 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
480 mpi_message.setTag(mpi_tag);
481 TRACE_DEBUG(2,
"Probe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
482 message_id = m_mpi_adapter->probeMessage(mpi_message);
483 if (message_id.isValid()){
// Recover the sender's local rank from the returned tag and build its
// global rank.
486 MessageRank mpi_rank = message_id.sourceInfo().rank();
487 MessageTag ret_tag = message_id.sourceInfo().tag();
488 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
489 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// Arguments follow the placeholder order: {0}=mpi_rank, {1}=local_rank, {2}=tag.
490 TRACE_DEBUG(2,
"Probe null rank found mpi_rank={0} local_rank={1} tag={2}",
491 mpi_rank,local_rank,ret_tag);
500 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
501 MP::PointToPointMessageInfo p2p_message(message);
502 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
503 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
504 message_id = m_thread_queue->probe(p2p_message);
507 MP::PointToPointMessageInfo mpi_message(message);
508 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
509 mpi_message.setTag(mpi_tag);
510 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
511 TRACE_DEBUG(2,
"Probe orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
512 message_id = m_mpi_adapter->probeMessage(mpi_message);
// Report the sender as a global rank in the returned MessageId.
515 if (message_id.isValid()){
518 MessageId::SourceInfo si = message_id.sourceInfo();
519 si.setRank(MessageRank(found_dest));
520 message_id.setSourceInfo(si);
// Legacy variant of probe() returning a MessageSourceInfo instead of a
// MessageId.
// - ANY_TAG (null user tag) throws NotImplementedException.
// - A null destination rank is fatal unless m_is_allow_null_rank_for_any_source
//   is set.
// - Any-source requests probe the thread queue, then MPI once per local
//   rank z (visible code order).
// - Specific ranks use the thread queue when both ranks share an MPI rank,
//   MPI otherwise.
// For MPI hits, the encoded tag is replaced by the user tag. For any valid
// result, the rank is rewritten to the global rank `found_dest`.
528MP::MessageSourceInfo HybridMessageQueue::
529legacyProbe(
const MP::PointToPointMessageInfo& message)
531 TRACE_DEBUG(1,
"LegacyProbe msg='{0}' queue={1} is_valid={2}",
532 message,
this,message.isValid());
534 MessageRank orig = message.emiterRank();
538 if (!message.isValid())
542 if (!message.isRankTag())
543 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
545 const MessageRank dest = message.destinationRank();
546 const MessageTag user_tag = message.tag();
547 const bool is_blocking = message.isBlocking();
550 if (user_tag.isNull())
551 ARCANE_THROW(NotImplementedException,
"legacyProbe with ANY_TAG");
552 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
553 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
554 MP::MessageSourceInfo message_source_info;
555 Int32 found_dest = dest.value();
556 const bool is_any_source = dest.isNull() || dest.isAnySource();
557 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
558 ARCANE_FATAL(
"Can not use legacyProbe() with null rank. Use MessageRank::anySourceRank() instead");
562 MP::PointToPointMessageInfo p2p_message(message);
563 p2p_message.setEmiterRank(orig_fri.localRank());
564 message_source_info = m_thread_queue->legacyProbe(p2p_message);
565 if (message_source_info.isValid()){
// Thread-queue hit: convert the local source rank (same MPI process as
// `orig`) to a global rank.
569 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_source_info.rank().value();
570 TRACE_DEBUG(2,
"LegacyProbe with null_rank (thread) orig={0} found_dest={1} tag={2}",
571 orig,found_dest,user_tag);
// MPI any-source: one probe per possible local sender rank z, because the
// tag encodes local ranks.
579 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
580 MP::PointToPointMessageInfo mpi_message(message);
581 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
582 mpi_message.setTag(mpi_tag);
583 TRACE_DEBUG(2,
"LegacyProbe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
584 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
585 if (message_source_info.isValid()){
588 MessageRank mpi_rank = message_source_info.rank();
589 MessageTag ret_tag = message_source_info.tag();
590 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
591 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// Arguments follow the placeholder order: {0}=mpi_rank, {1}=local_rank, {2}=tag.
592 TRACE_DEBUG(2,
"LegacyProbe null rank found mpi_rank={0} local_rank={1} tag={2}",
593 mpi_rank,local_rank,ret_tag);
// Hide the encoded MPI tag from the caller.
595 message_source_info.setTag(user_tag);
604 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
605 MP::PointToPointMessageInfo p2p_message(message);
606 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
607 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
608 TRACE_DEBUG(2,
"LegacyProbe SHM orig={0} dest={1} tag={2}",orig,dest,user_tag);
609 message_source_info = m_thread_queue->legacyProbe(p2p_message);
612 MP::PointToPointMessageInfo mpi_message(message);
613 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
614 mpi_message.setTag(mpi_tag);
615 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
616 TRACE_DEBUG(2,
"LegacyProbe MPI orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
617 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
618 if (message_source_info.isValid()){
620 message_source_info.setTag(user_tag);
// Report the sender as a global rank.
624 if (message_source_info.isValid()){
627 message_source_info.setRank(MessageRank(found_dest));
629 TRACE_DEBUG(2,
"LegacyProbe has matched message? = {0}",message_source_info.isValid());
630 return message_source_info;
// Writes `fri` as "(local=<m_local_rank>,global=<m_global_rank>,mpi=<m_mpi_rank>)"
// and returns the stream.
636std::ostream& operator<<(std::ostream& o,
const FullRankInfo& fri)
638 return o <<
"(local=" << fri.m_local_rank <<
",global="
639 << fri.m_global_rank <<
",mpi=" << fri.m_mpi_rank <<
")";
#define ARCANE_THROW(exception_class,...)
Macro pour envoyer une exception avec formatage.
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
#define ARCCORE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Interface d'une file de messages avec les threads.
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
Int32 Integer
Type représentant un entier.
std::int32_t Int32
Type entier signé sur 32 bits.