Arcane  v3.15.0.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
HybridParallelDispatch.cc
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* HybridParallelDispatch.cc (C) 2000-2024 */
9/* */
10/* Gestionnaire de parallélisme utilisant les threads et MPI. */
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
14#include "arcane/utils/Array.h"
15#include "arcane/utils/PlatformUtils.h"
16#include "arcane/utils/String.h"
17#include "arcane/utils/ITraceMng.h"
18#include "arcane/utils/Real2.h"
19#include "arcane/utils/Real3.h"
20#include "arcane/utils/Real2x2.h"
21#include "arcane/utils/Real3x3.h"
22#include "arcane/utils/APReal.h"
23#include "arcane/utils/FatalErrorException.h"
24#include "arcane/utils/NotImplementedException.h"
25#include "arcane/utils/NotSupportedException.h"
26#include "arcane/utils/IThreadBarrier.h"
27#include "arcane/utils/CheckedConvert.h"
28
29#include "arcane/MeshVariableRef.h"
30#include "arcane/IParallelMng.h"
31#include "arcane/ItemGroup.h"
32#include "arcane/IMesh.h"
33#include "arcane/IBase.h"
34
35#include "arcane/parallel/mpithread/HybridParallelDispatch.h"
36#include "arcane/parallel/mpithread/HybridParallelMng.h"
37#include "arcane/parallel/mpithread/HybridMessageQueue.h"
38#include "arcane/parallel/mpi/MpiParallelMng.h"
39#include "arcane/parallel/mpi/MpiParallelDispatch.h"
40
41/*---------------------------------------------------------------------------*/
42/*---------------------------------------------------------------------------*/
43
45{
46
47/*---------------------------------------------------------------------------*/
48/*---------------------------------------------------------------------------*/
49
50//TODO: Fusionner avec ce qui est possible dans SharedMemoryParallelDispatch
51
52/*---------------------------------------------------------------------------*/
53/*---------------------------------------------------------------------------*/
54
55template<class Type> HybridParallelDispatch<Type>::
56HybridParallelDispatch(ITraceMng* tm,HybridParallelMng* pm,HybridMessageQueue* message_queue,
57 ArrayView<HybridParallelDispatch<Type>*> all_dispatchs)
58: TraceAccessor(tm)
59, m_parallel_mng(pm)
60, m_local_rank(pm->localRank())
61, m_local_nb_rank(pm->localNbRank())
62, m_global_rank(pm->commRank())
63, m_global_nb_rank(pm->commSize())
64, m_mpi_rank(pm->mpiParallelMng()->commRank())
65, m_mpi_nb_rank(pm->mpiParallelMng()->commSize())
66, m_all_dispatchs(all_dispatchs)
67, m_message_queue(message_queue)
68, m_mpi_dispatcher(0)
69{
70 m_reduce_infos.m_index = 0;
71
72 // Ce tableau a été dimensionné par le créateur de cette instance.
73 // Il faut juste mettre à jour la valeur correspondant à son rang
74 m_all_dispatchs[m_local_rank] = this;
75
76 // Récupère le dispatcher MPI pour ce type.
77 MpiParallelMng* mpi_pm = pm->mpiParallelMng();
78 IParallelDispatchT<Type>* pd = mpi_pm->dispatcher((Type*)nullptr);
79 if (!pd)
80 ARCANE_FATAL("null dispatcher");
81
82 m_mpi_dispatcher = dynamic_cast<MpiParallelDispatchT<Type>*>(pd);
83 if (!m_mpi_dispatcher)
84 ARCANE_FATAL("null mpi dispatcher");
85}
86
87/*---------------------------------------------------------------------------*/
88/*---------------------------------------------------------------------------*/
89
90template<class Type> HybridParallelDispatch<Type>::
91~HybridParallelDispatch()
92{
93 finalize();
94}
95
96/*---------------------------------------------------------------------------*/
97/*---------------------------------------------------------------------------*/
98
99template<class Type> void HybridParallelDispatch<Type>::
100finalize()
101{
102}
103
104/*---------------------------------------------------------------------------*/
105/*---------------------------------------------------------------------------*/
106
107template<typename T>
109{
110 public:
111 typedef FalseType IsIntegral;
112};
113
114#define ARCANE_DEFINE_INTEGRAL_TYPE(datatype)\
115template<>\
116class _ThreadIntegralType<datatype>\
117{\
118 public:\
119 typedef TrueType IsIntegral;\
120}
121
122ARCANE_DEFINE_INTEGRAL_TYPE(long long);
123ARCANE_DEFINE_INTEGRAL_TYPE(long);
124ARCANE_DEFINE_INTEGRAL_TYPE(int);
125ARCANE_DEFINE_INTEGRAL_TYPE(short);
126ARCANE_DEFINE_INTEGRAL_TYPE(unsigned long long);
127ARCANE_DEFINE_INTEGRAL_TYPE(unsigned long);
128ARCANE_DEFINE_INTEGRAL_TYPE(unsigned int);
129ARCANE_DEFINE_INTEGRAL_TYPE(unsigned short);
130ARCANE_DEFINE_INTEGRAL_TYPE(double);
131ARCANE_DEFINE_INTEGRAL_TYPE(float);
132ARCANE_DEFINE_INTEGRAL_TYPE(HPReal);
133
134/*---------------------------------------------------------------------------*/
135/*---------------------------------------------------------------------------*/
136
137namespace{
138
139template<class Type> void
142 Int32& min_rank,Int32& max_rank,Int32 nb_rank,FalseType)
143{
144 ARCANE_UNUSED(all_dispatchs);
145 ARCANE_UNUSED(my_rank);
146 ARCANE_UNUSED(min_val);
147 ARCANE_UNUSED(max_val);
148 ARCANE_UNUSED(sum_val);
149 ARCANE_UNUSED(min_rank);
150 ARCANE_UNUSED(max_rank);
151 ARCANE_UNUSED(nb_rank);
152
153 throw NotImplementedException(A_FUNCINFO);
154}
155
156/*---------------------------------------------------------------------------*/
157/*---------------------------------------------------------------------------*/
158
159template<class Type> void
162 Int32& min_rank,Int32& max_rank,Int32 nb_rank,TrueType)
163{
164 ARCANE_UNUSED(my_rank);
165
166 HybridParallelDispatch<Type>* mtpd0 = all_dispatchs[0];
167 Type cval0 = mtpd0->m_reduce_infos.reduce_value;
171 Integer _min_rank = 0;
172 Integer _max_rank = 0;
173 for( Integer i=1; i<nb_rank; ++i ){
174 HybridParallelDispatch<Type>* mtpd = all_dispatchs[i];
175 Type cval = mtpd->m_reduce_infos.reduce_value;
176 Int32 grank = mtpd->globalRank();
177 if (cval<_min_val){
178 _min_val = cval;
180 }
181 if (_max_val<cval){
182 _max_val = cval;
183 _max_rank = grank;
184 }
185 _sum_val = (Type)(_sum_val + cval);
186 }
187 min_val = _min_val;
188 max_val = _max_val;
189 sum_val = _sum_val;
190 min_rank = _min_rank;
191 max_rank = _max_rank;
192}
193
194}
195
196/*---------------------------------------------------------------------------*/
197/*---------------------------------------------------------------------------*/
198
199template<class Type> void HybridParallelDispatch<Type>::
200computeMinMaxSum(Type val,Type& min_val,Type& max_val,Type& sum_val,
201 Int32& min_rank,Int32& max_rank)
202{
203 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
204 m_reduce_infos.reduce_value = val;
205 _collectiveBarrier();
206 _computeMinMaxSum2(m_all_dispatchs,m_global_rank,min_val,max_val,sum_val,min_rank,max_rank,m_local_nb_rank,IntegralType());
207 if (m_local_rank==0){
208 /*pinfo() << "COMPUTE_MIN_MAX_SUM_B rank=" << m_global_rank
209 << " min_rank=" << min_rank
210 << " max_rank=" << max_rank
211 << " min_val=" << min_val
212 << " max_val=" << max_val
213 << " sum_val=" << sum_val;*/
214 m_mpi_dispatcher->computeMinMaxSumNoInit(min_val,max_val,sum_val,min_rank,max_rank);
215 /*pinfo() << "COMPUTE_MIN_MAX_SUM_A rank=" << m_global_rank
216 << " min_rank=" << min_rank
217 << " max_rank=" << max_rank;*/
218
219 m_min_max_sum_infos.m_min_value = min_val;
220 m_min_max_sum_infos.m_max_value = max_val;
221 m_min_max_sum_infos.m_sum_value = sum_val;
222 m_min_max_sum_infos.m_min_rank = min_rank;
223 m_min_max_sum_infos.m_max_rank = max_rank;
224 }
225 _collectiveBarrier();
226 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
227 min_val = m_min_max_sum_infos.m_min_value;
228 max_val = m_min_max_sum_infos.m_max_value;
229 sum_val = m_min_max_sum_infos.m_sum_value;
230 min_rank = m_min_max_sum_infos.m_min_rank;
231 max_rank = m_min_max_sum_infos.m_max_rank;
232 _collectiveBarrier();
233}
234
235/*---------------------------------------------------------------------------*/
236/*---------------------------------------------------------------------------*/
237
238template<class Type> void HybridParallelDispatch<Type>::
239computeMinMaxSum(ConstArrayView<Type> values,
240 ArrayView<Type> min_values,
241 ArrayView<Type> max_values,
242 ArrayView<Type> sum_values,
243 ArrayView<Int32> min_ranks,
244 ArrayView<Int32> max_ranks)
245{
246 // Implémentation sous-optimale qui ne vectorise pas le calcul
247 // (c'est actuellement un copier-coller d'au-dessus mis dans une boucle)
248 typedef typename _ThreadIntegralType<Type>::IsIntegral IntegralType;
249 Integer n = values.size();
250 for(Integer i=0;i<n;++i) {
251 m_reduce_infos.reduce_value = values[i];
252 _collectiveBarrier();
253 _computeMinMaxSum2(m_all_dispatchs,m_global_rank,min_values[i],max_values[i],sum_values[i],min_ranks[i],max_ranks[i],m_local_nb_rank,IntegralType());
254 if (m_local_rank==0){
255 /*pinfo() << "COMPUTE_MIN_MAX_SUM_B rank=" << m_global_rank
256 << " min_rank=" << min_rank
257 << " max_rank=" << max_rank
258 << " min_val=" << min_val
259 << " max_val=" << max_val
260 << " sum_val=" << sum_val;*/
261 m_mpi_dispatcher->computeMinMaxSumNoInit(min_values[i],max_values[i],sum_values[i],min_ranks[i],max_ranks[i]);
262 /*pinfo() << "COMPUTE_MIN_MAX_SUM_A rank=" << m_global_rank
263 << " min_rank=" << min_rank
264 << " max_rank=" << max_rank;*/
265
266 m_min_max_sum_infos.m_min_value = min_values[i];
267 m_min_max_sum_infos.m_max_value = max_values[i];
268 m_min_max_sum_infos.m_sum_value = sum_values[i];
269 m_min_max_sum_infos.m_min_rank = min_ranks[i];
270 m_min_max_sum_infos.m_max_rank = max_ranks[i];
271 }
272 _collectiveBarrier();
273 m_min_max_sum_infos = m_all_dispatchs[0]->m_min_max_sum_infos;
274 min_values[i] = m_min_max_sum_infos.m_min_value;
275 max_values[i] = m_min_max_sum_infos.m_max_value;
276 sum_values[i] = m_min_max_sum_infos.m_sum_value;
277 min_ranks[i] = m_min_max_sum_infos.m_min_rank;
278 max_ranks[i] = m_min_max_sum_infos.m_max_rank;
279 _collectiveBarrier();
280 }
281}
282
283/*---------------------------------------------------------------------------*/
284/*---------------------------------------------------------------------------*/
285
286template<class Type> void HybridParallelDispatch<Type>::
287broadcast(Span<Type> send_buf,Int32 rank)
288{
289 m_broadcast_view = send_buf;
290 _collectiveBarrier();
291 FullRankInfo fri = FullRankInfo::compute(MP::MessageRank(rank),m_local_nb_rank);
292 int mpi_rank = fri.mpiRankValue();
293 if (m_mpi_rank==mpi_rank){
294 // J'ai le meme rang MPI que celui qui fait le broadcast
295 if (m_global_rank==rank){
296 //TODO: passage 64bit.
297 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(),mpi_rank);
298 }
299 else{
300 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[fri.localRankValue()]->m_broadcast_view);
301 }
302 }
303 else{
304 if (m_local_rank==0){
305 //TODO: passage 64bit.
306 m_parallel_mng->mpiParallelMng()->broadcast(send_buf.smallView(),mpi_rank);
307 }
308 }
309 _collectiveBarrier();
310 if (m_mpi_rank!=mpi_rank){
311 if (m_local_rank!=0)
312 m_all_dispatchs[m_local_rank]->m_broadcast_view.copy(m_all_dispatchs[0]->m_broadcast_view);
313 }
314 _collectiveBarrier();
315}
316
317/*---------------------------------------------------------------------------*/
318/*---------------------------------------------------------------------------*/
319
320template<class Type> void HybridParallelDispatch<Type>::
321allGather(Span<const Type> send_buf,Span<Type> recv_buf)
322{
323 //TODO: fusionner avec allGatherVariable()
324 m_const_view = send_buf;
325 _collectiveBarrier();
326 Int64 total_size = 0;
327 for( Int32 i=0; i<m_local_nb_rank; ++i ){
328 total_size += m_all_dispatchs[i]->m_const_view.size();
329 }
330 if (m_local_rank==0){
331 Int64 index = 0;
332 UniqueArray<Type> local_buf(total_size);
333 for( Integer i=0; i<m_local_nb_rank; ++i ){
334 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
335 Int64 size = view.size();
336 for( Int64 j=0; j<size; ++j )
337 local_buf[j+index] = view[j];
338 index += size;
339 }
340 IParallelMng* pm = m_parallel_mng->mpiParallelMng();
341 //TODO: 64bit
342 pm->allGather(local_buf,recv_buf.smallView());
343 m_const_view = recv_buf;
344 }
345 _collectiveBarrier();
346 if (m_local_rank!=0){
347 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
348 recv_buf.copy(view);
349 }
350 _collectiveBarrier();
351}
352
353/*---------------------------------------------------------------------------*/
354/*---------------------------------------------------------------------------*/
355
356template<class Type> void HybridParallelDispatch<Type>::
357gather(Span<const Type> send_buf,Span<Type> recv_buf,Int32 root_rank)
358{
359 UniqueArray<Type> tmp_buf;
360 if (m_global_rank==root_rank)
361 allGather(send_buf,recv_buf);
362 else{
363 tmp_buf.resize(send_buf.size() * m_global_nb_rank);
364 allGather(send_buf,tmp_buf);
365 }
366}
367
368/*---------------------------------------------------------------------------*/
369/*---------------------------------------------------------------------------*/
370
371template<class Type> void HybridParallelDispatch<Type>::
372allGatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf)
373{
374 m_const_view = send_buf;
375 _collectiveBarrier();
376 Int64 total_size = 0;
377 for( Integer i=0; i<m_local_nb_rank; ++i ){
378 total_size += m_all_dispatchs[i]->m_const_view.size();
379 }
380 if (m_local_rank==0){
381 Int64 index = 0;
382 UniqueArray<Type> local_buf(total_size);
383 for( Integer i=0; i<m_local_nb_rank; ++i ){
384 Span<const Type> view = m_all_dispatchs[i]->m_const_view;
385 Int64 size = view.size();
386 for( Int64 j=0; j<size; ++j )
387 local_buf[j+index] = view[j];
388 index += size;
389 }
390 m_parallel_mng->mpiParallelMng()->allGatherVariable(local_buf,recv_buf);
391 m_const_view = recv_buf.constView();
392 }
393 _collectiveBarrier();
394 if (m_local_rank!=0){
395 Span<const Type> view = m_all_dispatchs[0]->m_const_view;
396 recv_buf.resize(view.size());
397 recv_buf.copy(view);
398 }
399 _collectiveBarrier();
400}
401
402/*---------------------------------------------------------------------------*/
403/*---------------------------------------------------------------------------*/
404
405template<class Type> void HybridParallelDispatch<Type>::
406gatherVariable(Span<const Type> send_buf,Array<Type>& recv_buf,Int32 root_rank)
407{
408 UniqueArray<Type> tmp_buf;
409 if (m_global_rank==root_rank)
410 allGatherVariable(send_buf,recv_buf);
411 else
412 allGatherVariable(send_buf,tmp_buf);
413}
414
415/*---------------------------------------------------------------------------*/
416/*---------------------------------------------------------------------------*/
417
418template <class Type>
419void HybridParallelDispatch<Type>::
420scatterVariable(Span<const Type> send_buf, Span<Type> recv_buf, Int32 root)
421{
422 m_const_view = send_buf;
423 m_recv_view = recv_buf;
424
425 _collectiveBarrier();
426
427 // On calcule le nombre d'élément que veut tous les threads de notre processus.
428 Int64 total_size = 0;
429 for (Integer i = 0; i < m_local_nb_rank; ++i) {
430 total_size += m_all_dispatchs[i]->m_recv_view.size();
431 }
432
433 _collectiveBarrier();
434
435 // Les échanges MPI s'effectuent uniquement par les threads leaders des processus.
436 if (m_local_rank == 0) {
437 FullRankInfo fri(FullRankInfo::compute(MessageRank(root), m_local_nb_rank));
438
439 UniqueArray<Type> local_recv_buf(total_size);
440
441 // Si le thread "root" est dans notre processus.
442 if (m_mpi_rank == fri.mpiRankValue()) {
443 // Le thread leader s'occupe de l'échange.
444 m_parallel_mng->mpiParallelMng()->scatterVariable(m_all_dispatchs[fri.localRankValue()]->m_const_view.smallView(),
445 local_recv_buf, fri.mpiRankValue());
446 }
447 // Les autres threads leaders mettent leurs buffers d'envoi (qu'importe ce
448 // qu'ils contiennent, c'est un scatter).
449 else {
450 m_parallel_mng->mpiParallelMng()->scatterVariable(m_const_view.smallView(), local_recv_buf, fri.mpiRankValue());
451 }
452
453 // On a plus qu'à répartir les données reçues entre les threads.
454 Integer compt = 0;
455 for (Integer i = 0; i < m_local_nb_rank; ++i) {
456 Int64 size = m_all_dispatchs[i]->m_recv_view.size();
457 for (Integer j = 0; j < size; ++j) {
458 m_all_dispatchs[i]->m_recv_view[j] = local_recv_buf[compt++];
459 }
460 }
461 }
462 _collectiveBarrier();
463 recv_buf.copy(m_recv_view);
464 _collectiveBarrier();
465}
466
467/*---------------------------------------------------------------------------*/
468/*---------------------------------------------------------------------------*/
469
470template<class Type> void HybridParallelDispatch<Type>::
471allToAll(Span<const Type> send_buf,Span<Type> recv_buf,Int32 count)
472{
473 Int32 global_nb_rank = m_global_nb_rank;
474 //TODO: Faire une version sans allocation
475 Int32UniqueArray send_count(global_nb_rank,count);
476 Int32UniqueArray recv_count(global_nb_rank,count);
477
478 Int32UniqueArray send_indexes(global_nb_rank);
479 Int32UniqueArray recv_indexes(global_nb_rank);
480 for( Integer i=0; i<global_nb_rank; ++i ){
481 send_indexes[i] = count * i;
482 recv_indexes[i] = count * i;
483 }
484 this->allToAllVariable(send_buf,send_count,send_indexes,recv_buf,recv_count,recv_indexes);
485}
486
487/*---------------------------------------------------------------------------*/
488/*---------------------------------------------------------------------------*/
489
490template<class Type> void HybridParallelDispatch<Type>::
491allToAllVariable(Span<const Type> g_send_buf,
492 Int32ConstArrayView g_send_count,
493 Int32ConstArrayView g_send_index,
494 Span<Type> g_recv_buf,
495 Int32ConstArrayView g_recv_count,
496 Int32ConstArrayView g_recv_index
497 )
498{
499 m_alltoallv_infos.send_buf = g_send_buf;
500 m_alltoallv_infos.send_count = g_send_count;
501 m_alltoallv_infos.send_index = g_send_index;
502 m_alltoallv_infos.recv_buf = g_recv_buf;
503 m_alltoallv_infos.recv_count = g_recv_count;
504 m_alltoallv_infos.recv_index = g_recv_index;
505
506 _collectiveBarrier();
507
508 UniqueArray<Type> tmp_recv_buf;
509
510 // PREMIERE IMPLEMENTATION
511 // Le proc de rang local 0 fait tout le travail.
512
513 if (m_local_rank==0){
514
515 Int32UniqueArray tmp_send_count(m_mpi_nb_rank);
516 tmp_send_count.fill(0);
517 Int32UniqueArray tmp_recv_count(m_mpi_nb_rank);
518 tmp_recv_count.fill(0);
519
520 Int64 total_send_size = 0;
521 Int64 total_recv_size = 0;
522
523 for( Integer i=0; i<m_local_nb_rank; ++i ){
524 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
525 total_send_size += vinfo.send_buf.size();
526 total_recv_size += vinfo.recv_buf.size();
527 }
528
529 UniqueArray<Type> tmp_send_buf(total_send_size);
530 tmp_recv_buf.resize(total_recv_size);
531
532 // Calcule le nombre d'éléments à envoyer et recevoir pour chaque proc.
533 for( Integer i=0; i<m_local_nb_rank; ++i ){
534 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
535
536 for( Integer z=0; z<m_global_nb_rank; ++z ){
537
538 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z),m_local_nb_rank));
539 Int32 fri_mpi_rank = fri.mpiRankValue();
540
541 Int32 nb_send = vinfo.send_count[z];
542
543 tmp_send_count[fri_mpi_rank] += nb_send;
544 tmp_recv_count[fri_mpi_rank] += vinfo.recv_count[z];
545
546#if 0
547 info() << "my_local=" << i << " dest=" << z
548 << " send_count=" << vinfo.send_count[z] << " send_index=" << vinfo.send_index[z]
549 << " recv_count=" << vinfo.recv_count[z] << " recv_index=" << vinfo.recv_index[z];
550 {
551 Integer vindex = vinfo.send_index[z];
552 for( Integer w=0, wn=vinfo.send_count[z]; w<wn; ++w ){
553 info() << "V=" << vinfo.send_buf[ vindex + w ];
554 }
555 }
556#endif
557 }
558 }
559
560 Int32UniqueArray tmp_send_index(m_mpi_nb_rank);
561 Int32UniqueArray tmp_recv_index(m_mpi_nb_rank);
562 tmp_send_index[0] = 0;
563 tmp_recv_index[0] = 0;
564 for( Integer k=1, nmpi=m_mpi_nb_rank; k<nmpi; ++k ){
565 tmp_send_index[k] = tmp_send_index[k-1] + tmp_send_count[k-1];
566 tmp_recv_index[k] = tmp_recv_index[k-1] + tmp_recv_count[k-1];
567 }
568
569 for( Integer i=0; i<m_local_nb_rank; ++i ){
570 const AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
571
572 for( Integer z=0; z<m_global_nb_rank; ++ z){
573
574 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z),m_local_nb_rank));
575 Int32 fri_mpi_rank = fri.mpiRankValue();
576
577 Integer nb_send = vinfo.send_count[z];
578 {
579
580 Integer tmp_current_index = tmp_send_index[fri_mpi_rank];
581 Integer local_current_index = vinfo.send_index[z];
582 for( Integer j=0; j<nb_send; ++j )
583 tmp_send_buf[j+tmp_current_index] = vinfo.send_buf[j+local_current_index];
584 tmp_send_index[fri_mpi_rank] += nb_send;
585 }
586
587
588 }
589 }
590
591 tmp_send_index[0] = 0;
592 tmp_recv_index[0] = 0;
593 for( Integer k=1, nmpi=m_mpi_nb_rank; k<nmpi; ++k ){
594 tmp_send_index[k] = tmp_send_index[k-1] + tmp_send_count[k-1];
595 tmp_recv_index[k] = tmp_recv_index[k-1] + tmp_recv_count[k-1];
596 }
597
598
599
600 /* Integer send_index = 0;
601 for( Integer i=0; i<m_local_nb_rank; ++i ){
602 ConstArrayView<Type> send_view = m_all_dispatchs[i]->m_alltoallv_infos.send_buf;
603 Integer send_size = send_view.size();
604 info() << "ADD_TMP_SEND_BUF send_index=" << send_index << " size=" << send_size;
605 for( Integer j=0; j<send_size; ++j )
606 tmp_send_buf[j+send_index] = send_view[j];
607 send_index += send_size;
608 }
609 */
610
611#if 0
612 info() << "AllToAllV nb_send=" << total_send_size << " nb_recv=" << total_recv_size;
613 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
614 info() << "INFOS Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
615 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
616 }
617
618 for( Integer i=0; i<tmp_send_buf.size(); ++i )
619 info() << "SEND_BUF[" << i << "] = " << tmp_send_buf[i];
620
621 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
622 info() << "SEND Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
623 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
624 Integer vindex = tmp_send_index[k];
625 for( Integer w=0, wn=tmp_send_count[k]; w<wn; ++w ){
626 info() << "V=" << tmp_send_buf[ vindex + w ];
627 }
628 }
629#endif
630
631 m_parallel_mng->mpiParallelMng()->allToAllVariable(tmp_send_buf,tmp_send_count,
632 tmp_send_index,tmp_recv_buf,
633 tmp_recv_count,tmp_recv_index);
634
635#if 0
636 for( Integer i=0; i<tmp_recv_buf.size(); ++i )
637 info() << "RECV_BUF[" << i << "] = " << tmp_recv_buf[i];
638
639 for( Integer k=0; k<m_mpi_nb_rank; ++k ){
640 info() << "RECV Rank=" << k << " send_count=" << tmp_send_count[k] << " recv_count=" << tmp_recv_count[k]
641 << " send_index=" << tmp_send_index[k] << " recv_index=" << tmp_recv_index[k];
642 Integer vindex = tmp_recv_index[k];
643 for( Integer w=0, wn=tmp_recv_count[k]; w<wn; ++w ){
644 info() << "V=" << tmp_recv_buf[ vindex + w ];
645 }
646 }
647#endif
648
649 m_const_view = tmp_recv_buf.constView();
650
651
652 for( Integer z=0; z<m_global_nb_rank; ++ z){
653 FullRankInfo fri(FullRankInfo::compute(MP::MessageRank(z),m_local_nb_rank));
654 Int32 fri_mpi_rank = fri.mpiRankValue();
655
656 for( Integer i=0; i<m_local_nb_rank; ++i ){
657 AllToAllVariableInfo& vinfo = m_all_dispatchs[i]->m_alltoallv_infos;
658 Span<Type> my_buf = vinfo.recv_buf;
659 Int64 recv_size = vinfo.recv_count[z];
660 Int64 recv_index = tmp_recv_index[fri_mpi_rank];
661
662 Span<const Type> recv_view = tmp_recv_buf.span().subSpan(recv_index,recv_size);
663
664 Int64 my_recv_index = vinfo.recv_index[z];
665
666 //info() << "GET i=" << i << " z=" << z << " size=" << recv_size << " index=" << recv_index
667 // << " mpi_rank=" << fri_mpi_rank << " my_index=" << my_recv_index;
668
669 tmp_recv_index[fri_mpi_rank] = CheckedConvert::toInt32(tmp_recv_index[fri_mpi_rank] + recv_size);
670
671 for( Int64 j=0; j<recv_size; ++j )
672 my_buf[j+my_recv_index] = recv_view[j];
673
674 //for( Integer j=0; j<recv_size; ++j )
675 //info() << "V=" << recv_view[j];
676
677 my_recv_index += recv_size;
678 }
679 }
680
681 }
682 _collectiveBarrier();
683
684 //info() << "END_PHASE_1_ALL_TO_ALL_V my_rank=" << m_global_rank << " (local=" << m_local_rank << ")";
685
686 //_collectiveBarrier();
687}
688
689/*---------------------------------------------------------------------------*/
690/*---------------------------------------------------------------------------*/
691
692template<class Type> auto HybridParallelDispatch<Type>::
693send(Span<const Type> send_buffer,Int32 rank,bool is_blocked) -> Request
694{
695 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
696 PointToPointMessageInfo p2p_message(MessageRank(rank),block_mode);
697 return send(send_buffer,p2p_message);
698}
699
700/*---------------------------------------------------------------------------*/
701/*---------------------------------------------------------------------------*/
702
703template<class Type> void HybridParallelDispatch<Type>::
704send(ConstArrayView<Type> send_buf,Int32 rank)
705{
706 send(send_buf,rank,true);
707}
708
709/*---------------------------------------------------------------------------*/
710/*---------------------------------------------------------------------------*/
711
712template<class Type> Parallel::Request HybridParallelDispatch<Type>::
713receive(Span<Type> recv_buffer,Int32 rank,bool is_blocked)
714{
715 eBlockingType block_mode = (is_blocked) ? MP::Blocking : MP::NonBlocking;
716 PointToPointMessageInfo p2p_message(MessageRank(rank),block_mode);
717 return receive(recv_buffer,p2p_message);
718}
719
720/*---------------------------------------------------------------------------*/
721/*---------------------------------------------------------------------------*/
722
723template<class Type> Request HybridParallelDispatch<Type>::
724send(Span<const Type> send_buffer,const PointToPointMessageInfo& message2)
725{
726 PointToPointMessageInfo message(message2);
727 bool is_blocking = message.isBlocking();
728 message.setEmiterRank(MessageRank(m_global_rank));
729 Request r = m_message_queue->addSend(message, ConstMemoryView(send_buffer));
730 if (is_blocking){
731 m_message_queue->waitAll(ArrayView<MP::Request>(1,&r));
732 return Request();
733 }
734 return r;
735}
736
737/*---------------------------------------------------------------------------*/
738/*---------------------------------------------------------------------------*/
739
740template<class Type> Request HybridParallelDispatch<Type>::
741receive(Span<Type> recv_buffer,const PointToPointMessageInfo& message2)
742{
743 PointToPointMessageInfo message(message2);
744 message.setEmiterRank(MessageRank(m_global_rank));
745 bool is_blocking = message.isBlocking();
746 Request r = m_message_queue->addReceive(message,ReceiveBufferInfo(MutableMemoryView(recv_buffer)));
747 if (is_blocking){
748 m_message_queue->waitAll(ArrayView<Request>(1,&r));
749 return Request();
750 }
751 return r;
752}
753
754/*---------------------------------------------------------------------------*/
755/*---------------------------------------------------------------------------*/
756
757template<class Type> void HybridParallelDispatch<Type>::
758recv(ArrayView<Type> recv_buffer,Integer rank)
759{
760 recv(recv_buffer,rank,true);
761}
762
763/*---------------------------------------------------------------------------*/
764/*---------------------------------------------------------------------------*/
765
766template<class Type> void HybridParallelDispatch<Type>::
767sendRecv(ConstArrayView<Type> send_buffer,ArrayView<Type> recv_buffer,Integer proc)
768{
769 ARCANE_UNUSED(send_buffer);
770 ARCANE_UNUSED(recv_buffer);
771 ARCANE_UNUSED(proc);
772 throw NotImplementedException(A_FUNCINFO);
773}
774
775/*---------------------------------------------------------------------------*/
776/*---------------------------------------------------------------------------*/
777
778template<class Type> Type HybridParallelDispatch<Type>::
779allReduce(eReduceType op,Type send_buf)
780{
781 m_reduce_infos.reduce_value = send_buf;
782 //pinfo() << "ALL REDUCE BEGIN RANK=" << m_global_rank << " TYPE=" << (int)op << " MY=" << send_buf;
783 cout.flush();
784 _collectiveBarrier();
785 if (m_local_rank==0){
786 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
787 switch(op){
788 case Parallel::ReduceMin:
789 for( Integer i=1; i<m_local_nb_rank; ++i )
790 ret = math::min(ret,m_all_dispatchs[i]->m_reduce_infos.reduce_value);
791 break;
792 case Parallel::ReduceMax:
793 for( Integer i=1; i<m_local_nb_rank; ++i )
794 ret = math::max(ret,m_all_dispatchs[i]->m_reduce_infos.reduce_value);
795 break;
796 case Parallel::ReduceSum:
797 for( Integer i=1; i<m_local_nb_rank; ++i )
798 ret = (Type)(ret + m_all_dispatchs[i]->m_reduce_infos.reduce_value);
799 break;
800 default:
801 ARCANE_FATAL("Bad reduce type");
802 }
803 ret = m_parallel_mng->mpiParallelMng()->reduce(op,ret);
804 m_all_dispatchs[0]->m_reduce_infos.reduce_value = ret;
805 //pinfo() << "ALL REDUCE RANK=" << m_local_rank << " TYPE=" << (int)op << " MY=" << send_buf << " GLOBAL=" << ret << '\n';
806 }
807 _collectiveBarrier();
808 Type ret = m_all_dispatchs[0]->m_reduce_infos.reduce_value;
809 _collectiveBarrier();
810 return ret;
811}
812
813/*---------------------------------------------------------------------------*/
814/*---------------------------------------------------------------------------*/
815
816template <class Type> void HybridParallelDispatch<Type>::
817_applyReduceOperator(eReduceType op, Span<Type> result, AllDispatchView dispatch_view,
818 Int32 first_rank, Int32 last_rank)
819{
820 Int64 buf_size = result.size();
821 switch (op) {
822 case Parallel::ReduceMin:
823 for (Integer i = first_rank; i <= last_rank; ++i)
824 for (Int64 j = 0; j < buf_size; ++j)
825 result[j] = math::min(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
826 break;
827 case Parallel::ReduceMax:
828 for (Integer i = first_rank; i <= last_rank; ++i)
829 for (Int64 j = 0; j < buf_size; ++j)
830 result[j] = math::max(result[j], dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
831 break;
832 case Parallel::ReduceSum:
833 for (Integer i = first_rank; i <= last_rank; ++i)
834 for (Integer j = 0; j < buf_size; ++j) {
835 result[j] = static_cast<Type>(result[j] + dispatch_view[i]->m_reduce_infos.reduce_buf_span[j]);
836 }
837 break;
838 default:
839 ARCANE_FATAL("Bad reduce type");
840 }
841}
842
843/*---------------------------------------------------------------------------*/
844/*---------------------------------------------------------------------------*/
845
846template<class Type> void HybridParallelDispatch<Type>::
847_allReduceOrScan(eReduceType op, Span<Type> send_buf, bool is_scan)
848{
849 m_reduce_infos.reduce_buf_span = send_buf;
850 ++m_reduce_infos.m_index;
851 Int64 buf_size = send_buf.size();
852 UniqueArray<Type> ret(buf_size);
853 // Valeurs du rang MPI précédent (utilisé uniquement en mode Scan)
854 UniqueArray<Type> previous_rank_ret;
855 MpiParallelMng* mpi_pm = m_parallel_mng->mpiParallelMng();
856 Int32 my_mpi_rank = mpi_pm->commRank();
857 Int32 mpi_nb_rank = mpi_pm->commSize();
858
859 //cout << "ALL REDUCE BEGIN RANk=" << m_local_rank << " TYPE=" << (int)op << " MY=" << send_buf << '\n';
860 //cout.flush();
861 _collectiveBarrier();
862 {
863 Integer index0 = m_all_dispatchs[0]->m_reduce_infos.m_index;
864 for( Integer i=0; i<m_local_nb_rank; ++i ){
865 Integer indexi = m_all_dispatchs[i]->m_reduce_infos.m_index;
866 if (index0!=m_all_dispatchs[i]->m_reduce_infos.m_index){
867 ARCANE_FATAL("INTERNAL: incoherent all reduce i0={0} in={1} n={2}",
868 index0,indexi,i);
869 }
870 }
871 }
872
873 if (m_local_rank==0){
874 const Int32 nb_local_rank = m_local_nb_rank;
875 for( Integer j=0; j<buf_size; ++j )
876 ret[j] = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span[j];
877 _applyReduceOperator(op, ret, m_all_dispatchs, 1, nb_local_rank - 1);
878 if (is_scan) {
879 // Pour le scan, on a besoin de savoir la valeur du scan du rang qui nous précéde.
880 // On utilise ensuite cette valeur et on applique notre opérateur.
881 mpi_pm->scan(op, ret);
882 previous_rank_ret.resize(buf_size);
883 UniqueArray<Request> requests;
884 if (my_mpi_rank != 0)
885 requests.add(mpi_pm->recv(previous_rank_ret, my_mpi_rank - 1, false));
886 if (my_mpi_rank != (mpi_nb_rank - 1))
887 requests.add(mpi_pm->send(ret, my_mpi_rank + 1, false));
888 mpi_pm->waitAllRequests(requests);
889 if (my_mpi_rank != 0) {
890 // Applique le scan à mes valeurs.
891 _applyReduceOperator(op, previous_rank_ret, m_all_dispatchs, 0, 0);
892 send_buf.copy(previous_rank_ret);
893 }
894 else {
895 // Je suis le premier rang local et MPI. J'ai déja les bonnes valeurs
896 // dans \a send_buf.
897 }
898 }
899 else {
900 mpi_pm->reduce(op, ret);
901 send_buf.copy(ret);
902 }
903 }
904
905 _collectiveBarrier();
906
907 if (is_scan) {
908 if (m_local_rank != 0) {
909 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
910 ret.copy(global_buf);
911 // Le scan pour le rank local 0 a déjà été appliqué
912 _applyReduceOperator(op, ret, m_all_dispatchs, 1, m_local_rank);
913 }
914 // TODO: On pourrait éviter cette barrière si on copiait les valeurs de 'send_buf'
915 // avant de les modifier.
916 _collectiveBarrier();
917
918 if (m_local_rank != 0) {
919 send_buf.copy(ret);
920 }
921 }
922 else {
923 if (m_local_rank != 0) {
924 Span<const Type> global_buf = m_all_dispatchs[0]->m_reduce_infos.reduce_buf_span;
925 send_buf.copy(global_buf);
926 }
927 }
928
929 _collectiveBarrier();
930}
931
932/*---------------------------------------------------------------------------*/
933/*---------------------------------------------------------------------------*/
934
935template <class Type> void HybridParallelDispatch<Type>::
936allReduce(eReduceType op, Span<Type> send_buf)
937{
938 _allReduceOrScan(op, send_buf, false);
939}
940
941/*---------------------------------------------------------------------------*/
942/*---------------------------------------------------------------------------*/
943
944template<class Type> Request HybridParallelDispatch<Type>::
945nonBlockingAllReduce(eReduceType op,Span<const Type> send_buf,Span<Type> recv_buf)
946{
947 ARCANE_UNUSED(op);
948 ARCANE_UNUSED(send_buf);
949 ARCANE_UNUSED(recv_buf);
950 throw NotImplementedException(A_FUNCINFO);
951}
952
953/*---------------------------------------------------------------------------*/
954/*---------------------------------------------------------------------------*/
955template<class Type> Request HybridParallelDispatch<Type>::
956nonBlockingAllGather(Span<const Type> send_buf, Span<Type> recv_buf)
957{
958 ARCANE_UNUSED(send_buf);
959 ARCANE_UNUSED(recv_buf);
960 throw NotImplementedException(A_FUNCINFO);
961}
962
963/*---------------------------------------------------------------------------*/
964/*---------------------------------------------------------------------------*/
965
966template<class Type> Request HybridParallelDispatch<Type>::
967nonBlockingBroadcast(Span<Type> send_buf, Int32 rank)
968{
969 ARCANE_UNUSED(send_buf);
970 ARCANE_UNUSED(rank);
971 throw NotImplementedException(A_FUNCINFO);
972}
973
974/*---------------------------------------------------------------------------*/
975/*---------------------------------------------------------------------------*/
976
977template<class Type> Request HybridParallelDispatch<Type>::
978nonBlockingGather(Span<const Type> send_buf, Span<Type> recv_buf, Int32 rank)
979{
980 ARCANE_UNUSED(send_buf);
981 ARCANE_UNUSED(recv_buf);
982 ARCANE_UNUSED(rank);
983 throw NotImplementedException(A_FUNCINFO);
984}
985
986/*---------------------------------------------------------------------------*/
987/*---------------------------------------------------------------------------*/
988
989template<class Type> Request HybridParallelDispatch<Type>::
990nonBlockingAllToAll(Span<const Type> send_buf, Span<Type> recv_buf, Int32 count)
991{
992 ARCANE_UNUSED(send_buf);
993 ARCANE_UNUSED(recv_buf);
994 ARCANE_UNUSED(count);
995 throw NotImplementedException(A_FUNCINFO);
996}
997
998/*---------------------------------------------------------------------------*/
999/*---------------------------------------------------------------------------*/
1000
1001template<class Type> Request HybridParallelDispatch<Type>::
1002nonBlockingAllToAllVariable(Span<const Type> send_buf, ConstArrayView<Int32> send_count,
1003 ConstArrayView<Int32> send_index, Span<Type> recv_buf,
1004 ConstArrayView<Int32> recv_count, ConstArrayView<Int32> recv_index)
1005{
1006 ARCANE_UNUSED(send_buf);
1007 ARCANE_UNUSED(recv_buf);
1008 ARCANE_UNUSED(send_count);
1009 ARCANE_UNUSED(recv_count);
1010 ARCANE_UNUSED(send_index);
1011 ARCANE_UNUSED(recv_index);
1012 throw NotImplementedException(A_FUNCINFO);
1013}
1014
1015/*---------------------------------------------------------------------------*/
1016/*---------------------------------------------------------------------------*/
1017
1018template<class Type> Type HybridParallelDispatch<Type>::
1019scan(eReduceType op,Type send_buf)
1020{
1021 ARCANE_UNUSED(op);
1022 ARCANE_UNUSED(send_buf);
1023 throw NotImplementedException(A_FUNCINFO);
1024}
1025
1026/*---------------------------------------------------------------------------*/
1027/*---------------------------------------------------------------------------*/
1028
1029template<class Type> void HybridParallelDispatch<Type>::
1030scan(eReduceType op,ArrayView<Type> send_buf)
1031{
1032 _allReduceOrScan(op, send_buf, true);
1033}
1034
1035/*---------------------------------------------------------------------------*/
1036/*---------------------------------------------------------------------------*/
1037
1038template<class Type> Request HybridParallelDispatch<Type>::
1040{
1041 throw NotImplementedException(A_FUNCINFO);
1042}
1043
1044/*---------------------------------------------------------------------------*/
1045/*---------------------------------------------------------------------------*/
1046
1047template<class Type> void HybridParallelDispatch<Type>::
1048_collectiveBarrier()
1049{
1050 m_parallel_mng->getThreadBarrier()->wait();
1051}
1052
1053/*---------------------------------------------------------------------------*/
1054/*---------------------------------------------------------------------------*/
1055
1056template class HybridParallelDispatch<char>;
1057template class HybridParallelDispatch<signed char>;
1058template class HybridParallelDispatch<unsigned char>;
1059template class HybridParallelDispatch<short>;
1060template class HybridParallelDispatch<unsigned short>;
1061template class HybridParallelDispatch<int>;
1062template class HybridParallelDispatch<unsigned int>;
1063template class HybridParallelDispatch<long>;
1064template class HybridParallelDispatch<unsigned long>;
1065template class HybridParallelDispatch<long long>;
1066template class HybridParallelDispatch<unsigned long long>;
1067template class HybridParallelDispatch<float>;
1068template class HybridParallelDispatch<double>;
1069template class HybridParallelDispatch<long double>;
1070template class HybridParallelDispatch<Real2>;
1071template class HybridParallelDispatch<Real3>;
1072template class HybridParallelDispatch<Real2x2>;
1073template class HybridParallelDispatch<Real3x3>;
1074template class HybridParallelDispatch<HPReal>;
1075template class HybridParallelDispatch<APReal>;
1076
1077/*---------------------------------------------------------------------------*/
1078/*---------------------------------------------------------------------------*/
1079
1080} // End namespace Arcane::MessagePassing
1081
1082/*---------------------------------------------------------------------------*/
1083/*---------------------------------------------------------------------------*/
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Classe implémentant un réel Haute Précision.
Definition HPReal.h:161
virtual void allGather(ConstArrayView< char > send_buf, ArrayView< char > recv_buf)=0
Effectue un regroupement sur tous les processeurs. Il s'agit d'une opération collective....
Lecteur des fichiers de maillage via la bibliothèque LIMA.
Definition Lima.cc:149
Vue modifiable d'un tableau d'un type T.
Informations pour un message 'gather' pour le type de données DataType.
Exception lorsqu'une fonction n'est pas implémentée.
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
Definition Parallel.h:94
UniqueArray< Int32 > Int32UniqueArray
Tableau dynamique à une dimension d'entiers 32 bits.
Definition UtilsTypes.h:552
Espace de nommage contenant les types et déclarations qui gèrent le mécanisme de parallélisme par éch...
eBlockingType
Type indiquant si un message est bloquant ou non.
Int32 Integer
Type représentant un entier.
Type
Type of JSON value.
Definition rapidjson.h:665
Structure équivalente à la valeur booléenne vrai.
Structure équivalente à la valeur booléenne vrai.