Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
SimpleCSRDistributorImpl.h
1
2//-----------------------------------------------------------------------------
3// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6
7#pragma once
8
9#include <mpi.h>
10
11#include <arccore/base/Span.h>
12#include "SimpleCSRDistributor.h"
13
14namespace Alien
15{
16
17SimpleCSRDistributor::SimpleCSRDistributor(const RedistributorCommPlan* commPlan,
18 const VectorDistribution& source_distribution,
19 const Alien::SimpleCSRInternal::CSRStructInfo* src_profile)
20: m_comm_plan(commPlan)
21, m_src_profile(src_profile)
22{
23 const auto me = m_comm_plan->superParallelMng()->commRank();
24 const auto dst_me = _dstMe(me);
25
26 // build target_offset from comm_plan tgtdist
27 const auto tgt_dist = m_comm_plan->tgtDist();
28 Integer n_offset = 0;
29 for (Integer i = 1; i < tgt_dist.size(); ++i) {
30 if (tgt_dist[i] != tgt_dist[i - 1])
31 n_offset++;
32 }
33 if (dst_me.has_value()) {
34 assert(n_offset == m_comm_plan->tgtParallelMng()->commSize());
35 }
36 std::vector<Integer> target_offset(n_offset + 1);
37
38 target_offset[0] = 0;
39 Integer tgt_i = 1;
40 for (Integer i = 1; i < tgt_dist.size(); ++i) {
41 if (tgt_dist[i] != tgt_dist[i - 1]) {
42 target_offset[tgt_i] = tgt_dist[i];
43 tgt_i++;
44 }
45 }
46
47 auto dst_n_elems = 0;
48
49 for (auto global_row = source_distribution.offset();
50 global_row < source_distribution.offset() + source_distribution.localSize();
51 global_row++) {
52 auto target = _owner(target_offset, global_row); // target is and id in tgtParallelMng
53 auto src_local_row = source_distribution.globalToLocal(global_row);
54
55 if (target == dst_me) {
56 auto dst_local_row = m_comm_plan->distribution().globalToLocal(global_row);
57 m_src2dst_row_list.push_back({ src_local_row, dst_local_row });
58 dst_n_elems += m_src_profile->getRowSize(src_local_row);
59 }
60 else if (target.has_value()) {
61 auto& comm_info = m_send_comm_info[m_comm_plan->procNum(target.value())];
62 comm_info.m_row_list.push_back(src_local_row);
63 comm_info.m_n_item += m_src_profile->getRowSize(src_local_row);
64 }
65 else {
66 FatalErrorException("No target found");
67 }
68 }
69
70 auto ext_dst_n_rows = 0;
71
72 if (dst_me.has_value()) // I am in the target parallel manager
73 {
74 const auto& target_distribution = m_comm_plan->distribution();
75
76 for (auto global_row = target_distribution.offset();
77 global_row < target_distribution.offset() + target_distribution.localSize();
78 global_row++) {
79 auto source = source_distribution.owner(global_row); // source is an id in superParallelMng
80
81 if (source != me) {
82 auto dst_local_row = target_distribution.globalToLocal(global_row);
83 auto& comm_info = m_recv_comm_info[source];
84 comm_info.m_row_list.push_back(dst_local_row);
85 ext_dst_n_rows++;
86 }
87 }
88 }
89
90 for (auto& [send_to_id, comm_info] : m_send_comm_info) {
91 Arccore::MessagePassing::PointToPointMessageInfo message_info(MessageRank(me), MessageRank(send_to_id),
92 Arccore::MessagePassing::NonBlocking);
93 comm_info.m_message_info = message_info;
94 }
95
96 for (auto& [recv_from_id, comm_info] : m_recv_comm_info) {
97 Arccore::MessagePassing::PointToPointMessageInfo message_info(MessageRank(me), MessageRank(recv_from_id),
98 Arccore::MessagePassing::NonBlocking);
99
100 comm_info.m_message_info = message_info;
101 }
102
103 auto* pm = m_comm_plan->superParallelMng();
104
105 // perform an exchange of sizes
106 for (auto& [recv_from_id, comm_info] : m_recv_comm_info) {
107 comm_info.m_request = Arccore::MessagePassing::mpReceive(pm, Arccore::Span<size_t>(&comm_info.m_n_item, 1), comm_info.m_message_info);
108 }
109
110 for (auto& [send_to_id, comm_info] : m_send_comm_info) {
111 comm_info.m_request = Arccore::MessagePassing::mpSend(pm, Arccore::Span<size_t>(&comm_info.m_n_item, 1), comm_info.m_message_info);
112 }
113
114 for (const auto& [recv_from_id, comm_info] : m_recv_comm_info) {
115 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
116 dst_n_elems += comm_info.m_n_item;
117 }
118
119 _finishExchange();
120
121 // create destination profile
122 if (dst_me.has_value()) {
123 m_dst_profile = std::make_shared<Alien::SimpleCSRInternal::CSRStructInfo>();
124 m_dst_profile->init(ext_dst_n_rows + m_src2dst_row_list.size(), dst_n_elems);
125 }
126
127 // build destination kcol
128 // exchange row sizes
129 _resizeBuffers<Integer>(1); // comm_info.m_n_item are already computed so buffer sizes are sufficient for row size exchange
130
131 for (auto& [recv_id, comm_info] : m_recv_comm_info) {
132 comm_info.m_request = Arccore::MessagePassing::mpReceive(pm, Arccore::Span<Integer>((Integer*)comm_info.m_buffer.data(), comm_info.m_row_list.size()), comm_info.m_message_info);
133 }
134
135 for (auto& [send_id, comm_info] : m_send_comm_info) {
136 auto* buffer = (Integer*)(comm_info.m_buffer.data());
137 std::size_t buffer_idx = 0;
138 assert(comm_info.m_row_list.size() <= comm_info.m_n_item); // check that buffer is large enough
139 for (const auto& src_row : comm_info.m_row_list) {
140 buffer[buffer_idx] = m_src_profile->getRowSize(src_row);
141 buffer_idx++;
142 }
143 comm_info.m_request = Arccore::MessagePassing::mpSend(pm, Arccore::Span<Integer>((Integer*)comm_info.m_buffer.data(), comm_info.m_row_list.size()), comm_info.m_message_info);
144 }
145
146 std::vector<Integer> row_size(ext_dst_n_rows + m_src2dst_row_list.size(), 0);
147
148 // self rows (if exist)
149 for (const auto& [src_row, dst_row] : m_src2dst_row_list) {
150 row_size[dst_row] = m_src_profile->getRowSize(src_row);
151 }
152 // wait for recv messages
153 // mpWaitSome ?
154 for (auto const& [recv_id, comm_info] : m_recv_comm_info) {
155 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
156
157 const auto* buffer = (const Integer*)(comm_info.m_buffer.data());
158 std::size_t buffer_idx = 0;
159 for (const auto& dst_row : m_recv_comm_info[recv_id].m_row_list) {
160 row_size[dst_row] = buffer[buffer_idx];
161 buffer_idx++;
162 }
163 }
164 _finishExchange();
165
166 if (dst_me.has_value()) {
167 auto* kcol = m_dst_profile->kcol();
168 kcol[0] = 0;
169 for (Integer i = 1; i < m_dst_profile->getNRows() + 1; ++i) {
170 kcol[i] = kcol[i - 1] + row_size[i - 1];
171 }
172 }
173
174 // distribute profile cols
175 _distribute<Integer>(1, m_src_profile->cols(), dst_me.has_value() ? m_dst_profile->cols() : nullptr);
176}
177
178template <typename T>
179void SimpleCSRDistributor::_distribute(const int bb, const T* src, T* dst)
180{
181
182 using ItemType = T;
183
184 _resizeBuffers<T>(bb);
185
186 auto* pm = m_comm_plan->superParallelMng();
187 // post recv
188 for (auto& [recv_id, comm_info] : m_recv_comm_info) {
189 comm_info.m_request =
190 Arccore::MessagePassing::mpReceive(pm, Arccore::Span<T>((T*)comm_info.m_buffer.data(), comm_info.m_n_item),
191 comm_info.m_message_info);
192 }
193
194 // send rows
195 for (auto& [send_id, comm_info] : m_send_comm_info) {
196 // assemble message
197 auto* buffer = (ItemType*)(comm_info.m_buffer.data());
198 std::size_t buffer_idx = 0;
199 for (const auto& src_row : comm_info.m_row_list) {
200 for (auto k = m_src_profile->kcol()[src_row] * bb; k < m_src_profile->kcol()[src_row + 1] * bb; ++k) {
201 buffer[buffer_idx] = src[k];
202 buffer_idx++;
203 }
204 }
205 comm_info.m_request =
206 Arccore::MessagePassing::mpSend(pm, Arccore::Span<T>((T*)comm_info.m_buffer.data(), comm_info.m_n_item),
207 comm_info.m_message_info);
208 }
209 // perform direct transfer
210 for (const auto& [src_row, dst_row] : m_src2dst_row_list) {
211 auto k_src = m_src_profile->kcol()[src_row] * bb;
212 for (auto k_dst = m_dst_profile->kcol()[dst_row] * bb; k_dst < m_dst_profile->kcol()[dst_row + 1] * bb; ++k_dst) {
213 dst[k_dst] = src[k_src];
214 k_src++;
215 }
216 assert(k_src == m_src_profile->kcol()[src_row + 1] * bb);
217 }
218
219 // wait for recv messages
220 // Use mpWaitAny or mpWaitSome
221 for (auto const& [recv_id, comm_info] : m_recv_comm_info) {
222 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
223
224 // put received matrix values at the right place
225 const auto* buffer = (const ItemType*)(comm_info.m_buffer.data());
226 std::size_t buffer_idx = 0;
227 for (const auto& dst_row : m_recv_comm_info[recv_id].m_row_list) {
228 for (auto k = m_dst_profile->kcol()[dst_row] * bb; k < m_dst_profile->kcol()[dst_row + 1] * bb; ++k) {
229 dst[k] = buffer[buffer_idx];
230 buffer_idx++;
231 }
232 }
233 }
234
235 _finishExchange();
236}
237template <typename NumT>
238void SimpleCSRDistributor::distribute(const SimpleCSRMatrix<NumT>& src, SimpleCSRMatrix<NumT>& dst)
239{
240 const auto me = m_comm_plan->superParallelMng()->commRank();
241 const auto dst_me = _dstMe(me);
242
243 if (dst_me.has_value()) {
244 // I am in the target parallel manager
245 // fill dst profile with a copy of m_dst_profile
246 auto& profile = dst.internal()->getCSRProfile();
247 profile.init(m_dst_profile->getNRows(), m_dst_profile->getNElems());
248 dst.allocate();
249
250 for (Integer i = 0; i < profile.getNRows() + 1; ++i) {
251 profile.kcol()[i] = m_dst_profile->kcol()[i];
252 }
253 for (Integer k = 0; k < profile.getNElems(); ++k) {
254 profile.cols()[k] = m_dst_profile->cols()[k];
255 }
256 }
257
258 if (src.block()) {
259 _distribute(src.block()->sizeX() * src.block()->sizeY(), src.data(), dst.data());
260 }
261 else if (src.vblock()) {
262 throw Arccore::NotImplementedException(A_FUNCINFO);
263 }
264 else {
265 _distribute(1, src.data(), dst.data());
266 }
267
268 if (dst_me.has_value()) {
269 if (m_comm_plan->tgtParallelMng()->commSize() == 1) {
270 dst.sequentialStart();
271 }
272 else {
273 dst.parallelStart(dst.distribution().rowDistribution().offsets(), m_comm_plan->tgtParallelMng().get(), true);
274 }
275 }
276
277#if 0
278 if(dst_me.value_or(1) == 0)
279 {
280 const auto& profile = dst.internal().getCSRProfile();
281 for (Integer i = 0; i < profile.getNRows(); ++i)
282 {
283 std::cout << i ;
284 for (Integer k = profile.kcol()[i]; k < profile.kcol()[i+1]; ++k)
285 {
286 std::cout << " [" << profile.cols()[k] << " " << dst.data()[k] << "]";
287 }
288 std::cout << std::endl;
289 }
290 }
291#endif
292}
293
294template <typename NumT>
295void SimpleCSRDistributor::distribute(const SimpleCSRVector<NumT>& src, SimpleCSRVector<NumT>& dst)
296{
297 throw Arccore::NotImplementedException(A_FUNCINFO);
298}
299
300template <typename T>
301void SimpleCSRDistributor::_resizeBuffers(const int bb)
302{
303 // comm_info should be templated by the type to avoid cast and explicit size computations
304 for (auto& [send_id, comm_info] : m_send_comm_info) {
305 comm_info.m_buffer.resize((comm_info.m_n_item * sizeof(T) * bb + sizeof(uint64_t) - 1) / sizeof(uint64_t));
306 }
307
308 for (auto& [recv_id, comm_info] : m_recv_comm_info) {
309 comm_info.m_buffer.resize((comm_info.m_n_item * sizeof(T) * bb + sizeof(uint64_t) - 1) / sizeof(uint64_t));
310 }
311}
312
313void SimpleCSRDistributor::_finishExchange()
314{
315 auto* pm = m_comm_plan->superParallelMng();
316 // finish properly
317
318 // CC: should be a mpWaitAll and not a loop
319 for (auto const& [send_to_id, comm_info] : m_send_comm_info) {
320 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
321 }
322}
323
324// T must be an integer signed type
325template <typename T>
326std::optional<T> SimpleCSRDistributor::_owner(const std::vector<T>& offset, T global_id)
327{
328 if (global_id >= offset.back())
329 return {};
330
331 auto min = offset.size() - offset.size(); // just for the right auto type
332 auto max = offset.size();
333
334 while (max - min > 1) {
335 auto mid = (max - min) / 2 + min;
336 if (global_id < offset[mid]) {
337 max = mid;
338 }
339 else {
340 min = mid;
341 }
342 }
343
344 return min;
345}
346std::optional<int> SimpleCSRDistributor::_dstMe(int) const
347{
348 if (m_comm_plan->tgtParallelMng()) {
349 return m_comm_plan->tgtParallelMng()->commRank();
350 }
351
352 return {};
353}
354
355} // namespace Alien
Computes a vector distribution.
-- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature --
Definition BackEnd.h:17