10#include <arccore/collections/Array2.h>
12#include <alien/kernels/sycl/data/SYCLSendRecvOp.h>
15namespace Alien::SYCLInternal
// Constructor: initialises the member m_matrix_impl from `matrix`.
// NOTE(review): this region is an extract of a source listing. Original line 21
// (the constructor declarator, presumably
// `SYCLBEllPackMatrixMultT<ValueT>::SYCLBEllPackMatrixMultT(const MatrixType& matrix)`,
// matching the constructor summary listed for this class) and the constructor
// body (original lines 23-26) are missing here and must be restored from the
// full source.
20template <
typename ValueT>
22: m_matrix_impl(matrix)
// NOTE(review): original line 28 (the member-function declarator) and lines
// 31-35 (the branch bodies) are missing from this extract, so which member this
// is cannot be determined here. The visible part branches on
// m_matrix_impl.m_is_parallel; presumably it selects between the _parallelMult
// and _seqMult overloads defined further down — confirm against the full source.
27template <
typename ValueT>
30 if (m_matrix_impl.m_is_parallel)
36template <
typename ValueT>
37void SYCLBEllPackMatrixMultT<ValueT>::addLMult(Real alpha,
const VectorType& x,
VectorType& y)
const
39#ifdef ALIEN_USE_PERF_TIMER
40 typename MatrixType::SentryType sentry(m_matrix_impl.timer(),
"SYCL-AddLMult");
42 m_matrix_impl.addLMult(alpha, x, y);
45template <
typename ValueT>
46void SYCLBEllPackMatrixMultT<ValueT>::addUMult(Real alpha,
const VectorType& x,
VectorType& y)
const
48#ifdef ALIEN_USE_PERF_TIMER
49 typename MatrixType::SentryType sentry(m_matrix_impl.timer(),
"SYCL-AddUMult");
51 m_matrix_impl.addUMult(alpha, x, y);
// NOTE(review): original line 55 (the declarator) and lines 61-67 (the branch
// bodies and the end of the function) are missing from this extract. Given the
// "SYCL-SPMV" timer label and the `void mult(const VectorType &x, VectorType &y) const`
// ("Matrix vector product") summary listed for this class, this is presumably
// mult(). The visible part optionally times the call (ALIEN_USE_PERF_TIMER) and
// branches on m_matrix_impl.m_is_parallel — confirm the branch bodies against
// the full source.
54template <
typename ValueT>
57#ifdef ALIEN_USE_PERF_TIMER
58 typename MatrixType::SentryType sentry(m_matrix_impl.timer(),
"SYCL-SPMV");
60 if (m_matrix_impl.m_is_parallel)
// Distributed product for VectorType operands.
// NOTE(review): original lines 71-74, 84-86, 88-90 and 92-96 are missing from
// this extract. Lines 75-83 below are the trailing arguments of a call or
// constructor whose head lies in the missing lines 71-74; they pass, in order:
// the send distribution info, the send ids, the send policy, the ghost values of
// x_impl (x_impl.internal()->ghostValues(m_matrix_impl.getGhostSize())), the
// receive distribution info, the receive ids, the receive policy, the parallel
// manager and the trace manager.
// The visible tail then runs m_matrix_impl.mult(x_impl, y_impl) followed by
// m_matrix_impl.endDistMult(x_impl, y_impl). Presumably the missing lines start
// the ghost-value exchange before the local product and complete it before
// endDistMult — confirm against the full source.
68template <
typename ValueT>
69void SYCLBEllPackMatrixMultT<ValueT>::_parallelMult(
70const VectorType& x_impl, VectorType& y_impl)
const
75 m_matrix_impl.m_matrix_dist_info.m_send_info,
76 m_matrix_impl.internal()->getSendIds(),
77 m_matrix_impl.m_send_policy,
78 x_impl.internal()->ghostValues(m_matrix_impl.getGhostSize()),
79 m_matrix_impl.m_matrix_dist_info.m_recv_info,
80 m_matrix_impl.internal()->getRecvIds(),
81 m_matrix_impl.m_recv_policy,
82 m_matrix_impl.m_parallel_mng,
83 m_matrix_impl.m_trace);
87 m_matrix_impl.mult(x_impl, y_impl);
91 m_matrix_impl.endDistMult(x_impl, y_impl);
// Distributed product for raw UniqueArray<Real> operands. The visible body is
// guarded by ENABLE_MPI_SYCL; the parameters are marked [[maybe_unused]] so
// builds without that macro do not warn.
// NOTE(review): original lines 100, 111-115, 117-118, 125 and 128-134 are
// missing from this extract. These include the declarations of `local_row_size`
// and `tmpy` used below and the end of the function, so this region cannot be
// compiled or fully reviewed as shown.
97template <
typename ValueT>
98void SYCLBEllPackMatrixMultT<ValueT>::_parallelMult(
99[[maybe_unused]]
const UniqueArray<Real>& x_impl, [[maybe_unused]] UniqueArray<Real>& y_impl)
const
101#ifdef ENABLE_MPI_SYCL
102 Real* y_ptr = dataPtr(y_impl);
// NOTE(review): this C-style cast drops the const qualifier that x_impl carries.
// x_ptr is then given to SendRecvOp both as the buffer preceding the send info
// and as the buffer preceding the receive info. If the receive side writes
// through it, data reached via a const reference is modified. Consider an
// explicit const_cast with a justification, or a non-const parameter; confirm
// intent.
103 Real* x_ptr = (Real*)dataPtr(x_impl);
104 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
105 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
106 ConstArrayView<Integer> row_offset =
107 m_matrix_impl.m_matrix.getProfile().getRowOffset();
108 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
109 m_matrix_impl.m_send_policy, x_ptr, m_matrix_impl.m_matrix_dist_info.m_recv_info,
110 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace);
116#ifdef ENABLE_MPI_SYCL
// Loop over the interface rows listed in m_interface_rows. For each such row,
// accumulate matrix[j] * x_ptr[cols[j]] into tmpy for the entries in
// [row_offset[irow] + local_row_size[irow], row_offset[irow + 1]). These are
// presumably the entries whose columns refer to off-process (ghost) values
// stored after the local ones — confirm against the full source.
119 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
120 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
121 for (Integer i = 0; i < interface_nrow; ++i) {
122 Integer irow = row_ids[i];
123 Integer off = row_offset[irow] + local_row_size[irow];
124 Integer off2 = row_offset[irow + 1];
126 for (Integer j = off; j < off2; ++j) {
127 tmpy += matrix[j] * x_ptr[cols[j]];
135template <
typename ValueT>
136void SYCLBEllPackMatrixMultT<ValueT>::_seqMult(
const VectorType& x_impl, VectorType& y_impl)
const
138 m_matrix_impl.mult(x_impl, y_impl);
// UniqueArray<Real> overload of the non-distributed product.
// NOTE(review): the body (original lines 144-147) is missing from this extract.
// Both parameters are marked [[maybe_unused]], which suggests they are not read
// in at least some build configurations (possibly an empty or not-implemented
// stub) — confirm against the full source.
141template <
typename ValueT>
142void SYCLBEllPackMatrixMultT<ValueT>::_seqMult([[maybe_unused]]
const UniqueArray<Real>& x_impl,
143 [[maybe_unused]] UniqueArray<Real>& y_impl)
const
148template <
typename ValueT>
149void SYCLBEllPackMatrixMultT<ValueT>::multDiag(VectorType
const& y, VectorType& z)
const
151 m_matrix_impl.multDiag(y,z);
154template <
typename ValueT>
155void SYCLBEllPackMatrixMultT<ValueT>::multDiag(VectorType& y)
const
157 m_matrix_impl.multDiag(y);
160template <
typename ValueT>
161void SYCLBEllPackMatrixMultT<ValueT>::computeDiag(VectorType& y)
const
163 m_matrix_impl.computeDiag(y);
166template <
typename ValueT>
167void SYCLBEllPackMatrixMultT<ValueT>::multInvDiag(VectorType& y)
const
169 m_matrix_impl.multInvDiag(y);
172template <
typename ValueT>
173void SYCLBEllPackMatrixMultT<ValueT>::computeInvDiag(VectorType& y)
const
175 m_matrix_impl.computeInvDiag(y);
void mult(const VectorType &x, VectorType &y) const
Matrix-vector product.
SYCLBEllPackMatrixMultT(const MatrixType &matrix)
Class constructor.