14#include "arcane/utils/Array.h"
15#include "arcane/utils/PlatformUtils.h"
16#include "arcane/utils/String.h"
17#include "arcane/utils/ITraceMng.h"
18#include "arcane/utils/Real2.h"
19#include "arcane/utils/Real3.h"
20#include "arcane/utils/Real2x2.h"
21#include "arcane/utils/Real3x3.h"
22#include "arcane/utils/APReal.h"
23#include "arcane/utils/FatalErrorException.h"
24#include "arcane/utils/NotImplementedException.h"
25#include "arcane/utils/NotSupportedException.h"
26#include "arcane/utils/IThreadBarrier.h"
27#include "arcane/utils/CheckedConvert.h"
29#include "arcane/MeshVariableRef.h"
30#include "arcane/IParallelMng.h"
31#include "arcane/ItemGroup.h"
32#include "arcane/IMesh.h"
33#include "arcane/IBase.h"
35#include "arcane/parallel/mpithread/HybridParallelDispatch.h"
36#include "arcane/parallel/mpithread/HybridParallelMng.h"
37#include "arcane/parallel/mpithread/HybridMessageQueue.h"
38#include "arcane/parallel/mpi/MpiParallelMng.h"
39#include "arcane/parallel/mpi/MpiParallelDispatch.h"
// NOTE(review): this listing is an extraction artifact — original source line
// numbers are fused into the text and interior lines (braces, some statements)
// are missing. Code below is kept byte-identical; comments only are added.
//
// Constructor: caches the rank topology of the hybrid (MPI + threads) layout
// from the parent HybridParallelMng, registers this dispatcher in the shared
// per-local-rank table, and resolves the underlying typed MPI dispatcher.
55template<
class Type> HybridParallelDispatch<Type>::
56HybridParallelDispatch(ITraceMng* tm,HybridParallelMng* pm,HybridMessageQueue* message_queue,
57 ArrayView<HybridParallelDispatch<Type>*> all_dispatchs)
// Rank caches: local = thread rank within this MPI process, global = hybrid
// rank across all threads of all processes, mpi = rank of the host process.
60, m_local_rank(pm->localRank())
61, m_local_nb_rank(pm->localNbRank())
62, m_global_rank(pm->commRank())
63, m_global_nb_rank(pm->commSize())
64, m_mpi_rank(pm->mpiParallelMng()->commRank())
65, m_mpi_nb_rank(pm->mpiParallelMng()->commSize())
66, m_all_dispatchs(all_dispatchs)
67, m_message_queue(message_queue)
70 m_reduce_infos.m_index = 0;
// Publish this dispatcher so sibling threads can read our collective buffers.
74 m_all_dispatchs[m_local_rank] =
this;
77 MpiParallelMng* mpi_pm = pm->mpiParallelMng();
// dispatcher() is selected by the null typed pointer argument (type tag).
78 IParallelDispatchT<Type>* pd = mpi_pm->dispatcher((
Type*)
nullptr);
// Downcast to the concrete MPI dispatcher; checked below (error branch is
// missing from this extraction — presumably it throws; TODO confirm).
82 m_mpi_dispatcher =
dynamic_cast<MpiParallelDispatchT<Type>*
>(pd);
83 if (!m_mpi_dispatcher)
// Destructor (body not visible in this extraction).
90template<
class Type> HybridParallelDispatch<Type>::
91~HybridParallelDispatch()
// Dangling member-function header: the name and body of this member were lost
// during extraction (original lines ~99-113) — TODO recover from upstream source.
99template<
class Type>
void HybridParallelDispatch<Type>::
// Trait machinery: _ThreadIntegralType<T>::IsIntegral is TrueType for the
// types listed below, selecting the reduce path specialized for plain
// arithmetic types (tag dispatch used by computeMinMaxSum below).
// NOTE(review): the primary template and closing braces of the macro body are
// missing from this extraction.
114#define ARCANE_DEFINE_INTEGRAL_TYPE(datatype)\
116class _ThreadIntegralType<datatype>\
119 typedef TrueType IsIntegral;\
122ARCANE_DEFINE_INTEGRAL_TYPE(
long long);
123ARCANE_DEFINE_INTEGRAL_TYPE(
long);
124ARCANE_DEFINE_INTEGRAL_TYPE(
int);
125ARCANE_DEFINE_INTEGRAL_TYPE(
short);
126ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned long long);
127ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned long);
128ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned int);
129ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned short);
130ARCANE_DEFINE_INTEGRAL_TYPE(
double);
131ARCANE_DEFINE_INTEGRAL_TYPE(
float);
132ARCANE_DEFINE_INTEGRAL_TYPE(
HPReal);
// Two overloads of the min/max/sum helper (tag-dispatched on IsIntegral).
// First overload: non-integral tag — parameters unused (not supported for the
// type; signatures truncated by extraction).
139template<
class Type>
void
144 ARCANE_UNUSED(all_dispatchs);
151 ARCANE_UNUSED(nb_rank);
// Second overload: integral tag — scans the reduce_value published by each
// local thread dispatcher, accumulating sum and tracking min/max ranks.
// Loop start at i=1 suggests rank 0's value seeds the accumulators (seed lines
// missing from this extraction — TODO confirm).
159template<
class Type>
void
173 for( Integer i=1; i<nb_rank; ++i ){
185 _sum_val = (
Type)(_sum_val + cval);
190 min_rank = _min_rank;
191 max_rank = _max_rank;
// Scalar min/max/sum reduction across the hybrid machine:
// 1) each thread publishes its value and hits a barrier;
// 2) every thread computes the local (intra-process) partial via
//    _computeMinMaxSum2;
// 3) local rank 0 alone completes the reduction across MPI processes and
//    publishes the result, which all threads then copy after a barrier.
199template<
class Type>
void HybridParallelDispatch<Type>::
203 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
204 m_reduce_infos.reduce_value = val;
205 _collectiveBarrier();
206 _computeMinMaxSum2(m_all_dispatchs,m_global_rank,min_val,max_val,sum_val,min_rank,max_rank,m_local_nb_rank,IntegralType());
207 if (m_local_rank==0){
// Inter-process step: delegates to the MPI dispatcher without re-seeding.
214 m_mpi_dispatcher->computeMinMaxSumNoInit(min_val,max_val,sum_val,min_rank,max_rank);
// Publish the globally-reduced quintuple for sibling threads.
219 m_min_max_sum_infos.m_min_value = min_val;
220 m_min_max_sum_infos.m_max_value = max_val;
221 m_min_max_sum_infos.m_sum_value = sum_val;
222 m_min_max_sum_infos.m_min_rank = min_rank;
223 m_min_max_sum_infos.m_max_rank = max_rank;
225 _collectiveBarrier();
// All threads read back local-rank 0's published result.
226 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
227 min_val = m_min_max_sum_infos.m_min_value;
228 max_val = m_min_max_sum_infos.m_max_value;
229 sum_val = m_min_max_sum_infos.m_sum_value;
230 min_rank = m_min_max_sum_infos.m_min_rank;
231 max_rank = m_min_max_sum_infos.m_max_rank;
// Final barrier keeps the shared m_min_max_sum_infos stable until all read it.
232 _collectiveBarrier();
// Array variant of computeMinMaxSum: applies the scalar protocol element by
// element (one publish/barrier/reduce round-trip per element — O(n) barriers;
// potential optimization noted for upstream, not changed here).
238template<
class Type>
void HybridParallelDispatch<Type>::
239computeMinMaxSum(ConstArrayView<Type> values,
240 ArrayView<Type> min_values,
241 ArrayView<Type> max_values,
242 ArrayView<Type> sum_values,
243 ArrayView<Int32> min_ranks,
244 ArrayView<Int32> max_ranks)
248 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
250 for(Integer i=0;i<n;++i) {
251 m_reduce_infos.reduce_value = values[i];
252 _collectiveBarrier();
253 _computeMinMaxSum2(m_all_dispatchs,m_global_rank,min_values[i],max_values[i],sum_values[i],min_ranks[i],max_ranks[i],m_local_nb_rank,IntegralType());
254 if (m_local_rank==0){
261 m_mpi_dispatcher->computeMinMaxSumNoInit(min_values[i],max_values[i],sum_values[i],min_ranks[i],max_ranks[i]);
266 m_min_max_sum_infos.m_min_value = min_values[i];
267 m_min_max_sum_infos.m_max_value = max_values[i];
268 m_min_max_sum_infos.m_sum_value = sum_values[i];
269 m_min_max_sum_infos.m_min_rank = min_ranks[i];
270 m_min_max_sum_infos.m_max_rank = max_ranks[i];
272 _collectiveBarrier();
273 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
274 min_values[i] = m_min_max_sum_infos.m_min_value;
275 max_values[i] = m_min_max_sum_infos.m_max_value;
276 sum_values[i] = m_min_max_sum_infos.m_sum_value;
277 min_ranks[i] = m_min_max_sum_infos.m_min_rank;
278 max_ranks[i] = m_min_max_sum_infos.m_max_rank;
279 _collectiveBarrier();
// Hybrid broadcast from global `rank`:
// - within the broadcaster's MPI process, the source thread's buffer is the
//   MPI broadcast payload and sibling threads copy it thread-to-thread;
// - in other MPI processes, only local thread 0 participates in the MPI
//   broadcast, then its threads copy from thread 0's view.
286template<
class Type>
void HybridParallelDispatch<Type>::
287broadcast(Span<Type> send_buf,
Int32 rank)
289 m_broadcast_view = send_buf;
290 _collectiveBarrier();
// Decompose the global hybrid rank into (mpi rank, local thread rank).
291 FullRankInfo fri = FullRankInfo::compute(
MP::MessageRank(rank),m_local_nb_rank);
292 int mpi_rank = fri.mpiRankValue();
293 if (m_mpi_rank==mpi_rank){
295 if (m_global_rank==rank){
297 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(),mpi_rank);
// Intra-process copy from the broadcasting thread's published view.
300 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[fri.localRankValue()]->m_broadcast_view);
304 if (m_local_rank==0){
306 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(),mpi_rank);
309 _collectiveBarrier();
310 if (m_mpi_rank!=mpi_rank){
// Non-source processes: fan out from local thread 0 after the MPI step.
312 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[0]->m_broadcast_view);
314 _collectiveBarrier();
// Hybrid allGather: threads publish their const views; local thread 0
// concatenates them into one buffer, runs the MPI allGather, and republishes
// the result for sibling threads to copy.
// NOTE(review): the running offset `index` and the copy for non-zero local
// ranks are in lines missing from this extraction.
320template<
class Type>
void HybridParallelDispatch<Type>::
321allGather(Span<const Type> send_buf,Span<Type> recv_buf)
324 m_const_view = send_buf;
325 _collectiveBarrier();
326 Int64 total_size = 0;
327 for(
Int32 i=0; i<m_local_nb_rank; ++i ){
328 total_size += m_all_dispatchs[i]->m_const_view.size();
330 if (m_local_rank==0){
// Concatenate every local thread's contribution in local-rank order.
332 UniqueArray<Type> local_buf(total_size);
333 for( Integer i=0; i<m_local_nb_rank; ++i ){
334 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
335 Int64 size = view.size();
336 for(
Int64 j=0; j<size; ++j )
337 local_buf[j+index] = view[j];
340 IParallelMng* pm = m_parallel_mng->mpiParallelMng();
342 pm->
allGather(local_buf,recv_buf.smallView());
343 m_const_view = recv_buf;
345 _collectiveBarrier();
346 if (m_local_rank!=0){
347 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
350 _collectiveBarrier();
// gather implemented on top of allGather: only the root keeps the real
// recv_buf; other ranks gather into a throwaway buffer of the same size
// (collective participation is still required from every rank).
356template<
class Type>
void HybridParallelDispatch<Type>::
357gather(Span<const Type> send_buf,Span<Type> recv_buf,
Int32 root_rank)
359 UniqueArray<Type> tmp_buf;
360 if (m_global_rank==root_rank)
361 allGather(send_buf,recv_buf);
363 tmp_buf.resize(send_buf.size() * m_global_nb_rank);
364 allGather(send_buf,tmp_buf);
// Variable-length hybrid allGather: same shape as allGather, but the output
// Array is resized by the MPI allGatherVariable (on local thread 0) and by
// each sibling thread from thread 0's published view.
// NOTE(review): as in allGather, the `index` offset declaration and the
// final copy loop for non-zero local ranks are missing from this extraction.
371template<
class Type>
void HybridParallelDispatch<Type>::
372allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
374 m_const_view = send_buf;
375 _collectiveBarrier();
376 Int64 total_size = 0;
377 for( Integer i=0; i<m_local_nb_rank; ++i ){
378 total_size += m_all_dispatchs[i]->m_const_view.size();
380 if (m_local_rank==0){
382 UniqueArray<Type> local_buf(total_size);
383 for( Integer i=0; i<m_local_nb_rank; ++i ){
384 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
385 Int64 size = view.size();
386 for(
Int64 j=0; j<size; ++j )
387 local_buf[j+index] = view[j];
390 m_parallel_mng->mpiParallelMng()->allGatherVariable(local_buf,recv_buf);
391 m_const_view = recv_buf.constView();
393 _collectiveBarrier();
394 if (m_local_rank!=0){
395 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
396 recv_buf.resize(view.size());
399 _collectiveBarrier();
// Variable-length gather, same trick as gather(): non-root ranks participate
// via a temporary buffer that is discarded.
405template<
class Type>
void HybridParallelDispatch<Type>::
406gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,
Int32 root_rank)
408 UniqueArray<Type> tmp_buf;
409 if (m_global_rank==root_rank)
410 allGatherVariable(send_buf,recv_buf);
412 allGatherVariable(send_buf,tmp_buf);
// Hybrid variable scatter from global rank `root`: local thread 0 receives the
// whole per-process chunk via the MPI scatterVariable, then redistributes it
// to each sibling thread's published recv view.
// NOTE(review): the leading `template<class Type>` header and the declaration
// of the running counter `compt` are missing from this extraction.
419void HybridParallelDispatch<Type>::
420scatterVariable(Span<const Type> send_buf, Span<Type> recv_buf,
Int32 root)
422 m_const_view = send_buf;
423 m_recv_view = recv_buf;
425 _collectiveBarrier();
// Size of all receive views in this MPI process combined.
428 Int64 total_size = 0;
429 for (Integer i = 0; i < m_local_nb_rank; ++i) {
430 total_size += m_all_dispatchs[i]->m_recv_view.size();
433 _collectiveBarrier();
436 if (m_local_rank == 0) {
437 FullRankInfo fri(FullRankInfo::compute(MessageRank(root), m_local_nb_rank));
439 UniqueArray<Type> local_recv_buf(total_size);
// If root lives in this process, the source data is the root thread's view.
442 if (m_mpi_rank == fri.mpiRankValue()) {
444 m_parallel_mng->mpiParallelMng()->scatterVariable(m_all_dispatchs[fri.localRankValue()]->m_const_view.smallView(),
445 local_recv_buf, fri.mpiRankValue());
450 m_parallel_mng->mpiParallelMng()->scatterVariable(m_const_view.smallView(), local_recv_buf, fri.mpiRankValue());
// Redistribute the per-process chunk to each local thread, in rank order.
455 for (Integer i = 0; i < m_local_nb_rank; ++i) {
456 Int64 size = m_all_dispatchs[i]->m_recv_view.size();
457 for (Integer j = 0; j < size; ++j) {
458 m_all_dispatchs[i]->m_recv_view[j] = local_recv_buf[compt++];
462 _collectiveBarrier();
463 recv_buf.copy(m_recv_view);
464 _collectiveBarrier();
// Fixed-count allToAll expressed as the variable form with uniform counts and
// regular stride indexes (count*i).
// NOTE(review): the declarations/fills of send_count, recv_count, send_indexes
// and recv_indexes are in lines missing from this extraction.
470template<
class Type>
void HybridParallelDispatch<Type>::
471allToAll(Span<const Type> send_buf,Span<Type> recv_buf,
Int32 count)
473 Int32 global_nb_rank = m_global_nb_rank;
480 for( Integer i=0; i<global_nb_rank; ++i ){
481 send_indexes[i] = count * i;
482 recv_indexes[i] = count * i;
484 this->allToAllVariable(send_buf,send_count,send_indexes,recv_buf,recv_count,recv_indexes);
// Hybrid variable all-to-all. Protocol (performed by local thread 0 on behalf
// of all threads of this MPI process):
//  1) every thread publishes its send/recv buffers, counts and indexes;
//  2) per-destination-global-rank counts are folded into per-MPI-rank counts
//     (several global ranks map onto one MPI process);
//  3) contributions are packed into one contiguous MPI send buffer, the MPI
//     allToAllVariable runs, then the received data is unpacked back into each
//     local thread's recv view.
// NOTE(review): extraction has dropped interior lines throughout (verbose/
// debug guards around the info() traces, some index bookkeeping); code kept
// byte-identical.
490template<
class Type>
void HybridParallelDispatch<Type>::
491allToAllVariable(Span<const Type> g_send_buf,
492 Int32ConstArrayView g_send_count,
493 Int32ConstArrayView g_send_index,
494 Span<Type> g_recv_buf,
495 Int32ConstArrayView g_recv_count,
496 Int32ConstArrayView g_recv_index
// Publish this thread's全 exchange description for local thread 0 to read.
499 m_alltoallv_infos.send_buf = g_send_buf;
500 m_alltoallv_infos.send_count = g_send_count;
501 m_alltoallv_infos.send_index = g_send_index;
502 m_alltoallv_infos.recv_buf = g_recv_buf;
503 m_alltoallv_infos.recv_count = g_recv_count;
504 m_alltoallv_infos.recv_index = g_recv_index;
506 _collectiveBarrier();
508 UniqueArray<Type> tmp_recv_buf;
513 if (m_local_rank==0){
516 tmp_send_count.fill(0);
518 tmp_recv_count.fill(0);
520 Int64 total_send_size = 0;
521 Int64 total_recv_size = 0;
// Aggregate total payload over all local threads.
523 for( Integer i=0; i<m_local_nb_rank; ++i ){
524 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
525 total_send_size += vinfo.send_buf.size();
526 total_recv_size += vinfo.recv_buf.size();
529 UniqueArray<Type> tmp_send_buf(total_send_size);
530 tmp_recv_buf.resize(total_recv_size);
// Fold per-global-rank counts into per-MPI-process counts.
533 for( Integer i=0; i<m_local_nb_rank; ++i ){
534 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
536 for( Integer z=0; z<m_global_nb_rank; ++z ){
538 FullRankInfo fri(FullRankInfo::compute(
MP::MessageRank(z),m_local_nb_rank));
539 Int32 fri_mpi_rank = fri.mpiRankValue();
541 Int32 nb_send = vinfo.send_count[z];
543 tmp_send_count[fri_mpi_rank] += nb_send;
544 tmp_recv_count[fri_mpi_rank] += vinfo.recv_count[z];
// Debug traces (guard lines missing from this extraction).
547 info() <<
"my_local=" << i <<
" dest=" << z
548 <<
" send_count=" << vinfo.send_count[z] <<
" send_index=" << vinfo.send_index[z]
549 <<
" recv_count=" << vinfo.recv_count[z] <<
" recv_index=" << vinfo.recv_index[z];
551 Integer vindex = vinfo.send_index[z];
552 for( Integer w=0, wn=vinfo.send_count[z]; w<wn; ++w ){
553 info() <<
"V=" << vinfo.send_buf[ vindex + w ];
// Prefix-sum the per-MPI-rank counts into displacement indexes.
562 tmp_send_index[0] = 0;
563 tmp_recv_index[0] = 0;
564 for( Integer k=1, nmpi=m_mpi_nb_rank; k<nmpi; ++k ){
565 tmp_send_index[k] = tmp_send_index[k-1] + tmp_send_count[k-1];
566 tmp_recv_index[k] = tmp_recv_index[k-1] + tmp_recv_count[k-1];
// Pack every thread's payload into the contiguous MPI send buffer.
// tmp_send_index is used as a moving cursor here and is rebuilt afterwards.
569 for( Integer i=0; i<m_local_nb_rank; ++i ){
570 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
572 for( Integer z=0; z<m_global_nb_rank; ++ z){
574 FullRankInfo fri(FullRankInfo::compute(
MP::MessageRank(z),m_local_nb_rank));
575 Int32 fri_mpi_rank = fri.mpiRankValue();
577 Integer nb_send = vinfo.send_count[z];
580 Integer tmp_current_index = tmp_send_index[fri_mpi_rank];
581 Integer local_current_index = vinfo.send_index[z];
582 for( Integer j=0; j<nb_send; ++j )
583 tmp_send_buf[j+tmp_current_index] = vinfo.send_buf[j+local_current_index];
584 tmp_send_index[fri_mpi_rank] += nb_send;
// Rebuild the displacement indexes consumed by the packing pass above.
591 tmp_send_index[0] = 0;
592 tmp_recv_index[0] = 0;
593 for( Integer k=1, nmpi=m_mpi_nb_rank; k<nmpi; ++k ){
594 tmp_send_index[k] = tmp_send_index[k-1] + tmp_send_count[k-1];
595 tmp_recv_index[k] = tmp_recv_index[k-1] + tmp_recv_count[k-1];
612 info() <<
"AllToAllV nb_send=" << total_send_size <<
" nb_recv=" << total_recv_size;
613 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
614 info() <<
"INFOS Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
615 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
618 for( Integer i=0; i<tmp_send_buf.size(); ++i )
619 info() <<
"SEND_BUF[" << i <<
"] = " << tmp_send_buf[i];
621 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
622 info() <<
"SEND Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
623 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
624 Integer vindex = tmp_send_index[k];
625 for( Integer w=0, wn=tmp_send_count[k]; w<wn; ++w ){
626 info() <<
"V=" << tmp_send_buf[ vindex + w ];
// The actual inter-process exchange.
631 m_parallel_mng->mpiParallelMng()->allToAllVariable(tmp_send_buf,tmp_send_count,
632 tmp_send_index,tmp_recv_buf,
633 tmp_recv_count,tmp_recv_index);
636 for( Integer i=0; i<tmp_recv_buf.size(); ++i )
637 info() <<
"RECV_BUF[" << i <<
"] = " << tmp_recv_buf[i];
639 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
640 info() <<
"RECV Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
641 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
642 Integer vindex = tmp_recv_index[k];
643 for( Integer w=0, wn=tmp_recv_count[k]; w<wn; ++w ){
644 info() <<
"V=" << tmp_recv_buf[ vindex + w ];
649 m_const_view = tmp_recv_buf.constView();
// Unpack: walk destinations in global-rank order and copy each thread's slice
// from the contiguous MPI recv buffer back into that thread's recv view.
652 for( Integer z=0; z<m_global_nb_rank; ++ z){
653 FullRankInfo fri(FullRankInfo::compute(
MP::MessageRank(z),m_local_nb_rank));
654 Int32 fri_mpi_rank = fri.mpiRankValue();
656 for( Integer i=0; i<m_local_nb_rank; ++i ){
657 AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
658 Span<Type> my_buf = vinfo.recv_buf;
659 Int64 recv_size = vinfo.recv_count[z];
660 Int64 recv_index = tmp_recv_index[fri_mpi_rank];
662 Span<const Type> recv_view = tmp_recv_buf.span().subSpan(recv_index,recv_size);
664 Int64 my_recv_index = vinfo.recv_index[z];
// Advance the cursor past the slice just consumed (overflow-checked).
669 tmp_recv_index[fri_mpi_rank] = CheckedConvert::toInt32(tmp_recv_index[fri_mpi_rank] + recv_size);
671 for(
Int64 j=0; j<recv_size; ++j )
672 my_buf[j+my_recv_index] = recv_view[j];
677 my_recv_index += recv_size;
682 _collectiveBarrier();
// Point-to-point family. Blocking-flag overloads wrap the PointToPointMessageInfo
// overloads; the latter post onto the hybrid message queue and, when blocking,
// wait for completion before returning.
692template<
class Type>
auto HybridParallelDispatch<Type>::
693send(Span<const Type> send_buffer,
Int32 rank,
bool is_blocked) -> Request
695 eBlockingType block_mode = (is_blocked) ? MP::Blocking :
MP::NonBlocking;
696 PointToPointMessageInfo p2p_message(MessageRank(rank),block_mode);
697 return send(send_buffer,p2p_message);
// Legacy ArrayView overload: always blocking.
703template<
class Type>
void HybridParallelDispatch<Type>::
704send(ConstArrayView<Type> send_buf,
Int32 rank)
706 send(send_buf,rank,
true);
712template<
class Type> Parallel::Request HybridParallelDispatch<Type>::
713receive(Span<Type> recv_buffer,
Int32 rank,
bool is_blocked)
715 eBlockingType block_mode = (is_blocked) ? MP::Blocking :
MP::NonBlocking;
716 PointToPointMessageInfo p2p_message(MessageRank(rank),block_mode);
717 return receive(recv_buffer,p2p_message);
// Core send: stamps our global rank as emitter, enqueues, and (if blocking)
// waits on the single request. NOTE(review): the non-blocking return path is
// in lines missing from this extraction.
723template<
class Type> Request HybridParallelDispatch<Type>::
724send(Span<const Type> send_buffer,
const PointToPointMessageInfo& message2)
726 PointToPointMessageInfo message(message2);
727 bool is_blocking = message.isBlocking();
728 message.setEmiterRank(MessageRank(m_global_rank));
729 Request r = m_message_queue->addSend(message, ConstMemoryView(send_buffer));
731 m_message_queue->waitAll(ArrayView<MP::Request>(1,&r));
// Core receive: symmetric to send(), using a mutable memory view.
740template<
class Type> Request HybridParallelDispatch<Type>::
741receive(Span<Type> recv_buffer,
const PointToPointMessageInfo& message2)
743 PointToPointMessageInfo message(message2);
744 message.setEmiterRank(MessageRank(m_global_rank));
745 bool is_blocking = message.isBlocking();
746 Request r = m_message_queue->addReceive(message,ReceiveBufferInfo(MutableMemoryView(recv_buffer)));
748 m_message_queue->waitAll(ArrayView<Request>(1,&r));
// Legacy blocking recv.
757template<
class Type>
void HybridParallelDispatch<Type>::
758recv(ArrayView<Type> recv_buffer,Integer rank)
760 recv(recv_buffer,rank,
true);
// Combined sendRecv is not provided by this dispatcher.
766template<
class Type>
void HybridParallelDispatch<Type>::
767sendRecv(ConstArrayView<Type> send_buffer,ArrayView<Type> recv_buffer,Integer proc)
769 ARCANE_UNUSED(send_buffer);
770 ARCANE_UNUSED(recv_buffer);
772 throw NotImplementedException(A_FUNCINFO);
// Scalar allReduce: publish value, barrier, then local thread 0 folds the
// local threads' values (min/max/sum per `op`), completes the reduction over
// MPI, and republishes; all threads read slot 0 after the barrier.
// NOTE(review): the `switch(op)` header, `break`s and default branch are in
// lines missing from this extraction.
778template<
class Type>
Type HybridParallelDispatch<Type>::
779allReduce(eReduceType op,
Type send_buf)
781 m_reduce_infos.reduce_value = send_buf;
784 _collectiveBarrier();
785 if (m_local_rank==0){
786 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
788 case Parallel::ReduceMin:
789 for( Integer i=1; i<m_local_nb_rank; ++i )
790 ret = math::min(ret,m_all_dispatchs[i]->m_reduce_infos.reduce_value);
792 case Parallel::ReduceMax:
793 for( Integer i=1; i<m_local_nb_rank; ++i )
794 ret = math::max(ret,m_all_dispatchs[i]->m_reduce_infos.reduce_value);
796 case Parallel::ReduceSum:
797 for( Integer i=1; i<m_local_nb_rank; ++i )
798 ret = (
Type)(ret + m_all_dispatchs[i]->m_reduce_infos.reduce_value);
// Complete across MPI processes, then publish for sibling threads.
803 ret = m_parallel_mng->mpiParallelMng()->reduce(op,ret);
804 m_all_dispatchs[0]->m_reduce_infos.reduce_value = ret;
807 _collectiveBarrier();
808 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
809 _collectiveBarrier();
// Folds reduce_buf_span of dispatchers [first_rank, last_rank] (inclusive)
// element-wise into `result`, according to `op` (min/max/sum).
// NOTE(review): switch header/breaks and the unsupported-op branch are in
// lines missing from this extraction.
816template <
class Type>
void HybridParallelDispatch<Type>::
817_applyReduceOperator(eReduceType op, Span<Type> result, AllDispatchView dispatch_view,
820 Int64 buf_size = result.size();
822 case Parallel::ReduceMin:
823 for (Integer i = first_rank; i <= last_rank; ++i)
824 for (
Int64 j = 0; j < buf_size; ++j)
825 result[j] = math::min(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
827 case Parallel::ReduceMax:
828 for (Integer i = first_rank; i <= last_rank; ++i)
829 for (
Int64 j = 0; j < buf_size; ++j)
830 result[j] = math::max(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
832 case Parallel::ReduceSum:
833 for (Integer i = first_rank; i <= last_rank; ++i)
834 for (Integer j = 0; j < buf_size; ++j) {
835 result[j] =
static_cast<Type>(result[j] + dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
// Shared implementation for span allReduce and (exclusive-style) scan.
// All-reduce path: local thread 0 folds local buffers, runs the MPI reduce,
// publishes; others copy after the barrier.
// Scan path: local thread 0 folds, runs MPI scan, then exchanges boundary
// partials with neighbour MPI ranks (recv from rank-1 / send to rank+1) so
// each thread can combine the preceding ranks' partial with its own prefix.
// m_reduce_infos.m_index is a per-call sequence number used to detect
// mismatched (incoherent) collective calls between threads.
// NOTE(review): several interior lines (scan/reduce branch selection, parts
// of the per-thread prefix combination) are missing from this extraction.
846template<
class Type>
void HybridParallelDispatch<Type>::
847_allReduceOrScan(eReduceType op, Span<Type> send_buf,
bool is_scan)
849 m_reduce_infos.reduce_buf_span = send_buf;
850 ++m_reduce_infos.m_index;
851 Int64 buf_size = send_buf.size();
852 UniqueArray<Type> ret(buf_size);
854 UniqueArray<Type> previous_rank_ret;
855 MpiParallelMng* mpi_pm = m_parallel_mng->mpiParallelMng();
856 Int32 my_mpi_rank = mpi_pm->commRank();
857 Int32 mpi_nb_rank = mpi_pm->commSize();
861 _collectiveBarrier();
// Sanity check: every local thread must be in the same collective call.
863 Integer index0 = m_all_dispatchs[0]->m_reduce_infos.m_index;
864 for( Integer i=0; i<m_local_nb_rank; ++i ){
865 Integer indexi = m_all_dispatchs[i]->m_reduce_infos.m_index;
866 if (index0!=m_all_dispatchs[i]->m_reduce_infos.m_index){
867 ARCANE_FATAL(
"INTERNAL: incoherent all reduce i0={0} in={1} n={2}",
873 if (m_local_rank==0){
874 const Int32 nb_local_rank = m_local_nb_rank;
// Seed with thread 0's buffer, then fold threads 1..n-1.
875 for( Integer j=0; j<buf_size; ++j )
876 ret[j] = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span[j];
877 _applyReduceOperator(op, ret, m_all_dispatchs, 1, nb_local_rank - 1);
// Scan branch: neighbour exchange of boundary partials.
881 mpi_pm->scan(op, ret);
882 previous_rank_ret.resize(buf_size);
883 UniqueArray<Request> requests;
884 if (my_mpi_rank != 0)
885 requests.add(mpi_pm->recv(previous_rank_ret, my_mpi_rank - 1,
false));
886 if (my_mpi_rank != (mpi_nb_rank - 1))
887 requests.add(mpi_pm->send(ret, my_mpi_rank + 1,
false));
888 mpi_pm->waitAllRequests(requests);
889 if (my_mpi_rank != 0) {
891 _applyReduceOperator(op, previous_rank_ret, m_all_dispatchs, 0, 0);
892 send_buf.copy(previous_rank_ret);
// All-reduce branch: plain MPI reduce of the locally-folded buffer.
900 mpi_pm->reduce(op, ret);
905 _collectiveBarrier();
// Scan: each thread folds threads 1..m_local_rank into thread 0's published
// prefix to obtain its own inclusive prefix.
908 if (m_local_rank != 0) {
909 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
910 ret.copy(global_buf);
912 _applyReduceOperator(op, ret, m_all_dispatchs, 1, m_local_rank);
916 _collectiveBarrier();
918 if (m_local_rank != 0) {
// All-reduce: non-zero threads copy the published global result.
923 if (m_local_rank != 0) {
924 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
925 send_buf.copy(global_buf);
929 _collectiveBarrier();
// Span allReduce: delegates to the shared implementation with is_scan=false.
935template <
class Type>
void HybridParallelDispatch<Type>::
936allReduce(eReduceType op, Span<Type> send_buf)
938 _allReduceOrScan(op, send_buf,
false);
// Non-blocking collective variants are not implemented for the hybrid
// dispatcher; each stub silences its parameters and throws.
944template<
class Type> Request HybridParallelDispatch<Type>::
945nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
948 ARCANE_UNUSED(send_buf);
949 ARCANE_UNUSED(recv_buf);
950 throw NotImplementedException(A_FUNCINFO);
955template<
class Type> Request HybridParallelDispatch<Type>::
956nonBlockingAllGather(Span<const Type> send_buf, Span<Type> recv_buf)
958 ARCANE_UNUSED(send_buf);
959 ARCANE_UNUSED(recv_buf);
960 throw NotImplementedException(A_FUNCINFO);
966template<
class Type> Request HybridParallelDispatch<Type>::
967nonBlockingBroadcast(Span<Type> send_buf,
Int32 rank)
969 ARCANE_UNUSED(send_buf);
971 throw NotImplementedException(A_FUNCINFO);
977template<
class Type> Request HybridParallelDispatch<Type>::
978nonBlockingGather(Span<const Type> send_buf, Span<Type> recv_buf,
Int32 rank)
980 ARCANE_UNUSED(send_buf);
981 ARCANE_UNUSED(recv_buf);
983 throw NotImplementedException(A_FUNCINFO);
989template<
class Type> Request HybridParallelDispatch<Type>::
990nonBlockingAllToAll(Span<const Type> send_buf, Span<Type> recv_buf,
Int32 count)
992 ARCANE_UNUSED(send_buf);
993 ARCANE_UNUSED(recv_buf);
994 ARCANE_UNUSED(count);
995 throw NotImplementedException(A_FUNCINFO);
1001template<
class Type> Request HybridParallelDispatch<Type>::
1002nonBlockingAllToAllVariable(Span<const Type> send_buf, ConstArrayView<Int32> send_count,
1003 ConstArrayView<Int32> send_index, Span<Type> recv_buf,
1004 ConstArrayView<Int32> recv_count, ConstArrayView<Int32> recv_index)
1006 ARCANE_UNUSED(send_buf);
1007 ARCANE_UNUSED(recv_buf);
1008 ARCANE_UNUSED(send_count);
1009 ARCANE_UNUSED(recv_count);
1010 ARCANE_UNUSED(send_index);
1011 ARCANE_UNUSED(recv_index);
1012 throw NotImplementedException(A_FUNCINFO);
// Scalar scan: not implemented.
1018template<
class Type>
Type HybridParallelDispatch<Type>::
1019scan(eReduceType op,
Type send_buf)
1022 ARCANE_UNUSED(send_buf);
1023 throw NotImplementedException(A_FUNCINFO);
// Array scan: shares the allReduce machinery with is_scan=true.
1029template<
class Type>
void HybridParallelDispatch<Type>::
1030scan(eReduceType op,ArrayView<Type> send_buf)
1032 _allReduceOrScan(op, send_buf,
true);
// Unnamed not-implemented member (name lost to extraction).
1038template<
class Type> Request HybridParallelDispatch<Type>::
1041 throw NotImplementedException(A_FUNCINFO);
// _collectiveBarrier (header line lost to extraction): synchronizes all
// threads of this MPI process via the shared thread barrier.
1047template<
class Type>
void HybridParallelDispatch<Type>::
1050 m_parallel_mng->getThreadBarrier()->wait();
// Explicit instantiations for every data type the dispatcher supports.
1056template class HybridParallelDispatch<char>;
1057template class HybridParallelDispatch<signed char>;
1058template class HybridParallelDispatch<unsigned char>;
1059template class HybridParallelDispatch<short>;
1060template class HybridParallelDispatch<unsigned short>;
1061template class HybridParallelDispatch<int>;
1062template class HybridParallelDispatch<unsigned int>;
1063template class HybridParallelDispatch<long>;
1064template class HybridParallelDispatch<unsigned long>;
1065template class HybridParallelDispatch<long long>;
1066template class HybridParallelDispatch<unsigned long long>;
1067template class HybridParallelDispatch<float>;
1068template class HybridParallelDispatch<double>;
1069template class HybridParallelDispatch<long double>;
1070template class HybridParallelDispatch<Real2>;
1071template class HybridParallelDispatch<Real3>;
1072template class HybridParallelDispatch<Real2x2>;
1073template class HybridParallelDispatch<Real3x3>;
1074template class HybridParallelDispatch<HPReal>;
1075template class HybridParallelDispatch<APReal>;
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Classe implémentant un réel Haute Précision.
virtual void allGather(ConstArrayView< char > send_buf, ArrayView< char > recv_buf)=0
Effectue un regroupement sur tous les processeurs. Il s'agit d'une opération collective.
Lecteur des fichiers de maillage via la bibliothèque LIMA.
Vue modifiable d'un tableau d'un type T.
Informations pour un message 'gather' pour le type de données DataType.
Exception lorsqu'une fonction n'est pas implémentée.
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
UniqueArray< Int32 > Int32UniqueArray
Tableau dynamique à une dimension d'entiers 32 bits.
Espace de nommage contenant les types et déclarations qui gèrent le mécanisme de parallélisme par échange de messages.
eBlockingType
Type indiquant si un message est bloquant ou non.
Int32 Integer
Type représentant un entier.
Structure équivalente à la valeur booléenne vrai.
Structure équivalente à la valeur booléenne vrai.