Arcane  v4.1.2.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
ScanImpl.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2025 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* ScanImpl.h (C) 2000-2025 */
9/* */
10/* Implémentation spécifique de l'opération de scan pour les accélérateurs. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCCORE_ACCELERATOR_SCANIMPL_H
13#define ARCCORE_ACCELERATOR_SCANIMPL_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16
18#include "arccore/base/FatalErrorException.h"
19
20#include "arccore/common/NumArray.h"
21#include "arccore/common/accelerator/RunQueue.h"
22#include "arccore/common/accelerator/RunCommandLaunchInfo.h"
23
24#include "arccore/accelerator/CommonUtils.h"
26
27/*---------------------------------------------------------------------------*/
28/*---------------------------------------------------------------------------*/
29
30namespace Arcane::Accelerator::impl
31{
32
33/*---------------------------------------------------------------------------*/
34/*---------------------------------------------------------------------------*/
35
36#if defined(ARCCORE_COMPILING_SYCL)
37
46template <bool IsExclusive, typename DataType, typename Operator>
47class SyclScanner
48{
49 class InputInfo
50 {
51 public:
52
53 DataType _getInputValue(Int32 index) const
54 {
55 DataType local_value = identity;
56 if constexpr (IsExclusive) {
57 if (index == 0)
58 local_value = init_value;
59 else
60 local_value = ((index - 1) < nb_value) ? input_values[index - 1] : identity;
61 }
62 else
63 local_value = (index < nb_value) ? input_values[index] : identity;
64 return local_value;
65 }
66
67 public:
68
69 SmallSpan<const DataType> input_values;
70 DataType identity = {};
71 DataType init_value = {};
72 Int32 nb_value = 0;
73 };
74
75 public:
76
77 void doScan(RunQueue& rq, SmallSpan<const DataType> input, SmallSpan<DataType> output, DataType init_value)
78 {
79 DataType identity = Operator::defaultValue();
80 sycl::queue q = Impl::SyclUtils::toNativeStream(&rq);
81 // Contient l'application partielle de Operator pour chaque bloc de thread
82 NumArray<DataType, MDDim1> tmp;
83 // Contient l'application partielle de Operator cumulée avec les blocs précédents
84 NumArray<DataType, MDDim1> tmp2;
85 Int32 nb_item = input.size();
86 Int32 block_size = 256;
87 Int32 nb_block = (nb_item / block_size);
88 if ((nb_item % block_size) != 0)
89 ++nb_block;
90 tmp.resize(nb_block);
91 tmp2.resize(nb_block);
92 InputInfo input_info;
93 input_info.nb_value = nb_item;
94 input_info.init_value = init_value;
95 input_info.identity = identity;
96 input_info.input_values = input;
97 if (m_is_verbose)
98 std::cout << "DO_SCAN nb_item=" << nb_item << " nb_block=" << nb_block << "\n";
99 doscan1(q, input_info, tmp.to1DSpan(), nb_item, block_size);
100 if (m_is_verbose)
101 for (int i = 0; i < nb_block; ++i)
102 std::cout << "DO_SCAN_X1 i=" << i << " tmp[i]=" << tmp[i] << "\n";
103 doscan2(q, tmp.to1DSpan(), nb_block, block_size, identity);
104 if (m_is_verbose)
105 for (int i = 0; i < nb_block; ++i)
106 std::cout << "DO_SCAN_X2 i=" << i << " tmp[i]=" << tmp[i] << "\n";
107 doscan2_bis(q, tmp.to1DSpan(), tmp2.to1DSpan(), nb_block, block_size, identity);
108 if (m_is_verbose)
109 for (int i = 0; i < nb_block; ++i)
110 std::cout << "DO_SCAN_X2_BIS i=" << i << " tmp[i]=" << tmp[i] << " tmp2[i]=" << tmp2[i] << "\n";
111 doscan3(q, input_info, output, tmp2, nb_item, block_size);
112 }
113
114 private:
115
116 void doscan1(sycl::queue& q, const InputInfo& input_info, Span<DataType> tmp,
117 int nb_value, int block_size)
118 {
119 if (m_is_verbose)
120 std::cout << "DO_SCAN1 nb_value=" << nb_value << " L=" << block_size << "\n";
121 // Phase 1: Compute local scans over input blocks
122 Operator scan_op;
123 q.submit([&](sycl::handler& h) {
124 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
125 h.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
126 const int i = static_cast<int>(it.get_global_id(0));
127 const int li = static_cast<int>(it.get_local_id(0));
128 const int gid = static_cast<int>(it.get_group(0));
129 const int local_range0 = static_cast<int>(it.get_local_range()[0]);
130 // Effectue le scan sur le groupe.
131 DataType local_value = input_info._getInputValue(i);
132 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
133 // Le dernier élément sauve la valeur dans le tableau du groupe.
134 if (li == local_range0 - 1)
135 tmp[gid] = local[li];
136 });
137 })
138 .wait();
139 }
140
141 void doscan2(sycl::queue& q, Span<DataType> tmp, int nb_block, const int block_size, DataType identity)
142 {
143 if (m_is_verbose)
144 std::cout << "DO_SCAN2 nb_block=" << nb_block << " block_size=" << block_size << "\n";
145 // Phase 2: Compute scan over partial results
146 Operator scan_op;
147 q.submit([&](sycl::handler& h) {
148 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
149 h.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
150 int i = static_cast<int>(it.get_global_id(0));
151 int li = static_cast<int>(it.get_local_id(0));
152 // Copy input to local memory
153 DataType local_value = (i < nb_block) ? tmp[i] : identity;
154 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
155 // Overwrite result from each work-item in the temporary buffer
156 if (i < nb_block)
157 tmp[i] = local[li];
158 });
159 })
160 .wait();
161 }
162
163 void doscan2_bis(sycl::queue& q, Span<const DataType> tmp, Span<DataType> tmp2, int nb_block, int block_size, DataType identity)
164 {
165 if (m_is_verbose)
166 std::cout << "DO_SCAN2_bis nb_block=" << nb_block << " L=" << block_size << "\n";
167 Operator scan_op;
168 q.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
169 const int g = static_cast<int>(it.get_group(0));
170 const int i = static_cast<int>(it.get_global_id(0));
171 if (i < nb_block) {
172 DataType init_value = identity;
173 for (int j = 1; j <= g; ++j)
174 init_value = scan_op(init_value, tmp[(j * block_size) - 1]);
175 tmp2[i] = scan_op(init_value, tmp[i]);
176 }
177 })
178 .wait();
179 }
180
181 void doscan3(sycl::queue& q, const InputInfo& input_info, SmallSpan<DataType> output, SmallSpan<DataType> tmp2, int nb_value, int block_size)
182 {
183 if (m_is_verbose)
184 std::cout << "DO_SCAN3 nb_value=" << nb_value << " L=" << block_size << "\n";
185 Operator scan_op;
186 // Phase 3: Update local scans using partial results
187 q.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
188 const int i = static_cast<int>(it.get_global_id(0));
189 const int g = static_cast<int>(it.get_group(0));
190 DataType local_value = input_info._getInputValue(i);
191 DataType output_value = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
192 if (i < nb_value) {
193 output[i] = (g > 0) ? scan_op(output_value, tmp2[g - 1]) : output_value;
194 }
195 })
196 .wait();
197 }
198
199 private:
200
201 bool m_is_verbose = false;
202
203 private:
204
206 sycl::nd_range<1> _getNDRange(Int32 nb_value, Int32 block_size)
207 {
208 int x = nb_value / block_size;
209 if ((nb_value % block_size) != 0)
210 ++x;
211 x *= block_size;
212 return sycl::nd_range<1>(x, block_size);
213 }
214};
215#endif
216
217/*---------------------------------------------------------------------------*/
218/*---------------------------------------------------------------------------*/
219
220} // namespace Arcane::Accelerator::impl
221
222/*---------------------------------------------------------------------------*/
223/*---------------------------------------------------------------------------*/
224
225#endif
226
227/*---------------------------------------------------------------------------*/
228/*---------------------------------------------------------------------------*/
Types et macros pour gérer les boucles sur les accélérateurs.
Types et fonctions associés aux classes ArrayView et ConstArrayView.