Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
CG.h
1/*
2 * Copyright 2020 IFPEN-CEA
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * SPDX-License-Identifier: Apache-2.0
17 */
18/*
19 * cg.h
20 *
21 * Created on: Dec 1, 2021
22 * Author: gratienj
23 */
24
25#pragma once
26
#include <ostream>
#include <string>
#include <vector>
29
30namespace Alien
31{
32
33template <typename AlgebraT>
34class CG
35{
36 public:
37 // clang-format off
38 typedef AlgebraT AlgebraType;
39 typedef typename AlgebraType::Matrix MatrixType;
40 typedef typename AlgebraType::Vector VectorType;
41 typedef typename MatrixType::ValueType ValueType;
42 typedef typename AlgebraType::FutureType FutureType;
43 // clang-format on
44
45 CG(AlgebraType& algebra, ITraceMng* trace_mng = nullptr)
46 : m_algebra(algebra)
47 , m_trace_mng(trace_mng)
48 {}
49
50 virtual ~CG()
51 {}
52
53 void setOutputLevel(int level)
54 {
55 m_output_level = level;
56 }
57
58 template <typename PrecondT, typename iterT>
59 int solve(PrecondT& precond,
60 iterT& iter,
61 MatrixType const& A,
62 VectorType const& b,
63 VectorType& x)
64 {
65 if (iter.nullRhs())
66 return 0;
67 ValueType rho(0), rho1(0), alpha(0);
68 VectorType p, z, q, r;
69
70 m_algebra.allocate(AlgebraType::resource(A), p, z, q, r);
71
72 // SEQ0
73 // r = b - A * x;
74 m_algebra.copy(b, r);
75 m_algebra.mult(A, x, p);
76 m_algebra.axpy(-1., p, r);
77
78 // SEQ1
79 /*
80 * z = solve(M,r)
81 * p = z
82 * q = A p
83 * rho1 = dot(r,z)
84 * alpha = dot(p,q)
85 * alpha = rho1/alpha
86 * x += alpha*p
87 * r += alpha*q
88 */
89 m_algebra.exec(precond, r, z);
90 rho1 = m_algebra.dot(r, z);
91 m_algebra.copy(z, p);
92 m_algebra.mult(A, p, q);
93 alpha = m_algebra.dot(p, q);
94 if (m_output_level > 1)
95 _print(0, "Seq 1", "rho", rho1, "alpha", alpha);
96 if (alpha == 0) {
97 if (iter.stop(r)) {
98 ++iter;
99 m_algebra.free(p, z, q, r);
100 return 0;
101 }
102 else
103 throw typename AlgebraType::NullValueException("alpha");
104 }
105 alpha = rho1 / alpha;
106 m_algebra.axpy(alpha, p, x);
107 m_algebra.axpy(-alpha, q, r);
108 rho = rho1;
109 ++iter;
110
111 while (!iter.stop(r)) {
112 // SEQ2
113 /*
114 * z = solve(M,r)
115 * rho1 = dot(r,z)
116 * alpha = rho1/rho
117 * p = z + alpha*p
118 * q = A*p
119 * alpha = dot(q,p)
120 * alpha = rho1/alpha
121 * x += alpha*p
122 * r -= alpha*q
123 */
124 m_algebra.exec(precond, r, z);
125 rho1 = m_algebra.dot(r, z);
126 alpha = rho1 / rho;
127 m_algebra.axpy(alpha, p, z);
128 m_algebra.copy(z, p);
129 m_algebra.mult(A, p, q);
130 alpha = m_algebra.dot(q, p);
131 if (alpha == 0) {
132 if (iter.stop(r)) {
133 ++iter;
134 m_algebra.free(p, z, q, r);
135 return 0;
136 }
137 else
138 throw typename AlgebraType::NullValueException("alpha");
139 }
140 alpha = rho1 / alpha;
141 m_algebra.axpy(alpha, p, x);
142 m_algebra.axpy(-alpha, q, r);
143 rho = rho1;
144 ++iter;
145 }
146
147 m_algebra.free(p, z, q, r);
148
149 return 0;
150 }
151
152 template <typename PrecondT, typename iterT>
153 int solve2(PrecondT& precond,
154 iterT& iter,
155 MatrixType const& A,
156 VectorType const& b,
157 VectorType& x)
158 {
159
160 if (iter.nullRhs())
161 return 0;
162 ValueType rho(0), rho1(0), alpha(0);
163 FutureType frho(rho), frho1(rho1), falpha(alpha);
164 VectorType p, z, q, r;
165
166 m_algebra.allocate(AlgebraType::resource(A), p, z, q, r);
167
168 // SEQ0
169 // r = b - A * x;
170 m_algebra.copy(b, r);
171 m_algebra.mult(A, x, p);
172 m_algebra.axpy(-1., p, r);
173
174 // SEQ1
175 /*
176 * z = solve(M,r)
177 * p = z
178 * q = A p
179 * rho1 = dot(r,p)
180 * alpha = dot(p,q)
181 * alpha = rho1/alpha
182 * x += alpha*p
183 * r += alpha*q
184 */
185 m_algebra.exec(precond, r, z);
186 m_algebra.copy(z, p);
187 m_algebra.mult(A, p, q);
188 m_algebra.dot(r, p, frho1);
189 m_algebra.dot(p, q, falpha);
190 if (falpha.get() == 0) {
191 if (iter.stop(r)) {
192 ++iter;
193 m_algebra.free(p, z, q, r);
194 return 0;
195 }
196 else
197 throw typename AlgebraType::NullValueException("alpha");
198 }
199 alpha = frho1.get() / alpha;
200 m_algebra.axpy(alpha, p, x);
201 m_algebra.axpy(-alpha, q, r);
202 rho = rho1;
203 ++iter;
204
205 while (!iter.stop(r)) {
206
207 /*
208 * z = solve(M,r)
209 * rho1 = dot(r,z)
210 * alpha = rho1/rho
211 * p = z + alpha*p
212 * q = A*p
213 * alpha = dot(q,p)
214 * alpha = rho1/alpha
215 * x += alpha*p
216 * r -= alpha*q
217 */
218 m_algebra.exec(precond, r, z);
219 m_algebra.dot(r, z, frho1);
220 alpha = frho1.get() / rho;
221 m_algebra.axpy(alpha, p, z);
222 m_algebra.copy(z, p);
223 m_algebra.mult(A, p, q);
224 m_algebra.dot(p, q, falpha);
225 if (falpha.get() == 0) {
226 if (iter.stop(r)) {
227 ++iter;
228 m_algebra.free(p, z, q, r);
229 return 0;
230 }
231 else
232 throw typename AlgebraType::NullValueException("alpha");
233 }
234 alpha = rho1 / alpha;
235 m_algebra.axpy(alpha, p, x);
236 m_algebra.axpy(-alpha, q, r);
237 rho = rho1;
238 ++iter;
239 }
240
241 m_algebra.free(p, z, q, r);
242
243 return 0;
244 }
245
246 private:
247 void
248 _print(int iter, std::string const& msg, std::string const& label0,
249 ValueType value0)
250 {
251 if (m_trace_mng) {
252 m_trace_mng->info() << msg;
253 m_trace_mng->info() << "Iterate: " << iter << " " << label0 << " "
254 << value0;
255 }
256 }
257
258 void
259 _print(int iter, std::string const& msg, std::string const& label0,
260 ValueType value0, std::string const& label1, ValueType value1)
261 {
262 if (m_trace_mng) {
263 _print(iter, msg, label0, value0);
264 m_trace_mng->info() << "Iterate: " << iter << " " << label1 << " "
265 << value1;
266 }
267 }
268
269 void
270 _print(int iter, std::string const& msg, std::string const& label0,
271 ValueType value0, std::string const& label1, ValueType value1,
272 std::string const& label2, ValueType value2)
273 {
274 if (m_trace_mng) {
275 _print(iter, msg, label0, value0, label1, value1);
276 m_trace_mng->info() << "Iterate: " << iter << " " << label2 << " "
277 << value2;
278 }
279 }
280
281 void
282 _print(int iter, std::string const& msg, std::string const& label0,
283 ValueType value0, std::string const& label1, ValueType value1,
284 std::string const& label2, ValueType value2,
285 std::string const& label3, ValueType value3)
286 {
287 if (m_trace_mng) {
288 _print(iter, msg, label0, value0, label1, value1, label2, value2);
289 m_trace_mng->info() << "Iterate: " << iter << " " << label3 << " "
290 << value3;
291 }
292 }
293
294 AlgebraType& m_algebra;
295 ITraceMng* m_trace_mng = nullptr;
296 int m_output_level = 0;
297};
298} // namespace Alien
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
Definition BackEnd.h:17