13#ifndef ALIEN_EXPRESSION_MVEXPR_MVEXPR_H_
14#define ALIEN_EXPRESSION_MVEXPR_MVEXPR_H_
16#include <alien/kernels/simple_csr/algebra/SimpleCSRLinearAlgebra.h>
17#include <alien/utils/Precomp.h>
55 template <
class A,
class B>
56 auto add(A&& a, B&& b)
58 return [=](
auto visitor) {
return visitor(
lazy::add_tag{}, a(visitor), b(visitor)); };
61 template <
class A,
class B>
62 auto minus(A&& a, B&& b)
65 [=](
auto visitor) {
return visitor(
lazy::minus_tag{}, a(visitor), b(visitor)); };
68 template <
class A,
class B>
69 auto mul(A&& a, B&& b)
72 [=](
auto visitor) {
return visitor(
lazy::mult_tag{}, a(visitor), b(visitor)); };
75 template <
class A,
class B>
76 auto div(A&& a, B&& b)
78 return [=](
auto visitor) {
return visitor(
lazy::div_tag{}, a(visitor), b(visitor)); };
96 return &r.distribution();
102 return &r.distribution();
105 template <
typename L>
108 return &r.distribution();
120 template <
typename R>
123 return &r.distribution();
126 template <
typename R>
129 return &l.distribution().rowDistribution();
132 template <
typename L>
135 return &r.distribution();
139 template <
class A,
class B>
140 auto scalMul(A&& a, B&& b)
142 return [=](
auto visitor) {
144 lazy::dot_tag{}, a(distribution_evaluator()), a(visitor), b(visitor));
151 return [=](
auto visitor) {
return visitor(
lazy::cst_tag{}, expr); };
155 auto ref(T
const& expr)
157 return [&](
auto visitor) ->
decltype(visitor(
lazy::ref_tag{}, expr)) {
163 struct alloc_size_evaluator;
165 template <
typename T>
166 auto matrixMult(Matrix
const& matrix, UniqueArray<T>
const& x)
169 std::cout <<
"\t\t MatrixVectorMult" << std::endl;
171 std::size_t n = matrix.distribution().localRowSize();
172 UniqueArray<T> y(n, 0.);
173 SimpleCSRLinearAlgebraExpr alg;
174 alg.
mult(matrix, x, y);
178 template <
typename Tag,
typename T>
179 auto matrixMultT(Matrix
const& matrix, UniqueArray<T>
const& x)
181 std::size_t n = matrix.distribution().localRowSize();
182 UniqueArray<T> y(n, 0.);
183 LinearAlgebraExpr<Tag> alg(matrix.distribution().parallelMng());
184 alg.mult(matrix, x, y);
188 template <
typename T>
189 auto vectorAdd(UniqueArray<T>
const& x, UniqueArray<T>
const& y)
192 std::cout <<
"\t\t VectorAdd" << std::endl;
194 std::size_t n = x.size();
195 UniqueArray<T> result(n, 0.);
196 SimpleCSRLinearAlgebraExpr alg;
198 alg.axpy(1., x, result);
199 return std::move(result);
202 template <
typename Tag,
typename T>
203 auto vectorAddT(UniqueArray<T>
const& x, UniqueArray<T>
const& y)
205 std::size_t n = x.size();
206 UniqueArray<T> result(n, 0.);
207 LinearAlgebraExpr<Tag> alg(
nullptr);
209 alg.axpy(1., x, result);
210 return std::move(result);
213 template <
typename Tag>
214 auto matrixAddT(Matrix
const& a, Matrix
const& b)
216 LinearAlgebraExpr<Tag> alg(a.distribution().parallelMng());
217 Matrix c(a.distribution());
224 auto matrixAdd(Matrix
const& a, Matrix
const& b)
227 std::cout <<
"\t\t MatrixAdd" << std::endl;
230 Matrix c(a.distribution());
231 SimpleCSRLinearAlgebraExpr alg;
237 template <
typename T>
238 auto vectorMinus(UniqueArray<T>
const& x, UniqueArray<T>
const& y)
241 std::cout <<
"\t\t VectorMinus" << std::endl;
243 std::size_t n = x.size();
244 UniqueArray<T> result(n, 0.);
246 SimpleCSRLinearAlgebraExpr alg;
248 alg.axpy(-1., y, result);
249 return std::move(result);
252 template <
typename Tag,
typename T>
253 auto vectorMinusT(UniqueArray<T>
const& x, UniqueArray<T>
const& y)
255 std::size_t n = x.size();
256 UniqueArray<T> result(n, 0.);
258 LinearAlgebraExpr<Tag> alg(
nullptr);
260 alg.axpy(-1., y, result);
261 return std::move(result);
264 template <
typename T>
265 auto vectorMult(T
const& lambda, UniqueArray<T>
const& x)
268 std::cout <<
"\t\t VectorScal" << std::endl;
270 std::size_t n = x.size();
271 UniqueArray<T> y(n, 0.);
272 SimpleCSRLinearAlgebraExpr alg;
273 alg.
axpy(lambda, x, y);
277 template <
typename Tag,
typename T>
278 auto vectorMultT(T
const& lambda, UniqueArray<T>
const& x)
280 std::size_t n = x.size();
281 UniqueArray<T> y(n, 0.);
282 LinearAlgebraExpr<Tag> alg(
nullptr);
283 alg.axpy(lambda, x, y);
287 auto vectorScalProduct(Vector
const& a, Vector
const& b)
290 std::cout <<
"\t\t VectorScal" << std::endl;
292 SimpleCSRLinearAlgebraExpr alg;
293 return alg.
dot(a, b);
296 template <
typename Tag>
297 auto vectorScalProductT(Vector
const& a, Vector
const& b)
299 LinearAlgebraExpr<Tag> alg(a.distribution().parallelMng());
300 return alg.dot(a, b);
303 auto vectorScalProduct(
304 VectorDistribution
const* distribution, Vector
const& a, UniqueArray<Real>
const& b)
307 std::cout <<
"\t\t VectorScal" << std::endl;
309 SimpleCSRVector<Real>
const& csr_a = a.impl()->get<BackEnd::tag::simplecsr>();
310 SimpleCSRLinearAlgebraExpr alg;
311 Integer local_size = distribution ? distribution->localSize() : b.size();
312 Real value = alg.
dot(local_size, csr_a.getArrayValues(), b);
313 if (distribution && distribution->isParallel())
314 return Arccore::MessagePassing::mpAllReduce(
315 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
320 template <
typename Tag>
321 auto vectorScalProductT(
322 VectorDistribution
const* distribution, Vector
const& a, UniqueArray<Real>
const& b)
324 auto const& csr_a = a.impl()->get<Tag>();
325 LinearAlgebraExpr<Tag> alg(
nullptr);
326 Integer local_size = distribution ? distribution->localSize() : b.size();
327 Real value = alg.dot(local_size, csr_a.getArrayValues(), b);
328 if (distribution && distribution->isParallel())
329 return Arccore::MessagePassing::mpAllReduce(
330 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
335 auto vectorScalProduct(
336 VectorDistribution
const* distribution, UniqueArray<Real>
const& a, Vector
const& b)
339 std::cout <<
"\t\t VectorScal" << std::endl;
341 SimpleCSRVector<Real>
const& csr_b = b.impl()->get<BackEnd::tag::simplecsr>();
342 SimpleCSRLinearAlgebraExpr alg;
343 Integer local_size = distribution ? distribution->localSize() : a.size();
344 Real value = alg.
dot(local_size, a, csr_b.getArrayValues());
345 if (distribution && distribution->isParallel())
346 return Arccore::MessagePassing::mpAllReduce(
347 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
352 template <
typename Tag>
353 auto vectorScalProductT(
354 VectorDistribution
const* distribution, UniqueArray<Real>
const& a, Vector
const& b)
356 auto const& csr_b = b.impl()->get<Tag>();
357 LinearAlgebraExpr<Tag> alg(
nullptr);
358 Integer local_size = distribution ? distribution->localSize() : a.size();
359 Real value = alg.dot(local_size, a, csr_b.getArrayValues());
360 if (distribution && distribution->isParallel())
361 return Arccore::MessagePassing::mpAllReduce(
362 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
367 auto vectorScalProduct(VectorDistribution
const* distribution,
368 UniqueArray<Real>
const& a, UniqueArray<Real>
const& b)
371 std::cout <<
"\t\t VectorScal" << std::endl;
373 SimpleCSRLinearAlgebraExpr alg;
374 Integer local_size = distribution ? distribution->localSize() : a.size();
375 Real value = alg.
dot(local_size, a, b);
376 if (distribution && distribution->isParallel())
377 return Arccore::MessagePassing::mpAllReduce(
378 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
383 template <
typename Tag>
384 auto vectorScalProductT(VectorDistribution
const* distribution,
385 UniqueArray<Real>
const& a, UniqueArray<Real>
const& b)
387 LinearAlgebraExpr<Tag> alg(distribution->parallelMng());
388 Integer local_size = distribution ? distribution->localSize() : a.size();
389 Real value = alg.dot(local_size, a, b);
390 if (distribution && distribution->isParallel())
391 return Arccore::MessagePassing::mpAllReduce(
392 distribution->parallelMng(), Arccore::MessagePassing::ReduceSum, value);
397 template <
typename Tag,
typename T>
398 auto matrixScalT(T
const& lambda, Matrix
const& A)
400 Matrix B(A.distribution());
401 LinearAlgebraExpr<Tag> alg(A.distribution().parallelMng());
407 template <
typename T>
408 auto matrixScal(T
const& lambda, Matrix
const& A)
411 std::cout <<
"\t\t MatrixScal" << std::endl;
413 Matrix B(A.distribution());
414 SimpleCSRLinearAlgebraExpr alg;
423 template <
typename T>
427 std::cout <<
"\t return cst" << std::endl;
432 template <
typename T>
438 template <
typename T>
442 std::cout <<
"\t visit A*b" << std::endl;
444 return matrixMult(a, b);
451 std::cout <<
"\t visit A*b" << std::endl;
455 csr_b.resize(csr_matrix.getAllocSize());
456 return matrixMult(a, csr_b.getArrayValues());
462 std::cout <<
"\t visit lambda*b : " << b.name() << std::endl;
465 return vectorMult(lambda, csr_b.getArrayValues());
470 return matrixScal(lambda, a);
476 std::cout <<
"\t visit a+b" << std::endl;
480 return vectorAdd(csr_a.getArrayValues(), csr_b.getArrayValues());
486 std::cout <<
"\t visit a+b" << std::endl;
489 return vectorAdd(csr_a.getArrayValues(), b);
493 auto operator()(
lazy::add_tag, UniqueArray<T>
const& a, UniqueArray<T>
const& b)
496 std::cout <<
"\t visit a+b" << std::endl;
498 return vectorAdd(a, b);
504 std::cout <<
"\t visit a+b" << std::endl;
506 return matrixAdd(a, b);
515 std::cout <<
"\t visit a-b" << std::endl;
519 return vectorMinus(csr_a.getArrayValues(), csr_b.getArrayValues());
525 std::cout <<
"\t visit a-b" << std::endl;
528 return vectorMinus(csr_a.getArrayValues(), b);
535 std::cout <<
"\t visit dot(a,b)" << std::endl;
537 return vectorScalProduct(a, b);
541 Vector const& a, UniqueArray<Real>
const& b)
544 std::cout <<
"\t visit dot(a,b)" << std::endl;
546 return vectorScalProduct(distribution, a, b);
550 UniqueArray<Real>
const& a,
Vector const& b)
553 std::cout <<
"\t visit dot(a,b)" << std::endl;
555 return vectorScalProduct(distribution, a, b);
559 UniqueArray<Real>
const& a, UniqueArray<Real>
const& b)
562 std::cout <<
"\t visit dot(a,b)" << std::endl;
564 return vectorScalProduct(distribution, a, b);
567 template <
class A,
class B>
574 template <
typename Tag>
578 template <
typename T>
581 template <
typename T>
584 template <
typename T>
587 return matrixMultT<Tag>(a, b);
592 auto const& tag_matrix = a.
impl()->template get<Tag>();
593 auto const& tag_b = b.
impl()->template get<Tag>();
594 tag_b.resize(tag_matrix.getAllocSize());
595 return matrixMultT<Tag>(a, tag_b.getArrayValues());
600 auto const& tag_b = b.
impl()->template get<Tag>();
601 return vectorMultT<Tag>(lambda, tag_b.getArrayValues());
606 return matrixScalT<Tag, Real>(lambda, a);
611 auto const& csr_a = a.
impl()->template get<Tag>();
612 auto const& csr_b = b.
impl()->template get<Tag>();
613 return vectorAddT<Tag>(csr_a.getArrayValues(), csr_b.getArrayValues());
618 auto const& csr_a = a.
impl()->template get<Tag>();
619 return vectorAddT<Tag>(csr_a.getArrayValues(), b);
623 auto operator()(
lazy::add_tag, UniqueArray<T>
const& a, UniqueArray<T>
const& b)
625 return vectorAddT<Tag>(a, b);
630 return matrixAddT<Tag>(a, b);
635 auto const& csr_a = a.
impl()->template get<Tag>();
636 auto const& csr_b = b.
impl()->template get<Tag>();
637 return vectorMinusT<Tag>(csr_a.getArrayValues(), csr_b.getArrayValues());
642 auto const& csr_a = a.
impl()->template get<Tag>();
643 return vectorMinusT<Tag>(csr_a.getArrayValues(), b);
649 return vectorScalProductT<Tag>(a, b);
653 Vector const& a, UniqueArray<Real>
const& b)
655 return vectorScalProductT<Tag>(distribution, a, b);
659 UniqueArray<Real>
const& a,
Vector const& b)
661 return vectorScalProductT<Tag>(distribution, a, b);
665 UniqueArray<Real>
const& a, UniqueArray<Real>
const& b)
667 return vectorScalProductT<Tag>(distribution, a, b);
670 template <
class A,
class B>
682 return std::numeric_limits<size_t>::max();
690 return r.rowSpace().size();
693 template <
class T,
class A,
class B>
694 auto operator()(T, A a, B b)
696 return std::min(a, b);
700 std::size_t allocSize(
Matrix const& A)
703 return csr_matrix.getLocalSize() + csr_matrix.getGhostSize();
706 std::size_t allocSize(
Vector const& x)
709 return csr_x.getAllocSize();
715 auto operator()(
lazy::cst_tag, T c) {
return std::size_t(0); }
721 template <
typename L>
727 template <
typename L>
733 template <
typename L>
742 auto base_eval(E
const& expr) {
return expr(
cpu_evaluator()); }
744 auto eval(E
const& expr) {
return expr(cpu_evaluator()); }
746 auto operator*(Matrix
const& l, Vector
const& r) {
return mul(ref(l), ref(r)); }
748 auto operator*(Real lambda, Vector
const& r)
751 std::cout <<
"lambda*x" << std::endl;
753 return mul(cst(lambda), ref(r));
756 auto operator*(Real lambda, Matrix
const& r)
759 std::cout <<
"lambda*A" << std::endl;
761 return mul(cst(lambda), ref(r));
764 template <
typename R>
765 auto operator*(Matrix
const& l, R
const& r)
767 return mul(ref(l), r);
777 auto operator+(Vector
const& l, Vector
const& r) {
return add(ref(l), ref(r)); }
779 template <
typename R>
780 auto operator+(Vector
const& l, R
const& r)
782 return add(ref(l), r);
785 template <
typename L>
786 auto operator+(L
const& l, Vector
const& r)
788 return add(l, ref(r));
791 auto operator+(Matrix
const& l, Matrix
const& r) {
return add(ref(l), ref(r)); }
800 template <
typename R>
801 auto operator-(Vector& l, R&& r) {
return minus(ref(l), r); }
803 template <
typename L>
804 auto operator-(L&& l, Vector& r) {
return minus(l, ref(r)); }
806 template <
typename L,
typename R>
807 auto operator-(L&& l, R&& r) {
return minus(l, r); }
809 auto dot(Vector
const& x, Vector
const& y) {
return scalMul(ref(x), ref(y)); }
811 template <
typename R>
812 auto dot(Vector
const& x, R
const& y)
814 return scalMul(ref(x), y);
817 template <
typename L>
818 auto dot(L
const& x, Vector
const& y)
820 return scalMul(x, ref(y));
823 template <
typename L,
typename R>
824 auto dot(L
const& x, R
const& y)
826 return scalMul(x, y);
830 void assign(Vector& y, E
const& expr)
832 SimpleCSRVector<Real>& csr_y = y.impl()->get<BackEnd::tag::simplecsr>(
true);
837 template <
typename Tag,
class E>
838 void kassign(Vector& y, E
const& expr)
840 auto& backend_y = y.impl()->template get<Tag>(
true);
841 backend_y.allocate();
845 template <
typename A,
typename B>
846 auto vassign(A&& a, B&& b)
851 template <
typename E>
852 auto veval(
double& value, E&& e)
854 return [&](
auto visitor) {
return visitor(
lazy::eval_tag{}, value, e); };
859 template <
typename T>
863 std::cout <<
"pipeline eval" << std::endl;
867 template <
typename T0,
typename... T>
868 void eval(T0&& expr0, T&&... args)
874 template <
typename E>
878 std::cout <<
"vector assignement" << std::endl;
883 template <
typename E>
886 value = base_eval(expr);
891 template <
typename... T>
892 void pipeline(T... args)
895 evaluator.eval(args...);
901Vector::operator=(E
const& expr)
903 SimpleCSRVector<Real>& csr_y =
impl()->
get<BackEnd::tag::simplecsr>(
true);
905 csr_y.setArrayValues(expr(MVExpr::cpu_evaluator()));
911Matrix::operator=(E
const& expr)
913 *
this = expr(MVExpr::cpu_evaluator());
virtual Arccore::Integer size() const =0
Get space size.
void mult(const IMatrix &a, const IVector &x, IVector &r) const
Compute a matrix vector product.
void copy(const IVector &x, IVector &r) const
Copy a vector in another one.
void axpy(Real alpha, const IVector &x, IVector &y) const
Scale a vector by a factor and adds the result to another vector.
Real dot(const IVector &x, const IVector &y) const
Compute the dot product of two vectors.
const VectorDistribution & rowDistribution() const
Get the row distribution.
const ISpace & rowSpace() const
Get row space associated to the matrix.
MultiMatrixImpl * impl()
Get the multimatrix implementation.
const AlgebraTraits< tag >::matrix_type & get() const
Get a specific matrix implementation.
const AlgebraTraits< tag >::vector_type & get() const
Get a specific vector implementation.
Computes a vector distribution.
MultiVectorImpl * impl()
Get the multivector implementation.
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --