15#include <alien/kernels/sycl/data/HCSRMatrix.h>
16#include <alien/kernels/sycl/data/HCSRMatrixInternal.h>
18#include <alien/kernels/sycl/data/SYCLParallelEngine.h>
19#include <alien/kernels/sycl/data/SYCLParallelEngineImplT.h>
21#include <alien/handlers/scalar/sycl/ProfiledMatrixBuilderT.h>
// ---------------------------------------------------------------------------
// ProfiledMatrixBuilderT<ValueT,IndexT>::Impl
//
// Private holder that binds *references* to the three SYCL buffers of the
// HCSR matrix internal storage: the values, the column indices and the
// "kcol" array. The View/ConstView/HostView classes in this file scan a row
// over k in [kcol[row], kcol[row+1]), so kcol acts as per-row start offsets
// into cols/values.
// Only references are stored: the referenced buffers must outlive this Impl.
//
// NOTE(review): the class-head line (e.g. `class ...::Impl {` / `public:`)
// and the closing brace of this definition are not present in this chunk;
// verify against the full file.
// ---------------------------------------------------------------------------
39 template <
typename ValueT,
typename IndexT>
// Buffer types are taken from the HCSR internal representation for ValueT.
43 typedef typename HCSRMatrix<ValueT>::InternalType MatrixInternalType ;
44 typedef typename MatrixInternalType::ValueBufferType ValueBufferType ;
45 typedef typename MatrixInternalType::IndexBufferType IndexBufferType ;
// Binds the three buffer references; no buffer data is copied.
47 Impl(ValueBufferType& values_buffer,
48 IndexBufferType& cols_buffer,
49 IndexBufferType& kcol_buffer)
50 : m_values_buffer(values_buffer)
51 , m_cols_buffer(cols_buffer)
52 , m_kcol_buffer(kcol_buffer)
// Non-owning references to the HCSR internal buffers.
55 ValueBufferType& m_values_buffer ;
56 IndexBufferType& m_cols_buffer ;
57 IndexBufferType& m_kcol_buffer ;
// ---------------------------------------------------------------------------
// Constructor: prepares `matrix` for filling through this builder.
//
// Steps visible here:
//  1. Locks the matrix implementation (finalize() calls the matching unlock()).
//  2. Obtains the HCSR backend representation via get<BackEnd::tag::hcsr>(true)
//     (the meaning of the `true` argument is not visible in this chunk).
//  3. From the row distribution, caches dist.localRowSize(), dist.rowOffset()
//     and m_next_offset = rowOffset + localRowSize (the offset just past the
//     local rows).
//  4. Caches the CSR profile's row offsets and column indices, and
//     m_local_row_size from getDistStructInfo().
//  5. Stores references to the HCSR internal values/cols/kcol buffers in m_impl.
//
// NOTE(review): m_matrix is presumably bound to `matrix` in a member-initializer
// list that is not visible here; how `reset_values` is used is not visible.
// NOTE(review): if any call after lock() throws, the destructor does not run for
// the partially constructed object, so the lock would never be released;
// consider an RAII unlock guard. std::make_unique<Impl>(...) would also be
// preferable to reset(new Impl(...)).
// ---------------------------------------------------------------------------
60 template <
typename ValueT,
typename IndexT>
61 ProfiledMatrixBuilderT<ValueT,IndexT>::ProfiledMatrixBuilderT(IMatrix& matrix, ResetFlag reset_values)
// Acquire the matrix lock; released by finalize().
65 m_matrix.impl()->lock();
66 m_matrix_impl = &m_matrix.impl()->get<BackEnd::tag::hcsr>(
true);
// Cache the local row range of this process from the distribution.
68 const MatrixDistribution& dist = m_matrix_impl->distribution();
70 m_local_size = dist.localRowSize();
71 m_local_offset = dist.rowOffset();
72 m_next_offset = m_local_offset + m_local_size;
// Cache the sparsity profile (row offsets and column indices).
74 SimpleCSRInternal::CSRStructInfo
const& profile =
75 m_matrix_impl->getCSRProfile();
76 m_row_starts = profile.getRowOffset();
77 m_local_row_size = m_matrix_impl->getDistStructInfo().m_local_row_size;
78 m_cols = profile.getCols();
// Keep references to the HCSR internal SYCL buffers; the views are built from these.
79 m_impl.reset(
new Impl(m_matrix_impl->internal()->values(),
80 m_matrix_impl->internal()->cols(),
81 m_matrix_impl->internal()->kcol())) ;
// Destructor. Its body is not visible in this chunk.
// NOTE(review): confirm that it releases the lock taken in the constructor
// when finalize() has not been called.
84 template <
typename ValueT,
typename IndexT>
85 ProfiledMatrixBuilderT<ValueT,IndexT>::~ProfiledMatrixBuilderT()
// finalize(): ends the build. The visible statement releases the lock on the
// matrix implementation that the constructor acquired; the remaining
// statements of the body are not visible in this chunk.
94 template <
typename ValueT,
typename IndexT>
95 void ProfiledMatrixBuilderT<ValueT,IndexT>::finalize()
99 m_matrix.impl()->unlock();
// ---------------------------------------------------------------------------
// View: bundle of accessors obtained from a SYCL command-group handler
// (see view()), for use inside that command group to write matrix values.
//  - values     : read_write accessor
//  - cols, kcol : read accessors
//
// m_h, m_vb and m_ib are null placeholder pointers. In the visible code they
// appear only inside decltype (an unevaluated operand) to name the accessor
// types returned by get_access, so those expressions are never evaluated.
// ---------------------------------------------------------------------------
103 template <
typename ValueT,
typename IndexT>
104 class ProfiledMatrixBuilderT<ValueT,IndexT>::View
106 sycl::handler* m_h = nullptr ;
107 sycl::buffer<ValueT,1>* m_vb = nullptr ;
108 sycl::buffer<IndexT,1>* m_ib = nullptr ;
109 using ValueAccessorType =
decltype(m_vb->template get_access<sycl::access::mode::read_write>(*m_h));
110 using IndexAccessorType =
decltype(m_ib->template get_access<sycl::access::mode::read>(*m_h));
// Captures the three accessors by value.
113 explicit View(ValueAccessorType values_accessor,
114 IndexAccessorType cols_accessor,
115 IndexAccessorType kcol_accessor)
116 : m_values_accessor(values_accessor)
117 , m_cols_accessor(cols_accessor)
118 , m_kcol_accessor(kcol_accessor)
// Returns a writable reference to the value stored at flat position `index`.
121 ValueT& operator[](IndexT index)
const {
122 return m_values_accessor[index] ;
// Linear scan of row `row`: looks for cols[k] == col with
// k in [kcol[row], kcol[row+1]). Requires row+1 to be a valid kcol index.
125 IndexT entryIndex(IndexT row, IndexT col)
const {
126 for(
auto k=m_kcol_accessor[row];k<m_kcol_accessor[row+1];++k)
127 if(m_cols_accessor[k]==col)
// NOTE(review): the return on a match and the not-found return are not
// visible in this chunk (presumably `k` and a sentinel such as -1 — confirm).
// Accessors captured at construction.
133 ValueAccessorType m_values_accessor ;
134 IndexAccessorType m_cols_accessor ;
135 IndexAccessorType m_kcol_accessor ;
// ---------------------------------------------------------------------------
// ConstView: read-only counterpart of View. All three accessors (values,
// cols, kcol) are obtained in read mode from a SYCL command-group handler
// (see constView()).
//
// m_h, m_vb and m_ib are null placeholder pointers used only inside decltype
// (an unevaluated operand) to name the accessor types; they are never
// dereferenced at run time in the visible code.
// ---------------------------------------------------------------------------
139 template <
typename ValueT,
typename IndexT>
140 class ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView
142 sycl::handler* m_h = nullptr ;
143 sycl::buffer<ValueT,1>* m_vb = nullptr ;
144 sycl::buffer<IndexT,1>* m_ib = nullptr ;
145 using ValueAccessorType =
decltype(m_vb->template get_access<sycl::access::mode::read>(*m_h));
146 using IndexAccessorType =
decltype(m_ib->template get_access<sycl::access::mode::read>(*m_h));
// Captures the three read accessors by value.
149 explicit ConstView(ValueAccessorType values_accessor,
150 IndexAccessorType cols_accessor,
151 IndexAccessorType kcol_accessor)
152 : m_values_accessor(values_accessor)
153 , m_cols_accessor(cols_accessor)
154 , m_kcol_accessor(kcol_accessor)
// Returns (by value) the matrix value stored at flat position `index`.
157 ValueT operator[](IndexT index)
const {
158 return m_values_accessor[index] ;
// Linear scan of row `row`: looks for cols[k] == col with
// k in [kcol[row], kcol[row+1]). Requires row+1 to be a valid kcol index.
161 IndexT entryIndex(IndexT row,IndexT col)
const {
162 for(
auto k=m_kcol_accessor[row];k<m_kcol_accessor[row+1];++k)
163 if(m_cols_accessor[k]==col)
// NOTE(review): the return on a match and the not-found return are not
// visible in this chunk (presumably `k` and a sentinel — confirm).
// Accessors captured at construction.
169 ValueAccessorType m_values_accessor ;
170 IndexAccessorType m_cols_accessor ;
171 IndexAccessorType m_kcol_accessor ;
// ---------------------------------------------------------------------------
// HostView: host-side view built from buffer host accessors (see hostView()).
// operator[] returns values by value; col() and entryIndex() read the column
// indices; the bodies of the constructor and of kcol() are not visible here.
//
// m_b and m_ib are null placeholder pointers used only inside decltype to
// name the host-accessor types; they are never dereferenced in the visible code.
//
// NOTE(review): in SYCL 2020, get_host_access() without an access tag requests
// read_write access, although the visible members only read. Using
// sycl::read_only (and updating these decltypes) would avoid marking the
// buffers as modified on the host. Host accessors also wait for pending device
// work on the buffer and hold it for as long as this HostView lives.
// NOTE(review): `ValueType` (return type of operator[]) is not declared in the
// visible code; presumably a typedef of ValueT in the class declaration.
// ---------------------------------------------------------------------------
176 template <
typename ValueT,
typename IndexT>
177 class ProfiledMatrixBuilderT<ValueT,IndexT>::HostView
180 sycl::buffer<ValueT,1>* m_b = nullptr ;
181 using ValueAccessorType =
decltype(m_b->get_host_access());
183 sycl::buffer<IndexT,1>* m_ib = nullptr ;
184 using IndexAccessorType =
decltype(m_ib->get_host_access());
// Takes the three host accessors (initializer list not visible in this chunk).
186 HostView(ValueAccessorType values,
187 IndexAccessorType cols,
188 IndexAccessorType kcol)
// Returns (by value) the matrix value stored at flat position `index`.
194 ValueType operator[](IndexT index)
const {
195 return m_values[index] ;
// Linear scan of row `row` over k in [kcol[row], kcol[row+1]); the loop body
// and the return values are not visible in this chunk.
198 IndexT entryIndex(IndexT row,IndexT col)
const {
199 for(
auto k=m_kcol[row];k<m_kcol[row+1];++k)
// Body not visible in this chunk (presumably returns m_kcol[row] — confirm).
205 IndexT kcol(IndexT row)
const {
// Returns the column index stored at flat position `index`.
209 IndexT col(IndexT index)
const {
210 return m_cols[index] ;
// Host accessors captured at construction.
215 ValueAccessorType m_values ;
216 IndexAccessorType m_cols;
217 IndexAccessorType m_kcol;
// view(): builds a View whose accessors are requested on the underlying
// sycl::handler (cgh.m_internal). Values are requested in read_write mode;
// cols and kcol are requested in read mode. Because the accessors are bound
// to that handler, the returned View is meant to be used inside that command
// group.
223 template <
typename ValueT,
typename IndexT>
224 typename ProfiledMatrixBuilderT<ValueT,IndexT>::View ProfiledMatrixBuilderT<ValueT,IndexT>::view(SYCLControlGroupHandler& cgh)
226 return View(m_impl->m_values_buffer.template get_access<sycl::access::mode::read_write>(cgh.m_internal),
227 m_impl->m_cols_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
228 m_impl->m_kcol_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal)) ;
// constView(): builds a ConstView whose values, cols and kcol accessors are
// all requested in read mode on the underlying sycl::handler (cgh.m_internal).
// The returned ConstView is meant to be used inside that command group.
231 template <
typename ValueT,
typename IndexT>
232 typename ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView ProfiledMatrixBuilderT<ValueT,IndexT>::constView(SYCLControlGroupHandler& cgh)
const
234 return ProfiledMatrixBuilderT<ValueT,IndexT>::ConstView(m_impl->m_values_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
235 m_impl->m_cols_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal),
236 m_impl->m_kcol_buffer.template get_access<sycl::access::mode::read>(cgh.m_internal)) ;
// hostView(): builds a HostView from host accessors on the values, cols and
// kcol buffers.
// NOTE(review): this is a const member, yet get_host_access() with no access
// tag requests read_write access (SYCL 2020). Requesting sycl::read_only
// would match the read-only use of HostView, but it requires changing
// HostView's accessor types at the same time.
239 template <
typename ValueT,
typename IndexT>
240 typename ProfiledMatrixBuilderT<ValueT,IndexT>::HostView ProfiledMatrixBuilderT<ValueT,IndexT>::hostView()
const
242 return HostView(m_impl->m_values_buffer.get_host_access(),
243 m_impl->m_cols_buffer.get_host_access(),
244 m_impl->m_kcol_buffer.get_host_access()) ;
/* -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*- */