Arcane  v4.1.10.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
SparseMatrixMatrixProduct.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* SparseMatrixMatrixProduct.h (C) 2000-2026 */
9/* */
10/* Sparse matrix-matrix product algorithms. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCCORE_ALINA_SPARSEMATRIXMATRIXPRODUCT_H
13#define ARCCORE_ALINA_SPARSEMATRIXMATRIXPRODUCT_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16/*
17 * This file is based on the work on AMGCL library (version march 2026)
18 * which can be found at https://github.com/ddemidov/amgcl.
19 *
20 * Copyright (c) 2012-2022 Denis Demidov <dennis.demidov@gmail.com>
21 * SPDX-License-Identifier: MIT
22 */
23/*---------------------------------------------------------------------------*/
24/*---------------------------------------------------------------------------*/
25/*
26 * Sparse matrix-matrix product algorithms.
27 *
28 * This implements two algorithms.
29 *
30 * The first is an OpenMP-enabled modification of classic algorithm from Saad
31 * [1]. It is used whenever number of OpenMP cores is 4 or less.
32 *
33 * The second is Row-merge algorithm from Rupp et al. [2]. The algorithm
34 * requires less memory and shows much better scalability than classic one.
35 * It is used when number of OpenMP cores is more than 4.
36 *
37 * [1] Saad, Yousef. Iterative methods for sparse linear systems. Siam, 2003.
38 * [2] Rupp K, Rudolf F, Weinbub J, Morhammer A, Grasser T, Jungel A. Optimized
39 * Sparse Matrix-Matrix Multiplication for Multi-Core CPUs, GPUs, and Xeon
40 * Phi. Submitted
41 */
42
43#include "arccore/alina/BackendInterface.h"
44#include "arccore/accelerator/Atomic.h"
45#include "arccore/common/SmallArray.h"
46
47#include <vector>
48#include <algorithm>
49#include <atomic>
50
51/*---------------------------------------------------------------------------*/
52/*---------------------------------------------------------------------------*/
53
54namespace Arcane::Alina
55{
56
57/*---------------------------------------------------------------------------*/
58/*---------------------------------------------------------------------------*/
59
/*!
 * \brief Sparse matrix-matrix product C = A * B using the classic
 * two-pass algorithm from Saad [1].
 *
 * A first parallel pass counts the nonzeros of every row of C, storage is
 * then allocated, and a second parallel pass fills column indices and
 * values. Each parallel chunk owns a private column-marker array sized by
 * the number of columns of B.
 *
 * \param A left-hand operand (CSR: ptr / col / val, nbRow() rows).
 * \param B right-hand operand (CSR: ptr / col / val, ncols columns).
 * \param C output matrix; resized and filled by this call.
 * \param sort if true, each row of C is sorted by column index at the end.
 */
template <class AMatrix, class BMatrix, class CMatrix>
void spgemm_saad(const AMatrix& A, const BMatrix& B, CMatrix& C, bool sort = true)
{
  typedef typename backend::value_type<CMatrix>::type Val;
  typedef typename backend::col_type<CMatrix>::type Col;
  typedef ptrdiff_t Idx;

  C.set_size(A.nbRow(), B.ncols);
  C.ptr[0] = 0;

  // Pass 1: count the number of nonzeros of each row of C.
  arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
    // marker[c] holds the index of the last row in which column c was
    // counted; -1 means "never seen". Private to this chunk.
    std::vector<ptrdiff_t> marker(B.ncols, -1);

    for (Idx ia = begin; ia < (begin + size); ++ia) {
      Col C_cols = 0;
      for (Idx ja = A.ptr[ia], ea = A.ptr[ia + 1]; ja < ea; ++ja) {
        Col ca = A.col[ja];

        // Row ia of C touches every column of row ca of B; count each
        // distinct column exactly once.
        for (Idx jb = B.ptr[ca], eb = B.ptr[ca + 1]; jb < eb; ++jb) {
          Col cb = B.col[jb];
          if (marker[cb] != ia) {
            marker[cb] = ia;
            ++C_cols;
          }
        }
      }
      // Store the row size; scan_row_sizes() below converts these counts
      // into the usual CSR row offsets.
      C.ptr[ia + 1] = C_cols;
    }
  });

  C.set_nonzeros(C.scan_row_sizes());

  // Pass 2: fill the column indices and values of C.
  arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
    std::vector<ptrdiff_t> marker(B.ncols, -1);

    for (Idx ia = begin; ia < (begin + size); ++ia) {
      Idx row_beg = C.ptr[ia];
      Idx row_end = row_beg;

      for (Idx ja = A.ptr[ia], ea = A.ptr[ia + 1]; ja < ea; ++ja) {
        Col ca = A.col[ja];
        Val va = A.val[ja];

        for (Idx jb = B.ptr[ca], eb = B.ptr[ca + 1]; jb < eb; ++jb) {
          Col cb = B.col[jb];
          Val vb = B.val[jb];

          // marker[cb] >= row_beg means column cb already has a slot in
          // the current row (marker stores that slot's position in C);
          // otherwise open a new slot at row_end.
          if (marker[cb] < row_beg) {
            marker[cb] = row_end;
            C.col[row_end] = cb;
            C.val[row_end] = va * vb;
            ++row_end;
          }
          else {
            C.val[marker[cb]] += va * vb;
          }
        }
      }

      if (sort)
        detail::sort_row(C.col + row_beg, C.val + row_beg, row_end - row_beg);
    }
  });
}
124
125/*---------------------------------------------------------------------------*/
126/*---------------------------------------------------------------------------*/
127
/*!
 * \brief Merges two sorted column-index ranges, collapsing duplicates.
 *
 * When \a need_out is true the merged indices are written to \a col3 and a
 * pointer past the last written element is returned; when it is false
 * nothing is written and the returned pointer only encodes the merged
 * width (return - col3 == size of the union).
 */
