14#include "arcane/utils/Array.h"
15#include "arcane/utils/PlatformUtils.h"
16#include "arcane/utils/String.h"
17#include "arcane/utils/ITraceMng.h"
18#include "arcane/utils/Real2.h"
19#include "arcane/utils/Real3.h"
20#include "arcane/utils/Real2x2.h"
21#include "arcane/utils/Real3x3.h"
22#include "arcane/utils/APReal.h"
23#include "arcane/utils/FatalErrorException.h"
24#include "arcane/utils/NotImplementedException.h"
25#include "arcane/utils/NotSupportedException.h"
26#include "arcane/utils/IThreadBarrier.h"
27#include "arcane/utils/CheckedConvert.h"
29#include "arcane/core/MeshVariableRef.h"
30#include "arcane/core/IParallelMng.h"
31#include "arcane/core/ItemGroup.h"
32#include "arcane/core/IMesh.h"
33#include "arcane/core/IBase.h"
35#include "arcane/parallel/mpithread/HybridParallelDispatch.h"
36#include "arcane/parallel/mpithread/HybridParallelMng.h"
37#include "arcane/parallel/mpithread/HybridMessageQueue.h"
38#include "arcane/parallel/mpi/MpiParallelMng.h"
39#include "arcane/parallel/mpi/MpiParallelDispatch.h"
// Member initializers: cache the local, global and MPI rank/size values
// reported by the parallel manager `pm` and by pm->mpiParallelMng().
60, m_local_rank(pm->localRank())
61, m_local_nb_rank(pm->localNbRank())
62, m_global_rank(pm->commRank())
63, m_global_nb_rank(pm->commSize())
64, m_mpi_rank(pm->mpiParallelMng()->commRank())
65, m_mpi_nb_rank(pm->mpiParallelMng()->commSize())
66, m_all_dispatchs(all_dispatchs)
67, m_message_queue(message_queue)
// Reset m_index, the counter that _allReduceOrScan() increments on each call.
70 m_reduce_infos.m_index = 0;
// Register this instance in the dispatcher table, at the slot of its local rank.
74 m_all_dispatchs[m_local_rank] =
this;
// Look up the dispatcher for `Type` on the underlying MPI parallel manager.
// NOTE(review): the null Type* argument presumably only selects the overload - confirm.
77 MpiParallelMng* mpi_pm = pm->mpiParallelMng();
78 IParallelDispatchT<Type>* pd = mpi_pm->dispatcher((
Type*)
nullptr);
// Downcast to the MPI implementation; a null result is tested right after.
82 m_mpi_dispatcher =
dynamic_cast<MpiParallelDispatchT<Type>*
>(pd);
83 if (!m_mpi_dispatcher)
90template <
class Type> HybridParallelDispatch<Type>::
91~HybridParallelDispatch()
99template <
class Type>
void HybridParallelDispatch<Type>::
/*
 * Helper macro: defines class _ThreadIntegralType<datatype> whose IsIntegral
 * typedef is TrueType. It is applied below to the built-in integer types and
 * also to double, float and HPReal. computeMinMaxSum() passes an IsIntegral
 * instance as a tag argument to _computeMinMaxSum2().
 */
