Arcane  v3.15.0.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
ScanImpl.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* ScanImpl.h (C) 2000-2024 */
9/* */
10/* Implémentation spécifique de l'opération de scan pour les accélérateurs. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCANE_ACCELERATOR_SCANIMPL_H
13#define ARCANE_ACCELERATOR_SCANIMPL_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16
17#include "arcane/utils/ArrayView.h"
18#include "arcane/utils/FatalErrorException.h"
19
20#include "arcane/utils/NumArray.h"
21
22#include "arcane/accelerator/core/RunQueue.h"
23
24#include "arcane/accelerator/AcceleratorGlobal.h"
25#include "arcane/accelerator/CommonUtils.h"
26#include "arcane/accelerator/RunCommandLaunchInfo.h"
28
29/*---------------------------------------------------------------------------*/
30/*---------------------------------------------------------------------------*/
31
32namespace Arcane::Accelerator::impl
33{
34
35/*---------------------------------------------------------------------------*/
36/*---------------------------------------------------------------------------*/
37
38#if defined(ARCANE_COMPILING_SYCL)
39
48template <bool IsExclusive, typename DataType, typename Operator>
49class SyclScanner
50{
51 class InputInfo
52 {
53 public:
54
55 DataType _getInputValue(Int32 index) const
56 {
57 DataType local_value = identity;
58 if constexpr (IsExclusive) {
59 if (index == 0)
60 local_value = init_value;
61 else
62 local_value = ((index - 1) < nb_value) ? input_values[index - 1] : identity;
63 }
64 else
65 local_value = (index < nb_value) ? input_values[index] : identity;
66 return local_value;
67 }
68
69 public:
70
71 SmallSpan<const DataType> input_values;
72 DataType identity = {};
73 DataType init_value = {};
74 Int32 nb_value = 0;
75 };
76
77 public:
78
79 void doScan(RunQueue& rq, SmallSpan<const DataType> input, SmallSpan<DataType> output, DataType init_value)
80 {
81 DataType identity = Operator::defaultValue();
82 sycl::queue q = impl::SyclUtils::toNativeStream(&rq);
83 // Contient l'application partielle de Operator pour chaque bloc de thread
84 NumArray<DataType, MDDim1> tmp;
85 // Contient l'application partielle de Operator cumulée avec les blocs précédents
86 NumArray<DataType, MDDim1> tmp2;
87 Int32 nb_item = input.size();
88 Int32 block_size = 256;
89 Int32 nb_block = (nb_item / block_size);
90 if ((nb_item % block_size) != 0)
91 ++nb_block;
92 tmp.resize(nb_block);
93 tmp2.resize(nb_block);
94 InputInfo input_info;
95 input_info.nb_value = nb_item;
96 input_info.init_value = init_value;
97 input_info.identity = identity;
98 input_info.input_values = input;
99 if (m_is_verbose)
100 std::cout << "DO_SCAN nb_item=" << nb_item << " nb_block=" << nb_block << "\n";
101 doscan1(q, input_info, tmp.to1DSpan(), nb_item, block_size);
102 if (m_is_verbose)
103 for (int i = 0; i < nb_block; ++i)
104 std::cout << "DO_SCAN_X1 i=" << i << " tmp[i]=" << tmp[i] << "\n";
105 doscan2(q, tmp.to1DSpan(), nb_block, block_size, identity);
106 if (m_is_verbose)
107 for (int i = 0; i < nb_block; ++i)
108 std::cout << "DO_SCAN_X2 i=" << i << " tmp[i]=" << tmp[i] << "\n";
109 doscan2_bis(q, tmp.to1DSpan(), tmp2.to1DSpan(), nb_block, block_size, identity);
110 if (m_is_verbose)
111 for (int i = 0; i < nb_block; ++i)
112 std::cout << "DO_SCAN_X2_BIS i=" << i << " tmp[i]=" << tmp[i] << " tmp2[i]=" << tmp2[i] << "\n";
113 doscan3(q, input_info, output, tmp2, nb_item, block_size);
114 }
115
116 private:
117
118 void doscan1(sycl::queue& q, const InputInfo& input_info, Span<DataType> tmp,
119 int nb_value, int block_size)
120 {
121 if (m_is_verbose)
122 std::cout << "DO_SCAN1 nb_value=" << nb_value << " L=" << block_size << "\n";
123 // Phase 1: Compute local scans over input blocks
124 Operator scan_op;
125 q.submit([&](sycl::handler& h) {
126 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
127 h.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
128 const int i = static_cast<int>(it.get_global_id(0));
129 const int li = static_cast<int>(it.get_local_id(0));
130 const int gid = static_cast<int>(it.get_group(0));
131 const int local_range0 = static_cast<int>(it.get_local_range()[0]);
132 // Effectue le scan sur le groupe.
133 DataType local_value = input_info._getInputValue(i);
134 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
135 // Le dernier élément sauve la valeur dans le tableau du groupe.
136 if (li == local_range0 - 1)
137 tmp[gid] = local[li];
138 });
139 })
140 .wait();
141 }
142
143 void doscan2(sycl::queue& q, Span<DataType> tmp, int nb_block, const int block_size, DataType identity)
144 {
145 if (m_is_verbose)
146 std::cout << "DO_SCAN2 nb_block=" << nb_block << " block_size=" << block_size << "\n";
147 // Phase 2: Compute scan over partial results
148 Operator scan_op;
149 q.submit([&](sycl::handler& h) {
150 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
151 h.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
152 int i = static_cast<int>(it.get_global_id(0));
153 int li = static_cast<int>(it.get_local_id(0));
154 // Copy input to local memory
155 DataType local_value = (i < nb_block) ? tmp[i] : identity;
156 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
157 // Overwrite result from each work-item in the temporary buffer
158 if (i < nb_block)
159 tmp[i] = local[li];
160 });
161 })
162 .wait();
163 }
164
165 void doscan2_bis(sycl::queue& q, Span<const DataType> tmp, Span<DataType> tmp2, int nb_block, int block_size, DataType identity)
166 {
167 if (m_is_verbose)
168 std::cout << "DO_SCAN2_bis nb_block=" << nb_block << " L=" << block_size << "\n";
169 Operator scan_op;
170 q.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
171 const int g = static_cast<int>(it.get_group(0));
172 const int i = static_cast<int>(it.get_global_id(0));
173 if (i < nb_block) {
174 DataType init_value = identity;
175 for (int j = 1; j <= g; ++j)
176 init_value = scan_op(init_value, tmp[(j * block_size) - 1]);
177 tmp2[i] = scan_op(init_value, tmp[i]);
178 }
179 })
180 .wait();
181 }
182
183 void doscan3(sycl::queue& q, const InputInfo& input_info, SmallSpan<DataType> output, SmallSpan<DataType> tmp2, int nb_value, int block_size)
184 {
185 if (m_is_verbose)
186 std::cout << "DO_SCAN3 nb_value=" << nb_value << " L=" << block_size << "\n";
187 Operator scan_op;
188 // Phase 3: Update local scans using partial results
189 q.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
190 const int i = static_cast<int>(it.get_global_id(0));
191 const int g = static_cast<int>(it.get_group(0));
192 DataType local_value = input_info._getInputValue(i);
193 DataType output_value = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
194 if (i < nb_value) {
195 output[i] = (g > 0) ? scan_op(output_value, tmp2[g - 1]) : output_value;
196 }
197 })
198 .wait();
199 }
200
201 private:
202
203 bool m_is_verbose = false;
204
205 private:
206
208 sycl::nd_range<1> _getNDRange(Int32 nb_value, Int32 block_size)
209 {
210 int x = nb_value / block_size;
211 if ((nb_value % block_size) != 0)
212 ++x;
213 x *= block_size;
214 return sycl::nd_range<1>(x, block_size);
215 }
216};
217#endif
218
219/*---------------------------------------------------------------------------*/
220/*---------------------------------------------------------------------------*/
221
222} // namespace Arcane::Accelerator::impl
223
224/*---------------------------------------------------------------------------*/
225/*---------------------------------------------------------------------------*/
226
227#endif
228
229/*---------------------------------------------------------------------------*/
230/*---------------------------------------------------------------------------*/
Types et macros pour gérer les boucles sur les accélérateurs.
Lecteur des fichiers de maillage via la bibliothèque LIMA.
Definition Lima.cc:149