12#ifndef ARCCORE_ALINA_SPARSEMATRIXMATRIXPRODUCT_H
13#define ARCCORE_ALINA_SPARSEMATRIXMATRIXPRODUCT_H
43#include "arccore/alina/BackendInterface.h"
44#include "arccore/accelerator/Atomic.h"
45#include "arccore/common/SmallArray.h"
54namespace Arcane::Alina
/// Sparse matrix-matrix product C = A * B using a Saad-style two-pass
/// (symbolic, then numeric) row-by-row algorithm over CSR-like storage
/// (ptr / col / val arrays).
///
/// \param A     left-hand operand (read via A.ptr / A.col / presumably A.val).
/// \param B     right-hand operand (read via B.ptr / B.col and B.ncols).
/// \param C     output matrix; resized, its ptr/col/val arrays are filled here.
/// \param sort  when true, each finished row of C is sorted by column index
///              via detail::sort_row (see last line of the numeric pass).
///
/// NOTE(review): this listing is an extraction with many source lines elided
/// (the fused numeric prefixes at line starts show gaps, e.g. 80 -> 86).
/// In particular the parallel-loop wrappers providing `begin`/`size`
/// (presumably arccoreParallelFor, as in spgemm_rmerge below) and several
/// loop bodies are missing — confirm against the full source.
60template <
class AMatrix,
class BMatrix,
class CMatrix>
61void spgemm_saad(
const AMatrix& A,
const BMatrix& B, CMatrix& C,
bool sort =
true)
63 typedef typename backend::value_type<CMatrix>::type Val;
64 typedef typename backend::col_type<CMatrix>::type Col;
65 typedef ptrdiff_t Idx;
67 C.set_size(A.nbRow(), B.ncols);
    // Symbolic pass: count distinct columns in each output row.
    // marker[cb] records the last A-row index `ia` for which column cb was
    // seen (-1 = never), so each column is counted once per row.
