Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
SYCLBEllPackMatrixMultT.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8#pragma once
9
10#include <arccore/collections/Array2.h>
11
12#include <alien/kernels/sycl/data/SYCLSendRecvOp.h>
13/*---------------------------------------------------------------------------*/
14
15namespace Alien::SYCLInternal
16{
17
18/*---------------------------------------------------------------------------*/
19
20template <typename ValueT>
22: m_matrix_impl(matrix)
23{}
24
25/*---------------------------------------------------------------------------*/
26
27template <typename ValueT>
28void SYCLBEllPackMatrixMultT<ValueT>::mult(const VectorType& x, VectorType& y) const
29{
30 if (m_matrix_impl.m_is_parallel)
31 _parallelMult(x, y);
32 else
33 _seqMult(x, y);
34}
35
36template <typename ValueT>
37void SYCLBEllPackMatrixMultT<ValueT>::addLMult(Real alpha, const VectorType& x, VectorType& y) const
38{
39#ifdef ALIEN_USE_PERF_TIMER
40 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "SYCL-AddLMult");
41#endif
42 m_matrix_impl.addLMult(alpha, x, y);
43}
44
45template <typename ValueT>
46void SYCLBEllPackMatrixMultT<ValueT>::addUMult(Real alpha, const VectorType& x, VectorType& y) const
47{
48#ifdef ALIEN_USE_PERF_TIMER
49 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "SYCL-AddUMult");
50#endif
51 m_matrix_impl.addUMult(alpha, x, y);
52}
53
54template <typename ValueT>
55void SYCLBEllPackMatrixMultT<ValueT>::mult(const UniqueArray<Real>& x, UniqueArray<Real>& y) const
56{
57#ifdef ALIEN_USE_PERF_TIMER
58 typename MatrixType::SentryType sentry(m_matrix_impl.timer(), "SYCL-SPMV");
59#endif
60 if (m_matrix_impl.m_is_parallel)
61 _parallelMult(x, y);
62 else
63 _seqMult(x, y);
64}
65
66/*---------------------------------------------------------------------------*/
67
68template <typename ValueT>
69void SYCLBEllPackMatrixMultT<ValueT>::_parallelMult(
70const VectorType& x_impl, VectorType& y_impl) const
71{
72 //Alien::alien_debug([&] {Alien::cout() << "SYCL PARALLEL MULT : "<<m_matrix_impl.getGhostSize();});
73 //Universe().traceMng()->flush() ;
74 SYCLSendRecvOp<ValueT> op(x_impl.internal()->values(),
75 m_matrix_impl.m_matrix_dist_info.m_send_info,
76 m_matrix_impl.internal()->getSendIds(),
77 m_matrix_impl.m_send_policy,
78 x_impl.internal()->ghostValues(m_matrix_impl.getGhostSize()),
79 m_matrix_impl.m_matrix_dist_info.m_recv_info,
80 m_matrix_impl.internal()->getRecvIds(),
81 m_matrix_impl.m_recv_policy,
82 m_matrix_impl.m_parallel_mng,
83 m_matrix_impl.m_trace);
84
85 op.start();
86
87 m_matrix_impl.mult(x_impl, y_impl);
88
89 op.end();
90
91 m_matrix_impl.endDistMult(x_impl, y_impl);
92
93 //Alien::alien_debug([&] {Alien::cout() << "End SYCL PARALLEL MULT";});
94 //Universe().traceMng()->flush() ;
95}
96
97template <typename ValueT>
98void SYCLBEllPackMatrixMultT<ValueT>::_parallelMult(
99[[maybe_unused]] const UniqueArray<Real>& x_impl, [[maybe_unused]] UniqueArray<Real>& y_impl) const
100{
101#ifdef ENABLE_MPI_SYCL
102 Real* y_ptr = dataPtr(y_impl);
103 Real* x_ptr = (Real*)dataPtr(x_impl);
104 ConstArrayView<Real> matrix = m_matrix_impl.m_matrix.getValues();
105 ConstArrayView<Integer> cols = m_matrix_impl.getDistStructInfo().m_cols;
106 ConstArrayView<Integer> row_offset =
107 m_matrix_impl.m_matrix.getProfile().getRowOffset();
108 SendRecvOp<Real> op(x_ptr, m_matrix_impl.m_matrix_dist_info.m_send_info,
109 m_matrix_impl.m_send_policy, x_ptr, m_matrix_impl.m_matrix_dist_info.m_recv_info,
110 m_matrix_impl.m_recv_policy, m_matrix_impl.m_parallel_mng, m_matrix_impl.m_trace);
111 op.start();
112#endif
113
114 //m_matrix_impl.mult(x_impl,y_impl) ;
115
116#ifdef ENABLE_MPI_SYCL
117 op.end();
118
119 Integer interface_nrow = m_matrix_impl.m_matrix_dist_info.m_interface_nrow;
120 ConstArrayView<Integer> row_ids = m_matrix_impl.m_matrix_dist_info.m_interface_rows;
121 for (Integer i = 0; i < interface_nrow; ++i) {
122 Integer irow = row_ids[i];
123 Integer off = row_offset[irow] + local_row_size[irow];
124 Integer off2 = row_offset[irow + 1];
125 Real tmpy = 0.;
126 for (Integer j = off; j < off2; ++j) {
127 tmpy += matrix[j] * x_ptr[cols[j]];
128 }
129 y_ptr[irow] += tmpy;
130 }
131#endif
132}
133/*---------------------------------------------------------------------------*/
134
135template <typename ValueT>
136void SYCLBEllPackMatrixMultT<ValueT>::_seqMult(const VectorType& x_impl, VectorType& y_impl) const
137{
138 m_matrix_impl.mult(x_impl, y_impl);
139}
140
141template <typename ValueT>
142void SYCLBEllPackMatrixMultT<ValueT>::_seqMult([[maybe_unused]] const UniqueArray<Real>& x_impl,
143 [[maybe_unused]] UniqueArray<Real>& y_impl) const
144{
145}
146
147/*---------------------------------------------------------------------------*/
148template <typename ValueT>
149void SYCLBEllPackMatrixMultT<ValueT>::multDiag(VectorType const& y, VectorType& z) const
150{
151 m_matrix_impl.multDiag(y,z);
152}
153
154template <typename ValueT>
155void SYCLBEllPackMatrixMultT<ValueT>::multDiag(VectorType& y) const
156{
157 m_matrix_impl.multDiag(y);
158}
159
160template <typename ValueT>
161void SYCLBEllPackMatrixMultT<ValueT>::computeDiag(VectorType& y) const
162{
163 m_matrix_impl.computeDiag(y);
164}
165
166template <typename ValueT>
167void SYCLBEllPackMatrixMultT<ValueT>::multInvDiag(VectorType& y) const
168{
169 m_matrix_impl.multInvDiag(y);
170}
171
172template <typename ValueT>
173void SYCLBEllPackMatrixMultT<ValueT>::computeInvDiag(VectorType& y) const
174{
175 m_matrix_impl.computeInvDiag(y);
176}
177
178/*---------------------------------------------------------------------------*/
179
180/*---------------------------------------------------------------------------*/
181
182} // namespace Alien::SYCLInternal
183
184/*---------------------------------------------------------------------------*/
void mult(const VectorType &x, VectorType &y) const
Matrix vector product.
SYCLBEllPackMatrixMultT(const MatrixType &matrix)
Constructeur de la classe.