Arcane  4.1.12.0
Developer documentation
Loading...
Searching...
No Matches
HybridParallelDispatch.cc
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
/* HybridParallelDispatch.cc                                   (C) 2000-2024 */
9/* */
10/* Parallelism manager using threads and MPI. */
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
14#include "arcane/utils/Array.h"
15#include "arcane/utils/PlatformUtils.h"
16#include "arcane/utils/String.h"
17#include "arcane/utils/ITraceMng.h"
18#include "arcane/utils/Real2.h"
19#include "arcane/utils/Real3.h"
20#include "arcane/utils/Real2x2.h"
21#include "arcane/utils/Real3x3.h"
22#include "arcane/utils/APReal.h"
23#include "arcane/utils/FatalErrorException.h"
24#include "arcane/utils/NotImplementedException.h"
25#include "arcane/utils/NotSupportedException.h"
26#include "arcane/utils/IThreadBarrier.h"
27#include "arcane/utils/CheckedConvert.h"
28
29#include "arcane/core/MeshVariableRef.h"
30#include "arcane/core/IParallelMng.h"
31#include "arcane/core/ItemGroup.h"
32#include "arcane/core/IMesh.h"
33#include "arcane/core/IBase.h"
34
35#include "arcane/parallel/mpithread/HybridParallelDispatch.h"
36#include "arcane/parallel/mpithread/HybridParallelMng.h"
37#include "arcane/parallel/mpithread/HybridMessageQueue.h"
38#include "arcane/parallel/mpi/MpiParallelMng.h"
39#include "arcane/parallel/mpi/MpiParallelDispatch.h"
40
41/*---------------------------------------------------------------------------*/
42/*---------------------------------------------------------------------------*/
43
45{
46
47/*---------------------------------------------------------------------------*/
48/*---------------------------------------------------------------------------*/
49
50//TODO: Merge with what is possible in SharedMemoryParallelDispatch
51
52/*---------------------------------------------------------------------------*/
53/*---------------------------------------------------------------------------*/
54
55template <class Type> HybridParallelDispatch<Type>::
56HybridParallelDispatch(ITraceMng* tm, HybridParallelMng* pm, HybridMessageQueue* message_queue,
57 ArrayView<HybridParallelDispatch<Type>*> all_dispatchs)
58: TraceAccessor(tm)
59, m_parallel_mng(pm)
60, m_local_rank(pm->localRank())
61, m_local_nb_rank(pm->localNbRank())
62, m_global_rank(pm->commRank())
63, m_global_nb_rank(pm->commSize())
64, m_mpi_rank(pm->mpiParallelMng()->commRank())
65, m_mpi_nb_rank(pm->mpiParallelMng()->commSize())
66, m_all_dispatchs(all_dispatchs)
67, m_message_queue(message_queue)
68, m_mpi_dispatcher(0)
69{
70 m_reduce_infos.m_index = 0;
71
72 // This array was sized by the creator of this instance.
73 // We just need to update the value corresponding to its rank
74 m_all_dispatchs[m_local_rank] = this;
75
76 // Retrieves the MPI dispatcher for this type.
77 MpiParallelMng* mpi_pm = pm->mpiParallelMng();
78 IParallelDispatchT<Type>* pd = mpi_pm->dispatcher((Type*)nullptr);
79 if (!pd)
80 ARCANE_FATAL("null dispatcher");
81
82 m_mpi_dispatcher = dynamic_cast<MpiParallelDispatchT<Type>*>(pd);
83 if (!m_mpi_dispatcher)
84 ARCANE_FATAL("null mpi dispatcher");
85}
86
87/*---------------------------------------------------------------------------*/
88/*---------------------------------------------------------------------------*/
89
90template <class Type> HybridParallelDispatch<Type>::
91~HybridParallelDispatch()
92{
93 finalize();
94}
95
96/*---------------------------------------------------------------------------*/
97/*---------------------------------------------------------------------------*/
98
99template <class Type> void HybridParallelDispatch<Type>::
100finalize()
101{
102}
103
104/*---------------------------------------------------------------------------*/
105/*---------------------------------------------------------------------------*/
106
107template <typename T>
109{
110 public:
111
112 typedef FalseType IsIntegral;
113};
114
115#define ARCANE_DEFINE_INTEGRAL_TYPE(datatype) \
116 template <> \
117 class _ThreadIntegralType<datatype> \
118 { \
119 public: \
120\
121 typedef TrueType IsIntegral; \
122 }
123
124ARCANE_DEFINE_INTEGRAL_TYPE(long long);
125ARCANE_DEFINE_INTEGRAL_TYPE(long);
126ARCANE_DEFINE_INTEGRAL_TYPE(int);
127ARCANE_DEFINE_INTEGRAL_TYPE(short);
128ARCANE_DEFINE_INTEGRAL_TYPE(unsigned long long);
129ARCANE_DEFINE_INTEGRAL_TYPE(unsigned long);
130ARCANE_DEFINE_INTEGRAL_TYPE(unsigned int);
131ARCANE_DEFINE_INTEGRAL_TYPE(unsigned short);
132ARCANE_DEFINE_INTEGRAL_TYPE(double);
133ARCANE_DEFINE_INTEGRAL_TYPE(float);
134ARCANE_DEFINE_INTEGRAL_TYPE(HPReal);
135
136/*---------------------------------------------------------------------------*/
137/*---------------------------------------------------------------------------*/
138
139namespace
140{
141
142 template <class Type> void
143 _computeMinMaxSum2(ArrayView<HybridParallelDispatch<Type>*> all_dispatchs,
144 Int32 my_rank, Type& min_val, Type& max_val, Type& sum_val,
145 Int32& min_rank, Int32& max_rank, Int32 nb_rank, FalseType)
146 {
147 ARCANE_UNUSED(all_dispatchs);
148 ARCANE_UNUSED(my_rank);
149 ARCANE_UNUSED(min_val);
150 ARCANE_UNUSED(max_val);
151 ARCANE_UNUSED(sum_val);
152 ARCANE_UNUSED(min_rank);
153 ARCANE_UNUSED(max_rank);
154 ARCANE_UNUSED(nb_rank);
155
156 throw NotImplementedException(A_FUNCINFO);
157 }
158
159 /*---------------------------------------------------------------------------*/
160 /*---------------------------------------------------------------------------*/
161
162 template <class Type> void
163 _computeMinMaxSum2(ArrayView<HybridParallelDispatch<Type>*> all_dispatchs,
164 Int32 my_rank, Type& min_val, Type& max_val, Type& sum_val,
165 Int32& min_rank, Int32& max_rank, Int32 nb_rank, TrueType)
166 {
167 ARCANE_UNUSED(my_rank);
168
169 HybridParallelDispatch<Type>* mtpd0 = all_dispatchs[0];
170 Type cval0 = mtpd0->m_reduce_infos.reduce_value;
171 Type _min_val = cval0;
172 Type _max_val = cval0;
173 Type _sum_val = cval0;
174 Integer _min_rank = 0;
175 Integer _max_rank = 0;
176 for (Integer i = 1; i < nb_rank; ++i) {
177 HybridParallelDispatch<Type>* mtpd = all_dispatchs[i];
178 Type cval = mtpd->m_reduce_infos.reduce_value;
179 Int32 grank = mtpd->globalRank();
180 if (cval < _min_val) {
181 _min_val = cval;
182 _min_rank = grank;
183 }
184 if (_max_val < cval) {
185 _max_val = cval;
186 _max_rank = grank;
187 }
188 _sum_val = (Type)(_sum_val + cval);
189 }
190 min_val = _min_val;
191 max_val = _max_val;
192 sum_val = _sum_val;
193 min_rank = _min_rank;
194 max_rank = _max_rank;
195 }
196
197} // namespace
198
199/*---------------------------------------------------------------------------*/
200/*---------------------------------------------------------------------------*/
201
202template <class Type> void HybridParallelDispatch<Type>::
203computeMinMaxSum(Type val, Type& min_val, Type& max_val, Type& sum_val,
204 Int32& min_rank, Int32& max_rank)
205{
206 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
207 m_reduce_infos.reduce_value = val;
208 _collectiveBarrier();
209 _computeMinMaxSum2(m_all_dispatchs, m_global_rank, min_val, max_val, sum_val, min_rank, max_rank, m_local_nb_rank, IntegralType());
210 if (m_local_rank == 0) {
211 /*pinfo() << "COMPUTE_MIN_MAX_SUM_B rank=" << m_global_rank
212 << " min_rank=" << min_rank
213 << " max_rank=" << max_rank
214 << " min_val=" << min_val
215 << " max_val=" << max_val
216 << " sum_val=" << sum_val;*/
217 m_mpi_dispatcher->computeMinMaxSumNoInit(min_val, max_val, sum_val, min_rank, max_rank);
218 /*pinfo() << "COMPUTE_MIN_MAX_SUM_A rank=" << m_global_rank
219 << " min_rank=" << min_rank
220 << " max_rank=" << max_rank;*/
221
222 m_min_max_sum_infos.m_min_value = min_val;
223 m_min_max_sum_infos.m_max_value = max_val;
224 m_min_max_sum_infos.m_sum_value = sum_val;
225 m_min_max_sum_infos.m_min_rank = min_rank;
226 m_min_max_sum_infos.m_max_rank = max_rank;
227 }
228 _collectiveBarrier();
229 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
230 min_val = m_min_max_sum_infos.m_min_value;
231 max_val = m_min_max_sum_infos.m_max_value;
232 sum_val = m_min_max_sum_infos.m_sum_value;
233 min_rank = m_min_max_sum_infos.m_min_rank;
234 max_rank = m_min_max_sum_infos.m_max_rank;
235 _collectiveBarrier();
236}
237
238/*---------------------------------------------------------------------------*/
239/*---------------------------------------------------------------------------*/
240
241template <class Type> void HybridParallelDispatch<Type>::
242computeMinMaxSum(ConstArrayView<Type> values,
243 ArrayView<Type> min_values,
244 ArrayView<Type> max_values,
245 ArrayView<Type> sum_values,
246 ArrayView<Int32> min_ranks,
247 ArrayView<Int32> max_ranks)
248{
249 // Sub-optimal implementation that does not vectorize the calculation
250 // (it is currently a copy-paste of the above put into a loop)
251 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
252 Integer n = values.size();
253 for (Integer i = 0; i < n; ++i) {
254 m_reduce_infos.reduce_value = values[i];
255 _collectiveBarrier();
256 _computeMinMaxSum2(m_all_dispatchs, m_global_rank, min_values[i], max_values[i], sum_values[i], min_ranks[i], max_ranks[i], m_local_nb_rank, IntegralType());
257 if (m_local_rank == 0) {
258 /*pinfo() << "COMPUTE_MIN_MAX_SUM_B rank=" << m_global_rank
259 << " min_rank=" << min_rank
260 << " max_rank=" << max_rank
261 << " min_val=" << min_val
262 << " max_val=" << max_val
263 << " sum_val=" << sum_val;*/
264 m_mpi_dispatcher->computeMinMaxSumNoInit(min_values[i], max_values[i], sum_values[i], min_ranks[i], max_ranks[i]);
265 /*pinfo() << "COMPUTE_MIN_MAX_SUM_A rank=" << m_global_rank
266 << " min_rank=" << min_rank
267 << " max_rank=" << max_rank;*/
268
269 m_min_max_sum_infos.m_min_value = min_values[i];
270 m_min_max_sum_infos.m_max_value = max_values[i];
271 m_min_max_sum_infos.m_sum_value = sum_values[i];
272 m_min_max_sum_infos.m_min_rank = min_ranks[i];
273 m_min_max_sum_infos.m_max_rank = max_ranks[i];
274 }
275 _collectiveBarrier();
276 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
277 min_values[i] = m_min_max_sum_infos.m_min_value;
278 max_values[i] = m_min_max_sum_infos.m_max_value;
279 sum_values[i] = m_min_max_sum_infos.m_sum_value;
280 min_ranks[i] = m_min_max_sum_infos.m_min_rank;
281 max_ranks[i] = m_min_max_sum_infos.m_max_rank;
282 _collectiveBarrier();
283 }
284}
285
286/*---------------------------------------------------------------------------*/
287/*---------------------------------------------------------------------------*/
288
289template <class Type> void HybridParallelDispatch<Type>::
290broadcast(Span<Type> send_buf, Int32 rank)
291{
292 m_broadcast_view = send_buf;
293 _collectiveBarrier();
294 FullRankInfo fri = FullRankInfo::compute(MP::MessageRank(rank), m_local_nb_rank);
295 int mpi_rank = fri.mpiRankValue();
296 if (m_mpi_rank == mpi_rank) {
297 // I have the same MPI rank as the one doing the broadcast
298 if (m_global_rank == rank) {
299 //TODO: 64bit passage.
300 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(), mpi_rank);
301 }
302 else {
303 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[fri.localRankValue()]->m_broadcast_view);
304 }
305 }
306 else {
307 if (m_local_rank == 0) {
308 //TODO: 64bit passage.
309 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(), mpi_rank);
310 }
311 }
312 _collectiveBarrier();
313 if (m_mpi_rank != mpi_rank) {
314 if (m_local_rank != 0)
315 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[0]->m_broadcast_view);
316 }
317 _collectiveBarrier();
318}
319
320/*---------------------------------------------------------------------------*/
321/*---------------------------------------------------------------------------*/
322
323template <class Type> void HybridParallelDispatch<Type>::
324allGather(Span<const Type> send_buf, Span<Type> recv_buf)
325{
326 //TODO: merge with allGatherVariable()
327 m_const_view = send_buf;
328 _collectiveBarrier();
329 Int64 total_size = 0;
330 for (Int32 i = 0; i < m_local_nb_rank; ++i) {
331 total_size += m_all_dispatchs[i]->m_const_view.size();
332 }
333 if (m_local_rank == 0) {
334 Int64 index = 0;
335 UniqueArray<Type> local_buf(total_size);
336 for (Integer i = 0; i < m_local_nb_rank; ++i) {
337 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
338 Int64 size = view.size();
339 for (Int64 j = 0; j < size; ++j)
340 local_buf[j + index] = view[j];
341 index += size;
342 }
343 IParallelMng* pm = m_parallel_mng->mpiParallelMng();
344 //TODO: 64bit
345 pm->allGather(local_buf, recv_buf.smallView());
346 m_const_view = recv_buf;
347 }
348 _collectiveBarrier();
349 if (m_local_rank != 0) {
350 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
351 recv_buf.copy(view);
352 }
353 _collectiveBarrier();
354}
355
356/*---------------------------------------------------------------------------*/
357/*---------------------------------------------------------------------------*/
358
359template <class Type> void HybridParallelDispatch<Type>::
360gather(Span<const Type> send_buf, Span<Type> recv_buf, Int32 root_rank)
361{
362 UniqueArray<Type> tmp_buf;
363 if (m_global_rank == root_rank)
364 allGather(send_buf, recv_buf);
365 else {
366 tmp_buf.resize(send_buf.size() * m_global_nb_rank);
367 allGather(send_buf, tmp_buf);
368 }
369}
370
371/*---------------------------------------------------------------------------*/
372/*---------------------------------------------------------------------------*/
373
374template <class Type> void HybridParallelDispatch<Type>::
375allGatherVariable(Span<const Type> send_buf, Array<Type>& recv_buf)
376{
377 m_const_view = send_buf;
378 _collectiveBarrier();
379 Int64 total_size = 0;
380 for (Integer i = 0; i < m_local_nb_rank; ++i) {
381 total_size += m_all_dispatchs[i]->m_const_view.size();
382 }
383 if (m_local_rank == 0) {
384 Int64 index = 0;
385 UniqueArray<Type> local_buf(total_size);
386 for (Integer i = 0; i < m_local_nb_rank; ++i) {
387 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
388 Int64 size = view.size();
389 for (Int64 j = 0; j < size; ++j)
390 local_buf[j + index] = view[j];
391 index += size;
392 }
393 m_parallel_mng->mpiParallelMng()->allGatherVariable(local_buf, recv_buf);
394 m_const_view = recv_buf.constView();
395 }
396 _collectiveBarrier();
397 if (m_local_rank != 0) {
398 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
399 recv_buf.resize(view.size());
400 recv_buf.copy(view);
401 }
402 _collectiveBarrier();
403}
404
405/*---------------------------------------------------------------------------*/
406/*---------------------------------------------------------------------------*/
407
408template <class Type> void HybridParallelDispatch<Type>::
409gatherVariable(Span<const Type> send_buf, Array<Type>& recv_buf, Int32 root_rank)
410{
411 UniqueArray<Type> tmp_buf;
412 if (m_global_rank == root_rank)
413 allGatherVariable(send_buf, recv_buf);
414 else
415 allGatherVariable(send_buf, tmp_buf);
416}
417
418/*---------------------------------------------------------------------------*/
419/*---------------------------------------------------------------------------*/
420
421template <class Type>
422void HybridParallelDispatch<Type>::
423scatterVariable(Span<const Type> send_buf, Span<Type> recv_buf, Int32 root)
424{
425 m_const_view = send_buf;
426 m_recv_view = recv_buf;
427
428 _collectiveBarrier();
429
430 // We calculate the number of elements that all threads in our process want.
431 Int64 total_size = 0;
432 for (Integer i = 0; i < m_local_nb_rank; ++i) {
433 total_size += m_all_dispatchs[i]->m_recv_view.size();
434 }
435
436 _collectiveBarrier();
437
438 // MPI exchanges are performed only by the leader threads of the processes.
439 if (m_local_rank == 0) {
440 FullRankInfo fri(FullRankInfo::compute(MessageRank(root), m_local_nb_rank));
441
442 UniqueArray<Type> local_recv_buf(total_size);
443
444 // If the "root" thread is in our process.
445 if (m_mpi_rank == fri.mpiRankValue()) {
446 // The leader thread handles the exchange.
447 m_parallel_mng->mpiParallelMng()->scatterVariable(m_all_dispatchs[fri.localRankValue()]->m_const_view.smallView(),
448 local_recv_buf, fri.mpiRankValue());
449 }
450 // The other leader threads provide their send buffers (it doesn't matter what
451 // they contain, it's a scatter).
452 else {
453 m_parallel_mng->mpiParallelMng()->scatterVariable(m_const_view.smallView(), local_recv_buf, fri.mpiRankValue());
454 }
455
456 // We just need to distribute the received data among the threads.
457 Integer compt = 0;
458 for (Integer i = 0; i < m_local_nb_rank; ++i) {
459 Int64 size = m_all_dispatchs[i]->m_recv_view.size();
460 for (Integer j = 0; j < size; ++j) {
461 m_all_dispatchs[i]->m_recv_view[j] = local_recv_buf[compt++];
462 }
463 }
464 }
465 _collectiveBarrier();
466 recv_buf.copy(m_recv_view);
467 _collectiveBarrier();
468}
469
470/*---------------------------------------------------------------------------*/
471/*---------------------------------------------------------------------------*/
472
473template <class Type> void HybridParallelDispatch<Type>::
474allToAll(Span<const Type> send_buf, Span<Type> recv_buf, Int32 count)
475{
476 Int32 global_nb_rank = m_global_nb_rank;
477 //TODO: Faire une version sans allocation
478 Int32UniqueArray send_count(global_nb_rank, count);
479 Int32UniqueArray recv_count(global_nb_rank, count);
480
481 Int32UniqueArray send_indexes(global_nb_rank);
482 Int32UniqueArray recv_indexes(global_nb_rank);
483 for (Integer i = 0; i < global_nb_rank; ++i) {
484 send_indexes[i] = count * i;
485 recv_indexes[i] = count * i;
486 }
487 this->allToAllVariable(send_buf, send_count, send_indexes, recv_buf, recv_count, recv_indexes);
488}
489
490/*---------------------------------------------------------------------------*/
491/*---------------------------------------------------------------------------*/
492
493template <class Type> void HybridParallelDispatch<Type>::
494allToAllVariable(Span<const Type> g_send_buf,
495 Int32ConstArrayView g_send_count,
496 Int32ConstArrayView g_send_index,
497 Span<Type> g_recv_buf,
498 Int32ConstArrayView g_recv_count,
499 Int32ConstArrayView g_recv_index)
500{
501 m_alltoallv_infos.send_buf = g_send_buf;
502 m_alltoallv_infos.send_count = g_send_count;
503 m_alltoallv_infos.send_index = g_send_index;
504 m_alltoallv_infos.recv_buf = g_recv_buf;
505 m_alltoallv_infos.recv_count = g_recv_count;
506 m_alltoallv_infos.recv_index = g_recv_index;
507
508 _collectiveBarrier();
509
510 UniqueArray<Type> tmp_recv_buf;
511
512 // FIRST IMPLEMENTATION
513 // Local rank 0 process does all the work.
514
515 if (m_local_rank == 0) {
516
517 Int32UniqueArray tmp_send_count(m_mpi_nb_rank);
518 tmp_send_count.fill(0);
519 Int32UniqueArray tmp_recv_count(m_mpi_nb_rank);
520 tmp_recv_count.fill(0);
521
522 Int64 total_send_size = 0;
523 Int64 total_recv_size = 0;
524
525 for (Integer i = 0; i < m_local_nb_rank; ++i) {
526 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
527 total_send_size += vinfo.send_buf.size();
528 total_recv_size += vinfo.recv_buf.size();
529 }
530
531 UniqueArray<Type> tmp_send_buf(total_send_size);
532 tmp_recv_buf.resize(total_recv_size);
533
534 // We calculate the number of elements to send and receive for each proc.
535 for (Integer i = 0; i < m_local_nb_rank; ++i) {
536 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
537
538 for (Integer z = 0; z < m_global_nb_rank; ++z) {
539
540 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
541 Int32 fri_mpi_rank = fri.mpiRankValue();
542
543 Int32 nb_send = vinfo.send_count[z];
544
545 tmp_send_count[fri_mpi_rank] += nb_send;
546 tmp_recv_count[fri_mpi_rank] += vinfo.recv_count[z];
547
548#if 0
549 info() << "my_local=" << i << " dest=" << z
550 << " send_count=" << vinfo.send_count[z] << " send_index=" << vinfo.send_index[z]
551 << " recv_count=" << vinfo.recv_count[z] << " recv_index=" << vinfo.recv_index[z];
552 {
553 Integer vindex = vinfo.send_index[z];
554 for( Integer w=0, wn=vinfo.send_count[z]; w<wn; ++w ){
555 info() << "V=" << vinfo.send_buf[ vindex + w ];
556 }
557 }
558#endif
559 }
560 }
561
562 Int32UniqueArray tmp_send_index(m_mpi_nb_rank);
563 Int32UniqueArray tmp_recv_index(m_mpi_nb_rank);
564 tmp_send_index[0] = 0;
565 tmp_recv_index[0] = 0;
566 for (Integer k = 1, nmpi = m_mpi_nb_rank; k < nmpi; ++k) {
567 tmp_send_index[k] = tmp_send_index[k - 1] + tmp_send_count[k - 1];
568 tmp_recv_index[k] = tmp_recv_index[k - 1] + tmp_recv_count[k - 1];
569 }
570
571 for (Integer i = 0; i < m_local_nb_rank; ++i) {
572 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
573
574 for (Integer z = 0; z < m_global_nb_rank; ++z) {
575
576 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
577 Int32 fri_mpi_rank = fri.mpiRankValue();
578
579 Integer nb_send = vinfo.send_count[z];
580 {
581
582 Integer tmp_current_index = tmp_send_index[fri_mpi_rank];
583 Integer local_current_index = vinfo.send_index[z];
584 for (Integer j = 0; j < nb_send; ++j)
585 tmp_send_buf[j + tmp_current_index] = vinfo.send_buf[j + local_current_index];
586 tmp_send_index[fri_mpi_rank] += nb_send;
587 }
588 }
589 }
590
591 tmp_send_index[0] = 0;
592 tmp_recv_index[0] = 0;
593 for (Integer k = 1, nmpi = m_mpi_nb_rank; k < nmpi; ++k) {
594 tmp_send_index[k] = tmp_send_index[k - 1] + tmp_send_count[k - 1];
595 tmp_recv_index[k] = tmp_recv_index[k - 1] + tmp_recv_count[k - 1];
596 }
597
598 /* Integer send_index = 0;
599 for( Integer i=0; i<m_local_nb_rank; ++i ){
600 ConstArrayView<Type> send_view = m_all_dispatchs[i]->m_alltoallv_infos.send_buf;
601 Integer send_size = send_view.size();
602 info() << "ADD_TMP_SEND_BUF send_index=" << send_index << " size=" << send_size;
603 for( Integer j=0; j<send_size; ++j )
604 tmp_send_buf[j+send_index] = send_view[j];
605 send_index += send_size;
606 }
607 */
608
609#if 0
610 info() << "AllToAllV nb_send=" << total_send_size << " nb_recv=" << total_recv_size;
611 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
612 info() << "INFOS Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
613 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
614 }
615
616 for( Integer i=0; i<tmp_send_buf.size(); ++i )
617 info() << "SEND_BUF[" << i << "] = " << tmp_send_buf[i];
618
619 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
620 info() << "SEND Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
621 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
622 Integer vindex = tmp_send_index[k];
623 for( Integer w=0, wn=tmp_send_count[k]; w<wn; ++w ){
624 info() << "V=" << tmp_send_buf[ vindex + w ];
625 }
626 }
627#endif
628
629 m_parallel_mng->mpiParallelMng()->allToAllVariable(tmp_send_buf, tmp_send_count,
630 tmp_send_index, tmp_recv_buf,
631 tmp_recv_count, tmp_recv_index);
632
633#if 0
634 for( Integer i=0; i<tmp_recv_buf.size(); ++i )
635 info() << "RECV_BUF[" << i << "] = " << tmp_recv_buf[i];
636
637 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
638 info() << "RECV Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
639 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
640 Integer vindex = tmp_recv_index[k];
641 for( Integer w=0, wn=tmp_recv_count[k]; w<wn; ++w ){
642 info() << "V=" << tmp_recv_buf[ vindex + w ];
643 }
644 }
645#endif
646
647 m_const_view = tmp_recv_buf.constView();
648
649 for (Integer z = 0; z < m_global_nb_rank; ++z) {
650 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z), m_local_nb_rank));
651 Int32 fri_mpi_rank = fri.mpiRankValue();
652
653 for (Integer i = 0; i < m_local_nb_rank; ++i) {
654 AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
655 Span<Type> my_buf = vinfo.recv_buf;
656 Int64 recv_size = vinfo.recv_count[z];
657 Int64 recv_index = tmp_recv_index[fri_mpi_rank];
658
659 Span<const Type> recv_view = tmp_recv_buf.span().subSpan(recv_index, recv_size);
660
661 Int64 my_recv_index = vinfo.recv_index[z];
662
663 //info() << "GET i=" << i << " z=" << z << " size=" << recv_size << " index=" << recv_index
664 // << " mpi_rank=" << fri_mpi_rank << " my_index=" << my_recv_index;
665
666 tmp_recv_index[fri_mpi_rank] = CheckedConvert::toInt32(tmp_recv_index[fri_mpi_rank] + recv_size);
667
668 for (Int64 j = 0; j < recv_size; ++j)
669 my_buf[j + my_recv_index] = recv_view[j];
670
671 //for( Integer j=0; j<recv_size; ++j )
672 //info() << "V=" << recv_view[j];
673
674 my_recv_index += recv_size;
675 }
676 }
677 }
678 _collectiveBarrier();
679
680 //info() << "END_PHASE_1_ALL_TO_ALL_V my_rank=" << m_global_rank << " (local=" << m_local_rank << ")";
681
682 //_collectiveBarrier();
683}
684
685/*---------------------------------------------------------------------------*/
686/*---------------------------------------------------------------------------*/
687
688template <class Type> auto HybridParallelDispatch<Type>::
689send(Span<const Type> send_buffer, Int32 rank, bool is_blocked) -> Request
690{
691 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
692 PointToPointMessageInfo p2p_message(MessageRank(rank), block_mode);
693 return send(send_buffer, p2p_message);
694}
695
696/*---------------------------------------------------------------------------*/
697/*---------------------------------------------------------------------------*/
698
699template <class Type> void HybridParallelDispatch<Type>::
700send(ConstArrayView<Type> send_buf, Int32 rank)
701{
702 send(send_buf, rank, true);
703}
704
705/*---------------------------------------------------------------------------*/
706/*---------------------------------------------------------------------------*/
707
708template <class Type> Parallel::Request HybridParallelDispatch<Type>::
709receive(Span<Type> recv_buffer, Int32 rank, bool is_blocked)
710{
711 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
712 PointToPointMessageInfo p2p_message(MessageRank(rank), block_mode);
713 return receive(recv_buffer, p2p_message);
714}
715
716/*---------------------------------------------------------------------------*/
717/*---------------------------------------------------------------------------*/
718
719template <class Type> Request HybridParallelDispatch<Type>::
720send(Span<const Type> send_buffer, const PointToPointMessageInfo& message2)
721{
722 PointToPointMessageInfo message(message2);
723 bool is_blocking = message.isBlocking();
724 message.setEmiterRank(MessageRank(m_global_rank));
725 Request r = m_message_queue->addSend(message, ConstMemoryView(send_buffer));
726 if (is_blocking) {
727 m_message_queue->waitAll(ArrayView<MP::Request>(1, &r));
728 return Request();
729 }
730 return r;
731}
732
733/*---------------------------------------------------------------------------*/
734/*---------------------------------------------------------------------------*/
735
736template <class Type> Request HybridParallelDispatch<Type>::
737receive(Span<Type> recv_buffer, const PointToPointMessageInfo& message2)
738{
739 PointToPointMessageInfo message(message2);
740 message.setEmiterRank(MessageRank(m_global_rank));
741 bool is_blocking = message.isBlocking();
742 Request r = m_message_queue->addReceive(message, ReceiveBufferInfo(MutableMemoryView(recv_buffer)));
743 if (is_blocking) {
744 m_message_queue->waitAll(ArrayView<Request>(1, &r));
745 return Request();
746 }
747 return r;
748}
749
750/*---------------------------------------------------------------------------*/
751/*---------------------------------------------------------------------------*/
752
753template <class Type> void HybridParallelDispatch<Type>::
754recv(ArrayView<Type> recv_buffer, Integer rank)
755{
756 recv(recv_buffer, rank, true);
757}
758
759/*---------------------------------------------------------------------------*/
760/*---------------------------------------------------------------------------*/
761
762template <class Type> void HybridParallelDispatch<Type>::
763sendRecv(ConstArrayView<Type> send_buffer, ArrayView<Type> recv_buffer, Integer proc)
764{
765 ARCANE_UNUSED(send_buffer);
766 ARCANE_UNUSED(recv_buffer);
767 ARCANE_UNUSED(proc);
768 throw NotImplementedException(A_FUNCINFO);
769}
770
771/*---------------------------------------------------------------------------*/
772/*---------------------------------------------------------------------------*/
773
774template <class Type> Type HybridParallelDispatch<Type>::
775allReduce(eReduceType op, Type send_buf)
776{
777 m_reduce_infos.reduce_value = send_buf;
778 //pinfo() << "ALL REDUCE BEGIN RANK=" << m_global_rank << " TYPE=" << (int)op << " MY=" << send_buf;
779 cout.flush();
780 _collectiveBarrier();
781 if (m_local_rank == 0) {
782 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
783 switch (op) {
784 case Parallel::ReduceMin:
785 for (Integer i = 1; i < m_local_nb_rank; ++i)
786 ret = math::min(ret, m_all_dispatchs[i]->m_reduce_infos.reduce_value);
787 break;
788 case Parallel::ReduceMax:
789 for (Integer i = 1; i < m_local_nb_rank; ++i)
790 ret = math::max(ret, m_all_dispatchs[i]->m_reduce_infos.reduce_value);
791 break;
792 case Parallel::ReduceSum:
793 for (Integer i = 1; i < m_local_nb_rank; ++i)
794 ret = (Type)(ret + m_all_dispatchs[i]->m_reduce_infos.reduce_value);
795 break;
796 default:
797 ARCANE_FATAL("Bad reduce type");
798 }
799 ret = m_parallel_mng->mpiParallelMng()->reduce(op, ret);
800 m_all_dispatchs[0]->m_reduce_infos.reduce_value = ret;
801 //pinfo() << "ALL REDUCE RANK=" << m_local_rank << " TYPE=" << (int)op << " MY=" << send_buf << " GLOBAL=" << ret << '\n';
802 }
803 _collectiveBarrier();
804 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
805 _collectiveBarrier();
806 return ret;
807}
808
809/*---------------------------------------------------------------------------*/
810/*---------------------------------------------------------------------------*/
811
812template <class Type> void HybridParallelDispatch<Type>::
813_applyReduceOperator(eReduceType op, Span<Type> result, AllDispatchView dispatch_view,
814 Int32 first_rank, Int32 last_rank)
815{
816 Int64 buf_size = result.size();
817 switch (op) {
818 case Parallel::ReduceMin:
819 for (Integer i = first_rank; i <= last_rank; ++i)
820 for (Int64 j = 0; j < buf_size; ++j)
821 result[j] = math::min(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
822 break;
823 case Parallel::ReduceMax:
824 for (Integer i = first_rank; i <= last_rank; ++i)
825 for (Int64 j = 0; j < buf_size; ++j)
826 result[j] = math::max(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
827 break;
828 case Parallel::ReduceSum:
829 for (Integer i = first_rank; i <= last_rank; ++i)
830 for (Integer j = 0; j < buf_size; ++j) {
831 result[j] = static_cast<Type>(result[j] + dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
832 }
833 break;
834 default:
835 ARCANE_FATAL("Bad reduce type");
836 }
837}
838
839/*---------------------------------------------------------------------------*/
840/*---------------------------------------------------------------------------*/
841
842template <class Type> void HybridParallelDispatch<Type>::
843_allReduceOrScan(eReduceType op, Span<Type> send_buf, bool is_scan)
844{
845 m_reduce_infos.reduce_buf_span = send_buf;
846 ++m_reduce_infos.m_index;
847 Int64 buf_size = send_buf.size();
848 UniqueArray<Type> ret(buf_size);
849 // Values from the previous MPI rank (used only in Scan mode)
850 UniqueArray<Type> previous_rank_ret;
851 MpiParallelMng* mpi_pm = m_parallel_mng->mpiParallelMng();
852 Int32 my_mpi_rank = mpi_pm->commRank();
853 Int32 mpi_nb_rank = mpi_pm->commSize();
854
855 //cout << "ALL REDUCE BEGIN RANk=" << m_local_rank << " TYPE=" << (int)op << " MY=" << send_buf << '\n';
856 //cout.flush();
857 _collectiveBarrier();
858 {
859 Integer index0 = m_all_dispatchs[0]->m_reduce_infos.m_index;
860 for (Integer i = 0; i < m_local_nb_rank; ++i) {
861 Integer indexi = m_all_dispatchs[i]->m_reduce_infos.m_index;
862 if (index0 != m_all_dispatchs[i]->m_reduce_infos.m_index) {
863 ARCANE_FATAL("INTERNAL: incoherent all reduce i0={0} in={1} n={2}",
864 index0, indexi, i);
865 }
866 }
867 }
868
869 if (m_local_rank == 0) {
870 const Int32 nb_local_rank = m_local_nb_rank;
871 for (Integer j = 0; j < buf_size; ++j)
872 ret[j] = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span[j];
873 _applyReduceOperator(op, ret, m_all_dispatchs, 1, nb_local_rank - 1);
874 if (is_scan) {
875 // For scan, we need to know the scan value of the preceding rank.
876 // We then use this value and apply our operator.
877 mpi_pm->scan(op, ret);
878 previous_rank_ret.resize(buf_size);
879 UniqueArray<Request> requests;
880 if (my_mpi_rank != 0)
881 requests.add(mpi_pm->recv(previous_rank_ret, my_mpi_rank - 1, false));
882 if (my_mpi_rank != (mpi_nb_rank - 1))
883 requests.add(mpi_pm->send(ret, my_mpi_rank + 1, false));
884 mpi_pm->waitAllRequests(requests);
885 if (my_mpi_rank != 0) {
886 // Apply the scan to my values.
887 _applyReduceOperator(op, previous_rank_ret, m_all_dispatchs, 0, 0);
888 send_buf.copy(previous_rank_ret);
889 }
890 else {
891 // I am the first local and MPI rank. I already have the correct values
892 // in \a send_buf.
893 }
894 }
895 else {
896 mpi_pm->reduce(op, ret);
897 send_buf.copy(ret);
898 }
899 }
900
901 _collectiveBarrier();
902
903 if (is_scan) {
904 if (m_local_rank != 0) {
905 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
906 ret.copy(global_buf);
907 // The scan for local rank 0 has already been applied
908 _applyReduceOperator(op, ret, m_all_dispatchs, 1, m_local_rank);
909 }
910 // TODO: We could avoid this barrier if we copied the values of 'send_buf'
911 // before modifying them.
912 _collectiveBarrier();
913
914 if (m_local_rank != 0) {
915 send_buf.copy(ret);
916 }
917 }
918 else {
919 if (m_local_rank != 0) {
920 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
921 send_buf.copy(global_buf);
922 }
923 }
924
925 _collectiveBarrier();
926}
927
928/*---------------------------------------------------------------------------*/
929/*---------------------------------------------------------------------------*/
930
931template <class Type> void HybridParallelDispatch<Type>::
932allReduce(eReduceType op, Span<Type> send_buf)
933{
934 _allReduceOrScan(op, send_buf, false);
935}
936
937/*---------------------------------------------------------------------------*/
938/*---------------------------------------------------------------------------*/
939
940template <class Type> Request HybridParallelDispatch<Type>::
941nonBlockingAllReduce(eReduceType op, Span<const Type> send_buf, Span<Type> recv_buf)
942{
943 ARCANE_UNUSED(op);
944 ARCANE_UNUSED(send_buf);
945 ARCANE_UNUSED(recv_buf);
946 throw NotImplementedException(A_FUNCINFO);
947}
948
949/*---------------------------------------------------------------------------*/
950/*---------------------------------------------------------------------------*/
951template <class Type> Request HybridParallelDispatch<Type>::
952nonBlockingAllGather(Span<const Type> send_buf, Span<Type> recv_buf)
953{
954 ARCANE_UNUSED(send_buf);
955 ARCANE_UNUSED(recv_buf);
956 throw NotImplementedException(A_FUNCINFO);
957}
958
959/*---------------------------------------------------------------------------*/
960/*---------------------------------------------------------------------------*/
961
962template <class Type> Request HybridParallelDispatch<Type>::
963nonBlockingBroadcast(Span<Type> send_buf, Int32 rank)
964{
965 ARCANE_UNUSED(send_buf);
966 ARCANE_UNUSED(rank);
967 throw NotImplementedException(A_FUNCINFO);
968}
969
970/*---------------------------------------------------------------------------*/
971/*---------------------------------------------------------------------------*/
972
973template <class Type> Request HybridParallelDispatch<Type>::
974nonBlockingGather(Span<const Type> send_buf, Span<Type> recv_buf, Int32 rank)
975{
976 ARCANE_UNUSED(send_buf);
977 ARCANE_UNUSED(recv_buf);
978 ARCANE_UNUSED(rank);
979 throw NotImplementedException(A_FUNCINFO);
980}
981
982/*---------------------------------------------------------------------------*/
983/*---------------------------------------------------------------------------*/
984
985template <class Type> Request HybridParallelDispatch<Type>::
986nonBlockingAllToAll(Span<const Type> send_buf, Span<Type> recv_buf, Int32 count)
987{
988 ARCANE_UNUSED(send_buf);
989 ARCANE_UNUSED(recv_buf);
990 ARCANE_UNUSED(count);
991 throw NotImplementedException(A_FUNCINFO);
992}
993
994/*---------------------------------------------------------------------------*/
995/*---------------------------------------------------------------------------*/
996
997template <class Type> Request HybridParallelDispatch<Type>::
998nonBlockingAllToAllVariable(Span<const Type> send_buf, ConstArrayView<Int32> send_count,
999 ConstArrayView<Int32> send_index, Span<Type> recv_buf,
1000 ConstArrayView<Int32> recv_count, ConstArrayView<Int32> recv_index)
1001{
1002 ARCANE_UNUSED(send_buf);
1003 ARCANE_UNUSED(recv_buf);
1004 ARCANE_UNUSED(send_count);
1005 ARCANE_UNUSED(recv_count);
1006 ARCANE_UNUSED(send_index);
1007 ARCANE_UNUSED(recv_index);
1008 throw NotImplementedException(A_FUNCINFO);
1009}
1010
1011/*---------------------------------------------------------------------------*/
1012/*---------------------------------------------------------------------------*/
1013
1014template <class Type> Type HybridParallelDispatch<Type>::
1015scan(eReduceType op, Type send_buf)
1016{
1017 ARCANE_UNUSED(op);
1018 ARCANE_UNUSED(send_buf);
1019 throw NotImplementedException(A_FUNCINFO);
1020}
1021
1022/*---------------------------------------------------------------------------*/
1023/*---------------------------------------------------------------------------*/
1024
1025template <class Type> void HybridParallelDispatch<Type>::
1026scan(eReduceType op, ArrayView<Type> send_buf)
1027{
1028 _allReduceOrScan(op, send_buf, true);
1029}
1030
1031/*---------------------------------------------------------------------------*/
1032/*---------------------------------------------------------------------------*/
1033
1034template <class Type> Request HybridParallelDispatch<Type>::
1036{
1037 throw NotImplementedException(A_FUNCINFO);
1038}
1039
1040/*---------------------------------------------------------------------------*/
1041/*---------------------------------------------------------------------------*/
1042
1043template <class Type> void HybridParallelDispatch<Type>::
1044_collectiveBarrier()
1045{
1046 m_parallel_mng->getThreadBarrier()->wait();
1047}
1048
1049/*---------------------------------------------------------------------------*/
1050/*---------------------------------------------------------------------------*/
1051
1052template class HybridParallelDispatch<char>;
1053template class HybridParallelDispatch<signed char>;
1054template class HybridParallelDispatch<unsigned char>;
1055template class HybridParallelDispatch<short>;
1056template class HybridParallelDispatch<unsigned short>;
1057template class HybridParallelDispatch<int>;
1058template class HybridParallelDispatch<unsigned int>;
1059template class HybridParallelDispatch<long>;
1060template class HybridParallelDispatch<unsigned long>;
1061template class HybridParallelDispatch<long long>;
1062template class HybridParallelDispatch<unsigned long long>;
1063template class HybridParallelDispatch<float>;
1064template class HybridParallelDispatch<double>;
1065template class HybridParallelDispatch<long double>;
1066template class HybridParallelDispatch<Real2>;
1067template class HybridParallelDispatch<Real3>;
1068template class HybridParallelDispatch<Real2x2>;
1069template class HybridParallelDispatch<Real3x3>;
1070template class HybridParallelDispatch<HPReal>;
1071template class HybridParallelDispatch<APReal>;
1072
1073/*---------------------------------------------------------------------------*/
1074/*---------------------------------------------------------------------------*/
1075
1076} // End namespace Arcane::MessagePassing
1077
1078/*---------------------------------------------------------------------------*/
1079/*---------------------------------------------------------------------------*/
#define ARCANE_FATAL(...)
Macro throwing a FatalErrorException.
Modifiable view of an array of type T.
Class implementing a High-Precision real number.
Definition HPReal.h:159
Brief information for a 'gather' message for data type DataType.
Interface for a message queue with threads.
Thread-based parallelism manager.
Declarations of types and methods used by message exchange mechanisms.
eBlockingType
Type indicating whether a message is blocking or not.
Int32 Integer
Type representing an integer.
UniqueArray< Int32 > Int32UniqueArray
Dynamic 1D array of 32-bit integers.
Definition UtilsTypes.h:341
std::int32_t Int32
Signed integer type of 32 bits.
Type
Type of JSON value.
Definition rapidjson.h:730
Structure equivalent to the boolean value true.
Structure equivalent to the boolean value true.