Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
HCSRMatrix.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8
9#pragma once
10
13
14#include <alien/data/ISpace.h>
15
16
17#include <alien/kernels/sycl/SYCLPrecomp.h>
18#include <alien/handlers/accelerator/HCSRViewT.h>
19
20#include <alien/kernels/sycl/data/BEllPackStructInfo.h>
21#include <alien/kernels/sycl/data/SYCLDistStructInfo.h>
22
23#include <alien/kernels/sycl/SYCLBackEnd.h>
24
25#include <alien/utils/StdTimer.h>
26
27
28
29/*---------------------------------------------------------------------------*/
30
31namespace Alien
32{
33
namespace HCSRInternal
{
  //! Forward declaration only: the internal storage type used by HCSRMatrix.
  //! Its full definition is not part of this header.
  template <typename T>
  class MatrixInternal;
} // namespace HCSRInternal
39/*---------------------------------------------------------------------------*/
40
41template <typename ValueT>
42class ALIEN_EXPORT HCSRMatrix : public IMatrixImpl
43{
44 public:
45 // clang-format off
46 static const bool on_host_only = false ;
47 typedef BackEnd::tag::hcsr TagType ;
48 typedef ValueT ValueType;
49 typedef ValueT value_type ;
50
51 typedef SimpleCSRInternal::CSRStructInfo CSRStructInfo;
52 typedef SimpleCSRInternal::CSRStructInfo ProfileType;
53 typedef SYCLInternal::SYCLDistStructInfo DistStructInfo;
54 typedef HCSRInternal::MatrixInternal<ValueType> MatrixInternal;
55 typedef HCSRInternal::MatrixInternal<ValueType> InternalType;
56
57 typedef typename ProfileType::IndexType IndexType ;
58 // clang-format on
59
60 typedef HCSRViewT<HCSRMatrix<ValueType>> HCSRView ;
61
62
63
64 public:
67
69 HCSRMatrix(const MultiMatrixImpl* multi_impl);
70
72 virtual ~HCSRMatrix();
73
74 void setTraceMng(ITraceMng* trace_mng) { m_trace = trace_mng; }
75
76 void allocate() ;
77
78 CSRStructInfo& getCSRProfile() { return *m_profile; }
79
80 const CSRStructInfo& getCSRProfile() const { return *m_profile; }
81
82 const CSRStructInfo& getProfile() const { return *m_profile; }
83
84 const DistStructInfo& getDistStructInfo() const { return m_matrix_dist_info; }
85
86 IMessagePassingMng* getParallelMng()
87 {
88 return m_parallel_mng;
89 }
90
91
92 void sequentialStart()
93 {
94 m_local_offset = 0;
95 m_local_size = getCSRProfile().getNRows();
96 m_global_size = m_local_size;
97 m_myrank = 0;
98 m_nproc = 1;
99 m_is_parallel = false;
100 m_matrix_dist_info.m_local_row_size.resize(m_local_size);
101 auto& profile = internal()->getCSRProfile();
102 ConstArrayView<Integer> offset = profile.getRowOffset();
103 for (Integer i = 0; i < m_local_size; ++i)
104 m_matrix_dist_info.m_local_row_size[i] = offset[i + 1] - offset[i];
105 }
106
107 void parallelStart(ConstArrayView<Integer> offset, IMessagePassingMng* parallel_mng,
108 bool need_sort_ghost_col = false)
109 {
110 m_local_size = getCSRProfile().getNRows();
111 m_parallel_mng = parallel_mng;
112 // m_trace = parallel_mng->traceMng();
113 if (m_parallel_mng == NULL) {
114 m_local_offset = 0;
115 m_global_size = m_local_size;
116 m_myrank = 0;
117 m_nproc = 1;
118 m_is_parallel = false;
119 }
120 else {
121 m_myrank = m_parallel_mng->commRank();
122 m_nproc = m_parallel_mng->commSize();
123 m_local_offset = offset[m_myrank];
124 m_global_size = offset[m_nproc];
125 m_is_parallel = (m_nproc > 1);
126 }
127 if (m_is_parallel) {
128 if (need_sort_ghost_col)
129 sortGhostCols(offset);
130 m_matrix_dist_info.compute(
131 m_nproc, offset, m_myrank, m_parallel_mng, getCSRProfile(), m_trace);
132
133 m_ghost_size = m_matrix_dist_info.m_ghost_nrow;
134 }
135 }
136
137 public:
138 bool initMatrix(Arccore::MessagePassing::IMessagePassingMng* parallel_mng,
139 Integer local_offset,
140 Integer global_size,
141 std::size_t nrows,
142 int const* kcol,
143 int const* cols,
144 SimpleCSRInternal::DistStructInfo const& matrix_dist_info);
145
146 HCSRMatrix* cloneTo(const MultiMatrixImpl* multi) const;
147
148 bool isParallel() const { return m_is_parallel; }
149
150 Integer getLocalSize() const { return m_local_size; }
151
152 Integer getLocalOffset() const { return m_local_offset; }
153
154 Integer getGlobalSize() const { return m_global_size; }
155
156 Integer getGhostSize() const { return m_ghost_size; }
157
158 Integer getAllocSize() const { return m_local_size + m_ghost_size; }
159
160 bool setMatrixValues(Arccore::Real const* values, bool only_host);
161
162 void notifyChanges();
163 void endUpdate();
164
165 MatrixInternal* internal() { return m_internal.get(); }
166
167 MatrixInternal const* internal() const { return m_internal.get(); }
168
169 void allocateDevicePointers(std::size_t nrows,
170 std::size_t nnz,
171 IndexType** rows,
172 IndexType** ncols,
173 IndexType** cols,
174 ValueType** values) const ;
175
176 void initDevicePointers(IndexType** ncols,
177 IndexType** rows,
178 IndexType** cols,
179 ValueType** values) const ;
180
181 void freeDevicePointers(IndexType* ncols,
182 IndexType* rows,
183 IndexType* cols,
184 ValueType* values) const ;
185
186 void copyDevicePointers(std::size_t nrows,
187 std::size_t nnz,
188 IndexType* rows,
189 IndexType* ncols,
190 IndexType* cols,
191 ValueType* values) const ;
192
193 HCSRView hcsrView(BackEnd::Memory::eType memory, int nrows, int nnz) const;
194
195 void initCOODevicePointers(int** dof_uids, int** rows, int** cols, ValueType** values) const ;
196 void freeCOODevicePointers(int* dof_uids, int* rows, int* cols, ValueType* values) const ;
197
198 private:
199 class IsLocal
200 {
201 public:
202 IsLocal(const ConstArrayView<Integer> offset, const Integer myrank)
203 : m_offset(offset)
204 , m_myrank(myrank)
205 {}
206 bool operator()(Arccore::Integer col) const
207 {
208 return (col >= m_offset[m_myrank]) && (col < m_offset[m_myrank + 1]);
209 }
210
211 private:
212 const ConstArrayView<Integer> m_offset;
213 const Integer m_myrank;
214 };
215
216
217 void sortGhostCols([[maybe_unused]] ConstArrayView<Integer> offset)
218 {
219 //TODO
220 /*
221 IsLocal isLocal(offset, m_myrank);
222 //UniqueArray<ValueType>& values = m_internal->getValues();
223 auto& values = m_internal->getHostValues() ;;
224 ProfileType& profile = getCSRProfile();
225 UniqueArray<Integer>& cols = profile.getCols();
226 ConstArrayView<Integer> kcol = profile.getRowOffset();
227 Integer next = 0;
228 UniqueArray<Integer> gcols;
229 UniqueArray<ValueType> gvalues;
230 for (Integer irow = 0; irow < m_local_size; ++irow) {
231 bool need_sort = false;
232 Integer first = next;
233 next = kcol[irow + 1];
234 Integer row_size = next - first;
235 for (Integer k = first; k < next; ++k) {
236 if (!isLocal(cols[k])) {
237 need_sort = true;
238 break;
239 }
240 }
241 if (need_sort) {
242 gvalues.resize(row_size);
243 gcols.resize(row_size);
244 Integer local_count = 0;
245 Integer ghost_count = 0;
246 for (Integer k = first; k < next; ++k) {
247 Integer col = cols[k];
248 if (isLocal(col)) {
249 cols[first + local_count] = col;
250 values[first + local_count] = values[k];
251 ++local_count;
252 }
253 else {
254 gcols[ghost_count] = col;
255 gvalues[ghost_count] = values[k];
256 ++ghost_count;
257 }
258 }
259 for (Integer k = 0; k < ghost_count; ++k) {
260 cols[first + local_count] = gcols[k];
261 values[first + local_count] = gvalues[k];
262 ++local_count;
263 }
264 }
265 }
266 */
267 }
268
269
270 // clang-format off
271 Alien::BackEnd::Memory::eType m_mem_kind = Alien::BackEnd::Memory::Device;
272 std::unique_ptr<ProfileType> m_profile;
273 std::unique_ptr<InternalType> m_internal;
274 //InternalType* m_internal = nullptr ;
275
276
277 bool m_is_parallel = false;
278 IMessagePassingMng* m_parallel_mng = nullptr;
279 Integer m_nproc = 1;
280 Integer m_myrank = 0;
281
282 Integer m_local_size = 0;
283 Integer m_local_offset = 0;
284 Integer m_global_size = 0;
285 Integer m_ghost_size = 0;
286
287 //SimpleCSRInternal::DistStructInfo m_matrix_dist_info;
288 DistStructInfo m_matrix_dist_info;
289 ITraceMng* m_trace = nullptr;
290 // clang-format on
291
292
293};
294
295//extern template class SYCLBEllPackMatrix<double>;
296} // namespace Alien
IMatrixImpl.h.
ISpace.h.
MultiMatrixImpl.h.
HCSRMatrix(const MultiMatrixImpl *multi_impl)
Constructeur avec association à un MultiImpl.
IMatrixImpl(const MultiMatrixImpl *multi_impl, BackEndId backend="")
Constructor.
Multi-matrix representation container.
-*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
Definition BackEnd.h:17