Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
SimpleCSRMatrixMultT.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8#pragma once
9
10#include <arccore/collections/Array2.h>
11
12#include <alien/handlers/scalar/CSRModifierViewT.h>
13
14/*---------------------------------------------------------------------------*/
15
16namespace Alien::SimpleCSRInternal
17{
18
19/*---------------------------------------------------------------------------*/
20
21template <typename ValueT>
23: m_matrix_impl(matrix)
24{}
25
26/*---------------------------------------------------------------------------*/
27
28template <typename ValueT>
29void SimpleCSRMatrixMultT<ValueT>::synchronize(VectorType& x) const
30{
31 if (m_matrix_impl.block()) {
32 if (m_matrix_impl.m_is_parallel)
33 _synchronizeBlock(x);
34 }
35 else if (m_matrix_impl.vblock()) {
36 if (m_matrix_impl.m_is_parallel)
37 _synchronizeVariableBlock(x);
38 }
39 else {
40 if (m_matrix_impl.m_is_parallel)
41 _synchronize(x);
42 }
43}
44
45
46
47
48template <typename ValueT>
49void SimpleCSRMatrixMultT<ValueT>::mult(const VectorType& x, VectorType& y) const
50{
51 if (m_matrix_impl.block()) {
52 if (m_matrix_impl.m_is_parallel)
53 _parallelMultBlock(x, y);
54 else
55 _seqMultBlock(x, y);
56 }
57 else if (m_matrix_impl.vblock()) {
58 if (m_matrix_impl.m_is_parallel)
59 _parallelMultVariableBlock(x, y);
60 else
61 _seqMultVariableBlock(x, y);
62 }
63 else {
64 if (m_matrix_impl.m_is_parallel)
65 _parallelMult(x, y);
66 else
67 _seqMult(x, y);
68 }
69}
70
71template <typename ValueT>
72void SimpleCSRMatrixMultT<ValueT>::addLMult(Real alpha, const VectorType& x, VectorType& y) const
73{
74 _seqAddLMult(alpha, x, y);
75}
76
77template <typename ValueT>
78void SimpleCSRMatrixMultT<ValueT>::addUMult(Real alpha, const VectorType& x, VectorType& y) const
79{
80 _seqAddUMult(alpha, x, y);
81}
82
83template <typename ValueT>
84void SimpleCSRMatrixMultT<ValueT>::mult(const UniqueArray<Real>& x, UniqueArray<Real>& y) const
85{
86 if (m_matrix_impl.m_is_parallel)
87 _parallelMult(x, y);
88 else
89 _seqMult(x, y);
90}
91
92/*---------------------------------------------------------------------------*/
93
94template <typename ValueT>
95void SimpleCSRMatrixMultT<ValueT>::_synchronize(VectorType& x_impl) const
96{
97 Integer alloc_size = m_matrix_impl.m_local_size + m_matrix_impl.m_ghost_size;
98 x_impl.resize(alloc_size);
99 Real* x_ptr = (Real*)x_impl.getDataPtr();
100 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
101 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
102 ConstArrayView<Integer> row_offset =
103 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
104 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
105 m_matrix_impl.m_send_policy, x_ptr, m_matrix_impl.m_matrix_dist_info.m_recv_info,
106 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace);
107 op.start();
108 op.end();
109}
110
111template <typename ValueT>
112void SimpleCSRMatrixMultT<ValueT>::_parallelMult(
113const VectorType& x_impl, VectorType& y_impl) const
114{
115 Integer alloc_size = m_matrix_impl.m_local_size + m_matrix_impl.m_ghost_size;
116 x_impl.resize(alloc_size);
117 Real* y_ptr = y_impl.getDataPtr();
118 Real* x_ptr = (Real*)x_impl.getDataPtr();
119 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
120 // ConstArrayView<Integer> cols2 =
121 // m_matrix_impl.m_matrix.getCSRProfile().getCols();
122 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
123 ConstArrayView<Integer> row_offset =
124 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
125 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
126 m_matrix_impl.m_send_policy, x_ptr, m_matrix_impl.m_matrix_dist_info.m_recv_info,
127 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace);
128 op.start();
129 ConstArrayView<Integer> local_row_size =
130 m_matrix_impl.m_matrix_dist_info.m_local_row_size;
131 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
132 Integer off = row_offset[irow];
133 Integer off2 = off + local_row_size[irow];
134 Real tmpy = 0.;
135 for (Integer j = off; j < off2; ++j) {
136 tmpy += matrix[j] * x_ptr[cols[j]];
137 }
138 y_ptr[irow] = tmpy;
139 }
140 op.end();
141
142 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
143 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
144 for (Integer i = 0; i < interface_nrow; ++i) {
145 Integer irow = row_ids[i];
146 Integer off = row_offset[irow] + local_row_size[irow];
147 Integer off2 = row_offset[irow + 1];
148 Real tmpy = 0.;
149 for (Integer j = off; j < off2; ++j) {
150 tmpy += matrix[j] * x_ptr[cols[j]];
151 }
152 y_ptr[irow] += tmpy;
153 }
154}
155
156template <typename ValueT>
157void SimpleCSRMatrixMultT<ValueT>::_parallelMult(
158const UniqueArray<Real>& x_impl, UniqueArray<Real>& y_impl) const
159{
160 Real* y_ptr = dataPtr(y_impl);
161 Real* x_ptr = (Real*)dataPtr(x_impl);
162 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
163 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
164 ConstArrayView<Integer> row_offset =
165 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
166 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
167 m_matrix_impl.m_send_policy, x_ptr, m_matrix_impl.m_matrix_dist_info.m_recv_info,
168 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace);
169 op.start();
170 ConstArrayView<Integer> local_row_size =
171 m_matrix_impl.m_matrix_dist_info.m_local_row_size;
172 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
173 Integer off = row_offset[irow];
174 Integer off2 = off + local_row_size[irow];
175 Real tmpy = 0.;
176 for (Integer j = off; j < off2; ++j) {
177 tmpy += matrix[j] * x_ptr[cols[j]];
178 }
179 y_ptr[irow] = tmpy;
180 }
181 op.end();
182
183 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
184 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
185 for (Integer i = 0; i < interface_nrow; ++i) {
186 Integer irow = row_ids[i];
187 Integer off = row_offset[irow] + local_row_size[irow];
188 Integer off2 = row_offset[irow + 1];
189 Real tmpy = 0.;
190 for (Integer j = off; j < off2; ++j) {
191 tmpy += matrix[j] * x_ptr[cols[j]];
192 }
193 y_ptr[irow] += tmpy;
194 }
195}
196/*---------------------------------------------------------------------------*/
197
198template <typename ValueT>
199void SimpleCSRMatrixMultT<ValueT>::_seqMult(const VectorType& x_impl, VectorType& y_impl) const
200{
201#ifdef ALIEN_USE_PERF_TIMER
202 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "CSR-SPMV");
203#endif
204 Real* y_ptr = y_impl.getDataPtr();
205 Real* x_ptr = (Real*)x_impl.getDataPtr();
206 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
207 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
208 ConstArrayView<Integer> row_offset =
209 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
210 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
211 Real tmpy = 0.;
212 for (Integer j = row_offset[irow]; j < row_offset[irow + 1]; ++j) {
213 tmpy += matrix[j] * x_ptr[cols[j]];
214 }
215 y_ptr[irow] = tmpy;
216 }
217}
218
219template <typename ValueT>
220void SimpleCSRMatrixMultT<ValueT>::_seqAddLMult(Real alpha, const VectorType& x_impl, VectorType& y_impl) const
221{
222#ifdef ALIEN_USE_PERF_TIMER
223 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "CSR-AddLMult");
224#endif
225 Real* y_ptr = y_impl.getDataPtr();
226 Real* x_ptr = (Real*)x_impl.getDataPtr();
227 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
228 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
229 ConstArrayView<Integer> row_offset = m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
230 auto diag_offset = m_matrix_impl.m_matrix.getCSRProfile().getUpperDiagOffset();
231 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
232 Real tmpy = y_ptr[irow];
233 for (Integer j = row_offset[irow]; j < diag_offset[irow]; ++j) {
234 tmpy += alpha * matrix[j] * x_ptr[cols[j]];
235 }
236 y_ptr[irow] = tmpy;
237 }
238}
239
240template <typename ValueT>
241void SimpleCSRMatrixMultT<ValueT>::_seqAddUMult(Real alpha, const VectorType& x_impl, VectorType& y_impl) const
242{
243#ifdef ALIEN_USE_PERF_TIMER
244 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "CSR-AddUMult");
245#endif
246 Real* y_ptr = y_impl.getDataPtr();
247 Real* x_ptr = (Real*)x_impl.getDataPtr();
248 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
249 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
250 ConstArrayView<Integer> row_offset = m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
251 auto diag_offset = m_matrix_impl.m_matrix.getCSRProfile().getUpperDiagOffset();
252 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
253 Real tmpy = y_ptr[irow];
254 for (Integer j = diag_offset[irow] + 1; j < row_offset[irow + 1]; ++j) {
255 tmpy += alpha * matrix[j] * x_ptr[cols[j]];
256 }
257 y_ptr[irow] = tmpy;
258 }
259}
260
261template <typename ValueT>
262void SimpleCSRMatrixMultT<ValueT>::_seqMult(
263const UniqueArray<Real>& x_impl, UniqueArray<Real>& y_impl) const
264{
265 Real* y_ptr = dataPtr(y_impl);
266 Real* x_ptr = (Real*)dataPtr(x_impl);
267 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
268 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
269 ConstArrayView<Integer> row_offset =
270 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
271 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
272 Real tmpy = 0.;
273 for (Integer j = row_offset[irow]; j < row_offset[irow + 1]; ++j) {
274 tmpy += matrix[j] * x_ptr[cols[j]];
275 }
276 y_ptr[irow] = tmpy;
277 }
278}
279
280/*---------------------------------------------------------------------------*/
281
282template <typename ValueT>
283void SimpleCSRMatrixMultT<ValueT>::_synchronizeBlock(VectorType& x) const
284{
285 Integer alloc_size = m_matrix_impl.m_local_size + m_matrix_impl.m_ghost_size;
286 const Integer block_size = m_matrix_impl.block()->size();
287 x.resize(alloc_size * block_size);
288 ConstArrayView<Real> x_ptr = x.fullValues();
289 Real const* matrix = m_matrix_impl.m_matrix.getDataPtr();
290 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
291 ConstArrayView<Integer> row_offset =
292 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
293 SimpleCSRInternal::SendRecvOp<Real> op(x.getDataPtr(),
294 m_matrix_impl.m_matrix_dist_info.m_send_info, m_matrix_impl.m_send_policy,
295 (Real*)x.getDataPtr(), m_matrix_impl.m_matrix_dist_info.m_recv_info,
296 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace,
297 block_size);
298 op.start();
299 op.end();
300}
301
302
303
304
305
306template <typename ValueT>
307void SimpleCSRMatrixMultT<ValueT>::_parallelMultBlock(const VectorType& x, VectorType& y) const
308{
309 Integer alloc_size = m_matrix_impl.m_local_size + m_matrix_impl.m_ghost_size;
310 const Integer block_size = m_matrix_impl.block()->size();
311 x.resize(alloc_size * block_size);
312 ArrayView<Real> _y = y.fullValues();
313 ConstArrayView<Real> x_ptr = x.fullValues();
314 Real const* matrix = m_matrix_impl.m_matrix.getDataPtr();
315 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
316 // ConstArrayView<Integer> cols2 =
317 // m_matrix_impl.m_matrix.getCSRProfile().getCols();
318 ConstArrayView<Integer> row_offset =
319 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
320 SimpleCSRInternal::SendRecvOp<Real> op(x.getDataPtr(),
321 m_matrix_impl.m_matrix_dist_info.m_send_info, m_matrix_impl.m_send_policy,
322 (Real*)x.getDataPtr(), m_matrix_impl.m_matrix_dist_info.m_recv_info,
323 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace,
324 block_size);
325 op.start();
326 ConstArrayView<Integer> local_row_size =
327 m_matrix_impl.m_matrix_dist_info.m_local_row_size;
328 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
329 ArrayView<Real> y_ptr = _y.subView(irow * block_size, block_size);
330 Integer off = row_offset[irow];
331 Integer off2 = off + local_row_size[irow];
332 Real const* m = matrix + off * block_size * block_size;
333 for (Integer ieq = 0; ieq < block_size; ++ieq)
334 y_ptr[ieq] = 0.;
335 for (Integer jcol = off; jcol < off2; ++jcol) {
336 ConstArrayView<Real> ptr = x_ptr.subView(cols[jcol] * block_size, block_size);
337 for (Integer ieq = 0; ieq < block_size; ++ieq)
338 for (Integer iu = 0; iu < block_size; ++iu)
339 y_ptr[ieq] += m[iu + block_size * ieq] * ptr[iu];
340 m += block_size * block_size;
341 }
342 // y_ptr += block_size;
343 }
344 op.end();
345
346 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
347 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
348 ArrayView<Real> y_ptr = _y;
349 for (Integer i = 0; i < interface_nrow; ++i) {
350 Integer irow = row_ids[i];
351 ArrayView<Real> yptr = y_ptr.subView(irow * block_size, block_size);
352 Integer off = row_offset[irow] + local_row_size[irow];
353 Integer off2 = row_offset[irow + 1];
354 Real const* m = matrix + off * block_size * block_size;
355 for (Integer jcol = off; jcol < off2; ++jcol) {
356 ConstArrayView<Real> ptr = x_ptr.subView(cols[jcol] * block_size, block_size);
357 for (Integer ieq = 0; ieq < block_size; ++ieq)
358 for (Integer iu = 0; iu < block_size; ++iu)
359 yptr[ieq] += m[iu + block_size * ieq] * ptr[iu];
360 m += block_size * block_size;
361 }
362 }
363}
364
365/*---------------------------------------------------------------------------*/
366
367template <typename ValueT>
368void SimpleCSRMatrixMultT<ValueT>::_seqMultBlock(const VectorType& x, VectorType& y) const
369{
370 Real* y_ptr = y.getDataPtr();
371 Real const* x_ptr = x.getDataPtr();
372 Real const* matrix = m_matrix_impl.m_matrix.getDataPtr();
373 const Integer block_size = m_matrix_impl.block()->size();
374 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
375 ConstArrayView<Integer> row_offset = m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
376 Real* yptr = y_ptr ;
377 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow)
378 {
379 Integer off = row_offset[irow];
380 Integer off2 = row_offset[irow + 1];
381 Real const* m = matrix + off * block_size * block_size;
382 for (Integer ieq = 0; ieq < block_size; ++ieq)
383 yptr[ieq] = 0.;
384 for (Integer jcol = off; jcol < off2; ++jcol)
385 {
386 Real const* ptr = x_ptr + cols[jcol] * block_size;
387 for (Integer ieq = 0; ieq < block_size; ++ieq)
388 for (Integer iu = 0; iu < block_size; ++iu)
389 yptr[ieq] += m[iu + block_size * ieq] * ptr[iu];
390 m += block_size * block_size;
391 }
392 yptr += block_size;
393 }
394
395 /*
396 for(int il=0;il<m_matrix_impl.m_local_size;++il)
397 {
398 std::cout<<"X["<<il<<"]:";
399 for(int i=0;i<block_size;++i)
400 std::cout<<x_ptr[il*block_size+i]<<",";
401 std::cout<<std::endl;
402 }
403 for(int il=0;il<m_matrix_impl.m_local_size;++il)
404 {
405 std::cout<<"Y["<<il<<"]:";
406 for(int i=0;i<block_size;++i)
407 std::cout<<y_ptr[il*block_size+i]<<",";
408 std::cout<<std::endl;
409 }
410 for(std::size_t irow=0;irow<m_matrix_impl.m_local_size;++irow)
411 {
412 for(int ieq=0;ieq<block_size;++ieq)
413 {
414 std::cout<<"LINE["<<irow<<","<<ieq<<"] : ";
415 ValueType value = 0. ;
416 Integer off = row_offset[irow];
417 Integer off2 = row_offset[irow + 1];
418 Real const* m = matrix + off * block_size * block_size;
419 for (Integer jcol = off; jcol < off2; ++jcol)
420 {
421 Real const* ptr = x_ptr + cols[jcol] * block_size;
422 for (Integer iu = 0; iu < block_size; ++iu)
423 value += m[iu + block_size * ieq] * ptr[iu];
424 m += block_size * block_size;
425 }
426 std::cout<<"\nY_CPU["<<irow<<","<<ieq<<"]:"<<value<<std::endl;
427 }
428 }*/
429}
430
431/*---------------------------------------------------------------------------*/
432template <typename ValueT>
433void SimpleCSRMatrixMultT<ValueT>::_synchronizeVariableBlock(VectorType& x_impl) const
434{
435 // alien_info([&] { cout()<<"_parallelMultVariableBlock";}) ;
436
437 ConstArrayView<Integer> block_sizes = m_matrix_impl.getDistStructInfo().m_block_sizes;
438 ConstArrayView<Integer> block_offsets = m_matrix_impl.getDistStructInfo().m_block_offsets;
439 {
440 const Integer last = block_offsets.size() - 1;
441 x_impl.resize(block_offsets[last]);
442 }
443
444 const ValueT* x_ptr = x_impl.getDataPtr();
445 const ValueT* matrix_ptr = m_matrix_impl.m_matrix.getDataPtr();
446
447 ConstArrayView<Integer> row_offset =
448 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
449 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
450 ConstArrayView<Integer> block_cols =
451 m_matrix_impl.m_matrix.getCSRProfile().getBlockCols();
452
453 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
454 m_matrix_impl.m_send_policy, (ValueT*)x_ptr,
455 m_matrix_impl.m_matrix_dist_info.m_recv_info, m_matrix_impl.m_recv_policy,
456 m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace, block_sizes, block_offsets);
457
458 op.start();
459 op.end();
460}
461
462
463
464
465
466template <typename ValueT>
467void SimpleCSRMatrixMultT<ValueT>::_parallelMultVariableBlock(
468const VectorType& x_impl, VectorType& y_impl) const
469{
470 // alien_info([&] { cout()<<"_parallelMultVariableBlock";}) ;
471
472 ArrayView<ValueT> y = y_impl.fullValues();
473
474 ConstArrayView<Integer> block_sizes = m_matrix_impl.getDistStructInfo().m_block_sizes;
475 ConstArrayView<Integer> block_offsets = m_matrix_impl.getDistStructInfo().m_block_offsets;
476 {
477 const Integer last = block_offsets.size() - 1;
478 x_impl.resize(block_offsets[last]);
479 }
480
481 const ValueT* x_ptr = x_impl.getDataPtr();
482 const ValueT* matrix_ptr = m_matrix_impl.m_matrix.getDataPtr();
483
484 ConstArrayView<Integer> row_offset =
485 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
486 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
487 ConstArrayView<Integer> block_cols =
488 m_matrix_impl.m_matrix.getCSRProfile().getBlockCols();
489
490 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
491 m_matrix_impl.m_send_policy, (ValueT*)x_ptr,
492 m_matrix_impl.m_matrix_dist_info.m_recv_info, m_matrix_impl.m_recv_policy,
493 m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace, block_sizes, block_offsets);
494
495 op.start();
496
497 UniqueArray<ValueT> tmpy;
498 tmpy.reserve(m_matrix_impl.vblock()->maxBlockSize());
499
500 ConstArrayView<Integer> local_row_size =
501 m_matrix_impl.m_matrix_dist_info.m_local_row_size;
502 for (Integer irow = 0; irow < m_matrix_impl.m_local_size; ++irow) {
503 Integer off = row_offset[irow];
504 Integer off2 = off + local_row_size[irow];
505 const Integer block_size_row = block_sizes[irow];
506 tmpy.resize(block_size_row);
507 tmpy.fill(ValueT());
508 for (Integer j = off; j < off2; ++j) {
509 const Integer col = cols[j];
510 const Integer block_size_col = block_sizes[col];
511 ConstArrayView<ValueT> x(block_size_col, x_ptr + block_offsets[col]);
512 ConstArray2View<ValueT> matrix(
513 matrix_ptr + block_cols[j], block_size_row, block_size_col);
514 for (Integer krow = 0; krow < block_size_row; ++krow) {
515 for (Integer kcol = 0; kcol < block_size_col; ++kcol) {
516 tmpy[krow] += matrix[krow][kcol] * x[kcol];
517 }
518 }
519 }
520 y.subView(block_offsets[irow], block_size_row).copy(tmpy);
521 }
522
523 op.end();
524
525 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
526 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
527 for (Integer i = 0; i < interface_nrow; ++i) {
528 Integer irow = row_ids[i];
529 Integer off = row_offset[irow] + local_row_size[irow];
530 Integer off2 = row_offset[irow + 1];
531 const Integer block_size_row = block_sizes[irow];
532 tmpy.resize(block_size_row);
533 tmpy.fill(ValueT());
534 for (Integer j = off; j < off2; ++j) {
535 const Integer col = cols[j];
536 const Integer block_size_col = block_sizes[col];
537 ConstArrayView<ValueT> x(block_size_col, x_ptr + block_offsets[col]);
538 ConstArray2View<ValueT> matrix(
539 matrix_ptr + block_cols[j], block_size_row, block_size_col);
540 for (Integer krow = 0; krow < block_size_row; ++krow) {
541 for (Integer kcol = 0; kcol < block_size_col; ++kcol) {
542 tmpy[krow] += matrix[krow][kcol] * x[kcol];
543 }
544 }
545 }
546 ArrayView<ValueT> y_view = y.subView(block_offsets[irow], block_size_row);
547 for (Integer k = 0; k < block_size_row; ++k)
548 y_view[k] += tmpy[k];
549 }
550}
551
552/*---------------------------------------------------------------------------*/
553
554template <typename ValueT>
555void SimpleCSRMatrixMultT<ValueT>::_seqMultVariableBlock(
556const VectorType& x_impl, VectorType& y_impl) const
557{
558 // alien_info([&] { cout()<<"_seqMultVariableBlock";}) ;
559
560 ArrayView<ValueT> y = y_impl.fullValues();
561
562 const ValueT* x_ptr = x_impl.getDataPtr();
563 const ValueT* matrix_ptr = m_matrix_impl.m_matrix.getDataPtr();
564
565 ConstArrayView<Integer> row_offset =
566 m_matrix_impl.m_matrix.getCSRProfile().getRowOffset();
567 ConstArrayView<Integer> cols = m_matrix_impl.m_matrix.getCSRProfile().getCols();
568 ConstArrayView<Integer> block_cols =
569 m_matrix_impl.m_matrix.getCSRProfile().getBlockCols();
570 const VBlock* block = m_matrix_impl.vblock();
571 const VBlockImpl& block_infos = x_impl.vblockImpl();
572
573 UniqueArray<ValueT> tmpy;
574 tmpy.reserve(block->maxBlockSize());
575 for (Integer irow = 0; irow < m_matrix_impl.m_local_size;
576 ++irow) // Attention, c'est local !!!!
577 {
578 const Integer block_size_row = block->size(irow);
579 tmpy.resize(block_size_row);
580 tmpy.fill(ValueT());
581 for (Integer j = row_offset[irow]; j < row_offset[irow + 1]; ++j) {
582 const Integer col = cols[j];
583 const Integer block_size_col = block->size(col);
584 ConstArrayView<ValueT> x(block_size_col, x_ptr + block_infos.offset(col));
585 ConstArray2View<ValueT> matrix(
586 matrix_ptr + block_cols[j], block_size_row, block_size_col);
587 for (Integer krow = 0; krow < block_size_row; ++krow) {
588 for (Integer kcol = 0; kcol < block_size_col; ++kcol) {
589 tmpy[krow] += matrix[krow][kcol] * x[kcol];
590 }
591 }
592 }
593 y.subView(block_infos.offset(irow), block_size_row).copy(tmpy);
594 }
595}
596
597
598template <typename ValueT>
599void SimpleCSRMatrixMultT<ValueT>::multDiag(VectorType& y) const
600{
601 Real* y_ptr = y.getDataPtr();
602 CSRConstViewT<MatrixType> view(m_matrix_impl);
603 // clang-format off
604 auto nrows = view.nrows() ;
605 auto kcol = view.kcol() ;
606 auto dcol = view.dcol() ;
607 auto cols = view.cols() ;
608 auto values = view.data() ;
609 // clang-format on
610 for (Integer irow = 0; irow < nrows; ++irow)
611 y_ptr[irow] = y_ptr[irow] * values[dcol[irow]];
612}
613
614template <typename ValueT>
615void SimpleCSRMatrixMultT<ValueT>::multInvDiag(VectorType& y) const
616{
617 Real* y_ptr = y.getDataPtr();
618 CSRConstViewT<MatrixType> view(m_matrix_impl);
619 // clang-format off
620 auto nrows = view.nrows() ;
621 auto kcol = view.kcol() ;
622 auto dcol = view.dcol() ;
623 auto cols = view.cols() ;
624 auto values = view.data() ;
625 // clang-format on
626 for (Integer irow = 0; irow < nrows; ++irow)
627 y_ptr[irow] = y_ptr[irow] / values[dcol[irow]];
628}
629
630template <typename ValueT>
631void SimpleCSRMatrixMultT<ValueT>::computeDiag(VectorType& y) const
632{
633 Real* y_ptr = y.getDataPtr();
634
635 CSRConstViewT<MatrixType> view(m_matrix_impl);
636 // clang-format off
637 auto nrows = view.nrows() ;
638 auto kcol = view.kcol() ;
639 auto dcol = view.dcol() ;
640 auto cols = view.cols() ;
641 auto values = view.data() ;
642 // clang-format on
643 if(m_matrix_impl.blockSize()==1)
644 {
645 for (Integer irow = 0; irow < nrows; ++irow)
646 y_ptr[irow] = values[dcol[irow]];
647 }
648 else
649 {
650 Integer block_size = m_matrix_impl.blockSize();
651 Integer block2_size = block_size*block_size ;
652 for (Integer irow = 0; irow < nrows; ++irow)
653 {
654 for(Integer ieq=0;ieq<block_size;++ieq)
655 y_ptr[irow*block_size+ieq] = values[dcol[irow]*block2_size+ieq*block_size+ieq];
656 }
657 }
658}
659
660template <typename ValueT>
661void SimpleCSRMatrixMultT<ValueT>::computeInvDiag(VectorType& y) const
662{
663 Real* y_ptr = y.getDataPtr();
664
665 CSRConstViewT<MatrixType> view(m_matrix_impl);
666 // clang-format off
667 auto nrows = view.nrows() ;
668 auto kcol = view.kcol() ;
669 auto dcol = view.dcol() ;
670 auto cols = view.cols() ;
671 auto values = view.data() ;
672 // clang-format on
673 if(m_matrix_impl.blockSize()==1)
674 {
675 for (Integer irow = 0; irow < nrows; ++irow)
676 y_ptr[irow] = 1. / values[dcol[irow]];
677 }
678 else
679 {
680 Integer block_size = m_matrix_impl.blockSize();
681 Integer block2_size = block_size*block_size ;
682 for (Integer irow = 0; irow < nrows; ++irow)
683 {
684 for(Integer ieq=0;ieq<block_size;++ieq)
685 {
686 y_ptr[irow*block_size+ieq] = 1. / values[dcol[irow]*block2_size+ieq*block_size+ieq];
687 }
688 }
689 /*
690 auto kcol = view.kcol() ;
691 for(int il=0;il<nrows;++il)
692 {
693 std::cout<<"MAT["<<il<<","<<il<<"]:";
694 for(int k=kcol[il];k<kcol[il+1];++k)
695 {
696 std::cout<<"("<<cols[k]<<",";
697 for(int ieq=0;ieq<block_size;++ieq)
698 for(int ju=0;ju<block_size;++ju)
699 std::cout<<values[k*block2_size+ieq*block_size +ju]<<",";
700 std::cout<<")"<<std::endl;
701 }
702 }
703 for(int il=0;il<nrows;++il)
704 {
705 std::cout<<"DIAG["<<il<<"]:";
706 for(int i=0;i<block_size;++i)
707 std::cout<<y_ptr[il*block_size+i]<<",";
708 std::cout<<std::endl;
709 std::cout<<"MAT["<<il<<","<<il<<"]:";
710 auto mat_offset = kcol[il]*block2_size ;
711 for(int ieq=0;ieq<block_size;++ieq)
712 for(int ju=0;ju<block_size;++ju)
713 std::cout<<values[mat_offset+ieq*block_size +ju]<<",";
714 std::cout<<std::endl;
715 }*/
716 }
717}
718
719/*---------------------------------------------------------------------------*/
720
721} // namespace Alien::SimpleCSRInternal
722
723/*---------------------------------------------------------------------------*/
SimpleCSRMatrixMultT(const MatrixType &matrix)
Constructeur de la classe.
void mult(const VectorType &x, VectorType &y) const
Matrix vector product.