14#include "arcane/utils/ArcanePrecomp.h"
16#include "arcane/utils/Array.h"
17#include "arcane/utils/FatalErrorException.h"
18#include "arcane/utils/NotImplementedException.h"
19#include "arcane/utils/NotSupportedException.h"
20#include "arcane/utils/ArgumentException.h"
21#include "arcane/utils/TraceInfo.h"
22#include "arcane/utils/ITraceMng.h"
23#include "arcane/utils/ValueConvert.h"
25#include "arcane/parallel/mpithread/HybridMessageQueue.h"
26#include "arcane/parallel/mpi/MpiParallelMng.h"
28#include "arccore/message_passing_mpi/internal/MpiAdapter.h"
// Debug trace helper: when m_debug_level is at least needed_debug_level, the
// message is formatted with String::format (the literal "Hybrid " is prepended
// to format_str) and streamed to info().
// NOTE(review): some call sites pass a format string that already starts with
// "Hybrid: ", so the prefix presumably appears twice in those traces - confirm.
34#define TRACE_DEBUG(needed_debug_level,format_str,...) \
35 if (m_debug_level>=needed_debug_level){ \
36 info() << String::format("Hybrid " format_str,__VA_ARGS__);\
// Member initialisation (the constructor signature is not visible in this
// view): the trace accessor is built from mpi_pm->traceMng(); the thread queue,
// the MPI parallel manager and its adapter are stored; m_local_nb_rank and
// m_rank_tag_builder are both initialised from local_nb_rank.
52: TraceAccessor(mpi_pm->traceMng())
53, m_thread_queue(thread_queue)
54, m_mpi_parallel_mng(mpi_pm)
55, m_mpi_adapter(mpi_pm->adapter())
56, m_local_nb_rank(local_nb_rank)
57, m_rank_tag_builder(local_nb_rank)
// When ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE yields a value, a non-zero
// value sets m_is_allow_null_rank_for_any_source to true and zero sets it to
// false. Otherwise the member is left untouched by this statement.
60 if (
auto v = Convert::Type<Int32>::tryParseFromEnvironment(
"ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE",
true))
61 m_is_allow_null_rank_for_any_source = v.value() != 0;
// Helper taking a MessageRank by value.
// NOTE(review): the body is not visible in this view, so the checks it
// performs are deliberately not described here.
67void HybridMessageQueue::
68_checkValidRank(MessageRank rank)
// Helper working on the emitter rank of a point-to-point message.
// NOTE(review): only the read of message.emiterRank() is visible here; what is
// done with the value afterwards is not shown.
77void HybridMessageQueue::
78_checkValidSource(
const PointToPointMessageInfo& message)
80 MessageRank source = message.emiterRank();
// Builds the message description used for an exchange inside the same MPI
// process: a copy of `message` whose emitter and destination ranks are replaced
// by the local ranks taken from fri.source() and fri.destination().
// The copy is presumably the returned value (return statement not visible here).
88PointToPointMessageInfo HybridMessageQueue::
89_buildSharedMemoryMessage(
const PointToPointMessageInfo& message,
90 const SourceDestinationFullRankInfo& fri)
92 PointToPointMessageInfo p2p_message(message);
93 p2p_message.setEmiterRank(fri.source().localRank());
94 p2p_message.setDestinationRank(fri.destination().localRank());
// Builds the message description used for an exchange between MPI processes:
// a copy of `message` whose emitter and destination ranks are replaced by the
// MPI ranks taken from fri.source() and fri.destination().
// The copy is presumably the returned value (return statement not visible here).
101PointToPointMessageInfo HybridMessageQueue::
102_buildMPIMessage(
const PointToPointMessageInfo& message,
103 const SourceDestinationFullRankInfo& fri)
105 PointToPointMessageInfo p2p_message(message);
106 p2p_message.setEmiterRank(fri.source().mpiRank());
107 p2p_message.setDestinationRank(fri.destination().mpiRank());
// Waits for the completion of the given requests.
// Requests are sorted by their creator into two lists (MPI adapter and thread
// queue); each non-empty list is then waited on through the matching backend.
114void HybridMessageQueue::
115waitAll(ArrayView<Request> requests)
118 Integer nb_request = requests.size();
119 UniqueArray<Request> mpi_requests;
120 UniqueArray<Request> thread_requests;
121 for( Integer i=0; i<nb_request; ++i ){
122 Request r = requests[i];
// The creator pointer identifies which backend owns the request.
125 IRequestCreator* creator = r.creator();
// NOTE(review): the statement of the MPI branch is not visible in this view;
// it presumably appends r to mpi_requests.
126 if (creator==m_mpi_adapter) {
129 else if (creator==m_thread_queue)
130 thread_requests.add(r);
// Each backend is only called when it actually has requests to wait for.
135 if (mpi_requests.size()!=0)
136 m_mpi_adapter->waitAllRequests(mpi_requests);
137 if (thread_requests.size()!=0)
138 m_thread_queue->waitAll(thread_requests);
// NOTE(review): this loop iterates over copies (`Request r`, not `Request& r`).
// If its body (not visible here) mutates r, the change may not reach the
// caller's array unless Request shares its state between copies - confirm.
141 for( Request r : requests )
// Waits for some of the given requests by calling _testOrWaitSome() repeatedly;
// the trailing `while (nb_done==0)` repeats the call as long as nothing completed.
// rank, requests and requests_done are forwarded unchanged to _testOrWaitSome().
148void HybridMessageQueue::
149waitSome(Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done,
150 bool is_non_blocking)
// NOTE(review): nb_done is traced before being assigned in this iteration, so
// the value shown is the one from the previous pass (or its initial value).
154 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1} nb_done={2} is_non_blocking={3}",
155 rank,requests.size(),nb_done,is_non_blocking);
156 nb_done = _testOrWaitSome(rank,requests,requests_done);
// Special case for a non-blocking call or a -1 result; the statement executed
// is not visible here (presumably it leaves the loop).
157 if (is_non_blocking || nb_done==(-1))
159 }
while (nb_done==0);
// Performs one pass over `requests`: requests are split by creator into an MPI
// list and a shared-memory list, each list is submitted to its backend, and the
// results are written back into `requests` / `requests_done` at the original
// positions.
// NOTE(review): the return statements are not visible in this view; the caller
// compares the result with 0 and -1.
166_testOrWaitSome(Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done)
168 Integer nb_request = requests.size();
169 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1}",rank,nb_request);
174 UniqueArray<Request> mpi_requests;
175 UniqueArray<Request> shm_requests;
// For each sub-list, the position of every request in the `requests` array.
177 UniqueArray<Integer> mpi_requests_index;
178 UniqueArray<Integer> shm_requests_index;
181 for( Integer i=0; i<nb_request; ++i ){
182 Request r = requests[i];
185 IRequestCreator* creator = r.creator();
186 if (creator==m_mpi_adapter){
188 mpi_requests_index.add(i);
190 else if (creator==m_thread_queue){
192 shm_requests_index.add(i);
198 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} nb_mpi={1} nb_shm={2}",
199 rank,mpi_requests.size(),shm_requests.size());
// Nothing to wait for in either backend; the statement executed is not visible
// here.
205 if (mpi_requests.size()==0 && shm_requests.size()==0)
// --- MPI requests ---
216 UniqueArray<bool> mpi_done_indexes;
217 Integer nb_mpi_request = mpi_requests.size();
219 if (nb_mpi_request!=0){
220 mpi_done_indexes.resize(nb_mpi_request);
221 mpi_done_indexes.fill(
false);
// NOTE(review): the last argument is presumably the non-blocking flag - confirm
// against MpiAdapter::waitSomeRequests.
222 m_mpi_adapter->waitSomeRequests(mpi_requests,mpi_done_indexes,
true);
// NOTE(review): the format string ends with "after=" and has no {1}
// placeholder, so mpi_done_indexes is presumably never printed - confirm.
223 TRACE_DEBUG(2,
"Hybrid: MPI wait some requests n={0} after=",nb_mpi_request,mpi_done_indexes);
224 for( Integer i=0; i<nb_mpi_request; ++i ){
225 Integer index_in_global = mpi_requests_index[i];
// A completed request is flagged in requests_done and reset in the caller's array.
226 if (mpi_done_indexes[i]){
227 requests_done[index_in_global] =
true;
228 requests[index_in_global].reset();
230 TRACE_DEBUG(1,
"MPI rank={0} set done i={1} in_global={2}",
231 rank,i,index_in_global);
// Copies the backend's request back to its original slot (presumably the
// `else` branch for requests that are not done - the `else` is not visible).
234 requests[index_in_global] = mpi_requests[i];
// --- Shared-memory requests: same scheme as for MPI ---
238 UniqueArray<bool> shm_done_indexes;
239 Integer nb_shm_request = shm_requests.size();
240 TRACE_DEBUG(2,
"SHM wait some requests n={0}",nb_shm_request);
241 if (shm_requests.size()!=0){
242 shm_done_indexes.resize(nb_shm_request);
243 shm_done_indexes.fill(
false);
244 m_thread_queue->waitSome(rank,shm_requests,shm_done_indexes,
true);
245 for( Integer i=0; i<nb_shm_request; ++i ){
246 Integer index_in_global = shm_requests_index[i];
247 if (shm_done_indexes[i]){
248 requests_done[index_in_global] =
true;
249 requests[index_in_global].reset();
251 TRACE_DEBUG(1,
"SHM rank={0} set done i={1} in_global={2}",
252 rank,i,index_in_global);
255 requests[index_in_global] = shm_requests[i];
// Posts a receive for a message described by a (rank, tag) pair.
// Throws NotSupportedException when the destination rank of `message` is null.
// Three paths are visible: shared memory when both ends are on the same MPI
// rank, serializer receive through the MPI parallel manager, and raw byte
// receive through the MPI adapter.
264Request HybridMessageQueue::
265_addReceiveRankTag(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
272 if (message.destinationRank().isNull())
273 ARCANE_THROW(NotSupportedException,
"Receive with any rank. Use probe() and MessageId instead");
275 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
276 bool is_same_mpi_rank = fri.isSameMpiRank();
// Same MPI rank: the receive is delegated to the thread queue using local ranks.
278 if (is_same_mpi_rank){
279 TRACE_DEBUG(1,
"** MPITMQ SHM ADD RECV S queue={0} message={1}",
this,message);
280 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
281 return m_thread_queue->addReceive(p2p_message,buf_info);
// Serializer path (the test on `serializer` is not visible in this view): the
// MPI message carries the tag built by m_rank_tag_builder.tagForReceive().
284 ISerializer* serializer = buf_info.serializer();
286 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV S queue={0} message={1}",
this,message);
287 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
288 p2p_message.setTag(m_rank_tag_builder.tagForReceive(MessageTag(message.tag()),fri));
289 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
// Raw buffer path: the memory buffer is received as `size` elements of the MPI
// datatype matching `char`.
292 ByteSpan buf = buf_info.memoryBuffer();
293 Int64 size = buf.size();
295 TRACE_DEBUG(1,
"** MPITMQ THREAD ADD RECV B queue={0} message={1} size={2} same_mpi?={3}",
296 this,message,size,fri.isSameMpiRank());
299 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
300 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(message.tag(),fri);
// NOTE(review): the final `false` argument is presumably the blocking flag -
// confirm against MpiAdapter::directRecv.
301 Request r = m_mpi_adapter->directRecv(buf.data(),size,fri.destination().mpiRankValue(),
sizeof(
char),
302 char_data_type,mpi_tag.value(),
false);
// Posts a receive for a message identified by a MessageId.
// Raises a fatal error when the rank stored in the MessageId source info differs
// from message.destinationRank().
310Request HybridMessageQueue::
311_addReceiveMessageId(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
313 MessageId message_id = message.messageId();
314 MessageId::SourceInfo si(message_id.sourceInfo());
// NOTE(review): the error text below contains typos ("Incohence",
// "messsage_id"); left untouched here because it is a runtime string.
316 if (si.rank()!=message.destinationRank())
317 ARCANE_FATAL(
"Incohence between messsage_id rank and destination rank x1={0} x2={1}",
318 si.rank(),message.destinationRank());
320 TRACE_DEBUG(1,
"** MPITMQ ADD_RECV (message_id) queue={0} message={1}",
// Same MPI rank: the receive is delegated to the thread queue using local ranks.
323 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
324 if (fri.isSameMpiRank()){
325 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
326 return m_thread_queue->addReceive(p2p_message,buf_info);
329 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV (message_id) queue={0} message={1}",
this,message);
// Serializer path (the test on `serializer` is not visible in this view).
331 ISerializer* serializer = buf_info.serializer();
333 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
335 TRACE_DEBUG(1,
"** MPI ADD RECV Serializer (message_id) message={0} p2p_message={1}",
336 message,p2p_message);
337 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
// Raw buffer path: a copy of the MessageId is made whose source rank is
// replaced by the MPI rank of the destination, then the bytes are received as
// `char` elements through the MPI adapter.
340 ByteSpan buf = buf_info.memoryBuffer();
341 Int64 size = buf.size();
344 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
345 MessageId mpi_message(message_id);
346 MessageId::SourceInfo mpi_si(si);
347 mpi_si.setRank(fri.destination().mpiRank());
348 mpi_message.setSourceInfo(mpi_si);
349 return m_mpi_adapter->directRecv(buf.data(),size,mpi_message,
sizeof(
char),
350 char_data_type,
false);
// Entry point for posting a receive.
// After _checkValidSource(), the call is dispatched on the kind of message:
// rank/tag messages go to _addReceiveRankTag(), MessageId messages go to
// _addReceiveMessageId(); any other kind throws NotSupportedException.
357Request HybridMessageQueue::
358addReceive(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf)
360 _checkValidSource(message);
// NOTE(review): the statement executed for an invalid message is not visible
// in this view.
362 if (!message.isValid())
365 if (message.isRankTag())
366 return _addReceiveRankTag(message,buf);
368 if (message.isMessageId())
369 return _addReceiveMessageId(message,buf);
371 ARCANE_THROW(NotSupportedException,
"Invalid message_info");
// Entry point for posting a send.
// Raises a fatal error when the destination rank is null or when the message is
// not of the rank/tag kind. The send then goes through the thread queue (same
// MPI rank), the MPI parallel manager (serializer) or the MPI adapter (raw bytes).
377Request HybridMessageQueue::
378addSend(
const PointToPointMessageInfo& message,SendBufferInfo buf_info)
// NOTE(review): the statement executed for an invalid message is not visible
// in this view.
380 if (!message.isValid())
382 if (message.destinationRank().isNull())
383 ARCCORE_FATAL(
"Null destination");
384 if (!message.isRankTag())
385 ARCCORE_FATAL(
"Invalid message_info for sending: message.isRankTag() is false");
387 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
// Same MPI rank: the send is delegated to the thread queue using local ranks.
390 if (fri.isSameMpiRank()){
391 TRACE_DEBUG(1,
"** MPITMQ SHM ADD SEND S queue={0} message={1}",
this,message);
392 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
393 return m_thread_queue->addSend(p2p_message,buf_info);
// MPI paths: the tag actually sent is built by m_rank_tag_builder.tagForSend().
397 MessageTag mpi_tag = m_rank_tag_builder.tagForSend(message.tag(),fri);
// Serializer path (the test on `serializer` is not visible in this view).
398 const ISerializer* serializer = buf_info.serializer();
400 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
401 p2p_message.setTag(mpi_tag);
402 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND Serializer queue={0} message={1} p2p_message={2}",
403 this,message,p2p_message);
404 return m_mpi_parallel_mng->sendSerializer(serializer,p2p_message);
// Raw buffer path: the bytes are sent as `size` elements of the MPI datatype
// matching `char` to the MPI rank of the destination.
407 ByteConstSpan buf = buf_info.memoryBuffer();
408 Int64 size = buf.size();
413 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
415 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND B queue={0} message={1} size={2} mpi_tag={3} mpi_rank={4}",
416 this,message,size,mpi_tag,fri.destination().mpiRank());
418 return m_mpi_adapter->directSend(buf.data(),size,fri.destination().mpiRankValue(),
419 sizeof(
char),char_data_type,mpi_tag.value(),
false);
// Probes for an incoming message matching `message` and yields a MessageId.
// Only rank/tag messages with a non-null tag are handled: a fatal error is
// raised when message.isRankTag() is false, and NotImplementedException is
// thrown when the tag is null.
// NOTE(review): several branch headers and the return statements are not
// visible in this view; the branch structure described below is inferred from
// the variables used and should be confirmed against the full file.
426MP::MessageId HybridMessageQueue::
427probe(
const MP::PointToPointMessageInfo& message)
429 TRACE_DEBUG(1,
"Probe msg='{0}' queue={1} is_valid={2}",
430 message,
this,message.isValid());
432 MessageRank orig = message.emiterRank();
436 if (!message.isValid())
440 if (!message.isRankTag())
441 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
443 MessageRank dest = message.destinationRank();
444 MessageTag user_tag = message.tag();
445 bool is_blocking = message.isBlocking();
448 if (user_tag.isNull())
449 ARCANE_THROW(NotImplementedException,
"probe with ANY_TAG");
450 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
451 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
452 MessageId message_id;
// found_dest starts as the requested rank and is overwritten when an
// any-source probe finds a message.
453 Int32 found_dest = dest.value();
// A null rank counts as "any source", but is only accepted when
// m_is_allow_null_rank_for_any_source is set.
454 const bool is_any_source = dest.isNull() || dest.isAnySource();
455 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
456 ARCANE_FATAL(
"Can not use probe() with null rank. Use MessageRank::anySourceRank() instead");
// Probe of the thread queue with the local emitter rank (presumably the
// any-source case). On success the rank is computed as
// <MPI rank of orig> * m_local_nb_rank + <local rank reported by the queue>.
460 MP::PointToPointMessageInfo p2p_message(message);
461 p2p_message.setEmiterRank(orig_fri.localRank());
462 message_id = m_thread_queue->probe(p2p_message);
463 if (message_id.isValid()){
467 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_id.sourceInfo().rank().value();
468 TRACE_DEBUG(2,
"Probe with null_rank (thread) orig={0} found_dest={1} tag={2}",
469 orig,found_dest,user_tag);
// MPI probe repeated for every local rank z: the MPI tag is built from the
// user tag, the local rank of orig and z.
477 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
478 MP::PointToPointMessageInfo mpi_message(message);
479 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
480 mpi_message.setTag(mpi_tag);
481 TRACE_DEBUG(2,
"Probe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
482 message_id = m_mpi_adapter->probeMessage(mpi_message);
483 if (message_id.isValid()){
// The local rank is decoded from the tag returned by MPI, then combined with
// the MPI rank to form the rank stored in found_dest.
486 MessageRank mpi_rank = message_id.sourceInfo().rank();
487 MessageTag ret_tag = message_id.sourceInfo().tag();
488 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
489 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// NOTE(review): four arguments are passed for three placeholders and the first
// one is ret_tag, so "mpi_rank=" shows the tag, "local_rank=" shows the MPI
// rank and "tag=" shows the local rank - the leading ret_tag looks spurious.
490 TRACE_DEBUG(2,
"Probe null rank found mpi_rank={0} local_rank={1} tag={2}",
491 ret_tag,mpi_rank,local_rank,ret_tag);
// Probe for an explicit rank: thread queue when orig and dest share the same
// MPI rank, MPI adapter otherwise.
500 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
501 MP::PointToPointMessageInfo p2p_message(message);
502 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
503 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
504 message_id = m_thread_queue->probe(p2p_message);
507 MP::PointToPointMessageInfo mpi_message(message);
508 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
509 mpi_message.setTag(mpi_tag);
510 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
511 TRACE_DEBUG(2,
"Probe orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
512 message_id = m_mpi_adapter->probeMessage(mpi_message);
// When a message was found, the rank in its source info is replaced by found_dest.
515 if (message_id.isValid()){
518 MessageId::SourceInfo si = message_id.sourceInfo();
519 si.setRank(MessageRank(found_dest));
520 message_id.setSourceInfo(si);
// Probes for an incoming message matching `message` and returns a
// MessageSourceInfo describing it.
// Only rank/tag messages with a non-null tag are handled: a fatal error is
// raised when message.isRankTag() is false, and NotImplementedException is
// thrown when the tag is null.
// NOTE(review): several branch headers are not visible in this view; the
// branch structure described below is inferred from the variables used and
// should be confirmed against the full file.
528MP::MessageSourceInfo HybridMessageQueue::
529legacyProbe(
const MP::PointToPointMessageInfo& message)
531 TRACE_DEBUG(1,
"LegacyProbe msg='{0}' queue={1} is_valid={2}",
532 message,
this,message.isValid());
534 MessageRank orig = message.emiterRank();
538 if (!message.isValid())
542 if (!message.isRankTag())
543 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
545 const MessageRank dest = message.destinationRank();
546 const MessageTag user_tag = message.tag();
547 const bool is_blocking = message.isBlocking();
550 if (user_tag.isNull())
551 ARCANE_THROW(NotImplementedException,
"legacyProbe with ANY_TAG");
552 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
553 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
554 MP::MessageSourceInfo message_source_info;
// found_dest starts as the requested rank and is overwritten when an
// any-source probe finds a message.
555 Int32 found_dest = dest.value();
// A null rank counts as "any source", but is only accepted when
// m_is_allow_null_rank_for_any_source is set.
556 const bool is_any_source = dest.isNull() || dest.isAnySource();
557 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
558 ARCANE_FATAL(
"Can not use legacyProbe() with null rank. Use MessageRank::anySourceRank() instead");
// Probe of the thread queue with the local emitter rank (presumably the
// any-source case). On success the rank is computed as
// <MPI rank of orig> * m_local_nb_rank + <local rank reported by the queue>.
562 MP::PointToPointMessageInfo p2p_message(message);
563 p2p_message.setEmiterRank(orig_fri.localRank());
564 message_source_info = m_thread_queue->legacyProbe(p2p_message);
565 if (message_source_info.isValid()){
569 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_source_info.rank().value();
570 TRACE_DEBUG(2,
"LegacyProbe with null_rank (thread) orig={0} found_dest={1} tag={2}",
571 orig,found_dest,user_tag);
// MPI probe repeated for every local rank z: the MPI tag is built from the
// user tag, the local rank of orig and z.
579 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
580 MP::PointToPointMessageInfo mpi_message(message);
581 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
582 mpi_message.setTag(mpi_tag);
583 TRACE_DEBUG(2,
"LegacyProbe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
584 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
585 if (message_source_info.isValid()){
// The local rank is decoded from the tag returned by MPI, then combined with
// the MPI rank to form the rank stored in found_dest.
588 MessageRank mpi_rank = message_source_info.rank();
589 MessageTag ret_tag = message_source_info.tag();
590 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
591 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// NOTE(review): four arguments are passed for three placeholders and the first
// one is ret_tag, so "mpi_rank=" shows the tag, "local_rank=" shows the MPI
// rank and "tag=" shows the local rank - the leading ret_tag looks spurious.
592 TRACE_DEBUG(2,
"LegacyProbe null rank found mpi_rank={0} local_rank={1} tag={2}",
593 ret_tag,mpi_rank,local_rank,ret_tag);
// The tag reported to the caller is the user tag, not the MPI-level tag.
595 message_source_info.setTag(user_tag);
// Probe for an explicit rank: thread queue when orig and dest share the same
// MPI rank, MPI adapter otherwise.
604 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
605 MP::PointToPointMessageInfo p2p_message(message);
606 p2p_message.setDestinationRank(MP::MessageRank(dest_fri.localRank()));
607 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
608 TRACE_DEBUG(2,
"LegacyProbe SHM orig={0} dest={1} tag={2}",orig,dest,user_tag);
609 message_source_info = m_thread_queue->legacyProbe(p2p_message);
612 MP::PointToPointMessageInfo mpi_message(message);
613 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
614 mpi_message.setTag(mpi_tag);
615 mpi_message.setDestinationRank(MP::MessageRank(dest_fri.mpiRank()));
616 TRACE_DEBUG(2,
"LegacyProbe MPI orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
617 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
// As in the any-source case, the MPI-level tag is replaced by the user tag.
618 if (message_source_info.isValid()){
620 message_source_info.setTag(user_tag);
// When a message was found, its rank is replaced by found_dest.
624 if (message_source_info.isValid()){
627 message_source_info.setRank(MessageRank(found_dest));
629 TRACE_DEBUG(2,
"LegacyProbe has matched message? = {0}",message_source_info.isValid());
630 return message_source_info;
// Stream output for FullRankInfo: writes "(local=<m_local_rank>,
// global=<m_global_rank>,mpi=<m_mpi_rank>)" without the line break shown here,
// and returns the stream.
636std::ostream&
operator<<(std::ostream& o,
const FullRankInfo& fri)
638 return o <<
"(local=" << fri.m_local_rank <<
",global="
639 << fri.m_global_rank <<
",mpi=" << fri.m_mpi_rank <<
")";
#define ARCANE_THROW(exception_class,...)
Macro pour envoyer une exception avec formatage.
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Interface d'une file de messages avec les threads.
std::ostream & operator<<(std::ostream &o, eExecutionPolicy exec_policy)
Affiche le nom de la politique d'exécution.
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
Int32 Integer
Type représentant un entier.
std::int32_t Int32
Type entier signé sur 32 bits.