14#include "arcane/utils/ArcanePrecomp.h"
16#include "arcane/utils/Array.h"
17#include "arcane/utils/FatalErrorException.h"
18#include "arcane/utils/NotImplementedException.h"
19#include "arcane/utils/NotSupportedException.h"
20#include "arcane/utils/ArgumentException.h"
21#include "arcane/utils/TraceInfo.h"
22#include "arcane/utils/ITraceMng.h"
23#include "arcane/utils/ValueConvert.h"
25#include "arcane/parallel/mpithread/HybridMessageQueue.h"
26#include "arcane/parallel/mpi/MpiAdapter.h"
27#include "arcane/parallel/mpi/MpiParallelMng.h"
// Emits a debug trace through info(), prefixed with "Hybrid ", when
// m_debug_level is at least `needed_debug_level`.
// `format_str` must be a string literal: it is concatenated at compile time
// with the "Hybrid " prefix, and its {N} placeholders are filled by
// String::format() from the remaining arguments.
// The do/while(false) wrapper makes `TRACE_DEBUG(...);` a single statement,
// so it is safe inside an unbraced if/else.
#define TRACE_DEBUG(needed_debug_level,format_str,...) \
  do { \
    if ((m_debug_level)>=(needed_debug_level)){ \
      info() << String::format("Hybrid " format_str,__VA_ARGS__); \
    } \
  } while (false)
