Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
SimpleCSRMatrix.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7#pragma once
8
11
#include <memory>

#include <alien/data/ISpace.h>
#include <alien/kernels/simple_csr/CSRStructInfo.h>
#include <alien/kernels/simple_csr/DistStructInfo.h>
#include <alien/kernels/simple_csr/SendRecvOp.h>
#include <alien/kernels/simple_csr/SimpleCSRBackEnd.h>
#include <alien/kernels/simple_csr/SimpleCSRInternal.h>
#include <alien/kernels/simple_csr/SimpleCSRPrecomp.h>

#include <alien/utils/StdTimer.h>
24/*---------------------------------------------------------------------------*/
25
26namespace Alien::SimpleCSRInternal
27{
28
29template <typename ValueT>
31
32}
33
34/*---------------------------------------------------------------------------*/
35
36namespace Alien
37{
38
39/*---------------------------------------------------------------------------*/
40
41template <typename ValueT>
43{
44 public:
45 // clang-format off
46 static const bool on_host_only = true ;
47 typedef BackEnd::tag::simplecsr TagType ;
48 typedef ValueT ValueType;
49 typedef SimpleCSRInternal::CSRStructInfo CSRStructInfo;
50 typedef SimpleCSRInternal::CSRStructInfo ProfileType;
51 typedef SimpleCSRInternal::DistStructInfo DistStructInfo;
53 typedef typename ProfileType::IndexType IndexType ;
54 typedef Alien::StdTimer TimerType ;
55 typedef TimerType::Sentry SentryType ;
56 // clang-format on
57
58 public:
61 : IMatrixImpl(nullptr, AlgebraTraits<BackEnd::tag::simplecsr>::name())
62 , m_send_policy(SimpleCSRInternal::CommProperty::ASynch)
63 , m_recv_policy(SimpleCSRInternal::CommProperty::ASynch)
64 {}
65
68 : IMatrixImpl(multi_impl, AlgebraTraits<BackEnd::tag::simplecsr>::name())
69 , m_matrix(multi_impl ? multi_impl->vblock() != nullptr : false)
70 , m_send_policy(SimpleCSRInternal::CommProperty::ASynch)
71 , m_recv_policy(SimpleCSRInternal::CommProperty::ASynch)
72 {}
73
76 {
77#ifdef ALIEN_USE_PERF_TIMER
78 m_timer.printInfo("SimpleCSR-MATRIX");
79#endif
80 }
81
82 void setTraceMng(ITraceMng* trace_mng) { m_trace = trace_mng; }
83
84 public:
/// Release matrix resources.
/// Currently does nothing: the implementation is still a TODO.
void free()
{
  // TODO: not implemented yet.
}

/// Release the matrix value storage.
/// Currently does nothing: the implementation is still a TODO.
void freeData()
{
  // TODO: not implemented yet.
}

/// No-op for this backend: no internal data is wiped here.
void clear()
{
}
92
93 void allocate()
94 {
95 if (block()) {
96 const Integer size = block()->size();
97 m_matrix.getValues().resize((getCSRProfile().getNnz() + 1) * size * size);
98 }
99 else if (vblock()) {
100 m_matrix.getValues().resize(getCSRProfile().getBlockNnz() + 1);
101 }
102 else {
103 m_matrix.getValues().resize(getCSRProfile().getNnz() + 1);
104 }
105 }
106
107 void scal(ValueType const* values)
108 {
109 m_matrix.scal(values) ;
110 }
111
112 CSRStructInfo& getCSRProfile() { return m_matrix.getCSRProfile(); }
113
114 const CSRStructInfo& getCSRProfile() const { return m_matrix.getCSRProfile(); }
115
116 const CSRStructInfo& getProfile() const { return m_matrix.getCSRProfile(); }
117
118 const DistStructInfo& getDistStructInfo() const { return m_matrix_dist_info; }
119
120 SimpleCSRInternal::CommProperty::ePolicyType getSendPolicy() const
121 {
122 return m_send_policy;
123 }
124
125 SimpleCSRInternal::CommProperty::ePolicyType getRecvPolicy() const
126 {
127 return m_recv_policy;
128 }
129
130 ValueType* getAddressData() { return m_matrix.getDataPtr(); }
131 ValueType* data() { return m_matrix.getDataPtr(); }
132
133 ValueType const* getAddressData() const { return m_matrix.getDataPtr(); }
134 ValueType const* data() const { return m_matrix.getDataPtr(); }
135
136 MatrixInternal* internal() { return &m_matrix; }
137
138 MatrixInternal const* internal() const { return &m_matrix; }
139
140 bool isParallel() const { return m_is_parallel; }
141
142 Integer getLocalSize() const { return m_local_size; }
143
144 Integer getLocalOffset() const { return m_local_offset; }
145
146 Integer getGlobalSize() const { return m_global_size; }
147
148 Integer getGhostSize() const { return m_ghost_size; }
149
150 Integer getAllocSize() const
151 {
152 auto total_size = m_local_size + m_ghost_size;
153 if (block())
154 return total_size * block()->size();
155 else if (vblock()) {
156 return m_matrix_dist_info.m_block_offsets[total_size];
157 }
158 else
159 return total_size;
160 }
161
162 Integer blockSize() const
163 {
164 if (block())
165 {
166 return block()->size();
167 }
168 else if (vblock()) {
169 return 1 ;
170 }
171 else {
172 return m_own_block_size ;
173 }
174 }
175
176 void setBlockSize(Integer block_size)
177 {
178 if(this->m_multi_impl)
179 const_cast<MultiMatrixImpl*>(this->m_multi_impl)->setBlockInfos(block_size) ;
180 else
181 m_own_block_size = block_size ;
182 }
183
184 IMessagePassingMng* getParallelMng()
185 {
186 return m_parallel_mng;
187 }
188
189 void sequentialStart()
190 {
191 m_local_offset = 0;
192 m_local_size = getCSRProfile().getNRows();
193 m_global_size = m_local_size;
194 m_myrank = 0;
195 m_nproc = 1;
196 m_is_parallel = false;
197 m_matrix_dist_info.m_local_row_size.resize(m_local_size);
198 auto& profile = internal()->getCSRProfile();
199 ConstArrayView<Integer> offset = profile.getRowOffset();
200 for (Integer i = 0; i < m_local_size; ++i)
201 m_matrix_dist_info.m_local_row_size[i] = offset[i + 1] - offset[i];
202 }
203
204 void parallelStart(ConstArrayView<Integer> offset, IMessagePassingMng* parallel_mng,
205 bool need_sort_ghost_col = false)
206 {
207 m_local_size = getCSRProfile().getNRows();
208 m_parallel_mng = parallel_mng;
209 // m_trace = parallel_mng->traceMng();
210 if (m_parallel_mng == NULL) {
211 m_local_offset = 0;
212 m_global_size = m_local_size;
213 m_myrank = 0;
214 m_nproc = 1;
215 m_is_parallel = false;
216 }
217 else {
218 m_myrank = m_parallel_mng->commRank();
219 m_nproc = m_parallel_mng->commSize();
220 m_local_offset = offset[m_myrank];
221 m_global_size = offset[m_nproc];
222 m_is_parallel = (m_nproc > 1);
223 }
224 if (m_is_parallel) {
225 if (need_sort_ghost_col)
226 sortGhostCols(offset);
227 if (block()) {
228 m_matrix_dist_info.compute(
229 m_nproc, offset, m_myrank, m_parallel_mng, getCSRProfile(), m_trace);
230 }
231 else if (vblock()) {
232 m_matrix_dist_info.compute(m_nproc, offset, m_myrank, m_parallel_mng,
233 getCSRProfile(), vblock(), distribution(), m_trace);
234 }
235 else {
236 m_matrix_dist_info.compute(
237 m_nproc, offset, m_myrank, m_parallel_mng, getCSRProfile(), m_trace);
238 }
239 m_ghost_size = m_matrix_dist_info.m_ghost_nrow;
240 }
241 }
242
243 void sortGhostCols(ConstArrayView<Integer> offset)
244 {
245 IsLocal isLocal(offset, m_myrank);
246 UniqueArray<ValueType>& values = m_matrix.getValues();
247 ProfileType& profile = m_matrix.getCSRProfile();
248 UniqueArray<Integer>& cols = profile.getCols();
249 ConstArrayView<Integer> kcol = profile.getRowOffset();
250 Integer next = 0;
251 UniqueArray<Integer> gcols;
252 UniqueArray<ValueType> gvalues;
253 for (Integer irow = 0; irow < m_local_size; ++irow) {
254 bool need_sort = false;
255 Integer first = next;
256 next = kcol[irow + 1];
257 Integer row_size = next - first;
258 for (Integer k = first; k < next; ++k) {
259 if (!isLocal(cols[k])) {
260 need_sort = true;
261 break;
262 }
263 }
264 if (need_sort) {
265 gvalues.resize(row_size);
266 gcols.resize(row_size);
267 Integer local_count = 0;
268 Integer ghost_count = 0;
269 for (Integer k = first; k < next; ++k) {
270 Integer col = cols[k];
271 if (isLocal(col)) {
272 cols[first + local_count] = col;
273 values[first + local_count] = values[k];
274 ++local_count;
275 }
276 else {
277 gcols[ghost_count] = col;
278 gvalues[ghost_count] = values[k];
279 ++ghost_count;
280 }
281 }
282 for (Integer k = 0; k < ghost_count; ++k) {
283 cols[first + local_count] = gcols[k];
284 values[first + local_count] = gvalues[k];
285 ++local_count;
286 }
287 }
288 }
289 }
290
291 /*
292 DistStructInfo m_matrix_dist_info;
293 */
294
295 virtual SimpleCSRMatrix* cloneTo(const MultiMatrixImpl* multi) const
296 {
297 SimpleCSRMatrix* matrix = new SimpleCSRMatrix(multi);
298 matrix->m_is_parallel = m_is_parallel;
299 matrix->m_local_size = m_local_size;
300 matrix->m_local_offset = m_local_offset;
301 matrix->m_global_size = m_global_size;
302 matrix->m_ghost_size = m_ghost_size;
303 matrix->m_send_policy = m_send_policy;
304 matrix->m_recv_policy = m_recv_policy;
305 matrix->m_nproc = m_nproc;
306 matrix->m_myrank = m_myrank;
307 matrix->m_parallel_mng = m_parallel_mng;
308 matrix->m_trace = m_trace;
309 matrix->setBlockSize(blockSize()) ;
310 matrix->m_matrix.copy(m_matrix);
311 matrix->m_matrix_dist_info.copy(m_matrix_dist_info);
312 return matrix;
313 }
314
315 void copy(SimpleCSRMatrix const& matrix)
316 {
317 m_is_parallel = matrix.m_is_parallel;
318 m_local_size = matrix.m_local_size;
319 m_local_offset = matrix.m_local_offset;
320 m_global_size = matrix.m_global_size;
321 m_ghost_size = matrix.m_ghost_size;
322 m_send_policy = matrix.m_send_policy;
323 m_recv_policy = matrix.m_recv_policy;
324 m_nproc = matrix.m_nproc;
325 m_myrank = matrix.m_myrank;
326 m_parallel_mng = matrix.m_parallel_mng;
327 m_trace = matrix.m_trace;
328 if(blockSize()==matrix.blockSize())
329 m_matrix.copy(matrix.m_matrix);
330 else
331 {
332 auto nb_blocks = matrix.getCSRProfile().getNnz() + 1 ;
333 m_matrix.copy(matrix.m_matrix,blockSize(),matrix.blockSize(),nb_blocks) ;
334 }
335 m_matrix_dist_info.copy(matrix.m_matrix_dist_info);
336 }
337
338 void copyProfile(SimpleCSRMatrix const& matrix)
339 {
340 m_is_parallel = matrix.m_is_parallel;
341 m_local_size = matrix.m_local_size;
342 m_local_offset = matrix.m_local_offset;
343 m_global_size = matrix.m_global_size;
344 m_ghost_size = matrix.m_ghost_size;
345 m_send_policy = matrix.m_send_policy;
346 m_recv_policy = matrix.m_recv_policy;
347 m_nproc = matrix.m_nproc;
348 m_myrank = matrix.m_myrank;
349 m_parallel_mng = matrix.m_parallel_mng;
350 m_trace = matrix.m_trace;
351 m_matrix.getCSRProfile().copy(matrix.m_matrix.getCSRProfile());
352 m_matrix_dist_info.copy(matrix.m_matrix_dist_info);
353 if (vblock()) {
354 auto& profile = m_matrix.getCSRProfile();
355 const VBlock* block_sizes = vblock();
356 auto& block_row_offset = profile.getBlockRowOffset();
357 auto& block_cols = profile.getBlockCols();
358 auto kcol = profile.kcol();
359 auto cols = profile.cols();
360 Integer offset = 0;
361 for (Integer irow = 0; irow < m_local_size; ++irow) {
362 block_row_offset[irow] = offset;
363 auto row_blk_size = block_sizes->size(m_local_offset + irow);
364 for (auto k = kcol[irow]; k < kcol[irow + 1]; ++k) {
365 block_cols[k] = offset;
366 auto jcol = cols[k];
367 auto col_blk_size = block_sizes->size(jcol);
368 offset += row_blk_size * col_blk_size;
369 }
370 }
371 block_row_offset[m_local_size] = offset;
372 block_cols[kcol[m_local_size]] = offset;
373
374 const Integer total_size = m_local_size + m_ghost_size;
375
376 m_matrix_dist_info.m_block_sizes.resize(total_size);
377 m_matrix_dist_info.m_block_offsets.resize(total_size + 1);
378
379 offset = 0;
380 for (Integer i = 0; i < m_local_size; ++i) {
381 auto blk_size = block_sizes->size(m_local_offset + i);
382 m_matrix_dist_info.m_block_sizes[i] = blk_size;
383 m_matrix_dist_info.m_block_offsets[i] = offset;
384 offset += blk_size;
385 }
386 for (Integer i = m_local_size; i < total_size; ++i) {
387 auto blk_size = block_sizes->size(m_matrix_dist_info.m_recv_info.m_uids[i - m_local_size]);
388 m_matrix_dist_info.m_block_sizes[i] = blk_size;
389 m_matrix_dist_info.m_block_offsets[i] = offset;
390 offset += blk_size;
391 }
392 m_matrix_dist_info.m_block_offsets[total_size] = offset;
393 }
394 }
395
396 void notifyChanges()
397 {
398 m_matrix.notifyChanges();
399 }
400
401 void endUpdate()
402 {
403 if (m_matrix.needUpdate()) {
404 m_matrix.endUpdate();
405 this->updateTimestamp();
406 }
407 }
408
409 private:
410 class IsLocal
411 {
412 public:
413 IsLocal(const ConstArrayView<Integer> offset, const Integer myrank)
414 : m_offset(offset)
415 , m_myrank(myrank)
416 {}
417 bool operator()(Arccore::Integer col) const
418 {
419 return (col >= m_offset[m_myrank]) && (col < m_offset[m_myrank + 1]);
420 }
421
422 private:
423 const ConstArrayView<Integer> m_offset;
424 const Integer m_myrank;
425 };
426
427 MatrixInternal m_matrix;
428 bool m_is_parallel = 0;
429 Integer m_local_size = 0;
430 Integer m_local_offset = 0;
431 Integer m_global_size = 0;
432 Integer m_ghost_size = 0;
433 DistStructInfo m_matrix_dist_info;
434 SimpleCSRInternal::CommProperty::ePolicyType m_send_policy;
435 SimpleCSRInternal::CommProperty::ePolicyType m_recv_policy;
436 IMessagePassingMng* m_parallel_mng = nullptr;
437 Integer m_own_block_size = 1;
438 Integer m_nproc = 1;
439 Integer m_myrank = 0;
440 ITraceMng* m_trace = nullptr;
441
442 friend class SimpleCSRInternal::SimpleCSRMatrixMultT<ValueType>;
443
444 private:
445 mutable TimerType m_timer;
446
447 public:
448 TimerType& timer() const
449 {
450 return m_timer;
451 }
452};
453
454/*---------------------------------------------------------------------------*/
455
456} // namespace Alien
457
458/*---------------------------------------------------------------------------*/
Block.h.
IMatrixImpl.h.
ISpace.h.
MultiMatrixImpl.h.
VBlockOffsets.h.
VBlock.h.
Arccore::Integer size() const
Get square block size.
Definition Block.cc:90
virtual const VBlock * vblock() const
Get block datas of the matrix.
virtual const MatrixDistribution & distribution() const
Get the distribution of the matrix.
const MultiMatrixImpl * m_multi_impl
Pointer on matrices implementation.
virtual const Block * block() const
Get block datas of the matrix.
IMatrixImpl(const MultiMatrixImpl *multi_impl, BackEndId backend="")
Constructor.
Multi matrices representation container.
SimpleCSRMatrix(const MultiMatrixImpl *multi_impl)
void clear()
Wipe out internal data.
void updateTimestamp()
Updates the timestamp.
Definition Timestamp.cc:50
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
Definition BackEnd.h:17