template <bool need_out, class Idx> Idx*
merge_rows(const Idx* col1, const Idx* col1_end,
const Idx* col2, const Idx* col2_end,
Idx* col3)
{
  // Walk both inputs simultaneously, always emitting the smaller head;
  // on a tie both inputs advance so the column appears only once.
  while (col1 != col1_end && col2 != col2_end) {
    const Idx c1 = *col1;
    const Idx c2 = *col2;

    if (c1 <= c2) {
      if (need_out)
        *col3 = c1;
      ++col1;
      if (c1 == c2)
        ++col2;
    }
    else {
      if (need_out)
        *col3 = c2;
      ++col2;
    }

    ++col3;
  }

  // One input is exhausted; the other's remainder passes through as-is.
  if (!need_out)
    return col3 + (col1_end - col1) + (col2_end - col2);

  if (col1 < col1_end)
    return std::copy(col1, col1_end, col3);
  if (col2 < col2_end)
    return std::copy(col2, col2_end, col3);
  return col3;
}
171
172/*---------------------------------------------------------------------------*/
173/*---------------------------------------------------------------------------*/
174
/*!
 * \brief Merges two sorted (column, value) rows into col3 / val3.
 *
 * The first row's values are scaled by \a alpha1 and the second row's by
 * \a alpha2; values of coincident columns are summed. Returns a pointer
 * past the last written column.
 */