71 std::vector<ptrdiff_t> marker(B.ncols, -1);
73 for (Idx ia = begin; ia < (begin + size); ++ia) {
75 for (Idx ja = A.ptr[ia], ea = A.ptr[ia + 1]; ja < ea; ++ja) {
        // For each A(ia, ca), walk row ca of B; first visit of a column
        // in this output row increments the row's nonzero count.
78 for (Idx jb = B.ptr[ca], eb = B.ptr[ca + 1]; jb < eb; ++jb) {
80 if (marker[cb] != ia) {
        // Store the per-row nonzero count; scan_row_sizes() below
        // presumably turns these sizes into row offsets — TODO confirm.
86 C.ptr[ia + 1] = C_cols;
90 C.set_nonzeros(C.scan_row_sizes());
    // Numeric pass: fresh marker array; marker[cb] now holds the write
    // position of column cb inside the current output row.
93 std::vector<ptrdiff_t> marker(B.ncols, -1);
95 for (Idx ia = begin; ia < (begin + size); ++ia) {
96 Idx row_beg = C.ptr[ia];
97 Idx row_end = row_beg;
99 for (Idx ja = A.ptr[ia], ea = A.ptr[ia + 1]; ja < ea; ++ja) {
103 for (Idx jb = B.ptr[ca], eb = B.ptr[ca + 1]; jb < eb; ++jb) {
            // marker[cb] < row_beg => column cb not yet present in this row:
            // append a new entry; otherwise accumulate into the stored slot.
107 if (marker[cb] < row_beg) {
108 marker[cb] = row_end;
110 C.val[row_end] = va * vb;
114 C.val[marker[cb]] += va * vb;
        // Optionally order the finished row by column index.
120 detail::sort_row(C.col + row_beg, C.val + row_beg, row_end - row_beg);
/// Merges two sorted column-index ranges [col1, col1_end) and
/// [col2, col2_end) into their sorted union.
///
/// \tparam need_out  when true, the merged indices are written to col3;
///                   when false, nothing is written and only the returned
///                   end pointer (i.e. the union size) is meaningful.
/// \return pointer one past the last (actual or virtual) element of the
///         merged output, so `result - col3` is the union's size.
///
/// NOTE(review): the interior of the main merge loop is elided in this
/// listing (prefix jump 133 -> 157); only the tail handling is visible.
128template <
bool need_out,
class Idx> Idx*
129merge_rows(
const Idx* col1,
const Idx* col1_end,
130 const Idx* col2,
const Idx* col2_end,
    // Merge while both inputs have elements (body elided in this listing).
133 while (col1 != col1_end && col2 != col2_end) {
    // Exactly one input may have a remaining tail; copy it through.
157 if (col1 < col1_end) {
158 return std::copy(col1, col1_end, col3);
160 else if (col2 < col2_end) {
161 return std::copy(col2, col2_end, col3);
    // Count-only variant: advance the virtual output pointer by the
    // remaining tail lengths without writing.
168 return col3 + (col1_end - col1) + (col2_end - col2);
/// Numeric merge of two sorted sparse rows: computes
/// alpha1 * row1 + alpha2 * row2, writing merged column indices to col3
/// and the combined values to val3.
///
/// \return presumably the pointer past the last written column index
///         (the return statement is elided in this listing — confirm).
///
/// NOTE(review): the comparison branches of the merge loop are elided
/// here (prefix gaps 180 -> 188 -> 195 -> 201); the three visible value
/// assignments correspond to "only in row1", "in both", "only in row2".
175template <
class Idx,
class Val> Idx*
176merge_rows(
const Val& alpha1,
const Idx* col1,
const Idx* col1_end,
const Val* val1,
177 const Val& alpha2,
const Idx* col2,
const Idx* col2_end,
const Val* val2,
178 Idx* col3, Val* val3)
180 while (col1 != col1_end && col2 != col2_end) {
        // Column present only in row1.
188 *val3 = alpha1 * (*val1++);
        // Column present in both rows: combine the scaled values.
195 *val3 = alpha1 * (*val1++) + alpha2 * (*val2++);
        // Column present only in row2.
201 *val3 = alpha2 * (*val2++);
    // Flush whichever input still has a tail (at most one loop runs).
208 while (col1 < col1_end) {
210 *val3++ = alpha1 * (*val1++);
213 while (col2 < col2_end) {
215 *val3++ = alpha2 * (*val2++);
/// Computes the number of nonzeros in one row of C = A * B, given the
/// column indices [acol, acol_end) of the corresponding A row, by
/// pairwise-merging the referenced rows of B (symbolic RMerge step).
///
/// \param bptr, bcol  CSR row pointers and column indices of B.
/// \param tmp_col1..3 caller-provided scratch buffers, each sized to hold
///                    the widest possible output row (see spgemm_rmerge).
/// \return the output row's width; merge_rows<false> returns an end
///         pointer, so the result is presumably converted by pointer
///         arithmetic in elided code — confirm against full source.
///
/// NOTE(review): several lines are elided in this listing (prefix gaps
/// 229 -> 237, 245 -> 259, ...), including where a1/a2 are read from acol.
224template <
class Col,
class Ptr> Ptr
225prod_row_width(
const Col* acol,
const Col* acol_end,
226 const Ptr* bptr,
const Col* bcol,
227 Col* tmp_col1, Col* tmp_col2, Col* tmp_col3)
229 const Col nrows = acol_end - acol;
    // One A entry: the output row is a single B row; width is its length.
237 return bptr[*acol + 1] - bptr[*acol];
    // Two A entries: count the union of two B rows, no output written.
244 return merge_rows<false>(bcol + bptr[a1], bcol + bptr[a1 + 1],
245 bcol + bptr[a2], bcol + bptr[a2 + 1],
    // General case: merge the first two B rows into tmp_col1, ...
259 Col c_col1 = merge_rows<true>(bcol + bptr[a1], bcol + bptr[a1 + 1],
260 bcol + bptr[a2], bcol + bptr[a2 + 1],
    // ... then fold in the remaining B rows two at a time.
265 while (acol + 1 < acol_end) {
269 Col c_col2 = merge_rows<true>(bcol + bptr[a1], bcol + bptr[a1 + 1],
270 bcol + bptr[a2], bcol + bptr[a2 + 1],
        // No rows left after this pair: final merge is count-only.
274 if (acol == acol_end) {
275 return merge_rows<false>(tmp_col1, tmp_col1 + c_col1,
276 tmp_col2, tmp_col2 + c_col2,
        // Otherwise accumulate into tmp_col3 and rotate the buffers.
281 c_col1 = merge_rows<true>(tmp_col1, tmp_col1 + c_col1,
282 tmp_col2, tmp_col2 + c_col2,
286 std::swap(tmp_col1, tmp_col3);
    // Odd leftover row: count its union with the accumulated result.
292 return merge_rows<false>(tmp_col1, tmp_col1 + c_col1,
293 bcol + bptr[a2], bcol + bptr[a2 + 1],
/// Computes one row of C = A * B numerically (RMerge numeric step):
/// the weighted sum of B rows referenced by the A row
/// [acol, acol_end) / aval is built by pairwise merges.
///
/// \param bptr, bcol, bval       CSR arrays of B.
/// \param out_col, out_val       destination for the finished row
///                               (also reused as one of the scratch pairs).
/// \param tm2_*, tm3_*           caller-provided scratch buffers.
///
/// NOTE(review): this listing elides many lines (prefix gaps 308 -> 319,
/// 325 -> 340, ...), including where ac/ac1/ac2 and av/av1/av2 are read
/// from acol/aval — confirm against the full source.
301template <
class Col,
class Ptr,
class Val>
302void prod_row(
const Col* acol,
const Col* acol_end,
const Val* aval,
303 const Ptr* bptr,
const Col* bcol,
const Val* bval,
304 Col* out_col, Val* out_val,
305 Col* tm2_col, Val* tm2_val,
306 Col* tm3_col, Val* tm3_val)
308 const Col nrows = acol_end - acol;
    // One A entry: output row is a single B row scaled by av.
319 const Val* bv = bval + bptr[ac];
320 const Col* bc = bcol + bptr[ac];
321 const Col* be = bcol + bptr[ac + 1];
325 *out_val++ = av * (*bv++);
    // Two A entries: a single weighted merge straight into the output.
340 av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
341 av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
    // General case: use the output arrays as the first accumulator pair.
360 Col* tm1_col = out_col;
361 Val* tm1_val = out_val;
363 Col c_col1 = merge_rows(av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
364 av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
    // Fold in the remaining B rows two at a time, rotating scratch buffers.
369 while (acol + 1 < acol_end) {
376 Col c_col2 = merge_rows(av1, bcol + bptr[ac1], bcol + bptr[ac1 + 1], bval + bptr[ac1],
377 av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
        // Accumulator + fresh pair are already scaled; combine with unit
        // coefficients (math::identity<Val>() presumably yields 1 — confirm).
381 c_col1 = merge_rows(math::identity<Val>(), tm1_col, tm1_col + c_col1, tm1_val,
382 math::identity<Val>(), tm2_col, tm2_col + c_col2, tm2_val,
386 std::swap(tm3_col, tm1_col);
387 std::swap(tm3_val, tm1_val);
    // Odd leftover A entry: merge its scaled B row into the accumulator.
391 if (acol < acol_end) {
395 c_col1 = merge_rows(math::identity<Val>(), tm1_col, tm1_col + c_col1, tm1_val,
396 av2, bcol + bptr[ac2], bcol + bptr[ac2 + 1], bval + bptr[ac2],
400 std::swap(tm3_col, tm1_col);
401 std::swap(tm3_val, tm1_val);
    // After the buffer rotations the result may sit in a scratch buffer;
    // copy it back to the caller's output arrays if so.
406 if (tm1_col != out_col) {
407 std::copy(tm1_col, tm1_col + c_col1, out_col);
408 std::copy(tm1_val, tm1_val + c_col1, out_val);
/// Sparse matrix-matrix product C = A * B using RMerge-style iterative
/// pairwise row merging, parallelized over rows of A with
/// arccoreParallelFor. Three passes: (1) find the maximum possible output
/// row width to size per-thread scratch buffers, (2) symbolic pass filling
/// C.ptr via prod_row_width, (3) numeric pass filling C.col/C.val via
/// prod_row.
///
/// NOTE(review): the closing braces/parentheses of the parallel lambdas
/// and some lines (e.g. the declaration of my_max, row_width) are elided
/// in this listing — confirm against the full source.
415template <
class AMatrix,
class BMatrix,
class CMatrix>
416void spgemm_rmerge(
const AMatrix& A,
const BMatrix& B, CMatrix& C)
418 typedef typename backend::value_type<CMatrix>::type Val;
419 typedef typename backend::col_type<CMatrix>::type Col;
420 typedef ptrdiff_t Idx;
    // Pass 1: upper bound on any output row's width = sum of the lengths
    // of the B rows referenced by that A row; reduced via an atomic max.
422 Idx max_row_width = 0;
424 arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
426 for (Idx i = begin; i < (begin + size); ++i) {
427 Idx row_beg = A.ptr[i];
428 Idx row_end = A.ptr[i + 1];
430 for (Idx j = row_beg; j < row_end; ++j) {
431 Idx a_col = A.col[j];
432 row_width += B.ptr[a_col + 1] - B.ptr[a_col];
434 my_max = std::max(my_max, row_width);
437 Accelerator::doAtomic<Accelerator::eAtomicOperation::Max>(&max_row_width,my_max);
    // Per-thread scratch: 3 column buffers and 2 value buffers of
    // max_row_width each (prod_row_width needs 3 col areas, prod_row 2+2).
440 const int nthreads = ConcurrencyBase::maxAllowedThread();
443 SmallArray<std::vector<Col>,16> tmp_col(nthreads);
444 SmallArray<std::vector<Val>,16> tmp_val(nthreads);
446 for (
int i = 0; i < nthreads; ++i) {
447 tmp_col[i].resize(3 * max_row_width);
448 tmp_val[i].resize(2 * max_row_width);
451 C.set_size(A.nbRow(), B.ncols);
    // Pass 2 (symbolic): per-row nonzero counts into C.ptr[i + 1].
454 arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
455 const int tid = TaskFactory::currentTaskThreadIndex();
457 Col* t_col = &tmp_col[tid][0];
459 for (Idx i = begin; i < (begin + size); ++i) {
460 Idx row_beg = A.ptr[i];
461 Idx row_end = A.ptr[i + 1];
463 C.ptr[i + 1] = prod_row_width(A.col.data() + row_beg, A.col.data() + row_end,
464 B.ptr.data(), B.col.data(),
465 t_col, t_col + max_row_width, t_col + 2 * max_row_width);
    // Turn row sizes into offsets and allocate C's col/val storage
    // (scan_row_sizes presumably performs the prefix scan — confirm).
469 C.set_nonzeros(C.scan_row_sizes());
    // Pass 3 (numeric): compute each output row in place.
471 arccoreParallelFor(0, A.nbRow(), ForLoopRunInfo{}, [&](Int32 begin, Int32 size) {
472 const int tid = TaskFactory::currentTaskThreadIndex();
474 Col* t_col = tmp_col[tid].data();
475 Val* t_val = tmp_val[tid].data();
477 for (Idx i = begin; i < (begin + size); ++i) {
478 Idx row_beg = A.ptr[i];
479 Idx row_end = A.ptr[i + 1];
481 prod_row(A.col.data() + row_beg, A.col.data() + row_end, A.val.data() + row_beg,
482 B.ptr.data(), B.col.data(), B.val.data(),
483 C.col.data() + C.ptr[i], C.val.data() + C.ptr[i],
484 t_col, t_val, t_col + max_row_width, t_val + max_row_width);
// NOTE(review): the lines below are Doxygen tooltip text captured by the
// extraction, not source code; preserved here as a comment (translated
// from French) for reference.
//   void arccoreParallelFor(const ComplexForLoopRanges<RankValue>& loop_ranges,
//                           const ForLoopRunInfo& run_info,
//                           const LambdaType& lambda_function,
//                           const ReducerArgs&... reducer_args)
//     Applies the lambda `lambda_function` concurrently over the given
//     iteration interval.
//   std::int32_t Int32 — signed 32-bit integer type.