11#include <arccore/base/Span.h>
12#include "SimpleCSRDistributor.h"
// Constructor, presumably of SimpleCSRDistributor (the first part of the
// signature is not visible in this excerpt). It stores the communication plan
// and the source CSR profile, then builds the row lists, the message
// descriptors and the destination profile that _distribute() and distribute()
// read later.
// NOTE(review): the leading numbers are residue of the original line numbering
// and many original lines are missing here (braces, loop increments, the
// declarations of n_offset, tgt_i, source_distribution, dst_n_elems, ...).
// The comments below describe only the statements that are actually shown.
19 const Alien::SimpleCSRInternal::CSRStructInfo* src_profile)
20: m_comm_plan(commPlan)
21, m_src_profile(src_profile)
// Rank of this process in the super parallel manager, and the optional rank
// returned by _dstMe() (empty when this process has no target parallel manager).
23 const auto me = m_comm_plan->superParallelMng()->commRank();
24 const auto dst_me = _dstMe(me);
// First pass over the target distribution array: detect every position where
// two consecutive entries differ. The guarded statement is not visible;
// presumably it counts these positions into n_offset - TODO confirm.
27 const auto tgt_dist = m_comm_plan->tgtDist();
29 for (Integer i = 1; i < tgt_dist.size(); ++i) {
30 if (tgt_dist[i] != tgt_dist[i - 1])
// On processes that have a target rank, n_offset must equal the size of the
// target parallel manager (checked in debug builds only).
33 if (dst_me.has_value()) {
34 assert(n_offset == m_comm_plan->tgtParallelMng()->commSize());
// Second pass: store each changed value of tgt_dist into target_offset at
// index tgt_i (the initialisation and update of tgt_i are not visible).
36 std::vector<Integer> target_offset(n_offset + 1);
40 for (Integer i = 1; i < tgt_dist.size(); ++i) {
41 if (tgt_dist[i] != tgt_dist[i - 1]) {
42 target_offset[tgt_i] = tgt_dist[i];
// Walk the global rows held locally by the source distribution and classify
// each one by the target that _owner() reports for it.
49 for (
auto global_row = source_distribution.offset();
50 global_row < source_distribution.offset() + source_distribution.localSize();
52 auto target = _owner(target_offset, global_row);
53 auto src_local_row = source_distribution.globalToLocal(global_row);
// Row stays on this process: remember the (source local row, destination local
// row) pair and add its size to dst_n_elems.
55 if (target == dst_me) {
56 auto dst_local_row = m_comm_plan->distribution().globalToLocal(global_row);
57 m_src2dst_row_list.push_back({ src_local_row, dst_local_row });
58 dst_n_elems += m_src_profile->getRowSize(src_local_row);
// Row goes to another target: append it to the send entry keyed by
// procNum(target) and add its size to that entry's item count.
60 else if (target.has_value()) {
61 auto& comm_info = m_send_comm_info[m_comm_plan->procNum(target.value())];
62 comm_info.m_row_list.push_back(src_local_row);
63 comm_info.m_n_item += m_src_profile->getRowSize(src_local_row);
// NOTE(review): as shown, the exception object is only constructed; no `throw`
// keyword is visible, so this statement would have no effect and a row without
// a target would be silently dropped - confirm against the full file.
66 FatalErrorException(
"No target found");
// Added to m_src2dst_row_list.size() below to get the destination row count;
// presumably the number of rows received from other processes (its updates are
// not visible in this excerpt).
70 auto ext_dst_n_rows = 0;
// Walk the global rows of the target distribution and register each
// destination local row in the receive entry keyed by its owner in the source
// distribution. Which statements the `if` guards, and any filtering between
// the visible lines (e.g. skipping rows already owned locally), cannot be seen
// here.
72 if (dst_me.has_value())
74 const auto& target_distribution = m_comm_plan->distribution();
76 for (
auto global_row = target_distribution.offset();
77 global_row < target_distribution.offset() + target_distribution.localSize();
79 auto source = source_distribution.owner(global_row);
82 auto dst_local_row = target_distribution.globalToLocal(global_row);
83 auto& comm_info = m_recv_comm_info[source];
84 comm_info.m_row_list.push_back(dst_local_row);
// Prepare a non-blocking point-to-point descriptor (me, send_to_id) for every
// send entry.
90 for (
auto& [send_to_id, comm_info] : m_send_comm_info) {
91 Arccore::MessagePassing::PointToPointMessageInfo message_info(MessageRank(me), MessageRank(send_to_id),
92 Arccore::MessagePassing::NonBlocking);
93 comm_info.m_message_info = message_info;
// Same for every receive entry, built from (me, recv_from_id).
96 for (
auto& [recv_from_id, comm_info] : m_recv_comm_info) {
97 Arccore::MessagePassing::PointToPointMessageInfo message_info(MessageRank(me), MessageRank(recv_from_id),
98 Arccore::MessagePassing::NonBlocking);
100 comm_info.m_message_info = message_info;
// Exchange item counts: post a receive of one size_t per receive entry, send
// our own count per send entry, then wait for the receives and add the
// received counts to dst_n_elems.
103 auto* pm = m_comm_plan->superParallelMng();
106 for (
auto& [recv_from_id, comm_info] : m_recv_comm_info) {
107 comm_info.m_request = Arccore::MessagePassing::mpReceive(pm, Arccore::Span<size_t>(&comm_info.m_n_item, 1), comm_info.m_message_info);
110 for (
auto& [send_to_id, comm_info] : m_send_comm_info) {
111 comm_info.m_request = Arccore::MessagePassing::mpSend(pm, Arccore::Span<size_t>(&comm_info.m_n_item, 1), comm_info.m_message_info);
114 for (
const auto& [recv_from_id, comm_info] : m_recv_comm_info) {
115 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
116 dst_n_elems += comm_info.m_n_item;
// Target processes allocate the destination profile with
// (ext_dst_n_rows + locally kept rows) rows and dst_n_elems entries.
122 if (dst_me.has_value()) {
123 m_dst_profile = std::make_shared<Alien::SimpleCSRInternal::CSRStructInfo>();
124 m_dst_profile->init(ext_dst_n_rows + m_src2dst_row_list.size(), dst_n_elems);
// Exchange row sizes. Buffers are sized by _resizeBuffers<Integer>(1), i.e.
// from m_n_item, while the messages carry one Integer per listed row.
129 _resizeBuffers<Integer>(1);
131 for (
auto& [recv_id, comm_info] : m_recv_comm_info) {
132 comm_info.m_request = Arccore::MessagePassing::mpReceive(pm, Arccore::Span<Integer>((Integer*)comm_info.m_buffer.data(), comm_info.m_row_list.size()), comm_info.m_message_info);
// Pack the size of every row to send, then post the send (the buffer_idx
// increment is not visible here).
135 for (
auto& [send_id, comm_info] : m_send_comm_info) {
136 auto* buffer = (Integer*)(comm_info.m_buffer.data());
137 std::size_t buffer_idx = 0;
// NOTE(review): this assert is the only visible guard that the row count fits
// in a buffer sized from m_n_item; it disappears in release builds and there
// is no matching check on the receive side - confirm empty rows cannot occur.
138 assert(comm_info.m_row_list.size() <= comm_info.m_n_item);
139 for (
const auto& src_row : comm_info.m_row_list) {
140 buffer[buffer_idx] = m_src_profile->getRowSize(src_row);
143 comm_info.m_request = Arccore::MessagePassing::mpSend(pm, Arccore::Span<Integer>((Integer*)comm_info.m_buffer.data(), comm_info.m_row_list.size()), comm_info.m_message_info);
// Row sizes of the destination profile, indexed by destination local row.
146 std::vector<Integer> row_size(ext_dst_n_rows + m_src2dst_row_list.size(), 0);
// Rows kept locally take their size straight from the source profile.
149 for (
const auto& [src_row, dst_row] : m_src2dst_row_list) {
150 row_size[dst_row] = m_src_profile->getRowSize(src_row);
// Rows received from other processes take their size from the received
// buffers, in row-list order (the buffer_idx increment is not visible here).
154 for (
auto const& [recv_id, comm_info] : m_recv_comm_info) {
155 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
157 const auto* buffer = (
const Integer*)(comm_info.m_buffer.data());
158 std::size_t buffer_idx = 0;
159 for (
const auto& dst_row : m_recv_comm_info[recv_id].m_row_list) {
160 row_size[dst_row] = buffer[buffer_idx];
// On target processes, turn the row sizes into the kcol row-offset array by a
// running sum (the initialisation of kcol[0] is not visible here).
166 if (dst_me.has_value()) {
167 auto* kcol = m_dst_profile->kcol();
169 for (Integer i = 1; i < m_dst_profile->getNRows() + 1; ++i) {
170 kcol[i] = kcol[i - 1] + row_size[i - 1];
// Finally redistribute the column indices with one value per entry; processes
// without a target rank pass a null destination.
175 _distribute<Integer>(1, m_src_profile->cols(), dst_me.has_value() ? m_dst_profile->cols() :
nullptr);
// Moves per-entry data from the source layout to the destination layout.
//   bb  : number of values stored per profile entry (the callers shown in this
//         file pass 1, or sizeX * sizeY for block matrices).
//   src : source values, addressed through m_src_profile->kcol().
//   dst : destination values, addressed through m_dst_profile->kcol().
// NOTE(review): the template header, the ItemType alias, the index increments
// and the closing braces are not visible in this excerpt.
// NOTE(review): _resizeBuffers<T>(bb) sizes the buffers for m_n_item * bb
// values and the pack/unpack loops below move kcol-range * bb values, but the
// spans given to mpReceive/mpSend have m_n_item elements. For bb > 1 this
// looks like only part of the packed data is transferred - confirm.
179void SimpleCSRDistributor::_distribute(
const int bb,
const T* src, T* dst)
184 _resizeBuffers<T>(bb);
186 auto* pm = m_comm_plan->superParallelMng();
// Post one receive per receive entry, directly into that entry's buffer.
188 for (
auto& [recv_id, comm_info] : m_recv_comm_info) {
189 comm_info.m_request =
190 Arccore::MessagePassing::mpReceive(pm, Arccore::Span<T>((T*)comm_info.m_buffer.data(), comm_info.m_n_item),
191 comm_info.m_message_info);
// Pack the values of every row to send, row after row, then post the send.
195 for (
auto& [send_id, comm_info] : m_send_comm_info) {
197 auto* buffer = (ItemType*)(comm_info.m_buffer.data());
198 std::size_t buffer_idx = 0;
199 for (
const auto& src_row : comm_info.m_row_list) {
200 for (
auto k = m_src_profile->kcol()[src_row] * bb; k < m_src_profile->kcol()[src_row + 1] * bb; ++k) {
201 buffer[buffer_idx] = src[k];
205 comm_info.m_request =
206 Arccore::MessagePassing::mpSend(pm, Arccore::Span<T>((T*)comm_info.m_buffer.data(), comm_info.m_n_item),
207 comm_info.m_message_info);
// Rows that stay on this process are copied straight from src to dst; the
// assert checks that the source cursor ends exactly at the end of the row.
210 for (
const auto& [src_row, dst_row] : m_src2dst_row_list) {
211 auto k_src = m_src_profile->kcol()[src_row] * bb;
212 for (
auto k_dst = m_dst_profile->kcol()[dst_row] * bb; k_dst < m_dst_profile->kcol()[dst_row + 1] * bb; ++k_dst) {
213 dst[k_dst] = src[k_src];
216 assert(k_src == m_src_profile->kcol()[src_row + 1] * bb);
// Wait for each receive, then unpack its buffer into the destination rows in
// the order recorded in the receive row list.
221 for (
auto const& [recv_id, comm_info] : m_recv_comm_info) {
222 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
225 const auto* buffer = (
const ItemType*)(comm_info.m_buffer.data());
226 std::size_t buffer_idx = 0;
227 for (
const auto& dst_row : m_recv_comm_info[recv_id].m_row_list) {
228 for (
auto k = m_dst_profile->kcol()[dst_row] * bb; k < m_dst_profile->kcol()[dst_row + 1] * bb; ++k) {
229 dst[k] = buffer[buffer_idx];
// Redistributes a CSR matrix: on target processes the destination profile is
// filled from m_dst_profile, then the values are moved with _distribute() and
// the destination matrix is started in sequential or parallel mode.
//   src : matrix in the source layout.
//   dst : matrix to fill in the destination layout.
// Throws Arccore::NotImplementedException when src.vblock() is set.
// NOTE(review): braces and some conditions are not visible in this excerpt.
237template <
typename NumT>
238void SimpleCSRDistributor::distribute(
const SimpleCSRMatrix<NumT>& src, SimpleCSRMatrix<NumT>& dst)
240 const auto me = m_comm_plan->superParallelMng()->commRank();
241 const auto dst_me = _dstMe(me);
// Target processes: size dst's profile like m_dst_profile and copy the row
// offsets (kcol) and column indices (cols) into it.
243 if (dst_me.has_value()) {
246 auto& profile = dst.internal()->getCSRProfile();
247 profile.init(m_dst_profile->getNRows(), m_dst_profile->getNElems());
250 for (Integer i = 0; i < profile.getNRows() + 1; ++i) {
251 profile.kcol()[i] = m_dst_profile->kcol()[i];
253 for (Integer k = 0; k < profile.getNElems(); ++k) {
254 profile.cols()[k] = m_dst_profile->cols()[k];
// Block case: one entry carries sizeX * sizeY values. The condition selecting
// this branch is not visible; presumably it tests src.block() - TODO confirm.
259 _distribute(src.block()->sizeX() * src.block()->sizeY(), src.data(), dst.data());
// Variable-block matrices are not supported.
261 else if (src.vblock()) {
262 throw Arccore::NotImplementedException(A_FUNCINFO);
// Scalar case: one value per entry.
265 _distribute(1, src.data(), dst.data());
// Target processes start the destination matrix: sequential start when the
// target parallel manager has a single rank, parallel start otherwise.
268 if (dst_me.has_value()) {
269 if (m_comm_plan->tgtParallelMng()->commSize() == 1) {
270 dst.sequentialStart();
273 dst.parallelStart(dst.distribution().rowDistribution().offsets(), m_comm_plan->tgtParallelMng().get(),
true);
// Dump of the destination matrix on target rank 0: for each row, prints its
// (column, value) pairs to std::cout. value_or(1) makes the test false on
// processes without a target rank.
// NOTE(review): this uses dst.internal().getCSRProfile() while the code above
// uses dst.internal()->getCSRProfile(); both cannot compile for the same type,
// so this section is presumably disabled in the full file - confirm.
278 if(dst_me.value_or(1) == 0)
280 const auto& profile = dst.internal().getCSRProfile();
281 for (Integer i = 0; i < profile.getNRows(); ++i)
284 for (Integer k = profile.kcol()[i]; k < profile.kcol()[i+1]; ++k)
286 std::cout <<
" [" << profile.cols()[k] <<
" " << dst.data()[k] <<
"]";
288 std::cout << std::endl;
// Vector overload of distribute(): not implemented. The visible body only
// throws Arccore::NotImplementedException; src and dst are not used.
294template <
typename NumT>
295void SimpleCSRDistributor::distribute(
const SimpleCSRVector<NumT>& src, SimpleCSRVector<NumT>& dst)
297 throw Arccore::NotImplementedException(A_FUNCINFO);
// Resizes the buffer of every send and receive entry so that it can hold
// m_n_item * bb values of type T.
//   bb : number of values per item.
// The byte count is divided by sizeof(uint64_t) and rounded up, so the buffer
// length is expressed in 8-byte words (m_buffer presumably stores uint64_t
// elements - TODO confirm against the header).
// NOTE(review): the template header and braces are not visible in this excerpt.
301void SimpleCSRDistributor::_resizeBuffers(
const int bb)
// Send side.
304 for (
auto& [send_id, comm_info] : m_send_comm_info) {
305 comm_info.m_buffer.resize((comm_info.m_n_item *
sizeof(T) * bb +
sizeof(uint64_t) - 1) /
sizeof(uint64_t));
// Receive side, same sizing rule.
308 for (
auto& [recv_id, comm_info] : m_recv_comm_info) {
309 comm_info.m_buffer.resize((comm_info.m_n_item *
sizeof(T) * bb +
sizeof(uint64_t) - 1) /
sizeof(uint64_t));
// Waits on the request stored in every send entry, using the super parallel
// manager. Receive requests are not touched by the lines visible here.
313void SimpleCSRDistributor::_finishExchange()
315 auto* pm = m_comm_plan->superParallelMng();
319 for (
auto const& [send_to_id, comm_info] : m_send_comm_info) {
320 Arccore::MessagePassing::mpWait(pm, comm_info.m_request);
// Looks up which interval of `offset` contains `global_id` with a bisection
// over the index range [min, max).
//   offset    : offsets to search; the comparisons below only make sense if it
//               is sorted in increasing order - TODO confirm with the caller.
//   global_id : value to locate.
// Returns an optional; the returned values are not visible in this excerpt
// (presumably empty when global_id is past the last offset, otherwise the
// index found by the bisection).
// NOTE(review): offset.back() is undefined on an empty vector and no emptiness
// check is visible before it.
326std::optional<T> SimpleCSRDistributor::_owner(
const std::vector<T>& offset, T global_id)
// Out of range on the right: global_id is at or beyond the last offset.
328 if (global_id >= offset.back())
// size() - size() yields a zero of the vector's size type, so min and max
// share the same unsigned type.
331 auto min = offset.size() - offset.size();
332 auto max = offset.size();
// Halve the range until it is one element wide (the updates of min and max are
// not visible here).
334 while (max - min > 1) {
335 auto mid = (max - min) / 2 + min;
336 if (global_id < offset[mid]) {
// Returns the rank of this process in the target parallel manager when the
// communication plan has one. The int parameter is unnamed and unused.
// The return for the other case is not visible in this excerpt; callers test
// the result with has_value(), so presumably an empty optional - TODO confirm.
346std::optional<int> SimpleCSRDistributor::_dstMe(
int)
const
348 if (m_comm_plan->tgtParallelMng()) {
349 return m_comm_plan->tgtParallelMng()->commRank();
// Computes a vector distribution.
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-