115#define ARCANE_DEFINE_INTEGRAL_TYPE(datatype) \
117 class _ThreadIntegralType<datatype> \
121 typedef TrueType IsIntegral; \
124ARCANE_DEFINE_INTEGRAL_TYPE(
long long);
125ARCANE_DEFINE_INTEGRAL_TYPE(
long);
126ARCANE_DEFINE_INTEGRAL_TYPE(
int);
127ARCANE_DEFINE_INTEGRAL_TYPE(
short);
128ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned long long);
129ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned long);
130ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned int);
131ARCANE_DEFINE_INTEGRAL_TYPE(
unsigned short);
132ARCANE_DEFINE_INTEGRAL_TYPE(
double);
133ARCANE_DEFINE_INTEGRAL_TYPE(
float);
134ARCANE_DEFINE_INTEGRAL_TYPE(
HPReal);
// Variant whose statements below only flag every argument as unused;
// no computation on the arguments is performed by these statements.
142 template <
class Type>
void
147 ARCANE_UNUSED(all_dispatchs);
148 ARCANE_UNUSED(my_rank);
149 ARCANE_UNUSED(min_val);
150 ARCANE_UNUSED(max_val);
151 ARCANE_UNUSED(sum_val);
152 ARCANE_UNUSED(min_rank);
153 ARCANE_UNUSED(max_rank);
154 ARCANE_UNUSED(nb_rank);
// Variant that scans the reduce_value published by each dispatcher and
// derives the minimum, the maximum, the sum and the global ranks holding the
// minimum and the maximum.
162 template <
class Type>
void
167 ARCANE_UNUSED(my_rank);
// Seed min, max and sum with the value read from `mtpd0`
// (NOTE(review): presumably the first entry of the dispatcher table - confirm).
170 Type cval0 = mtpd0->m_reduce_infos.reduce_value;
171 Type _min_val = cval0;
172 Type _max_val = cval0;
173 Type _sum_val = cval0;
// Fold in the remaining dispatchers, indices 1 to nb_rank-1.
176 for (
Integer i = 1; i < nb_rank; ++i) {
178 Type cval = mtpd->m_reduce_infos.reduce_value;
179 Int32 grank = mtpd->globalRank();
// Strictly smaller value: new minimum candidate.
180 if (cval < _min_val) {
// Strictly greater value: new maximum candidate.
184 if (_max_val < cval) {
// Accumulate the sum, cast back to Type.
188 _sum_val = (
Type)(_sum_val + cval);
// Hand the ranks of the extrema back to the caller.
193 min_rank = _min_rank;
194 max_rank = _max_rank;
/*!
 * \brief Min/max/sum reduction of a single value, also returning the ranks
 * that hold the minimum and the maximum.
 *
 * The value is published in m_reduce_infos, reduced over the dispatchers of
 * this process with _computeMinMaxSum2(), then handed by local rank 0 to the
 * MPI dispatcher. Every thread finally reads the result stored by
 * dispatcher 0.
 */
202template <
class Type>
void HybridParallelDispatch<Type>::
204 Int32& min_rank, Int32& max_rank)
// Tag type used to select the _computeMinMaxSum2() overload for `Type`.
206 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
// Publish this thread's value, then wait until all threads have done so.
207 m_reduce_infos.reduce_value = val;
208 _collectiveBarrier();
// Reduce over the m_local_nb_rank dispatchers of this process.
209 _computeMinMaxSum2(m_all_dispatchs, m_global_rank, min_val, max_val, sum_val, min_rank, max_rank, m_local_nb_rank, IntegralType());
// Local rank 0 continues the reduction through the MPI dispatcher.
210 if (m_local_rank == 0) {
217 m_mpi_dispatcher->computeMinMaxSumNoInit(min_val, max_val, sum_val, min_rank, max_rank);
// Store the results so that they can be read through the dispatcher table.
222 m_min_max_sum_infos.m_min_value = min_val;
223 m_min_max_sum_infos.m_max_value = max_val;
224 m_min_max_sum_infos.m_sum_value = sum_val;
225 m_min_max_sum_infos.m_min_rank = min_rank;
226 m_min_max_sum_infos.m_max_rank = max_rank;
// After the barrier, every thread copies the results held by dispatcher 0.
228 _collectiveBarrier();
229 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
230 min_val = m_min_max_sum_infos.m_min_value;
231 max_val = m_min_max_sum_infos.m_max_value;
232 sum_val = m_min_max_sum_infos.m_sum_value;
233 min_rank = m_min_max_sum_infos.m_min_rank;
234 max_rank = m_min_max_sum_infos.m_max_rank;
// Final barrier (presumably so that dispatcher 0's results are not
// overwritten by a later call before every thread has read them).
235 _collectiveBarrier();
/*!
 * \brief Array version of the min/max/sum reduction: entry i of each output
 * view receives the result for values[i].
 *
 * Each element goes through the same sequence as the scalar version
 * (publish, local reduction, MPI step on local rank 0, read-back from
 * dispatcher 0), with three thread barriers per element.
 */
241template <
class Type>
void HybridParallelDispatch<Type>::
242computeMinMaxSum(ConstArrayView<Type> values,
243 ArrayView<Type> min_values,
244 ArrayView<Type> max_values,
245 ArrayView<Type> sum_values,
246 ArrayView<Int32> min_ranks,
247 ArrayView<Int32> max_ranks)
// Tag type used to select the _computeMinMaxSum2() overload for `Type`.
251 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
253 for (Integer i = 0; i < n; ++i) {
// Publish element i, then wait until all threads have done so.
254 m_reduce_infos.reduce_value = values[i];
255 _collectiveBarrier();
// Reduce element i over the dispatchers of this process.
256 _computeMinMaxSum2(m_all_dispatchs, m_global_rank, min_values[i], max_values[i], sum_values[i], min_ranks[i], max_ranks[i], m_local_nb_rank, IntegralType());
// Local rank 0 continues the reduction through the MPI dispatcher.
257 if (m_local_rank == 0) {
264 m_mpi_dispatcher->computeMinMaxSumNoInit(min_values[i], max_values[i], sum_values[i], min_ranks[i], max_ranks[i]);
// Store the results so that they can be read through the dispatcher table.
269 m_min_max_sum_infos.m_min_value = min_values[i];
270 m_min_max_sum_infos.m_max_value = max_values[i];
271 m_min_max_sum_infos.m_sum_value = sum_values[i];
272 m_min_max_sum_infos.m_min_rank = min_ranks[i];
273 m_min_max_sum_infos.m_max_rank = max_ranks[i];
// After the barrier, every thread copies the results held by dispatcher 0.
275 _collectiveBarrier();
276 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
277 min_values[i] = m_min_max_sum_infos.m_min_value;
278 max_values[i] = m_min_max_sum_infos.m_max_value;
279 sum_values[i] = m_min_max_sum_infos.m_sum_value;
280 min_ranks[i] = m_min_max_sum_infos.m_min_rank;
281 max_ranks[i] = m_min_max_sum_infos.m_max_rank;
// Barrier closing the processing of element i.
282 _collectiveBarrier();
/*!
 * \brief Broadcast of `send_buf` with global rank `rank` as the sender.
 *
 * Each thread publishes its buffer in m_broadcast_view. Data then moves
 * through an MPI broadcast and through copies between the published views.
 */
289template <
class Type>
void HybridParallelDispatch<Type>::
290broadcast(Span<Type> send_buf, Int32 rank)
// Publish this thread's buffer and wait for the other threads to do the same.
292 m_broadcast_view = send_buf;
293 _collectiveBarrier();
// Split the sender's global rank into its MPI rank and its local rank.
294 FullRankInfo fri = FullRankInfo::compute(MP::MessageRank(rank), m_local_nb_rank);
295 int mpi_rank = fri.mpiRankValue();
// Case where the sender belongs to this MPI process.
296 if (m_mpi_rank == mpi_rank) {
// The sending thread itself takes part in the MPI broadcast.
298 if (m_global_rank == rank) {
300 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(), mpi_rank);
// Copy from the view published by the sender's local rank into this thread's view.
303 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[fri.localRankValue()]->m_broadcast_view);
// Local rank 0 takes part in the MPI broadcast
// (NOTE(review): presumably for processes that do not host the sender - confirm).
307 if (m_local_rank == 0) {
309 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(), mpi_rank);
312 _collectiveBarrier();
// In processes other than the sender's, threads other than local rank 0
// copy the data from dispatcher 0's view.
313 if (m_mpi_rank != mpi_rank) {
314 if (m_local_rank != 0)
315 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[0]->m_broadcast_view);
317 _collectiveBarrier();
/*!
 * \brief Gathers the `send_buf` of the ranks into `recv_buf`.
 *
 * Local rank 0 concatenates the buffers published by the threads of this
 * process and calls allGather() on the MPI parallel manager. The other
 * threads then read the view published by dispatcher 0.
 */
323template <
class Type>
void HybridParallelDispatch<Type>::
324allGather(Span<const Type> send_buf, Span<Type> recv_buf)
// Publish this thread's send buffer for the other threads of the process.
327 m_const_view = send_buf;
328 _collectiveBarrier();
// Total number of elements published by the threads of this process.
329 Int64 total_size = 0;
330 for (Int32 i = 0; i < m_local_nb_rank; ++i) {
331 total_size += m_all_dispatchs[i]->m_const_view.size();
333 if (m_local_rank == 0) {
// Concatenate the published buffers, in local rank order, into local_buf.
335 UniqueArray<Type> local_buf(total_size);
336 for (Integer i = 0; i < m_local_nb_rank; ++i) {
337 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
338 Int64 size = view.size();
339 for (Int64 j = 0; j < size; ++j)
340 local_buf[j + index] = view[j];
343 IParallelMng* pm = m_parallel_mng->mpiParallelMng();
345 pm->allGather(local_buf, recv_buf.smallView());
// Publish the gathered data so the other threads can read it after the barrier.
346 m_const_view = recv_buf;
348 _collectiveBarrier();
// Threads other than local rank 0 read the view published by dispatcher 0.
349 if (m_local_rank != 0) {
350 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
353 _collectiveBarrier();
/*!
 * \brief Gather towards `root_rank`, built on top of allGather().
 */
359template <
class Type>
void HybridParallelDispatch<Type>::
360gather(Span<const Type> send_buf, Span<Type> recv_buf, Int32 root_rank)
// Scratch destination used when recv_buf is not the target of allGather().
362 UniqueArray<Type> tmp_buf;
// The root rank gathers directly into recv_buf.
363 if (m_global_rank == root_rank)
364 allGather(send_buf, recv_buf);
// Gather into a scratch buffer sized for one send_buf per global rank
// (NOTE(review): presumably the branch taken by non-root ranks - confirm).
366 tmp_buf.resize(send_buf.size() * m_global_nb_rank);
367 allGather(send_buf, tmp_buf);
/*!
 * \brief Variable-size gather of `send_buf` into the array `recv_buf`.
 *
 * Local rank 0 concatenates the buffers published by the threads of this
 * process and calls allGatherVariable() on the MPI parallel manager. The
 * other threads resize their `recv_buf` to the size of the view published by
 * dispatcher 0.
 */
374template <
class Type>
void HybridParallelDispatch<Type>::
375allGatherVariable(Span<const Type> send_buf, Array<Type>& recv_buf)
// Publish this thread's send buffer for the other threads of the process.
377 m_const_view = send_buf;
378 _collectiveBarrier();
// Total number of elements published by the threads of this process.
379 Int64 total_size = 0;
380 for (Integer i = 0; i < m_local_nb_rank; ++i) {
381 total_size += m_all_dispatchs[i]->m_const_view.size();
383 if (m_local_rank == 0) {
// Concatenate the published buffers, in local rank order, into local_buf.
385 UniqueArray<Type> local_buf(total_size);
386 for (Integer i = 0; i < m_local_nb_rank; ++i) {
387 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
388 Int64 size = view.size();
389 for (Int64 j = 0; j < size; ++j)
390 local_buf[j + index] = view[j];
393 m_parallel_mng->mpiParallelMng()->allGatherVariable(local_buf, recv_buf);
// Publish the gathered data so the other threads can read it after the barrier.
394 m_const_view = recv_buf.constView();
396 _collectiveBarrier();
// Threads other than local rank 0 size their array from dispatcher 0's view.
397 if (m_local_rank != 0) {
398 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
399 recv_buf.resize(view.size());
402 _collectiveBarrier();
/*!
 * \brief Variable-size gather towards `root_rank`, built on top of
 * allGatherVariable().
 */
408template <
class Type>
void HybridParallelDispatch<Type>::
409gatherVariable(Span<const Type> send_buf, Array<Type>& recv_buf, Int32 root_rank)
// Scratch destination used when recv_buf is not the target.
411 UniqueArray<Type> tmp_buf;
// The root rank gathers directly into recv_buf.
412 if (m_global_rank == root_rank)
413 allGatherVariable(send_buf, recv_buf);
// Gather into the scratch array
// (NOTE(review): presumably the branch taken by non-root ranks - confirm).
415 allGatherVariable(send_buf, tmp_buf);
/*!
 * \brief Scatter of variable-size chunks from global rank `root`.
 *
 * Each thread publishes its send and receive views. Local rank 0 calls
 * scatterVariable() on the MPI parallel manager with a receive buffer sized
 * for all the threads of the process, then splits the received data between
 * the published receive views in local rank order.
 */
422void HybridParallelDispatch<Type>::
423scatterVariable(Span<const Type> send_buf, Span<Type> recv_buf, Int32 root)
// Publish this thread's views for the other threads of the process.
425 m_const_view = send_buf;
426 m_recv_view = recv_buf;
428 _collectiveBarrier();
// Total number of elements expected by the threads of this process.
431 Int64 total_size = 0;
432 for (Integer i = 0; i < m_local_nb_rank; ++i) {
433 total_size += m_all_dispatchs[i]->m_recv_view.size();
436 _collectiveBarrier();
439 if (m_local_rank == 0) {
// Split the root's global rank into its MPI rank and its local rank.
440 FullRankInfo fri(FullRankInfo::compute(MessageRank(root), m_local_nb_rank));
442 UniqueArray<Type> local_recv_buf(total_size);
// When the root lives in this process, the data to send is the view
// published by the root's local rank.
445 if (m_mpi_rank == fri.mpiRankValue()) {
447 m_parallel_mng->mpiParallelMng()->scatterVariable(m_all_dispatchs[fri.localRankValue()]->m_const_view.smallView(),
448 local_recv_buf, fri.mpiRankValue());
// Call made with this thread's own view
// (NOTE(review): presumably for processes that do not host the root - confirm).
453 m_parallel_mng->mpiParallelMng()->scatterVariable(m_const_view.smallView(), local_recv_buf, fri.mpiRankValue());
// Hand out the received elements to each thread's receive view, in local rank order.
458 for (Integer i = 0; i < m_local_nb_rank; ++i) {
459 Int64 size = m_all_dispatchs[i]->m_recv_view.size();
// 64-bit index: `size` is an Int64, a 32-bit counter could overflow.
460 for (Int64 j = 0; j < size; ++j) {
461 m_all_dispatchs[i]->m_recv_view[j] = local_recv_buf[compt++];
465 _collectiveBarrier();
466 recv_buf.copy(m_recv_view);
467 _collectiveBarrier();
/*!
 * \brief allToAll with `count` elements per rank, forwarded to
 * allToAllVariable().
 */
473template <
class Type>
void HybridParallelDispatch<Type>::
474allToAll(Span<const Type> send_buf, Span<Type> recv_buf, Int32 count)
476 Int32 global_nb_rank = m_global_nb_rank;
// Rank i uses the slice starting at offset count*i, for sending and for receiving.
483 for (Integer i = 0; i < global_nb_rank; ++i) {
484 send_indexes[i] = count * i;
485 recv_indexes[i] = count * i;
487 this->allToAllVariable(send_buf, send_count, send_indexes, recv_buf, recv_count, recv_indexes);
/*!
 * \brief Variable-size all-to-all exchange between global ranks.
 *
 * Every thread publishes its arguments in m_alltoallv_infos. Local rank 0
 * then aggregates the per-global-rank counts of all local threads into
 * per-MPI-rank counts, packs the data to send into one buffer, calls
 * allToAllVariable() on the MPI parallel manager and finally distributes the
 * received data into the receive buffer of each local thread.
 */
493template <
class Type>
void HybridParallelDispatch<Type>::
494allToAllVariable(Span<const Type> g_send_buf,
495 Int32ConstArrayView g_send_count,
496 Int32ConstArrayView g_send_index,
497 Span<Type> g_recv_buf,
498 Int32ConstArrayView g_recv_count,
499 Int32ConstArrayView g_recv_index)
// Publish this thread's arguments so that they can be read through the dispatcher table.
501 m_alltoallv_infos.send_buf = g_send_buf;
502 m_alltoallv_infos.send_count = g_send_count;
503 m_alltoallv_infos.send_index = g_send_index;
504 m_alltoallv_infos.recv_buf = g_recv_buf;
505 m_alltoallv_infos.recv_count = g_recv_count;
506 m_alltoallv_infos.recv_index = g_recv_index;
508 _collectiveBarrier();
510 UniqueArray<Type> tmp_recv_buf;
515 if (m_local_rank == 0) {
518 tmp_send_count.fill(0);
520 tmp_recv_count.fill(0);
// Total sizes of the send and receive buffers of all the local threads.
522 Int64 total_send_size = 0;
523 Int64 total_recv_size = 0;
525 for (Integer i = 0; i < m_local_nb_rank; ++i) {
526 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
527 total_send_size += vinfo.send_buf.size();
528 total_recv_size += vinfo.recv_buf.size();
531 UniqueArray<Type> tmp_send_buf(total_send_size);
532 tmp_recv_buf.resize(total_recv_size);
// Accumulate, per MPI rank, the counts that each local thread exchanges with
// each global rank z (z is mapped to its MPI rank by FullRankInfo::compute).
535 for (Integer i = 0; i < m_local_nb_rank; ++i) {
536 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
538 for (Integer z = 0; z < m_global_nb_rank; ++z) {
540 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
541 Int32 fri_mpi_rank = fri.mpiRankValue();
543 Int32 nb_send = vinfo.send_count[z];
545 tmp_send_count[fri_mpi_rank] += nb_send;
546 tmp_recv_count[fri_mpi_rank] += vinfo.recv_count[z];
// Trace of the counts, indexes and values of the current (thread, destination) pair.
549 info() <<
"my_local=" << i <<
" dest=" << z
550 <<
" send_count=" << vinfo.send_count[z] <<
" send_index=" << vinfo.send_index[z]
551 <<
" recv_count=" << vinfo.recv_count[z] <<
" recv_index=" << vinfo.recv_index[z];
553 Integer vindex = vinfo.send_index[z];
554 for( Integer w=0, wn=vinfo.send_count[z]; w<wn; ++w ){
555 info() <<
"V=" << vinfo.send_buf[ vindex + w ];
// Start index of each MPI rank: running sum of the counts of the previous ranks.
564 tmp_send_index[0] = 0;
565 tmp_recv_index[0] = 0;
566 for (Integer k = 1, nmpi = m_mpi_nb_rank; k < nmpi; ++k) {
567 tmp_send_index[k] = tmp_send_index[k - 1] + tmp_send_count[k - 1];
568 tmp_recv_index[k] = tmp_recv_index[k - 1] + tmp_recv_count[k - 1];
// Pack the data to send: each (thread i, destination z) chunk is appended at
// the current position of the destination MPI rank, which is then advanced.
571 for (Integer i = 0; i < m_local_nb_rank; ++i) {
572 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
574 for (Integer z = 0; z < m_global_nb_rank; ++z) {
576 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
577 Int32 fri_mpi_rank = fri.mpiRankValue();
579 Integer nb_send = vinfo.send_count[z];
582 Integer tmp_current_index = tmp_send_index[fri_mpi_rank];
583 Integer local_current_index = vinfo.send_index[z];
584 for (Integer j = 0; j < nb_send; ++j)
585 tmp_send_buf[j + tmp_current_index] = vinfo.send_buf[j + local_current_index];
586 tmp_send_index[fri_mpi_rank] += nb_send;
// The send indexes were advanced while packing: recompute the start indexes.
591 tmp_send_index[0] = 0;
592 tmp_recv_index[0] = 0;
593 for (Integer k = 1, nmpi = m_mpi_nb_rank; k < nmpi; ++k) {
594 tmp_send_index[k] = tmp_send_index[k - 1] + tmp_send_count[k - 1];
595 tmp_recv_index[k] = tmp_recv_index[k - 1] + tmp_recv_count[k - 1];
// Trace of the aggregated sizes, counts and indexes, and of the packed send buffer.
610 info() <<
"AllToAllV nb_send=" << total_send_size <<
" nb_recv=" << total_recv_size;
611 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
612 info() <<
"INFOS Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
613 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
616 for( Integer i=0; i<tmp_send_buf.size(); ++i )
617 info() <<
"SEND_BUF[" << i <<
"] = " << tmp_send_buf[i];
619 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
620 info() <<
"SEND Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
621 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
622 Integer vindex = tmp_send_index[k];
623 for( Integer w=0, wn=tmp_send_count[k]; w<wn; ++w ){
624 info() <<
"V=" << tmp_send_buf[ vindex + w ];
// Exchange between MPI processes, using the aggregated counts and indexes.
629 m_parallel_mng->mpiParallelMng()->allToAllVariable(tmp_send_buf, tmp_send_count,
630 tmp_send_index, tmp_recv_buf,
631 tmp_recv_count, tmp_recv_index);
// Trace of the received buffer.
634 for( Integer i=0; i<tmp_recv_buf.size(); ++i )
635 info() <<
"RECV_BUF[" << i <<
"] = " << tmp_recv_buf[i];
637 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
638 info() <<
"RECV Rank=" << k <<
" send_count=" << tmp_send_count[k] <<
" recv_count=" << tmp_recv_count[k]
639 <<
" send_index=" << tmp_send_index[k] <<
" recv_index=" << tmp_recv_index[k];
640 Integer vindex = tmp_recv_index[k];
641 for( Integer w=0, wn=tmp_recv_count[k]; w<wn; ++w ){
642 info() <<
"V=" << tmp_recv_buf[ vindex + w ];
647 m_const_view = tmp_recv_buf.constView();
// Unpack: for each source global rank z, then each local thread i, take
// recv_count[z] elements at the current position of the source MPI rank and
// copy them into the thread's receive buffer at recv_index[z].
649 for (Integer z = 0; z < m_global_nb_rank; ++z) {
650 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
651 Int32 fri_mpi_rank = fri.mpiRankValue();
653 for (Integer i = 0; i < m_local_nb_rank; ++i) {
654 AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
655 Span<Type> my_buf = vinfo.recv_buf;
656 Int64 recv_size = vinfo.recv_count[z];
657 Int64 recv_index = tmp_recv_index[fri_mpi_rank];
659 Span<const Type> recv_view = tmp_recv_buf.span().subSpan(recv_index, recv_size);
661 Int64 my_recv_index = vinfo.recv_index[z];
// Advance the read position of the source MPI rank (checked conversion to Int32).
666 tmp_recv_index[fri_mpi_rank] = CheckedConvert::toInt32(tmp_recv_index[fri_mpi_rank] + recv_size);
668 for (Int64 j = 0; j < recv_size; ++j)
669 my_buf[j + my_recv_index] = recv_view[j];
674 my_recv_index += recv_size;
678 _collectiveBarrier();
/*!
 * \brief Sends `send_buffer` to `rank`; `is_blocked` selects a blocking or a
 * non-blocking message. Forwards to the PointToPointMessageInfo overload.
 */
688template <
class Type>
auto HybridParallelDispatch<Type>::
689send(Span<const Type> send_buffer, Int32 rank,
bool is_blocked) -> Request
691 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
692 PointToPointMessageInfo p2p_message(MessageRank(rank), block_mode);
693 return send(send_buffer, p2p_message);
/*!
 * \brief Send of `send_buf` to `rank`, forwarded with the blocking flag set to true.
 */
699template <
class Type>
void HybridParallelDispatch<Type>::
700send(ConstArrayView<Type> send_buf, Int32 rank)
702 send(send_buf, rank,
true);
/*!
 * \brief Receives into `recv_buffer` from `rank`; `is_blocked` selects a
 * blocking or a non-blocking message. Forwards to the
 * PointToPointMessageInfo overload.
 */
708template <
class Type> Parallel::Request HybridParallelDispatch<Type>::
709receive(Span<Type> recv_buffer, Int32 rank,
bool is_blocked)
711 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
712 PointToPointMessageInfo p2p_message(MessageRank(rank), block_mode);
713 return receive(recv_buffer, p2p_message);
/*!
 * \brief Posts the send of `send_buffer` described by `message2` on the
 * message queue.
 */
719template <
class Type> Request HybridParallelDispatch<Type>::
720send(Span<const Type> send_buffer,
const PointToPointMessageInfo& message2)
// Work on a copy so that the emitter rank can be set to this global rank.
722 PointToPointMessageInfo message(message2);
723 bool is_blocking = message.isBlocking();
724 message.setEmiterRank(MessageRank(m_global_rank));
725 Request r = m_message_queue->addSend(message, ConstMemoryView(send_buffer));
// Wait for the request just posted
// (NOTE(review): presumably only when is_blocking is true - confirm).
727 m_message_queue->waitAll(ArrayView<MP::Request>(1, &r));
/*!
 * \brief Posts the receive into `recv_buffer` described by `message2` on the
 * message queue.
 */
736template <
class Type> Request HybridParallelDispatch<Type>::
737receive(Span<Type> recv_buffer,
const PointToPointMessageInfo& message2)
// Work on a copy so that the emitter rank can be set to this global rank.
739 PointToPointMessageInfo message(message2);
740 message.setEmiterRank(MessageRank(m_global_rank));
741 bool is_blocking = message.isBlocking();
742 Request r = m_message_queue->addReceive(message, ReceiveBufferInfo(MutableMemoryView(recv_buffer)));
// Wait for the request just posted
// (NOTE(review): presumably only when is_blocking is true - confirm).
744 m_message_queue->waitAll(ArrayView<Request>(1, &r));
/*!
 * \brief Receive into `recv_buffer` from `rank`, forwarded with the blocking
 * flag set to true.
 */
753template <
class Type>
void HybridParallelDispatch<Type>::
754recv(ArrayView<Type> recv_buffer, Integer rank)
756 recv(recv_buffer, rank,
true);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
762template <
class Type>
void HybridParallelDispatch<Type>::
763sendRecv(ConstArrayView<Type> send_buffer, ArrayView<Type> recv_buffer, Integer proc)
765 ARCANE_UNUSED(send_buffer);
766 ARCANE_UNUSED(recv_buffer);
768 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Reduction of a single value with operator `op`.
 *
 * Local rank 0 combines the values published by the threads of this process,
 * reduces the result through the MPI parallel manager and stores it in
 * dispatcher 0, from which every thread reads it after a barrier.
 */
774template <
class Type>
Type HybridParallelDispatch<Type>::
775allReduce(eReduceType op,
Type send_buf)
// Publish this thread's value, then wait until all threads have done so.
777 m_reduce_infos.reduce_value = send_buf;
780 _collectiveBarrier();
781 if (m_local_rank == 0) {
// Start from dispatcher 0's value and fold in dispatchers 1..m_local_nb_rank-1.
782 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
784 case Parallel::ReduceMin:
785 for (Integer i = 1; i < m_local_nb_rank; ++i)
786 ret = math::min(ret, m_all_dispatchs[i]->m_reduce_infos.reduce_value);
788 case Parallel::ReduceMax:
789 for (Integer i = 1; i < m_local_nb_rank; ++i)
790 ret = math::max(ret, m_all_dispatchs[i]->m_reduce_infos.reduce_value);
792 case Parallel::ReduceSum:
793 for (Integer i = 1; i < m_local_nb_rank; ++i)
794 ret = (
Type)(ret + m_all_dispatchs[i]->m_reduce_infos.reduce_value);
// Reduce the per-process result between MPI processes, then store it in dispatcher 0.
799 ret = m_parallel_mng->mpiParallelMng()->reduce(op, ret);
800 m_all_dispatchs[0]->m_reduce_infos.reduce_value = ret;
// After the barrier, every thread reads the value held by dispatcher 0.
803 _collectiveBarrier();
804 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
805 _collectiveBarrier();
/*!
 * \brief Folds the reduce buffers of the dispatchers `first_rank` to
 * `last_rank` (both included) into `result`, element by element, with
 * operator `op`.
 *
 * Only the first result.size() elements of each reduce_buf_span are read.
 */
812template <
class Type>
void HybridParallelDispatch<Type>::
813_applyReduceOperator(eReduceType op, Span<Type> result, AllDispatchView dispatch_view,
814 Int32 first_rank, Int32 last_rank)
816 Int64 buf_size = result.size();
818 case Parallel::ReduceMin:
819 for (Integer i = first_rank; i <= last_rank; ++i)
820 for (Int64 j = 0; j < buf_size; ++j)
821 result[j] = math::min(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
823 case Parallel::ReduceMax:
824 for (Integer i = first_rank; i <= last_rank; ++i)
825 for (Int64 j = 0; j < buf_size; ++j)
826 result[j] = math::max(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
828 case Parallel::ReduceSum:
829 for (Integer i = first_rank; i <= last_rank; ++i)
// 64-bit index like the Min/Max branches: buf_size is an Int64, a 32-bit
// counter could overflow on large spans.
830 for (Int64 j = 0; j < buf_size; ++j) {
831 result[j] =
static_cast<Type>(result[j] + dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
/*!
 * \brief Common implementation of the array reduction and of the scan.
 *
 * Every thread publishes `send_buf` in m_reduce_infos.reduce_buf_span and
 * increments m_reduce_infos.m_index; the indexes of all local dispatchers
 * must match, otherwise a fatal error is raised. Local rank 0 combines the
 * local buffers and talks to the other MPI processes (scan + exchange with
 * the neighbouring MPI ranks, or reduce). The other threads then derive
 * their result from the buffer published by dispatcher 0.
 */
842template <
class Type>
void HybridParallelDispatch<Type>::
843_allReduceOrScan(eReduceType op, Span<Type> send_buf,
bool is_scan)
// Publish this thread's buffer and count this call.
845 m_reduce_infos.reduce_buf_span = send_buf;
846 ++m_reduce_infos.m_index;
847 Int64 buf_size = send_buf.size();
848 UniqueArray<Type> ret(buf_size);
850 UniqueArray<Type> previous_rank_ret;
851 MpiParallelMng* mpi_pm = m_parallel_mng->mpiParallelMng();
852 Int32 my_mpi_rank = mpi_pm->commRank();
853 Int32 mpi_nb_rank = mpi_pm->commSize();
857 _collectiveBarrier();
// Consistency check: all local dispatchers must be at the same call index.
859 Integer index0 = m_all_dispatchs[0]->m_reduce_infos.m_index;
860 for (Integer i = 0; i < m_local_nb_rank; ++i) {
861 Integer indexi = m_all_dispatchs[i]->m_reduce_infos.m_index;
862 if (index0 != m_all_dispatchs[i]->m_reduce_infos.m_index) {
863 ARCANE_FATAL(
"INTERNAL: incoherent all reduce i0={0} in={1} n={2}",
869 if (m_local_rank == 0) {
870 const Int32 nb_local_rank = m_local_nb_rank;
// Start from dispatcher 0's buffer (64-bit index: buf_size is an Int64),
// then fold in the buffers of dispatchers 1..nb_local_rank-1.
871 for (Int64 j = 0; j < buf_size; ++j)
872 ret[j] = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span[j];
873 _applyReduceOperator(op, ret, m_all_dispatchs, 1, nb_local_rank - 1);
// MPI scan of the per-process result, then exchange with the neighbours:
// receive the result of the previous MPI rank and send ours to the next one
// (NOTE(review): presumably the is_scan branch - confirm).
877 mpi_pm->scan(op, ret);
878 previous_rank_ret.resize(buf_size);
879 UniqueArray<Request> requests;
880 if (my_mpi_rank != 0)
881 requests.add(mpi_pm->recv(previous_rank_ret, my_mpi_rank - 1,
false));
882 if (my_mpi_rank != (mpi_nb_rank - 1))
883 requests.add(mpi_pm->send(ret, my_mpi_rank + 1,
false));
884 mpi_pm->waitAllRequests(requests);
// Except on MPI rank 0: combine the previous rank's result with dispatcher
// 0's buffer and write it back into send_buf.
885 if (my_mpi_rank != 0) {
887 _applyReduceOperator(op, previous_rank_ret, m_all_dispatchs, 0, 0);
888 send_buf.copy(previous_rank_ret);
// MPI reduction of the per-process result
// (NOTE(review): presumably the non-scan branch - confirm).
896 mpi_pm->reduce(op, ret);
901 _collectiveBarrier();
// Threads other than local rank 0: start from dispatcher 0's buffer and fold
// in the buffers of dispatchers 1..m_local_rank.
904 if (m_local_rank != 0) {
905 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
906 ret.copy(global_buf);
908 _applyReduceOperator(op, ret, m_all_dispatchs, 1, m_local_rank);
912 _collectiveBarrier();
914 if (m_local_rank != 0) {
// Threads other than local rank 0 copy dispatcher 0's buffer into send_buf.
919 if (m_local_rank != 0) {
920 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
921 send_buf.copy(global_buf);
925 _collectiveBarrier();
/*!
 * \brief In-place reduction of `send_buf` with operator `op`; forwards to
 * _allReduceOrScan() with is_scan set to false.
 */
931template <
class Type>
void HybridParallelDispatch<Type>::
932allReduce(eReduceType op, Span<Type> send_buf)
934 _allReduceOrScan(op, send_buf,
false);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
940template <
class Type> Request HybridParallelDispatch<Type>::
941nonBlockingAllReduce(eReduceType op, Span<const Type> send_buf, Span<Type> recv_buf)
944 ARCANE_UNUSED(send_buf);
945 ARCANE_UNUSED(recv_buf);
946 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
951template <
class Type> Request HybridParallelDispatch<Type>::
952nonBlockingAllGather(Span<const Type> send_buf, Span<Type> recv_buf)
954 ARCANE_UNUSED(send_buf);
955 ARCANE_UNUSED(recv_buf);
956 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
962template <
class Type> Request HybridParallelDispatch<Type>::
963nonBlockingBroadcast(Span<Type> send_buf, Int32 rank)
965 ARCANE_UNUSED(send_buf);
967 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
973template <
class Type> Request HybridParallelDispatch<Type>::
974nonBlockingGather(Span<const Type> send_buf, Span<Type> recv_buf, Int32 rank)
976 ARCANE_UNUSED(send_buf);
977 ARCANE_UNUSED(recv_buf);
979 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
985template <
class Type> Request HybridParallelDispatch<Type>::
986nonBlockingAllToAll(Span<const Type> send_buf, Span<Type> recv_buf, Int32 count)
988 ARCANE_UNUSED(send_buf);
989 ARCANE_UNUSED(recv_buf);
990 ARCANE_UNUSED(count);
991 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Not implemented: throws NotImplementedException.
 */
997template <
class Type> Request HybridParallelDispatch<Type>::
998nonBlockingAllToAllVariable(Span<const Type> send_buf, ConstArrayView<Int32> send_count,
999 ConstArrayView<Int32> send_index, Span<Type> recv_buf,
1000 ConstArrayView<Int32> recv_count, ConstArrayView<Int32> recv_index)
1002 ARCANE_UNUSED(send_buf);
1003 ARCANE_UNUSED(recv_buf);
1004 ARCANE_UNUSED(send_count);
1005 ARCANE_UNUSED(recv_count);
1006 ARCANE_UNUSED(send_index);
1007 ARCANE_UNUSED(recv_index);
1008 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief Scan of a single value. Not implemented: throws NotImplementedException.
 */
1014template <
class Type>
Type HybridParallelDispatch<Type>::
1015scan(eReduceType op,
Type send_buf)
1018 ARCANE_UNUSED(send_buf);
1019 throw NotImplementedException(A_FUNCINFO);
/*!
 * \brief In-place scan of `send_buf` with operator `op`; forwards to
 * _allReduceOrScan() with is_scan set to true.
 */
1025template <
class Type>
void HybridParallelDispatch<Type>::
1026scan(eReduceType op, ArrayView<Type> send_buf)
1028 _allReduceOrScan(op, send_buf,
true);
1034template <
class Type> Request HybridParallelDispatch<Type>::
1037 throw NotImplementedException(A_FUNCINFO);
// Waits on the thread barrier returned by m_parallel_mng->getThreadBarrier().
1043template <
class Type>
void HybridParallelDispatch<Type>::
1046 m_parallel_mng->getThreadBarrier()->wait();
// Explicit instantiations of HybridParallelDispatch for the character,
// integer and floating-point built-in types and for Real2, Real3, Real2x2,
// Real3x3, HPReal and APReal.
1052template class HybridParallelDispatch<char>;
1053template class HybridParallelDispatch<signed char>;
1054template class HybridParallelDispatch<unsigned char>;
1055template class HybridParallelDispatch<short>;
1056template class HybridParallelDispatch<unsigned short>;
1057template class HybridParallelDispatch<int>;
1058template class HybridParallelDispatch<unsigned int>;
1059template class HybridParallelDispatch<long>;
1060template class HybridParallelDispatch<unsigned long>;
1061template class HybridParallelDispatch<long long>;
1062template class HybridParallelDispatch<unsigned long long>;
1063template class HybridParallelDispatch<float>;
1064template class HybridParallelDispatch<double>;
1065template class HybridParallelDispatch<long double>;
1066template class HybridParallelDispatch<Real2>;
1067template class HybridParallelDispatch<Real3>;
1068template class HybridParallelDispatch<Real2x2>;
1069template class HybridParallelDispatch<Real3x3>;
1070template class HybridParallelDispatch<HPReal>;
1071template class HybridParallelDispatch<APReal>;
#define ARCANE_FATAL(...)
Macro throwing a FatalErrorException.
Modifiable view of an array of type T.
Class implementing a High-Precision real number.
Brief information for a 'gather' message for data type DataType.
Interface for a message queue with threads.
Message interface for type Type.
Thread-based parallelism manager.
Exception when a function is not implemented.
Declarations of types and methods used by message exchange mechanisms.
eBlockingType
Type indicating whether a message is blocking or not.
Int32 Integer
Type representing an integer.
UniqueArray< Int32 > Int32UniqueArray
Dynamic 1D array of 32-bit integers.
std::int32_t Int32
Signed integer type of 32 bits.
Structure equivalent to the boolean value true.
Structure equivalent to the boolean value true.