template <class Idx, class Val> Idx*
merge_rows(const Val& alpha1, const Idx* col1, const Idx* col1_end, const Val* val1,
const Val& alpha2, const Idx* col2, const Idx* col2_end, const Val* val2,
Idx* col3, Val* val3)
{
  // Standard two-pointer merge over the common prefix of both rows.
  for (; col1 != col1_end && col2 != col2_end; ++col3, ++val3) {
    const Idx c1 = *col1;
    const Idx c2 = *col2;

    if (c1 < c2) {
      *col3 = c1;
      *val3 = alpha1 * (*val1++);
      ++col1;
    }
    else if (c2 < c1) {
      *col3 = c2;
      *val3 = alpha2 * (*val2++);
      ++col2;
    }
    else {
      // Same column in both rows: accumulate the scaled contributions.
      *col3 = c1;
      *val3 = alpha1 * (*val1++) + alpha2 * (*val2++);
      ++col1;
      ++col2;
    }
  }

  // Drain whichever input still has entries.
  for (; col1 < col1_end; ++col1) {
    *col3++ = *col1;
    *val3++ = alpha1 * (*val1++);
  }

  for (; col2 < col2_end; ++col2) {
    *col3++ = *col2;
    *val3++ = alpha2 * (*val2++);
  }

  return col3;
}
220
221/*---------------------------------------------------------------------------*/
222/*---------------------------------------------------------------------------*/
223
224template <class Col, class Ptr> Ptr
225prod_row_width(const Col* acol, const Col* acol_end,
226 const Ptr* bptr, const Col* bcol,
227 Col* tmp_col1, Col* tmp_col2, Col* tmp_col3)
228{
229 const Col nrows = acol_end - acol;
230
231 /* No rows to merge, nothing to do */
232 if (nrows == 0)
233 return 0;
234
235 /* Single row, just copy it to output */
236 if (nrows == 1)
237 return bptr[*acol + 1] - bptr[*acol];
238
239 /* Two rows, merge them */
240 if (nrows == 2) {
241 int a1 = acol[0];
242 int a2 = acol[1];
243
244 return merge_rows<false>(bcol + bptr[a1], bcol + bptr[a1 + 1],
245 bcol + bptr[a2], bcol + bptr[a2 + 1],
246 tmp_col1) -
247 tmp_col1;
248 }
249
250 /* Generic case (more than two rows).
251 *
252 * Merge rows by pairs, then merge the results together.
253 * When merging two rows, the result is always wider (or equal).
254 * Merging by pairs allows to work with short rows as often as possible.
255 */
256 // Merge first two.
257 Col a1 = *acol++;
258 Col a2 = *acol++;
259 Col c_col1 = merge_rows<true>(bcol + bptr[a1], bcol + bptr[a1 + 1],
260 bcol + bptr[a2], bcol + bptr[a2 + 1],
261 tmp_col1) -
262 tmp_col1;
263
264 // Go by pairs.
265 while (acol + 1 < acol_end) {
266 a1 = *acol++;
267 a2 = *acol++;
268
269 Col c_col2 = merge_rows<true>(bcol + bptr[a1], bcol + bptr[a1 + 1],
270 bcol + bptr[a2], bcol + bptr[a2 + 1],
271 tmp_col2) -
272 tmp_col2;
273
274 if (acol == acol_end) {
275 return merge_rows<false>(tmp_col1, tmp_col1 + c_col1,
276 tmp_col2, tmp_col2 + c_col2,
277 tmp_col3) -
278 tmp_col3;
279 }
280 else {
281 c_col1 = merge_rows<true>(tmp_col1, tmp_col1 + c_col1,
282 tmp_col2, tmp_col2 + c_col2,
283 tmp_col3) -
284 tmp_col3;
285
286 std::swap(tmp_col1, tmp_col3);
287 }
288 }
289
290 // Merge the tail.
291 a2 = *acol;
292 return merge_rows<false>(tmp_col1, tmp_col1 + c_col1,
293 bcol + bptr[a2], bcol + bptr[a2 + 1],
294 tmp_col2) -
295 tmp_col2;
296}
297
298/*---------------------------------------------------------------------------*/
299/*---------------------------------------------------------------------------*/
300
/*!
 * \brief Computes one row of the product C = A * B by merging scaled rows
 * of B.
 *
 * [acol, acol_end) / aval are the column indices and values of the current
 * row of A; (bptr, bcol, bval) is the CSR layout of B. The result is
 * written to out_col / out_val, which must be large enough for the width
 * previously computed by prod_row_width(). tm2_* and tm3_* are per-thread
 * scratch buffers of the same capacity; out_* itself doubles as the first
 * scratch buffer.
 */
