Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
CBLASMPIKernel.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8#pragma once
9
10#include <cmath>
11
12#include <arccore/message_passing/Messages.h>
13
14#include <alien/utils/Precomp.h>
15#include <alien/kernels/simple_csr/algebra/alien_cblas.h>
16
17namespace Alien
18{
19
21{
22 public:
23 // static const tag::eType type = tag::CPU;
// Compile-time boolean flags attached to this kernel; the code that reads them is not
// visible in this file.
// NOTE(review): presumably is_mpi advertises that the reductions in this class go through
// Arccore::MessagePassing::mpAllReduce, and is_hybrid that there is no hybrid code path -
// confirm against the code that consumes these flags.
24 static const bool is_hybrid = false;
25 static const bool is_mpi = true;
26
27 template <typename Distribution, typename VectorT>
28 static void copy(
29 Distribution const& dist ALIEN_UNUSED_PARAM, const VectorT& x, VectorT& y)
30 {
31 typedef typename VectorT::ValueType ValueType;
32 cblas::copy(
33 x.scalarizedLocalSize(), (ValueType*)x.getDataPtr(), 1, y.getDataPtr(), 1);
34 }
35
36 template <typename Distribution, typename VectorT>
37 static void copy(
38 Distribution const& dist, const VectorT& x, Integer stride_x, VectorT& y, Integer stride_y)
39 {
40 typedef typename VectorT::ValueType ValueType;
41 cblas::copy(dist.localSize(), (ValueType*)x.getDataPtr(), stride_x, y.getDataPtr(), stride_y);
42 }
43
44 template <typename Distribution, typename VectorT>
45 static void axpy(Distribution const& dist ALIEN_UNUSED_PARAM,
46 typename VectorT::ValueType alpha, const VectorT& x, VectorT& y)
47 {
48 cblas::axpy(x.scalarizedLocalSize(), alpha, x.getDataPtr(), 1, y.getDataPtr(), 1);
49 }
50
51 template <typename Distribution, typename VectorT>
52 static void axpy(Distribution const& dist,
53 typename VectorT::ValueType alpha,
54 const VectorT& x,
55 Integer stride_x,
56 VectorT& y,
57 Integer stride_y)
58 {
59 cblas::axpy(x.scalarizedLocalSize(), alpha, x.getDataPtr(), stride_x, y.getDataPtr(), stride_y);
60 }
61 template <typename Distribution, typename VectorT>
62 static void scal(Distribution const& dist ALIEN_UNUSED_PARAM,
63 typename VectorT::ValueType alpha, VectorT& x)
64 {
65 cblas::scal(x.scalarizedLocalSize(), alpha, x.getDataPtr(), 1);
66 }
67
/// Element-wise product: z[i] = x[i] * y[i] for every local scalar entry of x.
///
/// @param dist  Distribution of the vectors; not read by this routine.
/// @param x     Left operand; x.scalarizedLocalSize() fixes the number of entries processed.
/// @param y     Right operand, read through y.getDataPtr().
/// @param z     Destination, written through z.getDataPtr().
///
/// No size check is performed on y or z: both are indexed over the local range of x.
/// When PRINT_DEBUG_INFO is defined, each product is also traced on std::cout.
template <typename Distribution, typename VectorT>
static void pointwiseMult(Distribution const& dist,
                          VectorT const& x,
                          VectorT const& y,
                          VectorT& z)
{
  static_cast<void>(dist); // intentionally unused; silences unused-parameter warnings

  auto local_size = x.scalarizedLocalSize();
  auto x_ptr = x.getDataPtr();
  auto y_ptr = y.getDataPtr();
  auto z_ptr = z.getDataPtr();
  // The index takes the type of the bound: comparing an unsigned index against a signed
  // size would convert a (bogus) negative size into a huge trip count.
  for (decltype(local_size) i = 0; i < local_size; ++i) {
    z_ptr[i] = x_ptr[i] * y_ptr[i];
#ifdef PRINT_DEBUG_INFO
    std::cout<<"X Y Z ["<<i<<"] : "<<x_ptr[i]<<"*"<<y_ptr[i]<<"="<<z_ptr[i]<<std::endl ;
#endif
  }
}
85
/// Sets every local scalar entry of y to alpha.
///
/// @param dist   Distribution of the vector; not read by this routine.
/// @param alpha  Value stored in each entry.
/// @param y      Destination; y.scalarizedLocalSize() entries are written through y.getDataPtr().
template <typename Distribution, typename VectorT>
static void assign(Distribution const& dist,
                   typename VectorT::ValueType alpha,
                   VectorT& y)
{
  static_cast<void>(dist); // intentionally unused; silences unused-parameter warnings

  auto local_size = y.scalarizedLocalSize();
  auto y_ptr = y.getDataPtr();
  // The index takes the type of the bound: comparing an unsigned index against a signed
  // size would convert a (bogus) negative size into a huge trip count.
  for (decltype(local_size) i = 0; i < local_size; ++i) {
    y_ptr[i] = alpha;
  }
}
97
98 template <typename Distribution, typename VectorT>
99 static typename VectorT::ValueType dot(
100 Distribution const& dist, const VectorT& x, const VectorT& y)
101 {
102 typedef typename VectorT::ValueType ValueType;
103 ValueType value = cblas::dot(x.scalarizedLocalSize(), (ValueType*)x.getDataPtr(), 1,
104 (ValueType*)y.getDataPtr(), 1);
105 if (dist.isParallel()) {
106 return Arccore::MessagePassing::mpAllReduce(
107 dist.parallelMng(), Arccore::MessagePassing::ReduceSum, value);
108 }
109 return value;
110 }
111
112 template <typename Distribution, typename VectorT>
113 static typename VectorT::ValueType nrm0(Distribution const& dist, const VectorT& x)
114 {
115 typedef typename VectorT::ValueType ValueType;
116 auto local_size = x.scalarizedLocalSize();
117 auto x_ptr = x.getDataPtr();
118 ValueType value = ValueType() ;
119 for(std::size_t i = 0; i < local_size; ++i)
120 value += (std::abs(x_ptr[i])>0?1:0) ;
121
122 if (dist.isParallel()) {
123 value = Arccore::MessagePassing::mpAllReduce(
124 dist.parallelMng(), Arccore::MessagePassing::ReduceSum, value);
125 }
126 return value;
127 }
128
129 template <typename Distribution, typename VectorT>
130 static typename VectorT::ValueType nrm1(Distribution const& dist, const VectorT& x)
131 {
132 typedef typename VectorT::ValueType ValueType;
133 typename VectorT::ValueType value = cblas::nrm1(x.scalarizedLocalSize(),
134 (ValueType*)x.getDataPtr(), 1);
135 if (dist.isParallel()) {
136 value = Arccore::MessagePassing::mpAllReduce(
137 dist.parallelMng(), Arccore::MessagePassing::ReduceSum, value);
138 }
139 return value;
140 }
141
142 template <typename Distribution, typename VectorT>
143 static typename VectorT::ValueType nrm2(Distribution const& dist, const VectorT& x)
144 {
145 typedef typename VectorT::ValueType ValueType;
146 typename VectorT::ValueType value = cblas::dot(x.scalarizedLocalSize(),
147 (ValueType*)x.getDataPtr(), 1, (ValueType*)x.getDataPtr(), 1);
148 if (dist.isParallel()) {
149 value = Arccore::MessagePassing::mpAllReduce(
150 dist.parallelMng(), Arccore::MessagePassing::ReduceSum, value);
151 }
152 return std::sqrt(value);
153 }
154
155 template <typename Distribution, typename VectorT>
156 static typename VectorT::ValueType nrmInf(Distribution const& dist, const VectorT& x)
157 {
158 typedef typename VectorT::ValueType ValueType;
159 auto local_size = x.scalarizedLocalSize();
160 auto x_ptr = x.getDataPtr();
161 ValueType value = ValueType() ;
162 for(std::size_t i = 0; i < local_size; ++i)
163 value = std::max(value,std::abs(x_ptr[i])) ;
164
165 if (dist.isParallel()) {
166 value = Arccore::MessagePassing::mpAllReduce(
167 dist.parallelMng(), Arccore::MessagePassing::ReduceMax, value);
168 }
169 return value;
170 }
171
172 template <typename Distribution, typename MatrixT>
173 static typename MatrixT::ValueType matrix_nrm2(Distribution const& dist, const MatrixT& x)
174 {
175 typedef typename MatrixT::ValueType ValueType;
176 typename MatrixT::ValueType value = cblas::dot(x.getProfile().getNnz(),
177 (ValueType*)x.data(), 1, (ValueType*)x.data(), 1);
178 if (dist.isParallel()) {
179 value = Arccore::MessagePassing::mpAllReduce(
180 dist.parallelMng(), Arccore::MessagePassing::ReduceSum, value);
181 }
182 return std::sqrt(value);
183 }
184};
185
186} // namespace Alien
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
Definition BackEnd.h:17