Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
BiCGStab.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2026 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7
8#pragma once
9
10#include <ostream>
11#include <vector>
12
13namespace Alien
14{
15
16template <typename AlgebraT>
17class BiCGStab
18{
19 public:
20 // clang-format off
21 typedef AlgebraT AlgebraType;
22 typedef typename AlgebraType::Matrix MatrixType;
23 typedef typename AlgebraType::Vector VectorType;
24 typedef typename MatrixType::ValueType ValueType;
25 typedef typename AlgebraType::FutureType FutureType;
26 // clang-format on
27
28 class Iteration
29 {
30 public:
31 Iteration(AlgebraType& algebra,
32 VectorType const& b,
33 ValueType tol,
34 int max_iter,
35 ITraceMng* trace_mng = nullptr)
36 : m_algebra(algebra)
37 , m_value(0)
38 , m_f_value(m_value)
39 , m_max_iteration(max_iter)
40 , m_tol(tol)
41 , m_iter(0)
42 , m_trace_mng(trace_mng)
43 {
44 m_algebra.dot(b, b, m_f_value);
45 m_nrm2_b = m_f_value.get();
46 if (m_trace_mng)
47 m_trace_mng->info() << "STOP CRITERIA NORME B = " << m_nrm2_b;
48 m_criteria_value = m_tol * m_tol * m_nrm2_b;
49 m_sqrt_nrm2_b = std::sqrt(m_nrm2_b);
50 m_value = m_criteria_value + 1;
51 if (m_nrm2_b == 0)
52 m_status = true;
53 else
54 m_status = false;
55 }
56
57 virtual ~Iteration()
58 {}
59
60 bool nullRhs() const
61 {
62 return m_nrm2_b == 0.;
63 }
64
65 bool first() const
66 {
67 return m_iter == 0;
68 }
69
70 bool stop(VectorType const& r)
71 {
72 if (m_iter >= m_max_iteration)
73 return true;
74 m_algebra.dot(r, r, m_f_value);
75 m_status = m_f_value.get() < m_criteria_value;
76 return m_status;
77 }
78
79 void operator++()
80 {
81 if (m_trace_mng)
82 m_trace_mng->info() << "iteration (" << m_iter << ") criteria = " << getValue();
83 ++m_iter;
84 }
85
86 ValueType getValue() const
87 {
88 if (m_sqrt_nrm2_b == 0)
89 return 0.;
90 else
91 return std::sqrt(m_value) / m_sqrt_nrm2_b;
92 }
93
94 int operator()() const
95 {
96 return m_iter;
97 }
98
99 bool getStatus() const
100 {
101 return m_status;
102 }
103
104 private:
105 // clang-format off
106 AlgebraType& m_algebra;
107 int m_max_iteration = 0;
108 ValueType m_tol = 0.;
109 int m_iter = 0;
110 ValueType m_value = 0.;
111 FutureType m_f_value;
112 ValueType m_criteria_value = 0.;
113 ValueType m_value_init = 0.;
114 ValueType m_nrm2_b = 0.;
115 ValueType m_sqrt_nrm2_b = 0.;
116 bool m_status = false;
117 ITraceMng* m_trace_mng = nullptr;
118 // clang-format on
119 };
120
121 BiCGStab(AlgebraType& algebra, ITraceMng* trace_mng = nullptr)
122 : m_algebra(algebra)
123 , m_trace_mng(trace_mng)
124 {}
125
126 virtual ~BiCGStab()
127 {}
128
129 void setOutputLevel(int level)
130 {
131 m_output_level = level;
132 }
133
134 template <typename PrecondT, typename iterT>
135 int solve(PrecondT& precond, iterT& iter, MatrixType const& A,
136 VectorType const& b, VectorType& x)
137 {
138 if (iter.nullRhs())
139 return 0;
140 ValueType rho(0), rho1(0), alpha(0), beta(0), omega(0);
141 VectorType p, phat, s, shat, t, v, r, r0;
142
143 m_algebra.allocate(AlgebraType::resource(A), p, phat, s, shat, t, v, r, r0);
144
145 // SEQ0
146 // r = b - A * x;
147 m_algebra.copy(b, r);
148 m_algebra.mult(A, x, r0);
149 m_algebra.axpy(-1., r0, r);
150
151 // rtilde = r
152 m_algebra.copy(r, r0);
153 m_algebra.copy(r, p);
154 rho1 = m_algebra.dot(r, r0);
155 if (m_output_level > 1)
156 _print(0, "Seq 0", "rho1", rho1);
157
158 /*
159 phat = solve(M, p);
160 v = A * phat;
161 gamma = dot(r0, v);
162 alpha = rho_1 / gamma;
163 s = r - alpha * v;
164 */
165 // SEQ1
166 m_algebra.exec(precond, p, phat);
167 m_algebra.mult(A, phat, v);
168 alpha = m_algebra.dot(v, r0);
169 if (alpha == 0)
170 throw typename AlgebraType::NullValueException("alpha");
171 alpha = rho1 / alpha;
172 m_algebra.copy(r, s);
173 m_algebra.axpy(-alpha, v, s);
174 if (m_output_level > 1)
175 _print(0, "Seq 1", "alpha", alpha);
176
177 if (iter.stop(s)) {
178 ++iter;
179 m_algebra.axpy(alpha, phat, x);
180 m_algebra.free(p, phat, s, shat, t, v, r, r0);
181 return 0;
182 }
183
184 // SEQ 2
185 m_algebra.exec(precond, s, shat);
186 m_algebra.mult(A, shat, t);
187 omega = m_algebra.dot(t, s);
188 beta = m_algebra.dot(t, t);
189
190 if (beta == 0) {
191 if (iter.stop(r)) {
192 ++iter;
193 m_algebra.axpy(alpha, phat, x);
194 m_algebra.free(p, phat, s, shat, t, v, r, r0);
195 return 0;
196 }
197 else
198 throw typename AlgebraType::NullValueException("beta");
199 }
200 omega = omega / beta;
201 if (m_output_level > 1)
202 _print(iter(), "Seq 2", "beta", beta, "alpha", alpha, "rho1", rho1);
203
204 // SEQ 3
205 m_algebra.axpy(omega, shat, x);
206 m_algebra.axpy(alpha, phat, x);
207 m_algebra.copy(s, r);
208 m_algebra.axpy(-omega, t, r);
209
210 rho = rho1;
211 ++iter;
212 if (m_output_level > 1)
213 _print(iter(), "Seq 3", "beta", beta, "alpha", alpha, "rho1", rho1);
214
215 while (!iter.stop(r)) {
216 //SEQ4
217 rho1 = m_algebra.dot(r, r0);
218 beta = (rho1 / rho) * (alpha / omega);
219 m_algebra.axpy(-omega, v, p);
220 m_algebra.scal(beta, p);
221 m_algebra.axpy(1., r, p);
222 if (m_output_level > 1)
223 _print(iter(), "Seq 4", "beta", beta, "alpha", alpha, "rho1", rho1);
224
225 if (rho == 0)
226 throw typename AlgebraType::NullValueException("rho");
227
228 //m_algebra.process (seq1);
229 m_algebra.exec(precond, p, phat);
230 m_algebra.mult(A, phat, v);
231 alpha = m_algebra.dot(v, r0);
232 if (alpha == 0)
233 throw typename AlgebraType::NullValueException("alpha");
234 else
235 alpha = rho1 / alpha;
236
237 m_algebra.copy(r, s);
238 m_algebra.axpy(-alpha, v, s);
239
240 if (m_output_level > 1)
241 _print(iter(), "Seq 1", "alpha", alpha);
242 if (iter.stop(s)) {
243 m_algebra.axpy(alpha, phat, x);
244 m_algebra.free(p, phat, s, shat, t, v, r, r0);
245 return 0;
246 }
247
248 //m_algebra.process (seq2);
249 m_algebra.exec(precond, s, shat);
250 m_algebra.mult(A, shat, t);
251 omega = m_algebra.dot(t, s);
252 beta = m_algebra.dot(t, t);
253
254 if (m_output_level > 1)
255 _print(iter(), "Seq 2", "beta", beta, "alpha", alpha, "rho1", rho1, "omega", omega);
256 if (beta == 0) {
257 if (iter.stop(s)) {
258 m_algebra.axpy(alpha, phat, x);
259 m_algebra.free(p, phat, s, shat, t, v, r, r0);
260 return 0;
261 }
262 throw typename AlgebraType::NullValueException("beta");
263 }
264 else
265 omega = omega / beta;
266
267 //m_algebra.process (seq3);
268 m_algebra.axpy(omega, shat, x);
269 m_algebra.axpy(alpha, phat, x);
270 m_algebra.copy(s, r);
271 m_algebra.axpy(-omega, t, r);
272
273 rho = rho1;
274
275 ++iter;
276 if (m_output_level > 1)
277 _print(iter(), "end loop", "beta", beta, "alpha", alpha, "rho", rho);
278 }
279
280 m_algebra.free(p, phat, s, shat, t, v, r, r0);
281
282 return 0;
283 }
284
285 template <typename PrecondT, typename iterT>
286 int solve2(PrecondT& precond, iterT& iter, MatrixType const& A,
287 VectorType const& b, VectorType& x)
288 {
289 if (iter.nullRhs())
290 return 0;
291 // clang-format off
292 ValueType rho (0), rho1 (0), alpha (0), beta (0), gamma (0), omega (0);
293 FutureType frho(rho), frho1(rho1), falpha(alpha), fbeta(beta), fgamma(gamma), fomega(omega) ;
294 VectorType p, phat, s, shat, t, v, r, r0;
295 // clang-format on
296
297 m_algebra.allocate(AlgebraType::resource(A), p, phat, s, shat, t, v, r, r0);
298
299 // SEQ0
300 // r = b - A * x;
301 m_algebra.copy(b, r);
302 m_algebra.mult(A, x, r0);
303 m_algebra.axpy(-1., r0, r);
304
305 // rtilde = r
306 m_algebra.copy(r, r0);
307 m_algebra.copy(r, p);
308 m_algebra.dot(r, r0, frho1);
309 if (m_output_level > 1)
310 _print(0, "Seq 0", "rho1", frho1.get());
311
312 /*
313 phat = solve(M, p);
314 v = A * phat;
315 gamma = dot(r0, v);
316 alpha = rho_1 / gamma;
317 s = r - alpha * v;
318 */
319 // SEQ1
320 m_algebra.exec(precond, p, phat);
321 m_algebra.mult(A, phat, v);
322 m_algebra.dot(v, r0, falpha);
323 if (falpha.get() == 0)
324 throw typename AlgebraType::NullValueException("alpha");
325 alpha = frho1.get() / alpha;
326
327 m_algebra.copy(r, s);
328 m_algebra.axpy(-alpha, v, s);
329 if (m_output_level > 1)
330 _print(0, "Seq 1", "alpha", alpha);
331
332 if (iter.stop(s)) {
333 ++iter;
334 m_algebra.axpy(alpha, phat, x);
335 m_algebra.free(p, phat, s, shat, t, v, r, r0);
336 return 0;
337 }
338
339 // SEQ 2
340 m_algebra.exec(precond, s, shat);
341 m_algebra.mult(A, shat, t);
342 m_algebra.dot(t, s, fomega);
343 m_algebra.dot(t, t, fbeta);
344 if (fbeta.get() == 0) {
345 if (iter.stop(r)) {
346 ++iter;
347 m_algebra.axpy(alpha, phat, x);
348 m_algebra.free(p, phat, s, shat, t, v, r, r0);
349 return 0;
350 }
351 else
352 throw typename AlgebraType::NullValueException("beta");
353 }
354 omega = fomega.get() / beta;
355 if (m_output_level > 1)
356 _print(iter(), "Seq 2", "beta", beta, "alpha", alpha, "rho1", rho1);
357
358 // SEQ 3
359 m_algebra.axpy(omega, shat, x);
360 m_algebra.axpy(alpha, phat, x);
361 m_algebra.copy(s, r);
362 m_algebra.axpy(-omega, t, r);
363
364 rho = rho1;
365 ++iter;
366 if (m_output_level > 1)
367 _print(iter(), "Seq 3", "beta", beta, "alpha", alpha, "rho1", rho1);
368
369 while (!iter.stop(r)) {
370 //SEQ4
371 /*
372 beta = (rho_1 / rho_2) * (alpha / omega);
373 p = r + beta * (p - omega * v);
374 */
375 m_algebra.dot(r, r0, frho1);
376 beta = (frho1.get() / rho) * (alpha / omega);
377 m_algebra.axpy(-omega, v, p);
378 m_algebra.scal(beta, p);
379 m_algebra.axpy(1., r, p);
380 if (m_output_level > 1)
381 _print(iter(), "Seq 4", "beta", beta, "alpha", alpha, "rho1", rho1);
382
383 if (rho == 0)
384 throw typename AlgebraType::NullValueException("rho");
385
386 //m_algebra.process (seq1);
387 m_algebra.exec(precond, p, phat);
388 m_algebra.mult(A, phat, v);
389 m_algebra.dot(v, r0, falpha);
390 if (falpha.get() == 0)
391 throw typename AlgebraType::NullValueException("alpha");
392 else
393 alpha = rho1 / alpha;
394
395 m_algebra.copy(r, s);
396 m_algebra.axpy(-alpha, v, s);
397 if (m_output_level > 1)
398 _print(iter(), "Seq 1", "alpha", alpha);
399
400 if (iter.stop(s)) {
401 m_algebra.axpy(alpha, phat, x);
402 m_algebra.free(p, phat, s, shat, t, v, r, r0);
403 return 0;
404 }
405
406 //m_algebra.process (seq2);
407 m_algebra.exec(precond, s, shat);
408 m_algebra.mult(A, shat, t);
409 m_algebra.dot(t, s, fomega);
410 m_algebra.dot(t, t, fbeta);
411 if (m_output_level > 1)
412 _print(iter(), "Seq 2", "beta", beta, "alpha", alpha, "rho1", rho1, "omega", omega);
413 if (fbeta.get() == 0) {
414 if (iter.stop(s)) {
415 m_algebra.axpy(alpha, phat, x);
416 m_algebra.free(p, phat, s, shat, t, v, r, r0);
417 return 0;
418 }
419 throw typename AlgebraType::NullValueException("beta");
420 }
421 else
422 omega = fomega.get() / beta;
423
424 //m_algebra.process (seq3);
425 m_algebra.axpy(omega, shat, x);
426 m_algebra.axpy(alpha, phat, x);
427 m_algebra.copy(s, r);
428 m_algebra.axpy(-omega, t, r);
429
430 rho = rho1;
431
432 ++iter;
433 if (m_output_level > 1)
434 _print(iter(), "end loop", "beta", beta, "alpha", alpha, "rho", rho);
435 }
436
437 m_algebra.free(p, phat, s, shat, t, v, r, r0);
438
439 return 0;
440 }
441
442 private:
443 void
444 _print(int iter, std::string const& msg, std::string const& label0,
445 ValueType value0)
446 {
447 if (m_trace_mng) {
448 m_trace_mng->info() << msg;
449 m_trace_mng->info() << "Iterate: " << iter << " " << label0 << " "
450 << value0;
451 }
452 }
453
454 void
455 _print(int iter, std::string const& msg, std::string const& label0,
456 ValueType value0, std::string const& label1, ValueType value1)
457 {
458 if (m_trace_mng) {
459 _print(iter, msg, label0, value0);
460 m_trace_mng->info() << "Iterate: " << iter << " " << label1 << " "
461 << value1;
462 }
463 }
464
465 void
466 _print(int iter, std::string const& msg, std::string const& label0,
467 ValueType value0, std::string const& label1, ValueType value1,
468 std::string const& label2, ValueType value2)
469 {
470 if (m_trace_mng) {
471 _print(iter, msg, label0, value0, label1, value1);
472 m_trace_mng->info() << "Iterate: " << iter << " " << label2 << " "
473 << value2;
474 }
475 }
476
477 void
478 _print(int iter, std::string const& msg, std::string const& label0,
479 ValueType value0, std::string const& label1, ValueType value1,
480 std::string const& label2, ValueType value2,
481 std::string const& label3, ValueType value3)
482 {
483 if (m_trace_mng) {
484 _print(iter, msg, label0, value0, label1, value1, label2, value2);
485 m_trace_mng->info() << "Iterate: " << iter << " " << label3 << " "
486 << value3;
487 }
488 }
489
490 AlgebraType& m_algebra;
491 ITraceMng* m_trace_mng = nullptr;
492 int m_output_level = 0;
493};
494} // namespace Alien
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
Definition BackEnd.h:17