Arcane  v4.1.10.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
MessagePassingUtils.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* MessagePassingUtils.h (C) 2000-2026 */
9/* */
10/* Various utilities to handle message passing. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCCORE_ALINA_MESSAGEPASSINGUTILS_H
13#define ARCCORE_ALINA_MESSAGEPASSINGUTILS_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16/*
17 * This file is based on the work on AMGCL library (version march 2026)
18 * which can be found at https://github.com/ddemidov/amgcl.
19 *
20 * Copyright (c) 2012-2022 Denis Demidov <dennis.demidov@gmail.com>
21 * SPDX-License-Identifier: MIT
22 */
23/*---------------------------------------------------------------------------*/
24/*---------------------------------------------------------------------------*/
25
26#include "arccore/base/FixedArray.h"
27#include "arccore/common/Array.h"
28
29#include "arccore/message_passing_mpi/StandaloneMpiMessagePassingMng.h"
31#include "arccore/message_passing/PointToPointMessageInfo.h"
32
33#include "arccore/alina/ValueTypeInterface.h"
34#include "arccore/alina/AlinaUtils.h"
35
36#include <vector>
37#include <numeric>
38#include <complex>
39#include <type_traits>
40
41/*---------------------------------------------------------------------------*/
42/*---------------------------------------------------------------------------*/
43
44namespace Arcane::Alina
45{
46
48template <class T, class Enable = void>
50{
51 static MPI_Datatype get()
52 {
53 static const MPI_Datatype t = create();
54 return t;
55 }
56
57 static MPI_Datatype create()
58 {
59 typedef typename math::scalar_of<T>::type S;
60 MPI_Datatype t;
61 int n = sizeof(T) / sizeof(S);
62 MPI_Type_contiguous(n, mpi_datatype_impl<S>::get(), &t);
63 MPI_Type_commit(&t);
64 return t;
65 }
66};
67
68template <>
69struct mpi_datatype_impl<float>
70{
71 static MPI_Datatype get() { return MPI_FLOAT; }
72};
73
74template <>
75struct mpi_datatype_impl<double>
76{
77 static MPI_Datatype get() { return MPI_DOUBLE; }
78};
79
80template <>
81struct mpi_datatype_impl<long double>
82{
83 static MPI_Datatype get() { return MPI_LONG_DOUBLE; }
84};
85
86template <>
88{
89 static MPI_Datatype get() { return MPI_INT; }
90};
91
92template <>
93struct mpi_datatype_impl<unsigned>
94{
95 static MPI_Datatype get() { return MPI_UNSIGNED; }
96};
97
98template <>
99struct mpi_datatype_impl<long long>
100{
101 static MPI_Datatype get() { return MPI_LONG_LONG_INT; }
102};
103
104template <>
105struct mpi_datatype_impl<unsigned long long>
106{
107 static MPI_Datatype get() { return MPI_UNSIGNED_LONG_LONG; }
108};
109
110#if (MPI_VERSION > 2) || (MPI_VERSION == 2 && MPI_SUBVERSION >= 2)
111template <>
112struct mpi_datatype_impl<std::complex<double>>
113{
114 static MPI_Datatype get() { return MPI_CXX_DOUBLE_COMPLEX; }
115};
116
117template <>
118struct mpi_datatype_impl<std::complex<float>>
119{
120 static MPI_Datatype get() { return MPI_CXX_FLOAT_COMPLEX; }
121};
122#endif
123
124template <typename T>
126 typename std::enable_if<
127 std::is_same<T, ptrdiff_t>::value &&
128 !std::is_same<ptrdiff_t, long long>::value &&
129 !std::is_same<ptrdiff_t, int>::value>::type> : std::conditional<sizeof(ptrdiff_t) == sizeof(int), mpi_datatype_impl<int>, mpi_datatype_impl<long long>>::type
130{};
131
132template <typename T>
134 typename std::enable_if<
135 std::is_same<T, size_t>::value &&
136 !std::is_same<size_t, unsigned long long>::value &&
137 !std::is_same<ptrdiff_t, unsigned int>::value>::type>
138: std::conditional<
139 sizeof(size_t) == sizeof(unsigned), mpi_datatype_impl<unsigned>, mpi_datatype_impl<unsigned long long>>::type
140{};
141
142template <>
144{
145 static MPI_Datatype get() { return MPI_CHAR; }
146};
147
148/*---------------------------------------------------------------------------*/
149/*---------------------------------------------------------------------------*/
153template <typename T>
154MPI_Datatype mpi_datatype()
155{
156 return mpi_datatype_impl<T>::get();
157}
158
159/*---------------------------------------------------------------------------*/
160/*---------------------------------------------------------------------------*/
161
163struct mpi_init
164{
165 mpi_init(int* argc, char*** argv)
166 {
167 MPI_Init(argc, argv);
168 }
169
170 ~mpi_init()
171 {
172 MPI_Finalize();
173 }
174};
175
176/*---------------------------------------------------------------------------*/
177/*---------------------------------------------------------------------------*/
178
180struct mpi_init_thread
181{
182 mpi_init_thread(int* argc, char*** argv)
183 {
184 int _;
185 MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &_);
186 }
187
188 ~mpi_init_thread()
189 {
190 MPI_Finalize();
191 }
192};
193
194/*---------------------------------------------------------------------------*/
195/*---------------------------------------------------------------------------*/
199struct mpi_communicator
200{
201 MPI_Comm comm = MPI_COMM_NULL;
202 int rank = 0;
203 int size = 0;
204 Ref<IMessagePassingMng> m_message_passing_mng;
205
206 mpi_communicator() = default;
207
208 explicit mpi_communicator(MPI_Comm comm)
209 : comm(comm)
210 {
211 MPI_Comm_rank(comm, &rank);
212 MPI_Comm_size(comm, &size);
214 };
215
216 operator MPI_Comm() const
217 {
218 return comm;
219 }
220
222 template <typename T>
223 std::vector<T> exclusive_sum(T n) const
224 {
225 // TODO: Utiliser scan.
226 std::vector<T> v(size + 1);
227 v[0] = 0;
228 MPI_Allgather(&n, 1, mpi_datatype<T>(), &v[1], 1, mpi_datatype<T>(), comm);
229 std::partial_sum(v.begin(), v.end(), v.begin());
230 return v;
231 }
232
233 std::complex<long double> reduceSum(const std::complex<long double>& lval) const
234 {
235 return _reduceSumForComplex(lval);
236 }
237 std::complex<double> reduceSum(const std::complex<double>& lval) const
238 {
239 return _reduceSumForComplex(lval);
240 }
241 std::complex<float> reduceSum(const std::complex<float>& lval) const
242 {
243 return _reduceSumForComplex(lval);
244 }
245
246 template <typename T> T reduceSum(const T& lval) const
247 {
248 return mpAllReduce(m_message_passing_mng.get(), MessagePassing::eReduceType::ReduceSum, lval);
249 }
250
251 void waitAll(ArrayView<MessagePassing::Request> requests) const
252 {
253 mpWaitAll(m_message_passing_mng.get(), requests);
254 }
255 void wait(MessagePassing::Request request) const
256 {
257 ArrayView<MessagePassing::Request> requests(1, &request);
258 mpWaitAll(m_message_passing_mng.get(), requests);
259 }
260
270 template <class Condition, class Message>
271 void check(const Condition& cond, const Message& message)
272 {
273 int lc = static_cast<int>(cond);
274 int gc = _reduce(MPI_PROD, lc);
275
276 if (!gc) {
277 std::vector<int> c(size);
278 MPI_Gather(&lc, 1, MPI_INT, &c[0], size, MPI_INT, 0, comm);
279 if (rank == 0) {
280 std::cerr << "Failed assumption: " << message << std::endl;
281 std::cerr << "Offending processes:";
282 for (int i = 0; i < size; ++i)
283 if (!c[i])
284 std::cerr << " " << i;
285 std::cerr << std::endl;
286 }
287 MPI_Barrier(comm);
288 ARCCORE_FATAL("CheckError in MessagePassingUtils: {0}", message);
289 }
290 }
291
292 template <typename T> MessagePassing::Request
293 doIReceive(T* buf, int count, int source, int tag) const
294 {
295 using namespace Arcane::MessagePassing;
296 Span<T> s(buf, count);
297 Span<unsigned char> schar(reinterpret_cast<unsigned char*>(s.data()), s.sizeBytes());
298 PointToPointMessageInfo msg_info(MessageRank{ source }, MessageTag{ tag }, eBlockingType::NonBlocking);
299 return mpReceive(m_message_passing_mng.get(), schar, msg_info);
300 }
301
302 template <typename T> void
303 doReceive(T* buf, int count, int source, int tag) const
304 {
305 using namespace Arcane::MessagePassing;
306 Span<T> s(buf, count);
307 Span<unsigned char> schar(reinterpret_cast<unsigned char*>(s.data()), s.sizeBytes());
308 PointToPointMessageInfo msg_info(MessageRank{ source }, MessageTag{ tag }, eBlockingType::Blocking);
309 mpReceive(m_message_passing_mng.get(), schar, msg_info);
310 }
311
312 template <typename T> MessagePassing::Request
313 doISend(const T* buf, int count, int dest, int tag) const
314 {
315 using namespace Arcane::MessagePassing;
316 Span<const T> s(buf, count);
317 Span<const unsigned char> schar(reinterpret_cast<const unsigned char*>(s.data()), s.sizeBytes());
318 PointToPointMessageInfo msg_info(MessageRank{ dest }, MessageTag{ tag }, eBlockingType::NonBlocking);
319 return mpSend(m_message_passing_mng.get(), schar, msg_info);
320 }
321
322 template <typename T> void
323 doSend(const T* buf, int count, int dest, int tag) const
324 {
325 using namespace Arcane::MessagePassing;
326 Span<const T> s(buf, count);
327 Span<const unsigned char> schar(reinterpret_cast<const unsigned char*>(s.data()), s.sizeBytes());
328 PointToPointMessageInfo msg_info(MessageRank{ dest }, MessageTag{ tag }, eBlockingType::Blocking);
329 mpSend(m_message_passing_mng.get(), schar, msg_info);
330 }
331
332 private:
333
334 template <typename T> T _reduce(MPI_Op op, const T& lval) const
335 {
336 const int elems = math::static_rows<T>::value * math::static_cols<T>::value;
337 T gval;
338
339 MPI_Allreduce((void*)&lval, &gval, elems, mpi_datatype<T>(), op, comm);
340 return gval;
341 }
342
343 template <typename T> std::complex<T>
344 _reduceSumForComplex(const std::complex<T>& lval) const
345 {
346 // Specialisation for 'std::complex<float>' as 2 float.
347 FixedArray<T, 2> values = { { lval.real(), lval.imag() } };
348 mpAllReduce(m_message_passing_mng.get(), MessagePassing::eReduceType::ReduceSum, values.view());
349 return std::complex<T>(values[0], values[1]);
350 }
351};
352
353/*---------------------------------------------------------------------------*/
354/*---------------------------------------------------------------------------*/
355
356} // namespace Arcane::Alina
357
358/*---------------------------------------------------------------------------*/
359/*---------------------------------------------------------------------------*/
360
361#endif
#define ARCCORE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Liste des fonctions d'échange de message.
constexpr __host__ __device__ ArrayView< T > view()
Vue modifiable sur le tableau.
static Ref< IMessagePassingMng > createRef(MPI_Comm comm, bool clean_comm=false)
Créé un gestionnaire associé au communicateur comm.
Informations pour envoyer/recevoir un message point à point.
Requête d'un message.
Definition Request.h:77
Référence à une instance.
InstanceType * get() const
Instance associée ou nullptr si aucune.
Vue d'un tableau d'éléments de type T.
Definition Span.h:633
Déclarations des types et méthodes utilisés par les mécanismes d'échange de messages.
C char mpAllReduce(IMessagePassingMng *pm, eReduceType rt, char v)
void mpWaitAll(IMessagePassingMng *pm, ArrayView< Request > requests)
Bloque tant que les requêtes de requests ne sont pas terminées.
Definition Messages.cc:163
Request mpReceive(IMessagePassingMng *pm, ISerializer *values, const PointToPointMessageInfo &message)
Message de réception utilisant un ISerializer.
Definition Messages.cc:289
Request mpSend(IMessagePassingMng *pm, const ISerializer *values, const PointToPointMessageInfo &message)
Message d'envoi utilisant un ISerializer.
Definition Messages.cc:278
void check(const Condition &cond, const Message &message)
Communicator-wise condition checking.
std::vector< T > exclusive_sum(T n) const
Exclusive sum over mpi communicator.
Converts C type to MPI datatype.