// Base class shared by distributed direct solvers. `Solver` is presumably the
// concrete derived solver type supplied as a template parameter (the template
// header is outside this excerpt); solver() static_casts `this` to it.
41class DistributedDirectSolverBase
// Scalar and right-hand-side element types obtained from value_type through
// the math::scalar_of / math::rhs_of traits.
45 typedef typename math::scalar_of<value_type>::type scalar_type;
46 typedef typename math::rhs_of<value_type>::type rhs_type;
// NOTE(review): this constructor does not initialize masters_comm, and the
// destructor compares it against MPI_COMM_NULL. Confirm it is always set
// (e.g. by the initialization routine) before an instance can be destroyed.
49 DistributedDirectSolverBase() {}
// `domain` is an exclusive prefix sum of the per-rank row counts n. It is used
// below so that domain[r + 1] - domain[r] is rank r's row count and
// domain.back() is the global number of rows.
56 std::vector<int> domain = comm.exclusive_sum(n);
// Ranks owning at least one row ("active" ranks), in rank order.
57 std::vector<int> active;
58 active.reserve(comm.size);
62 for (
int i = 0; i < comm.size; ++i) {
63 if (domain[i + 1] - domain[i] > 0) {
// Records this rank's position within `active`. The guarding rank check and
// the push_back are not visible in this excerpt; presumably they are present.
65 active_rank = active.size();
// Number of master ranks: never more than the number of active ranks,
// otherwise what the concrete solver requests for the global problem size.
71 int nmasters = std::min<int>(active.size(), solver().comm_size(domain.back()));
// Active ranks per master, rounded up so every active rank falls in a group.
72 int slaves_per_master = (active.size() + nmasters - 1) / nmasters;
// Index (into `active`) of the first rank of this rank's group. That rank
// becomes the group master.
73 int group_beg = (active_rank / slaves_per_master) * slaves_per_master;
75 group_master = active[group_beg];
// Arguments of a communicator split (the call itself is not visible here,
// presumably MPI_Comm_split). Masters get color 0 and every other rank
// MPI_UNDEFINED, which in MPI leaves masters_comm == MPI_COMM_NULL there.
79 comm.rank == group_master ? 0 : MPI_UNDEFINED,
80 comm.rank, &masters_comm);
// Convert the local CSR row pointers into per-row nonzero counts so they can
// be sent to the master and re-scanned into row pointers there.
86 std::vector<ptrdiff_t> widths(n);
87 for (ptrdiff_t i = 0; i < n; ++i)
88 widths[i] = Astrip.ptr[i + 1] - Astrip.ptr[i];
// Master: gather the matrix strips of its whole group into one matrix A.
90 if (comm.rank == group_master) {
91 int group_end = std::min<int>(group_beg + slaves_per_master, active.size());
// Number of slave ranks in the group.
// NOTE(review): this assumes group_beg was advanced past the master itself in
// lines not shown; otherwise the master would appear in `slaves` and post
// receives from itself. TODO confirm.
93 int group_size = group_end - group_beg;
// Per-slave request handles and bookkeeping.
99 solve_req.resize(group_size);
100 slaves.reserve(group_size);
101 counts.reserve(group_size);
// For each slave in the group, compute its row count m. Recording i/m into
// `slaves`/`counts` and accumulating nloc happen in lines not shown.
106 for (
int j = group_beg; j < group_end; ++j) {
109 int m = domain[i + 1] - domain[i];
// Consolidated matrix: nloc rows by domain.back() (global) columns. The
// meaning of the trailing `false` is defined by set_size (not shown).
117 A.set_size(nloc, domain.back(),
false);
// Solution buffer sized to the consolidated row count.
121 cons_x.resize(A.nbRow());
// The master's own row widths occupy A.ptr[1..n].
124 std::copy(widths.begin(), widths.end(), &A.ptr[1]);
// Post non-blocking receives for each slave's row widths, placed after the
// master's rows in slave order.
126 for (
int j = 0; j < group_size; ++j) {
129 cnt_req[j] = comm.doIReceive(&A.ptr[shift], counts[j], i, cnt_tag);
134 comm.waitAll(cnt_req);
// Presumably turns row widths into row offsets and returns the total nonzero
// count, which then sizes the col/val storage.
136 A.set_nonzeros(A.scan_row_sizes());
// The master's own column indices and values go first.
138 std::copy(Astrip.col.data(), Astrip.col.data() + Astrip.nbNonZero(), A.col.data());
139 std::copy(Astrip.val.data(), Astrip.val.data() + Astrip.nbNonZero(), A.val.data());
// Receive each slave's columns and values directly after the preceding block.
// A slave's nonzero count is read from the consolidated row pointers. Its
// global row range is made local by subtracting d0, the master's first
// global row.
141 shift = Astrip.nbNonZero();
142 for (
int j = 0, d0 = domain[comm.rank]; j < group_size; ++j) {
145 int nnz = A.ptr[domain[i + 1] - d0] - A.ptr[domain[i] - d0];
// NOTE(review): A.col / A.val are used as containers above (.data()) but as
// raw pointers here (A.col + shift). If they are std::vector-like this does
// not compile; A.col.data() + shift was probably intended.
147 col_req[j] = comm.doIReceive(A.col + shift, nnz, i, col_tag);
148 val_req[j] = comm.doIReceive(A.val + shift, nnz, i, val_tag);
153 comm.waitAll(col_req);
154 comm.waitAll(val_req);
// Slave: send row widths, column indices and values to the group master,
// using the same tags the master receives on.
159 comm.doSend(widths.data(), n, group_master, cnt_tag);
160 comm.doSend(Astrip.col.data(), Astrip.nbNonZero(), group_master, col_tag);
161 comm.doSend(Astrip.val.data(), Astrip.nbNonZero(), group_master, val_tag);
// Local and remote parts of the distributed matrix A, as returned by
// A.local() / A.remote().
170 const build_matrix& A_loc = *A.local();
171 const build_matrix& A_rem = *A.remote();
// Single CRS matrix with A's local row count and global column count, with
// room for the nonzeros of both parts.
175 a.set_size(A.loc_rows(), A.glob_cols(),
false);
176 a.set_nonzeros(A_loc.nbNonZero() + A_rem.nbNonZero());
// Per row: local entries first, with columns offset by A.loc_col_shift()
// (presumably to global numbering). Remote entries follow, with columns
// copied unchanged. Advancing `head` and filling a.ptr happen in lines not
// shown.
179 for (
size_t i = 0, head = 0; i < A_loc.nbRow(); ++i) {
// Loop-invariant; could be hoisted out of the row loop.
180 ptrdiff_t shift = A.loc_col_shift();
182 for (ptrdiff_t j = A_loc.ptr[i], e = A_loc.ptr[i + 1]; j < e; ++j) {
183 a.col[head] = A_loc.col[j] + shift;
184 a.val[head] = A_loc.val[j];
188 for (ptrdiff_t j = A_rem.ptr[i], e = A_rem.ptr[i + 1]; j < e; ++j) {
189 a.col[head] = A_rem.col[j];
190 a.val[head] = A_rem.val[j];
// Releases the masters communicator on ranks that hold one.
// NOTE(review): masters_comm has no in-class initializer and the default
// constructor does not set it, so this check reads an indeterminate handle if
// initialization never ran. The class also owns an MPI handle with no copy
// control visible in this excerpt, so copying an instance would free the same
// communicator twice.
200 virtual ~DistributedDirectSolverBase()
202 if (masters_comm != MPI_COMM_NULL)
203 MPI_Comm_free(&masters_comm);
// Non-const accessor body (signature not visible): static downcast of `this`
// to the derived solver type.
208 return *
static_cast<Solver*
>(
this);
// Const accessor to the derived solver (static downcast of `this`).
211 const Solver& solver()
const
213 return *
static_cast<const Solver*
>(
this);
// Solves the consolidated system for this rank's part of the right-hand side.
// f is first copied into the host buffer host_v.
// - Group master: builds cons_f from its own n entries followed by each
//   slave's RHS block, calls the concrete solver, keeps its first n solution
//   entries and sends every slave its slice of cons_x.
// - Slave: sends its RHS to the master and blocks until its solution slice
//   arrives.
// Finally host_v is copied into x.
216 template <
class VecF,
class VecX>
217 void operator()(
const VecF& f, VecX& x)
const
222 backend::copy(f, host_v);
224 if (comm.rank == group_master) {
// The master's own RHS entries come first.
225 std::copy(host_v.begin(), host_v.end(), cons_f.begin());
// Slave RHS blocks follow the master's n entries, in `slaves` order with
// sizes from `counts`.
227 int shift = n, j = 0;
228 for (
int i : slaves) {
229 solve_req[j] = comm.doIReceive(&cons_f[shift], counts[j], i, rhs_tag);
230 shift += counts[j++];
233 comm.waitAll(solve_req);
235 solver().solve(cons_f, cons_x);
// Keep the master's own slice of the solution.
237 std::copy(cons_x.begin(), cons_x.begin() + n, host_v.begin());
// Send each slave its solution slice. shift/j are presumably reset to n/0 in
// lines not shown.
241 for (
int i : slaves) {
242 solve_req[j] = comm.doISend(&cons_x[shift], counts[j], i, sol_tag);
243 shift += counts[j++];
// Wait for all sends so cons_x is not touched while transfers are in flight.
246 comm.waitAll(solve_req);
// Slave path: ship the RHS to the master, then receive the solution in place.
249 comm.doSend(host_v.data(), n, group_master, rhs_tag);
250 comm.doReceive(host_v.data(), n, group_master, sol_tag);
253 backend::copy(host_v, x);
// Message tags, one per transfer kind: row widths, column indices and values
// (matrix consolidation), then right-hand side and solution (solve phase).
258 static const int cnt_tag = 5001;
259 static const int col_tag = 5002;
260 static const int val_tag = 5003;
261 static const int rhs_tag = 5004;
262 static const int sol_tag = 5005;
// Communicator connecting the group masters. It is presumably MPI_COMM_NULL on
// non-master ranks because of the MPI_UNDEFINED split color.
267 MPI_Comm masters_comm;
// On a master: ranks of its slaves and their row counts, in matching order.
268 std::vector<int> slaves;
269 std::vector<int> counts;
// Scratch buffers: the consolidated RHS/solution (master) and the local
// host-side vector. They are mutable so the const operator() can fill them.
270 mutable std::vector<rhs_type> cons_f, cons_x, host_v;