// HybridMessageQueue constructor.
// The queue routes point-to-point messages in one of two ways:
// - through the shared-memory queue (`thread_queue`) when both ends live in
//   the same MPI process (see fri.isSameMpiRank() in addSend()/addReceive());
// - through MPI (`mpi_pm` and its adapter) otherwise.
// NOTE(review): this listing does not show the `HybridMessageQueue::`
// qualifier line, the declaration of the last parameter (`local_nb_rank`),
// any remaining member initializers, or the body braces.
49HybridMessageQueue(ISharedMemoryMessageQueue* thread_queue,MpiParallelMng* mpi_pm,
51: TraceAccessor(mpi_pm->traceMng())
52, m_thread_queue(thread_queue)
53, m_mpi_parallel_mng(mpi_pm)
54, m_mpi_adapter(mpi_pm->adapter())
// Presumably the number of ranks per MPI process. probe() builds global
// ranks as mpi_rank*m_local_nb_rank+local_rank. Confirm.
55, m_local_nb_rank(local_nb_rank)
56, m_rank_tag_builder(local_nb_rank)
// Optional override read from the environment. If
// ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE parses as an Int32, a non-zero
// value sets m_is_allow_null_rank_for_any_source. probe() and legacyProbe()
// check that flag before accepting a null rank as "any source".
59 if (
auto v = Convert::Type<Int32>::tryParseFromEnvironment(
"ARCCORE_ALLOW_NULL_RANK_FOR_MPI_ANY_SOURCE",
true))
60 m_is_allow_null_rank_for_any_source = v.value() != 0;
// Presumably checks that `rank` is usable by this queue.
// NOTE(review): the body of this check (condition and error raised) is not
// visible in this listing.
66void HybridMessageQueue::
67_checkValidRank(MessageRank rank)
// Presumably validates the emitter (source) rank of `message`.
// addReceive() calls it before dispatching the receive.
// NOTE(review): only the read of message.emiterRank() is visible here. The
// actual check and the error it raises are not in this listing.
76void HybridMessageQueue::
77_checkValidSource(
const PointToPointMessageInfo& message)
79 MessageRank source = message.emiterRank();
87PointToPointMessageInfo HybridMessageQueue::
88_buildSharedMemoryMessage(
const PointToPointMessageInfo& message,
89 const SourceDestinationFullRankInfo& fri)
91 PointToPointMessageInfo p2p_message(message);
92 p2p_message.setEmiterRank(fri.source().localRank());
93 p2p_message.setDestinationRank(fri.destination().localRank());
100PointToPointMessageInfo HybridMessageQueue::
101_buildMPIMessage(
const PointToPointMessageInfo& message,
102 const SourceDestinationFullRankInfo& fri)
104 PointToPointMessageInfo p2p_message(message);
105 p2p_message.setEmiterRank(fri.source().mpiRank());
106 p2p_message.setDestinationRank(fri.destination().mpiRank());
// Waits for completion of every request in `requests`.
// The requests are partitioned by creator:
// - those created by the MPI adapter are waited on together with
//   m_mpi_adapter->waitAllRequests();
// - those created by the shared-memory queue are waited on with
//   m_thread_queue->waitAll().
// NOTE(review): several lines are not visible in this listing: the function
// braces, any handling of null requests or unknown creators, the add to
// mpi_requests, and the body of the final loop over `requests`.
113void HybridMessageQueue::
114waitAll(ArrayView<Request> requests)
117 Integer nb_request = requests.size();
118 UniqueArray<Request> mpi_requests;
119 UniqueArray<Request> thread_requests;
120 for( Integer i=0; i<nb_request; ++i ){
121 Request r = requests[i];
124 IRequestCreator* creator = r.creator();
125 if (creator==m_mpi_adapter) {
// (line not visible: presumably appends r to mpi_requests)
128 else if (creator==m_thread_queue)
129 thread_requests.add(r);
// Only call a backend when it actually has requests to wait for.
134 if (mpi_requests.size()!=0)
135 m_mpi_adapter->waitAllRequests(mpi_requests);
136 if (thread_requests.size()!=0)
137 m_thread_queue->waitAll(thread_requests);
// NOTE(review): the loop body is not visible. `r` is a by-value copy, so if
// the body resets `r`, the elements of `requests` are NOT modified.
// Consider `Request& r`, as _testOrWaitSome() does with
// requests[index].reset().
140 for( Request r : requests )
// Repeatedly calls _testOrWaitSome() on `requests` until at least one request
// has completed. Completed entries are flagged in `requests_done` by
// _testOrWaitSome().
// When `is_non_blocking` is true, or when _testOrWaitSome() returns -1, the
// loop presumably stops after a single attempt.
// NOTE(review): this listing does not show the function braces, the
// declaration/initialisation of nb_done, the `do {` opening, or the
// statement guarded by the `if` below (presumably `break;`).
147void HybridMessageQueue::
148waitSome(
Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done,
149 bool is_non_blocking)
// NOTE(review): this format string starts with "Hybrid: " while TRACE_DEBUG
// already prepends "Hybrid ", so the trace reads "Hybrid Hybrid: ...".
// nb_done is printed before this iteration updates it.
153 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1} nb_done={2} is_non_blocking={3}",
154 rank,requests.size(),nb_done,is_non_blocking);
155 nb_done = _testOrWaitSome(rank,requests,requests_done);
156 if (is_non_blocking || nb_done==(-1))
158 }
while (nb_done==0);
// Body of HybridMessageQueue::_testOrWaitSome(). The return-type line, the
// function braces and several statements are not visible in this listing.
// What it does:
// - Splits `requests` by creator into an MPI subset and a shared-memory
//   subset, remembering each request's index in the caller's array.
// - Polls each subset once.
// - For a completed request: sets requests_done[index] to true and resets
//   requests[index].
// - For a pending request: copies it back from the subset array (the
//   `else` is not visible; presumably this keeps any update made by the
//   backend).
// NOTE(review): waitSome() treats a result of -1 specially and loops while
// the result is 0. Presumably this returns -1 when there is no request at
// all, and otherwise the number of completed requests. Confirm against the
// hidden lines.
165_testOrWaitSome(
Int32 rank,ArrayView<Request> requests,ArrayView<bool> requests_done)
167 Integer nb_request = requests.size();
168 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} requests n={1}",rank,nb_request);
// Requests of each backend, plus their position in the caller's array.
173 UniqueArray<Request> mpi_requests;
174 UniqueArray<Request> shm_requests;
176 UniqueArray<Integer> mpi_requests_index;
177 UniqueArray<Integer> shm_requests_index;
180 for( Integer i=0; i<nb_request; ++i ){
181 Request r = requests[i];
184 IRequestCreator* creator = r.creator();
// (lines not visible: the request itself is presumably added to
// mpi_requests / shm_requests next to its index)
185 if (creator==m_mpi_adapter){
187 mpi_requests_index.add(i);
189 else if (creator==m_thread_queue){
191 shm_requests_index.add(i);
197 TRACE_DEBUG(2,
"Hybrid: wait some rank={0} nb_mpi={1} nb_shm={2}",
198 rank,mpi_requests.size(),shm_requests.size());
// (guarded statement not visible: presumably the early exit when there is
// no request at all)
204 if (mpi_requests.size()==0 && shm_requests.size()==0)
// MPI subset: one polling pass, then report completions to the caller.
215 UniqueArray<bool> mpi_done_indexes;
216 Integer nb_mpi_request = mpi_requests.size();
218 if (nb_mpi_request!=0){
219 mpi_done_indexes.resize(nb_mpi_request);
220 mpi_done_indexes.fill(
false);
// NOTE(review): the trailing `true` presumably requests a non-blocking test.
// Confirm with MpiAdapter::waitSomeRequests().
221 m_mpi_adapter->waitSomeRequests(mpi_requests,mpi_done_indexes,
true);
// NOTE(review): the format string has no {1} placeholder after 'after=', so
// mpi_done_indexes is never printed.
222 TRACE_DEBUG(2,
"Hybrid: MPI wait some requests n={0} after=",nb_mpi_request,mpi_done_indexes);
223 for( Integer i=0; i<nb_mpi_request; ++i ){
224 Integer index_in_global = mpi_requests_index[i];
225 if (mpi_done_indexes[i]){
226 requests_done[index_in_global] =
true;
227 requests[index_in_global].reset();
229 TRACE_DEBUG(1,
"MPI rank={0} set done i={1} in_global={2}",
230 rank,i,index_in_global);
// Pending request: write back the (possibly updated) handle.
233 requests[index_in_global] = mpi_requests[i];
// Shared-memory subset: same processing as the MPI subset above.
237 UniqueArray<bool> shm_done_indexes;
238 Integer nb_shm_request = shm_requests.size();
239 TRACE_DEBUG(2,
"SHM wait some requests n={0}",nb_shm_request);
240 if (shm_requests.size()!=0){
241 shm_done_indexes.resize(nb_shm_request);
242 shm_done_indexes.fill(
false);
// NOTE(review): as above, the trailing `true` presumably selects a
// non-blocking test. Confirm with the shared-memory queue's waitSome().
243 m_thread_queue->waitSome(rank,shm_requests,shm_done_indexes,
true);
244 for( Integer i=0; i<nb_shm_request; ++i ){
245 Integer index_in_global = shm_requests_index[i];
246 if (shm_done_indexes[i]){
247 requests_done[index_in_global] =
true;
248 requests[index_in_global].reset();
250 TRACE_DEBUG(1,
"SHM rank={0} set done i={1} in_global={2}",
251 rank,i,index_in_global);
254 requests[index_in_global] = shm_requests[i];
// Posts a receive for a message addressed by a (rank, tag) pair.
// - A null destination rank (receive from any rank) is rejected with
//   NotSupportedException. Callers must use probe() + MessageId instead.
// - If both ranks live in the same MPI process, the receive is forwarded to
//   the shared-memory queue with local ranks.
// - Otherwise it goes through MPI, with a tag rebuilt by m_rank_tag_builder:
//   - serialized receives use MpiParallelMng::receiveSerializer();
//   - raw memory buffers use MpiAdapter::directRecv() on a char datatype.
// NOTE(review): this listing does not show the function braces, the test
// separating the serializer path from the memory-buffer path (presumably on
// `serializer`), or the final return of `r`.
263Request HybridMessageQueue::
264_addReceiveRankTag(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
271 if (message.destinationRank().isNull())
272 ARCANE_THROW(NotSupportedException,
"Receive with any rank. Use probe() and MessageId instead");
274 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
275 bool is_same_mpi_rank = fri.isSameMpiRank();
// Same MPI process: delegate to the shared-memory queue.
277 if (is_same_mpi_rank){
278 TRACE_DEBUG(1,
"** MPITMQ SHM ADD RECV S queue={0} message={1}",
this,message);
279 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
280 return m_thread_queue->addReceive(p2p_message,buf_info);
// MPI path, serialized data.
283 ISerializer* serializer = buf_info.serializer();
285 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV S queue={0} message={1}",
this,message);
286 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
// NOTE(review): message.tag() already returns a MessageTag, so the
// MessageTag(...) wrap is redundant (the buffer path below does not use it).
287 p2p_message.setTag(m_rank_tag_builder.tagForReceive(MessageTag(message.tag()),fri));
288 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
// MPI path, raw memory buffer received as `size` chars.
291 ByteSpan buf = buf_info.memoryBuffer();
292 Int64 size = buf.size();
// NOTE(review): the label says THREAD although this path appears to post an
// MPI receive (the same-MPI case returned above).
294 TRACE_DEBUG(1,
"** MPITMQ THREAD ADD RECV B queue={0} message={1} size={2} same_mpi?={3}",
295 this,message,size,fri.isSameMpiRank());
298 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
299 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(message.tag(),fri);
// The peer is the destination's MPI rank. The trailing `false` presumably
// means non-blocking; confirm with MpiAdapter::directRecv().
300 Request r = m_mpi_adapter->directRecv(buf.data(),size,fri.destination().mpiRankValue(),
sizeof(
char),
301 char_data_type,mpi_tag.value(),
false);
// Posts a receive for a message identified by a MessageId (as produced by
// probe()).
// - The rank stored in the MessageId must equal message.destinationRank();
//   otherwise ARCANE_FATAL raises a FatalErrorException.
// - Same-MPI-process messages go to the shared-memory queue.
// - Otherwise the receive is done through MPI:
//   - serialized data uses MpiParallelMng::receiveSerializer();
//   - memory buffers use MpiAdapter::directRecv() with a copy of the
//     MessageId whose rank is rewritten to the peer's MPI rank.
// In the visible lines, no tag remapping is applied on this path.
// NOTE(review): this listing does not show the function braces, the
// argument line of the first TRACE_DEBUG, or the test selecting the
// serializer branch.
309Request HybridMessageQueue::
310_addReceiveMessageId(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf_info)
312 MessageId message_id = message.messageId();
313 MessageId::SourceInfo si(message_id.sourceInfo());
// The MessageId and the message info must describe the same peer.
315 if (si.rank()!=message.destinationRank())
316 ARCANE_FATAL(
"Incoherence between message_id rank and destination rank x1={0} x2={1}",
317 si.rank(),message.destinationRank());
319 TRACE_DEBUG(1,
"** MPITMQ ADD_RECV (message_id) queue={0} message={1}",
322 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
// Same MPI process: delegate to the shared-memory queue.
323 if (fri.isSameMpiRank()){
324 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
325 return m_thread_queue->addReceive(p2p_message,buf_info);
328 TRACE_DEBUG(1,
"** MPITMQ MPI ADD RECV (message_id) queue={0} message={1}",
this,message);
// MPI path, serialized data.
330 ISerializer* serializer = buf_info.serializer();
332 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
334 TRACE_DEBUG(1,
"** MPI ADD RECV Serializer (message_id) message={0} p2p_message={1}",
335 message,p2p_message);
336 return m_mpi_parallel_mng->receiveSerializer(serializer,p2p_message);
// MPI path, raw memory buffer received as `size` chars. The MessageId is
// copied and its rank replaced by the destination's MPI rank.
339 ByteSpan buf = buf_info.memoryBuffer();
340 Int64 size = buf.size();
343 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
344 MessageId mpi_message(message_id);
345 MessageId::SourceInfo mpi_si(si);
346 mpi_si.setRank(fri.destination().mpiRank());
347 mpi_message.setSourceInfo(mpi_si);
348 return m_mpi_adapter->directRecv(buf.data(),size,mpi_message,
sizeof(
char),
349 char_data_type,
false);
// Posts a receive described by `message` into `buf`.
// Steps:
// 1. _checkValidSource() validates the emitter rank first.
// 2. The request is dispatched on the addressing mode:
//    - (rank, tag) goes to _addReceiveRankTag();
//    - MessageId goes to _addReceiveMessageId().
// 3. Any other kind raises NotSupportedException("Invalid message_info").
// NOTE(review): the statement executed for an invalid `message` (after
// `if (!message.isValid())`) and the function braces are not visible here.
356Request HybridMessageQueue::
357addReceive(
const PointToPointMessageInfo& message,ReceiveBufferInfo buf)
359 _checkValidSource(message);
361 if (!message.isValid())
364 if (message.isRankTag())
365 return _addReceiveRankTag(message,buf);
367 if (message.isMessageId())
368 return _addReceiveMessageId(message,buf);
370 ARCANE_THROW(NotSupportedException,
"Invalid message_info");
// Posts a send of `buf_info` described by `message`.
// `message` must be a (rank, tag) message with a non-null destination;
// otherwise a fatal error is raised.
// - Same-MPI-process sends go to the shared-memory queue with local ranks.
// - Otherwise m_rank_tag_builder.tagForSend() rebuilds the user tag, and the
//   data is sent through MPI:
//   - serialized data via MpiParallelMng::sendSerializer();
//   - raw memory via MpiAdapter::directSend() on a char datatype.
// NOTE(review): this listing does not show the function braces, the
// statement executed for an invalid `message`, or the test selecting the
// serializer branch. This block uses ARCCORE_FATAL while the rest of the
// file uses ARCANE_FATAL.
376Request HybridMessageQueue::
377addSend(
const PointToPointMessageInfo& message,SendBufferInfo buf_info)
379 if (!message.isValid())
381 if (message.destinationRank().isNull())
382 ARCCORE_FATAL(
"Null destination");
383 if (!message.isRankTag())
384 ARCCORE_FATAL(
"Invalid message_info for sending: message.isRankTag() is false");
386 SourceDestinationFullRankInfo fri = _getFullRankInfo(message);
// Same MPI process: delegate to the shared-memory queue.
389 if (fri.isSameMpiRank()){
390 TRACE_DEBUG(1,
"** MPITMQ SHM ADD SEND S queue={0} message={1}",
this,message);
391 PointToPointMessageInfo p2p_message(_buildSharedMemoryMessage(message,fri));
392 return m_thread_queue->addSend(p2p_message,buf_info);
// MPI path: tag rebuilt for the (source, destination) pair.
396 MessageTag mpi_tag = m_rank_tag_builder.tagForSend(message.tag(),fri);
397 const ISerializer* serializer = buf_info.serializer();
399 PointToPointMessageInfo p2p_message(_buildMPIMessage(message,fri));
400 p2p_message.setTag(mpi_tag);
401 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND Serializer queue={0} message={1} p2p_message={2}",
402 this,message,p2p_message);
403 return m_mpi_parallel_mng->sendSerializer(serializer,p2p_message);
// MPI path, raw memory buffer sent as `size` chars to the destination's
// MPI rank. The trailing `false` presumably means non-blocking; confirm.
406 ByteConstSpan buf = buf_info.memoryBuffer();
407 Int64 size = buf.size();
412 MPI_Datatype char_data_type = MpiBuiltIn::datatype(
char());
414 TRACE_DEBUG(1,
"** MPITMQ MPI ADD SEND B queue={0} message={1} size={2} mpi_tag={3} mpi_rank={4}",
415 this,message,size,mpi_tag,fri.destination().mpiRank());
417 return m_mpi_adapter->directSend(buf.data(),size,fri.destination().mpiRankValue(),
418 sizeof(
char),char_data_type,mpi_tag.value(),
false);
// Body of HybridMessageQueue::probe(). Its signature and several lines
// (declarations of orig/dest/p2p_message/mpi_message, braces, the final
// return) are not visible in this listing.
// Behaviour:
// - ANY_TAG (null user tag) is not implemented.
// - A null destination rank is accepted as "any source" only when
//   m_is_allow_null_rank_for_any_source is set; otherwise it is fatal.
// - Any source: the shared-memory queue is probed first. MPI is then probed
//   once per local rank z, with the tag built for (orig local rank, z); the
//   local rank of a match is decoded back from the returned tag.
// - Given source: the shared-memory queue is used when both ranks share an
//   MPI process, MPI with a rebuilt tag otherwise.
// A found MessageId has its source rank rewritten to the global rank
// found_dest = mpi_rank*m_local_nb_rank + local_rank.
428 TRACE_DEBUG(1,
"Probe msg='{0}' queue={1} is_valid={2}",
429 message,
this,message.isValid());
435 if (!message.isValid())
440 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
443 MessageTag user_tag = message.
tag();
447 if (user_tag.isNull())
448 ARCANE_THROW(NotImplementedException,
"probe with ANY_TAG");
449 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
450 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
451 MessageId message_id;
452 Int32 found_dest = dest.value();
453 const bool is_any_source = dest.isNull() || dest.isAnySource();
454 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
455 ARCANE_FATAL(
"Can not use probe() with null rank. Use MessageRank::anySourceRank() instead");
// Any source, shared-memory side first.
460 p2p_message.setEmiterRank(orig_fri.localRank());
461 message_id = m_thread_queue->probe(p2p_message);
462 if (message_id.isValid()){
466 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_id.sourceInfo().rank().value();
467 TRACE_DEBUG(2,
"Probe with null_rank (thread) orig={0} found_dest={1} tag={2}",
468 orig,found_dest,user_tag);
// Any source, MPI side: one probe per possible remote local rank z.
476 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
478 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
479 mpi_message.setTag(mpi_tag);
480 TRACE_DEBUG(2,
"Probe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
481 message_id = m_mpi_adapter->probeMessage(mpi_message);
482 if (message_id.isValid()){
485 MessageRank mpi_rank = message_id.sourceInfo().rank();
486 MessageTag ret_tag = message_id.sourceInfo().tag();
487 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
488 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// Arguments follow the placeholder order: mpi_rank, local_rank, tag.
489 TRACE_DEBUG(2,
"Probe null rank found mpi_rank={0} local_rank={1} tag={2}",
490 mpi_rank,local_rank,ret_tag);
// Given source, same MPI process: shared-memory queue.
499 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
502 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
503 message_id = m_thread_queue->probe(p2p_message);
// Given source, different MPI process: MPI with the rebuilt tag.
507 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
508 mpi_message.setTag(mpi_tag);
510 TRACE_DEBUG(2,
"Probe orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
511 message_id = m_mpi_adapter->probeMessage(mpi_message);
// Report the source as a global rank.
514 if (message_id.isValid()){
517 MessageId::SourceInfo si = message_id.sourceInfo();
518 si.setRank(MessageRank(found_dest));
519 message_id.setSourceInfo(si);
// Body of HybridMessageQueue::legacyProbe(). Its signature and several
// lines (declarations of orig/dest/p2p_message/mpi_message/
// message_source_info, braces) are not visible in this listing.
// Behaviour:
// - Same routing as probe(), but the result is a message source
//   information object instead of a MessageId.
// - A match found through MPI gets its tag reset to the user tag, because
//   the MPI tag was rebuilt by m_rank_tag_builder.
// - Any valid result gets its rank rewritten to the global rank
//   found_dest = mpi_rank*m_local_nb_rank + local_rank before being
//   returned.
// NOTE(review): is_blocking is read but its use is not visible here.
530 TRACE_DEBUG(1,
"LegacyProbe msg='{0}' queue={1} is_valid={2}",
531 message,
this,message.isValid());
537 if (!message.isValid())
542 ARCCORE_FATAL(
"Invalid message_info: message.isRankTag() is false");
545 const MessageTag user_tag = message.
tag();
546 const bool is_blocking = message.
isBlocking();
549 if (user_tag.isNull())
550 ARCANE_THROW(NotImplementedException,
"legacyProbe with ANY_TAG");
551 FullRankInfo orig_fri = m_rank_tag_builder.rank(orig);
552 FullRankInfo dest_fri = m_rank_tag_builder.rank(dest);
554 Int32 found_dest = dest.value();
555 const bool is_any_source = dest.isNull() || dest.isAnySource();
556 if (dest.isNull() && !m_is_allow_null_rank_for_any_source)
557 ARCANE_FATAL(
"Can not use legacyProbe() with null rank. Use MessageRank::anySourceRank() instead");
// Any source, shared-memory side first.
562 p2p_message.setEmiterRank(orig_fri.localRank());
563 message_source_info = m_thread_queue->legacyProbe(p2p_message);
564 if (message_source_info.
isValid()){
568 found_dest = orig_fri.mpiRankValue()*m_local_nb_rank + message_source_info.
rank().
value();
569 TRACE_DEBUG(2,
"LegacyProbe with null_rank (thread) orig={0} found_dest={1} tag={2}",
570 orig,found_dest,user_tag);
// Any source, MPI side: one probe per possible remote local rank z.
578 for( Integer z=0, zn=m_local_nb_rank; z<zn; ++z ){
580 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri.localRank(),MessageRank(z));
581 mpi_message.setTag(mpi_tag);
582 TRACE_DEBUG(2,
"LegacyProbe with null_rank orig={0} dest={1} tag={2}",orig,dest,mpi_tag);
583 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
584 if (message_source_info.
isValid()){
587 MessageRank mpi_rank = message_source_info.
rank();
588 MessageTag ret_tag = message_source_info.
tag();
589 Int32 local_rank = m_rank_tag_builder.getReceiveRankFromTag(ret_tag);
590 found_dest = mpi_rank.value()*m_local_nb_rank + local_rank;
// Arguments follow the placeholder order: mpi_rank, local_rank, tag.
591 TRACE_DEBUG(2,
"LegacyProbe null rank found mpi_rank={0} local_rank={1} tag={2}",
592 mpi_rank,local_rank,ret_tag);
// Hide the internal MPI tag from the caller.
594 message_source_info.
setTag(user_tag);
// Given source, same MPI process: shared-memory queue.
603 if (orig_fri.mpiRank()==dest_fri.mpiRank()){
606 p2p_message.setEmiterRank(MessageRank(orig_fri.localRank()));
607 TRACE_DEBUG(2,
"LegacyProbe SHM orig={0} dest={1} tag={2}",orig,dest,user_tag);
608 message_source_info = m_thread_queue->legacyProbe(p2p_message);
// Given source, different MPI process: MPI with the rebuilt tag.
612 MessageTag mpi_tag = m_rank_tag_builder.tagForReceive(user_tag,orig_fri,dest_fri);
613 mpi_message.setTag(mpi_tag);
615 TRACE_DEBUG(2,
"LegacyProbe MPI orig={0} dest={1} mpi_tag={2} user_tag={3}",orig,dest,mpi_tag,user_tag);
616 message_source_info = m_mpi_adapter->legacyProbeMessage(mpi_message);
617 if (message_source_info.
isValid()){
619 message_source_info.
setTag(user_tag);
// Report the source as a global rank.
623 if (message_source_info.
isValid()){
626 message_source_info.
setRank(MessageRank(found_dest));
628 TRACE_DEBUG(2,
"LegacyProbe has matched message? = {0}",message_source_info.
isValid());
629 return message_source_info;
635std::ostream&
operator<<(std::ostream& o,
const FullRankInfo& fri)
637 return o <<
"(local=" << fri.m_local_rank <<
",global="
638 << fri.m_global_rank <<
",mpi=" << fri.m_mpi_rank <<
")";
#define ARCANE_THROW(exception_class,...)
Macro pour envoyer une exception avec formatage.
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Lecteur des fichiers de maillage via la bibliothèque LIMA.
Int32 value() const
Valeur du rang.
Informations sur la source d'un message.
MessageTag tag() const
Tag du message.
MessageRank rank() const
Rang de la source.
bool isValid() const
Indique si la source est valide.
void setTag(MessageTag tag)
Positionne le tag du message.
void setRank(MessageRank rank)
Positionne le rang de la source.
Informations pour envoyer/recevoir un message point à point.
MessageRank emiterRank() const
Rang de l'émetteur du message.
MessageRank destinationRank() const
Rang de la destination du message.
bool isBlocking() const
Indique si le message est bloquant.
MessageTag tag() const
Tag du message.
bool isRankTag() const
Vrai si l'instance a été créée avec un couple (rank,tag). Dans ce cas rank() et tag() sont valides.
std::ostream & operator<<(std::ostream &o, eExecutionPolicy exec_policy)
Affiche le nom de la politique d'exécution.
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
Int32 Integer
Type représentant un entier.