template <class Col, class Ptr, class Val>
void prod_row(const Col* acol, const Col* acol_end, const Val* aval,
const Ptr* bptr, const Col* bcol, const Val* bval,
Col* out_col, Val* out_val,
Col* tm2_col, Val* tm2_val,
Col* tm3_col, Val* tm3_val)
{
  const Col nrows = acol_end - acol;

  /* No rows to merge, nothing to do */
  if (nrows == 0)
    return;

  /* Single row, just copy it to output */
  if (nrows == 1) {
    Col ac = *acol;
    Val av = *aval;

    const Val* bv = bval + bptr[ac];
    const Col* bc = bcol + bptr[ac];
    const Col* be = bcol + bptr[ac + 1];

    // Row of B scaled by the single coefficient of A.
    while (bc != be) {
      *out_col++ = *bc++;
      *out_val++ = av * (*bv++);
    }

    return;
  }

  /* Two rows, merge them */
  if (nrows == 2) {
    Col ac1 = acol[0];
    Col ac2 = acol[1];

    Val av1 = aval[0];
    Val av2 = aval[1];

    merge_rows(
    av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
    av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
    out_col, out_val);

    return;
  }

  /* Generic case (more than two rows).
   *
   * Merge rows by pairs, then merge the results together.
   * When merging two rows, the result is always wider (or equal).
   * Merging by pairs allows to work with short rows as often as possible.
   */
  // Merge first two.
  Col ac1 = *acol++;
  Col ac2 = *acol++;

  Val av1 = *aval++;
  Val av2 = *aval++;

  // The output buffer serves as the first scratch buffer (tm1); if the
  // swap chain below leaves the result elsewhere it is copied back at the
  // end of the function.
  Col* tm1_col = out_col;
  Val* tm1_val = out_val;

  Col c_col1 = merge_rows(av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
                          av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
                          tm1_col, tm1_val) -
               tm1_col;

  // Go by pairs.
  while (acol + 1 < acol_end) {
    ac1 = *acol++;
    ac2 = *acol++;

    av1 = *aval++;
    av2 = *aval++;

    Col c_col2 = merge_rows(av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
                            av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
                            tm2_col, tm2_val) -
                 tm2_col;

    // Fold the fresh pair into the accumulated row. Scaling by A's
    // coefficients already happened, hence identity weights here.
    c_col1 = merge_rows(math::identity<Val>(), tm1_col, tm1_col + c_col1, tm1_val,
                        math::identity<Val>(), tm2_col, tm2_col + c_col2, tm2_val,
                        tm3_col, tm3_val) -
             tm3_col;

    // Rotate buffers so tm1 always designates the accumulator.
    std::swap(tm3_col, tm1_col);
    std::swap(tm3_val, tm1_val);
  }

  // Merge the tail if there is one (odd row count leaves one unpaired row).
  if (acol < acol_end) {
    ac2 = *acol++;
    av2 = *aval++;

    c_col1 = merge_rows(math::identity<Val>(), tm1_col, tm1_col + c_col1, tm1_val,
                        av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
                        tm3_col, tm3_val) -
             tm3_col;

    std::swap(tm3_col, tm1_col);
    std::swap(tm3_val, tm1_val);
  }

  // If we are lucky, tm1 now points to out.
  // Otherwise, copy the results.
  if (tm1_col != out_col) {
    std::copy(tm1_col, tm1_col + c_col1, out_col);
    std::copy(tm1_val, tm1_val + c_col1, out_val);
  }
}
411
412/*---------------------------------------------------------------------------*/
413/*---------------------------------------------------------------------------*/
414
/*!
 * \brief Sparse matrix-matrix product C = A * B using the row-merge
 * algorithm of Rupp et al. [2].
 *
 * Requires less memory than spgemm_saad() and scales better with the
 * number of threads. Per-thread scratch buffers are sized from an upper
 * bound on the output row width, then a first parallel pass computes the
 * exact width of each row of C and a second pass fills it.
 */
