Arcane 4.1.11.0
Developer documentation
DistributedDirectSolverBase.h
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
//-----------------------------------------------------------------------------
// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: Apache-2.0
//-----------------------------------------------------------------------------
/*---------------------------------------------------------------------------*/
/* DistributedDirectSolverBase.h                               (C) 2000-2026 */
/*                                                                           */
/* Base class for distributed direct solvers.                                */
/*---------------------------------------------------------------------------*/
#ifndef ARCCORE_ALINA_MPI_DISTRIBUTEDDIRECTSOLVERBASE_H
#define ARCCORE_ALINA_MPI_DISTRIBUTEDDIRECTSOLVERBASE_H
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
/*
 * This file is based on the work of the AMGCL library (version of March 2026),
 * which can be found at https://github.com/ddemidov/amgcl.
 *
 * Copyright (c) 2012-2022 Denis Demidov <dennis.demidov@gmail.com>
 * SPDX-License-Identifier: MIT
 */
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#include "arccore/alina/MessagePassingUtils.h"
#include "arccore/alina/DistributedMatrix.h"

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

namespace Arcane::Alina
{

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
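/*!
 * \brief CRTP base class for distributed direct solvers.
 *
 * The matrix, distributed by rows over the communicator, is consolidated onto
 * a reduced set of master ranks chosen through Solver::comm_size(). Each
 * master gathers the row chunks of the ranks assigned to it and hands the
 * consolidated block to Solver::init() on the masters-only communicator. At
 * solve time the right-hand side is gathered the same way, Solver::solve() is
 * called on the masters and the solution is scattered back.
 */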
template <typename Backend, class Solver>
class DistributedDirectSolverBase
{
 public:

  typedef typename Backend::value_type value_type;
  typedef typename math::scalar_of<value_type>::type scalar_type;
  typedef typename math::rhs_of<value_type>::type rhs_type;
  typedef typename Backend::matrix build_matrix;
  using col_type = typename Backend::col_type;
  using ptr_type = typename Backend::ptr_type;

  DistributedDirectSolverBase() {}

  void init(mpi_communicator comm, const build_matrix& Astrip)
  {
    this->comm = comm;
    n = Astrip.nbRow();

    std::vector<int> domain = comm.exclusive_sum(n);
    std::vector<int> active;
    active.reserve(comm.size);

    // Find out how many ranks are active (own non-zero matrix rows):
    int active_rank = 0;
    for (int i = 0; i < comm.size; ++i) {
      if (domain[i + 1] - domain[i] > 0) {
        if (comm.rank == i)
          active_rank = active.size();
        active.push_back(i);
      }
    }

    // Consolidate the matrix on fewer processes.
    int nmasters = std::min<int>(active.size(), solver().comm_size(domain.back()));
    int slaves_per_master = (active.size() + nmasters - 1) / nmasters;
    int group_beg = (active_rank / slaves_per_master) * slaves_per_master;

    group_master = active[group_beg];
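
    // Example: with 7 active ranks consolidated onto nmasters = 3 masters,
    // slaves_per_master = 3, so active ranks 0..2 get group_master = active[0],
    // ranks 3..5 get active[3] and rank 6 gets active[6].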

    // Communicator for masters (used to solve the coarse problem):
    MPI_Comm_split(comm,
                   comm.rank == group_master ? 0 : MPI_UNDEFINED,
                   comm.rank, &masters_comm);

    if (!n)
      return; // I am not active

    // Shift from row pointers to row widths:
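    // e.g. Astrip.ptr = {0, 3, 5, 9} becomes widths = {3, 2, 4}.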
    std::vector<ptr_type> widths(n);
    for (ptrdiff_t i = 0; i < n; ++i)
      widths[i] = Astrip.ptr[i + 1] - Astrip.ptr[i];

    if (comm.rank == group_master) {
      int group_end = std::min<int>(group_beg + slaves_per_master, active.size());
      group_beg += 1;
      int group_size = group_end - group_beg;

      UniqueArray<MessagePassing::Request> cnt_req(group_size);
      UniqueArray<MessagePassing::Request> col_req(group_size);
      UniqueArray<MessagePassing::Request> val_req(group_size);

      solve_req.resize(group_size);
      slaves.reserve(group_size);
      counts.reserve(group_size);

      // Count rows in local chunk of the consolidated matrix,
      // see who is reporting to us.
      int nloc = n;
      for (int j = group_beg; j < group_end; ++j) {
        int i = active[j];

        int m = domain[i + 1] - domain[i];
        nloc += m;
        counts.push_back(m);
        slaves.push_back(i);
      }

      // Get matrix chunks from my slaves.
      build_matrix A;
      A.set_size(nloc, domain.back(), false);
      A.ptr[0] = 0;

      cons_f.resize(A.nbRow());
      cons_x.resize(A.nbRow());

      int shift = n + 1;
      std::copy(widths.begin(), widths.end(), &A.ptr[1]);

      for (int j = 0; j < group_size; ++j) {
        int i = slaves[j];

        cnt_req[j] = comm.doIReceive(&A.ptr[shift], counts[j], i, cnt_tag);

        shift += counts[j];
      }

      comm.waitAll(cnt_req);

      A.set_nonzeros(A.scan_row_sizes());
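      // A.ptr now holds the row offsets of the consolidated chunk: the local
      // rows come first, followed by each slave's rows in the order of 'slaves'.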

      std::copy(Astrip.col.data(), Astrip.col.data() + Astrip.nbNonZero(), A.col.data());
      std::copy(Astrip.val.data(), Astrip.val.data() + Astrip.nbNonZero(), A.val.data());

      shift = Astrip.nbNonZero();
      for (int j = 0, d0 = domain[comm.rank]; j < group_size; ++j) {
        int i = slaves[j];

        int nnz = A.ptr[domain[i + 1] - d0] - A.ptr[domain[i] - d0];

        col_req[j] = comm.doIReceive(A.col.data() + shift, nnz, i, col_tag);
        val_req[j] = comm.doIReceive(A.val.data() + shift, nnz, i, val_tag);

        shift += nnz;
      }

      comm.waitAll(col_req);
      comm.waitAll(val_req);

      solver().init(mpi_communicator(masters_comm), A);
    }
    else {
      comm.doSend(widths.data(), n, group_master, cnt_tag);
      comm.doSend(Astrip.col.data(), Astrip.nbNonZero(), group_master, col_tag);
      comm.doSend(Astrip.val.data(), Astrip.nbNonZero(), group_master, val_tag);
    }

    host_v.resize(n);
  }

  template <class B>
  void init(mpi_communicator comm, const DistributedMatrix<B>& A)
  {
    const build_matrix& A_loc = *A.local();
    const build_matrix& A_rem = *A.remote();

    build_matrix a;

    a.set_size(A.loc_rows(), A.glob_cols(), false);
    a.set_nonzeros(A_loc.nbNonZero() + A_rem.nbNonZero());
    a.ptr[0] = 0;

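    // Merge the local and remote parts into one strip with global column
    // indices: local columns are shifted by loc_col_shift(); remote columns
    // already store global indices.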
    for (size_t i = 0, head = 0; i < A_loc.nbRow(); ++i) {
      ptrdiff_t shift = A.loc_col_shift();

      for (ptrdiff_t j = A_loc.ptr[i], e = A_loc.ptr[i + 1]; j < e; ++j) {
        a.col[head] = A_loc.col[j] + shift;
        a.val[head] = A_loc.val[j];
        ++head;
      }

      for (ptrdiff_t j = A_rem.ptr[i], e = A_rem.ptr[i + 1]; j < e; ++j) {
        a.col[head] = A_rem.col[j];
        a.val[head] = A_rem.val[j];
        ++head;
      }

      a.ptr[i + 1] = head;
    }

    init(comm, a);
  }

  virtual ~DistributedDirectSolverBase()
  {
    if (masters_comm != MPI_COMM_NULL)
      MPI_Comm_free(&masters_comm);
  }

  // CRTP access to the derived solver.
  Solver& solver()
  {
    return *static_cast<Solver*>(this);
  }

  const Solver& solver() const
  {
    return *static_cast<const Solver*>(this);
  }

  template <class VecF, class VecX>
  void operator()(const VecF& f, VecX& x) const
  {
    if (!n)
      return;

    backend::copy(f, host_v);

    if (comm.rank == group_master) {
      std::copy(host_v.begin(), host_v.end(), cons_f.begin());

      // Gather the right-hand side chunks of the slaves behind the local part.
      int shift = n, j = 0;
      for (int i : slaves) {
        solve_req[j] = comm.doIReceive(&cons_f[shift], counts[j], i, rhs_tag);
        shift += counts[j++];
      }

      comm.waitAll(solve_req);

      solver().solve(cons_f, cons_x);

      // Keep the local part of the solution and scatter the rest back.
      std::copy(cons_x.begin(), cons_x.begin() + n, host_v.begin());
      shift = n;
      j = 0;

      for (int i : slaves) {
        solve_req[j] = comm.doISend(&cons_x[shift], counts[j], i, sol_tag);
        shift += counts[j++];
      }

      comm.waitAll(solve_req);
    }
    else {
      comm.doSend(host_v.data(), n, group_master, rhs_tag);
      comm.doReceive(host_v.data(), n, group_master, sol_tag);
    }

    backend::copy(host_v, x);
  }

 private:

  static const int cnt_tag = 5001;
  static const int col_tag = 5002;
  static const int val_tag = 5003;
  static const int rhs_tag = 5004;
  static const int sol_tag = 5005;

  mpi_communicator comm;
  int n = 0;
  int group_master = -1;
  MPI_Comm masters_comm = MPI_COMM_NULL;
  std::vector<int> slaves;
  std::vector<int> counts;
  mutable std::vector<rhs_type> cons_f, cons_x, host_v;
  mutable UniqueArray<MessagePassing::Request> solve_req;
};
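
/*
 * Usage sketch (illustrative only): a concrete direct solver derives from this
 * base through CRTP and provides the three hooks invoked above via solver(),
 * namely comm_size(), init() and solve(). 'MyDirectSolver' below is a
 * hypothetical name, not part of Arcane; only the signatures are implied by
 * the calls made in this file.
 *
 *   template <typename Backend>
 *   class MyDirectSolver
 *   : public DistributedDirectSolverBase<Backend, MyDirectSolver<Backend>>
 *   {
 *     using Base = DistributedDirectSolverBase<Backend, MyDirectSolver<Backend>>;
 *
 *    public:
 *     using Base::init; // keep the base init() overloads visible
 *
 *     // Number of master ranks used to solve a system of global size n.
 *     int comm_size(int n) const;
 *
 *     // Factorize the consolidated matrix (called on the master ranks only).
 *     void init(mpi_communicator comm, const typename Base::build_matrix& A);
 *
 *     // Solve the factorized system for one right-hand side (masters only).
 *     void solve(const std::vector<typename Base::rhs_type>& f,
 *                std::vector<typename Base::rhs_type>& x) const;
 *   };
 */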

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

} // namespace Arcane::Alina

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#endif
Referenced symbols:
- void resize(size_t new_size): sets the new size. Warning: this method does not handle deletion of the current values. Definition: CSRMatrix.h:72.
- DistributedMatrix: distributed matrix using message passing.
- UniqueArray: 1D data array with value semantics (STL style).
- CSRMatrix: sparse matrix stored in CSR (Compressed Sparse Row) format. Definition: CSRMatrix.h:98.
- mpi_communicator: convenience wrapper around MPI_Comm.