Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
ProfiledMatrixBuilderImplT.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2025 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8
9#pragma once
10
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
#include <alien/kernels/sycl/data/HCSRMatrix.h>
#include <alien/kernels/sycl/data/HCSRMatrixInternal.h>

#include <alien/kernels/sycl/data/SYCLParallelEngine.h>
#include <alien/kernels/sycl/data/SYCLParallelEngineImplT.h>

#include <alien/handlers/scalar/sycl/ProfiledMatrixBuilderT.h>

#include <span>
#include <utility>
24
25/*---------------------------------------------------------------------------*/
26/*---------------------------------------------------------------------------*/
27
28namespace Alien
29{
30
31/*---------------------------------------------------------------------------*/
32/*---------------------------------------------------------------------------*/
33
34namespace SYCL
35{
36
37 /*---------------------------------------------------------------------------*/
38 /*---------------------------------------------------------------------------*/
39 template <typename ValueT,typename IndexT>
40 class ProfiledMatrixBuilderT<ValueT,IndexT>::Impl
41 {
42 public :
43 typedef typename HCSRMatrix<ValueT>::InternalType MatrixInternalType ;
44 typedef typename MatrixInternalType::ValueBufferType ValueBufferType ;
45 typedef typename MatrixInternalType::IndexBufferType IndexBufferType ;
46
47 Impl(ValueBufferType& values_buffer,
48 IndexBufferType& cols_buffer,
49 IndexBufferType& kcol_buffer)
50 : m_values_buffer(values_buffer)
51 , m_cols_buffer(cols_buffer)
52 , m_kcol_buffer(kcol_buffer)
53 {}
54
55 ValueBufferType& m_values_buffer ;
56 IndexBufferType& m_cols_buffer ;
57 IndexBufferType& m_kcol_buffer ;
58 };
59
60 template <typename ValueT,typename IndexT>
61 ProfiledMatrixBuilderT<ValueT,IndexT>::ProfiledMatrixBuilderT(IMatrix& matrix, ResetFlag reset_values)
62 : m_matrix(matrix)
63 , m_finalized(false)
64 {
65 m_matrix.impl()->lock();
66 m_matrix_impl = &m_matrix.impl()->get<BackEnd::tag::hcsr>(true);
67
68 const MatrixDistribution& dist = m_matrix_impl->distribution();
69
70 m_local_size = dist.localRowSize();
71 m_local_offset = dist.rowOffset();
72 m_next_offset = m_local_offset + m_local_size;
73
74 SimpleCSRInternal::CSRStructInfo const& profile =
75 m_matrix_impl->getCSRProfile();
76 m_row_starts = profile.getRowOffset();
77 m_local_row_size = m_matrix_impl->getDistStructInfo().m_local_row_size;
78 m_cols = profile.getCols();
79 m_impl.reset(new Impl(m_matrix_impl->internal()->values(),
80 m_matrix_impl->internal()->cols(),
81 m_matrix_impl->internal()->kcol())) ;
82 }
83
84 template <typename ValueT,typename IndexT>
85 ProfiledMatrixBuilderT<ValueT,IndexT>::~ProfiledMatrixBuilderT()
86 {
87 if (!m_finalized) {
88 finalize();
89 }
90 }
91
92 /*---------------------------------------------------------------------------*/
93
94 template <typename ValueT,typename IndexT>
95 void ProfiledMatrixBuilderT<ValueT,IndexT>::finalize()
96 {
97 if (m_finalized)
98 return;
99 m_matrix.impl()->unlock();
100 m_finalized = true;
101 }
102
103 template <typename ValueT, typename IndexT>
104 class ProfiledMatrixBuilderT<ValueT,IndexT>::View
105 {
106 sycl::handler* m_h = nullptr ;
107 sycl::buffer<ValueT,1>* m_vb = nullptr ;
108 sycl::buffer<IndexT,1>* m_ib = nullptr ;
109 using ValueAccessorType = decltype(m_vb->template get_access<sycl::access::mode::read_write>(*m_h));
110 using IndexAccessorType = decltype(m_ib->template get_access<sycl::access::mode::read>(*m_h));
111
112 public :
113 explicit View(ValueAccessorType values_accessor,
114 IndexAccessorType cols_accessor,
115 IndexAccessorType kcol_accessor)
116 : m_values_accessor(values_accessor)
117 , m_cols_accessor(cols_accessor)
118 , m_kcol_accessor(kcol_accessor)
119 {}
120
121 ValueT& operator[](IndexT index) const {
122 return m_values_accessor[index] ;
123 }
124
125 IndexT entryIndex(IndexT row, IndexT col) const {
126 for(auto k=m_kcol_accessor[row];k<m_kcol_accessor[row+1];++k)
127 if(m_cols_accessor[k]==col)
128 return k ;
129 return -1 ;
130 }
131
132 protected :
133 ValueAccessorType m_values_accessor ;
134 IndexAccessorType m_cols_accessor ;
135 IndexAccessorType m_kcol_accessor ;
136
137 } ;
138
139 template <typename ValueT,typename IndexT>
140 class ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView
141 {
142 sycl::handler* m_h = nullptr ;
143 sycl::buffer<ValueT,1>* m_vb = nullptr ;
144 sycl::buffer<IndexT,1>* m_ib = nullptr ;
145 using ValueAccessorType = decltype(m_vb->template get_access<sycl::access::mode::read>(*m_h));
146 using IndexAccessorType = decltype(m_ib->template get_access<sycl::access::mode::read>(*m_h));
147
148 public :
149 explicit ConstView(ValueAccessorType values_accessor,
150 IndexAccessorType cols_accessor,
151 IndexAccessorType kcol_accessor)
152 : m_values_accessor(values_accessor)
153 , m_cols_accessor(cols_accessor)
154 , m_kcol_accessor(kcol_accessor)
155 {}
156
157 ValueT operator[](IndexT index) const {
158 return m_values_accessor[index] ;
159 }
160
161 IndexT entryIndex(IndexT row,IndexT col) const {
162 for(auto k=m_kcol_accessor[row];k<m_kcol_accessor[row+1];++k)
163 if(m_cols_accessor[k]==col)
164 return k ;
165 return -1 ;
166 }
167
168 protected :
169 ValueAccessorType m_values_accessor ;
170 IndexAccessorType m_cols_accessor ;
171 IndexAccessorType m_kcol_accessor ;
172
173 } ;
174
175
176 template <typename ValueT,typename IndexT>
177 class ProfiledMatrixBuilderT<ValueT,IndexT>::HostView
178 {
179 public :
180 sycl::buffer<ValueT,1>* m_b = nullptr ;
181 using ValueAccessorType = decltype(m_b->get_host_access());
182
183 sycl::buffer<IndexT,1>* m_ib = nullptr ;
184 using IndexAccessorType = decltype(m_ib->get_host_access());
185
186 HostView(ValueAccessorType values,
187 IndexAccessorType cols,
188 IndexAccessorType kcol)
189 : m_values(values)
190 , m_cols(cols)
191 , m_kcol(kcol)
192 {}
193
194 ValueType operator[](IndexT index) const {
195 return m_values[index] ;
196 }
197
198 IndexT entryIndex(IndexT row,IndexT col) const {
199 for(auto k=m_kcol[row];k<m_kcol[row+1];++k)
200 if(m_cols[k]==col)
201 return k ;
202 return -1 ;
203 }
204
205 IndexT kcol(IndexT row) const {
206 return m_kcol[row] ;
207 }
208
209 IndexT col(IndexT index) const {
210 return m_cols[index] ;
211 }
212
213
214 protected:
215 ValueAccessorType m_values ;
216 IndexAccessorType m_cols;
217 IndexAccessorType m_kcol;
218 //std::span<IndexT> m_kcol ;
219 //std::span<IndexT> m_cols ;
220 };
221
222 /*---------------------------------------------------------------------------*/
223 template <typename ValueT,typename IndexT>
224 typename ProfiledMatrixBuilderT<ValueT,IndexT>::View ProfiledMatrixBuilderT<ValueT,IndexT>::view(SYCLControlGroupHandler& cgh)
225 {
226 return View(m_impl->m_values_buffer.template get_access<sycl::access::mode::read_write>(cgh.m_internal),
227 m_impl->m_cols_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
228 m_impl->m_kcol_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal)) ;
229 }
230
231 template <typename ValueT,typename IndexT>
232 typename ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView ProfiledMatrixBuilderT<ValueT,IndexT>::constView(SYCLControlGroupHandler& cgh) const
233 {
234 return ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView(m_impl->m_values_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
235 m_impl->m_cols_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
236 m_impl->m_kcol_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal)) ;
237 }
238
239 template <typename ValueT,typename IndexT>
240 typename ProfiledMatrixBuilderT<ValueT,IndexT>::HostView ProfiledMatrixBuilderT<ValueT,IndexT>::hostView() const
241 {
242 return HostView(m_impl->m_values_buffer.get_host_access(),
243 m_impl->m_cols_buffer.get_host_access(),
244 m_impl->m_kcol_buffer.get_host_access()) ;
245 }
246
247 /*---------------------------------------------------------------------------*/
248 /*---------------------------------------------------------------------------*/
249
250 /*---------------------------------------------------------------------------*/
251 /*---------------------------------------------------------------------------*/
252
253} // namespace SYCL
254
255/*---------------------------------------------------------------------------*/
256/*---------------------------------------------------------------------------*/
257
258} // namespace Alien
259
260/*---------------------------------------------------------------------------*/
261/*---------------------------------------------------------------------------*/
MultiMatrixImpl.h.
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
Definition BackEnd.h:17