template <class AMatrix, class BMatrix, class CMatrix>
void spgemm_rmerge(const AMatrix& A, const BMatrix& B, CMatrix& C)
{
  typedef typename backend::value_type<CMatrix>::type Val;
  typedef typename backend::col_type<CMatrix>::type Col;
  typedef ptrdiff_t Idx;

  // Upper bound of the width of any row of C: sum of the widths of the
  // rows of B selected by the corresponding row of A.
  Idx max_row_width = 0;

  arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
    Idx my_max = 0;
    for (Idx i = begin; i < (begin + size); ++i) {
      Idx row_beg = A.ptr[i];
      Idx row_end = A.ptr[i + 1];
      Idx row_width = 0;
      for (Idx j = row_beg; j < row_end; ++j) {
        Idx a_col = A.col[j];
        row_width += B.ptr[a_col + 1] - B.ptr[a_col];
      }
      my_max = std::max(my_max, row_width);
    }

    // Combine the per-chunk maxima without a lock.
    Accelerator::doAtomic<Accelerator::eAtomicOperation::Max>(&max_row_width,my_max);
  });

  const int nthreads = ConcurrencyBase::maxAllowedThread();

  // TODO: keep these values instead of rebuilding them
  // Per-thread scratch: prod_row_width() needs three column blocks of
  // max_row_width entries; prod_row() needs two column/value blocks (the
  // output row of C serves as the third).
  SmallArray<std::vector<Col>,16> tmp_col(nthreads);
  SmallArray<std::vector<Val>,16> tmp_val(nthreads);

  for (int i = 0; i < nthreads; ++i) {
    tmp_col[i].resize(3 * max_row_width);
    tmp_val[i].resize(2 * max_row_width);
  }

  C.set_size(A.nbRow(), B.ncols);
  C.ptr[0] = 0;

  // Pass 1: exact width of every row of C.
  arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
    const int tid = TaskFactory::currentTaskThreadIndex();

    Col* t_col = &tmp_col[tid][0];

    for (Idx i = begin; i < (begin + size); ++i) {
      Idx row_beg = A.ptr[i];
      Idx row_end = A.ptr[i + 1];

      // Row sizes are stored here; scan_row_sizes() turns them into the
      // usual CSR offsets.
      C.ptr[i + 1] = prod_row_width(A.col.data() + row_beg, A.col.data() + row_end,
                                    B.ptr.data(), B.col.data(),
                                    t_col, t_col + max_row_width, t_col + 2 * max_row_width);
    }
  });

  C.set_nonzeros(C.scan_row_sizes());

  // Pass 2: fill the column indices and values of C.
  arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
    const int tid = TaskFactory::currentTaskThreadIndex();

    Col* t_col = tmp_col[tid].data();
    Val* t_val = tmp_val[tid].data();

    for (Idx i = begin; i < (begin + size); ++i) {
      Idx row_beg = A.ptr[i];
      Idx row_end = A.ptr[i + 1];

      prod_row(A.col.data() + row_beg, A.col.data() + row_end, A.val.data() + row_beg,
               B.ptr.data(), B.col.data(), B.val.data(),
               C.col.data() + C.ptr[i], C.val.data() + C.ptr[i],
               t_col, t_val, t_col + max_row_width, t_val + max_row_width);
    }
  });
}
488
489/*---------------------------------------------------------------------------*/
490/*---------------------------------------------------------------------------*/
491
492} // namespace Arcane::Alina
493
494/*---------------------------------------------------------------------------*/
495/*---------------------------------------------------------------------------*/
496
497#endif
void arccoreParallelFor(const ComplexForLoopRanges< RankValue > &loop_ranges, const ForLoopRunInfo &run_info, const LambdaType &lambda_function, const ReducerArgs &... reducer_args)
Applique en concurrence la fonction lambda lambda_function sur l'intervalle d'itération donné par loo...
Definition ParallelFor.h:85
std::int32_t Int32
Type entier signé sur 32 bits.