12#ifndef ARCANE_ACCELERATOR_SCAN_H
13#define ARCANE_ACCELERATOR_SCAN_H
17#include "arcane/utils/ArrayView.h"
18#include "arcane/utils/FatalErrorException.h"
20#include "arcane/utils/NumArray.h"
22#include "arcane/accelerator/core/RunQueue.h"
24#include "arcane/accelerator/AcceleratorGlobal.h"
25#include "arcane/accelerator/CommonUtils.h"
26#include "arcane/accelerator/RunCommandLaunchInfo.h"
28#include "arcane/accelerator/ScanImpl.h"
// NOTE(review): this listing is a documentation extract — the original file's
// line numbers are fused into each code line and many lines (braces, 'case'
// labels, the RunCommand/exec_policy setup, the enclosing class) are elided,
// so the fragments below are not compilable as shown.
33namespace Arcane::Accelerator::impl
// Applies a parallel scan (prefix operation, exclusive or inclusive per the
// IsExclusive template flag) over 'nb_item' items read through 'input_data'
// and written through 'output_data', dispatching on the queue's execution
// policy (CUDA / HIP / SYCL / host fallback visible below).
52 template <
bool IsExclusive,
typename InputIterator,
typename OutputIterator,
53 typename Operator,
typename DataType>
54 void apply(
Int32 nb_item, InputIterator input_data, OutputIterator output_data,
55 DataType init_value, Operator op,
const TraceInfo& trace_info)
58 command << trace_info;
62 switch (exec_policy) {
// CUDA backend: CUB DeviceScan. The first call passes a null storage pointer
// and only queries the required temporary-storage size; the second call (after
// m_storage.allocate) performs the actual scan on 'stream'.
63#if defined(ARCANE_COMPILING_CUDA)
65 size_t temp_storage_size = 0;
66 cudaStream_t stream = impl::CudaUtils::toNativeStream(&m_queue);
68 if constexpr (IsExclusive)
69 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(
nullptr, temp_storage_size,
70 input_data, output_data, op, init_value, nb_item, stream));
72 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(
nullptr, temp_storage_size,
73 input_data, output_data, op, nb_item, stream));
74 void* temp_storage = m_storage.allocate(temp_storage_size);
75 if constexpr (IsExclusive)
76 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(temp_storage, temp_storage_size,
77 input_data, output_data, op, init_value, nb_item, stream));
79 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(temp_storage, temp_storage_size,
80 input_data, output_data, op, nb_item, stream));
// HIP backend: same size-query-then-run pattern via rocPRIM. Note the argument
// order differs from CUB: init_value and nb_item come before the operator.
83#if defined(ARCANE_COMPILING_HIP)
85 size_t temp_storage_size = 0;
87 hipStream_t stream = impl::HipUtils::toNativeStream(&m_queue);
88 if constexpr (IsExclusive)
89 ARCANE_CHECK_HIP(rocprim::exclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
90 init_value, nb_item, op, stream));
92 ARCANE_CHECK_HIP(rocprim::inclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
93 nb_item, op, stream));
94 void* temp_storage = m_storage.allocate(temp_storage_size);
95 if constexpr (IsExclusive)
96 ARCANE_CHECK_HIP(rocprim::exclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
97 init_value, nb_item, op, stream));
99 ARCANE_CHECK_HIP(rocprim::inclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
100 nb_item, op, stream));
// SYCL backend: uses oneDPL device scans when built with the Intel LLVM
// compiler; otherwise falls back to an internal SyclScanner operating on
// copies of the input/output (presumably because the iterators may not be
// directly device-usable — TODO confirm against the full source).
103#if defined(ARCANE_COMPILING_SYCL)
105#if defined(ARCANE_USE_SCAN_ONEDPL) && defined(__INTEL_LLVM_COMPILER)
106 sycl::queue queue = impl::SyclUtils::toNativeStream(&m_queue);
107 auto policy = oneapi::dpl::execution::make_device_policy(queue);
108 if constexpr (IsExclusive) {
109 oneapi::dpl::exclusive_scan(policy, input_data, input_data + nb_item, output_data, init_value, op);
112 oneapi::dpl::inclusive_scan(policy, input_data, input_data + nb_item, output_data, op);
116 copy_input_data(nb_item);
125 in_data[i] = input_data[i];
129 SyclScanner<IsExclusive, DataType, Operator> scanner;
130 scanner.doScan(m_queue, in_data, out_data, init_value);
136 output_data[i] = out_data[i];
// Host (sequential) fallback: plain running accumulation seeded with
// init_value; the exclusive variant writes the value accumulated BEFORE the
// current element (per the if constexpr branch below).
147 DataType sum = init_value;
148 for (
Int32 i = 0; i < nb_item; ++i) {
149 DataType v = *input_data;
150 if constexpr (IsExclusive) {
// Convenience scan wrappers over SmallSpan arrays (class header elided in this
// extract). Each static method forwards to _applyArray with the matching
// operator; the boolean template argument selects exclusive (true) versus
// inclusive (false) semantics.
191template <
typename DataType>
199 _applyArray<true>(queue, input, output, ScannerSumOperator<DataType>{});
202 static void exclusiveMin(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output)
204 _applyArray<true>(queue, input, output, ScannerMinOperator<DataType>{});
207 static void exclusiveMax(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output)
209 _applyArray<true>(queue, input, output, ScannerMaxOperator<DataType>{});
212 static void inclusiveSum(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output)
214 _applyArray<false>(queue, input, output, ScannerSumOperator<DataType>{});
217 static void inclusiveMin(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output)
219 _applyArray<false>(queue, input, output, ScannerMinOperator<DataType>{});
222 static void inclusiveMax(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output)
224 _applyArray<false>(queue, input, output, ScannerMaxOperator<DataType>{});
// Shared implementation: checks input/output sizes match (ARCANE_FATAL
// otherwise), then delegates to impl::ScannerImpl::apply using the operator's
// defaultValue() as the scan seed.
229 template <
bool IsExclusive,
typename Operator>
230 static void _applyArray(RunQueue* queue, SmallSpan<const DataType> input, SmallSpan<DataType> output,
const Operator& op)
233 impl::ScannerImpl scanner(*queue);
234 const Int32 nb_item = input.size();
235 if (output.size() != nb_item)
236 ARCANE_FATAL(
"Sizes are not equals: input={0} output={1}", nb_item, output.size());
237 const DataType* input_data = input.data();
238 DataType* output_data = output.data();
239 DataType init_value = op.defaultValue();
240 scanner.apply<IsExclusive>(nb_item, input_data, output_data, init_value, op, TraceInfo{});
// Output-iterator adapter over a user lambda (class/struct declaration lines
// elided in this extract): assigning through the iterator invokes
// setter_lambda(index, value) instead of storing to memory. The nested Setter
// is the proxy object returned by operator* / operator[].
261 template <
typename DataType,
typename SetterLambda>
// Proxy constructor: captures the lambda and the element index it represents.
271 ARCCORE_HOST_DEVICE
explicit Setter(
const SetterLambda& s, Int32 index)
// Assignment through the proxy forwards (index, value) to the user lambda.
275 ARCCORE_HOST_DEVICE
void operator=(
const DataType& value)
277 m_lambda(m_index, value);
283 SetterLambda m_lambda;
// Iterator-trait typedefs: advertises a random-access iterator so device scan
// backends (CUB/rocPRIM/oneDPL above) accept it as an output iterator;
// 'pointer' is void because there is no addressable element.
286 using value_type = DataType;
287 using iterator_category = std::random_access_iterator_tag;
289 using difference_type = ptrdiff_t;
290 using pointer = void;
// Random-access arithmetic and comparisons below all operate on m_index only;
// the lambda is carried along unchanged.
305 ARCCORE_HOST_DEVICE ThatClass& operator++()
310 ARCCORE_HOST_DEVICE
friend ThatClass operator+(
const ThatClass& iter, Int32 x)
312 return ThatClass(iter.m_lambda, iter.m_index + x);
314 ARCCORE_HOST_DEVICE
friend ThatClass operator+(Int32 x,
const ThatClass& iter)
316 return ThatClass(iter.m_lambda, iter.m_index + x);
318 ARCCORE_HOST_DEVICE
friend bool operator<(
const ThatClass& iter1,
const ThatClass& iter2)
320 return iter1.m_index < iter2.m_index;
322 ARCCORE_HOST_DEVICE ThatClass operator-(Int32 x)
324 return ThatClass(m_lambda, m_index - x);
326 ARCCORE_HOST_DEVICE Int32 operator-(
const ThatClass& x)
const
328 return m_index - x.m_index;
// Dereference yields a fresh Setter proxy for the current (or offset) index.
330 ARCCORE_HOST_DEVICE reference operator*()
const
332 return Setter(m_lambda, m_index);
334 ARCCORE_HOST_DEVICE reference operator[](Int32 x)
const {
return Setter(m_lambda, m_index + x); }
335 ARCCORE_HOST_DEVICE
friend bool operator!=(
const ThatClass& a,
const ThatClass& b)
337 return a.m_index != b.m_index;
343 SetterLambda m_lambda;
// GenericScanner fragment: public API for scans driven either by getter/setter
// lambdas (applyWithIndex*) or by SmallSpan arrays (apply*). Constructed from
// a RunQueue (body elided in this extract).
348 explicit GenericScanner(
const RunQueue& queue)
// Exclusive scan where element i is read via getter_lambda(i) and written via
// setter_lambda(i, value); forwards to _applyWithIndex<true>.
354 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
355 void applyWithIndexExclusive(Int32 nb_value,
const DataType& initial_value,
356 const GetterLambda& getter_lambda,
357 const SetterLambda& setter_lambda,
358 const Operator& op_lambda,
359 const TraceInfo& trace_info = TraceInfo())
361 _applyWithIndex<true>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Inclusive counterpart of the above; forwards to _applyWithIndex<false>.
364 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
365 void applyWithIndexInclusive(Int32 nb_value,
const DataType& initial_value,
366 const GetterLambda& getter_lambda,
367 const SetterLambda& setter_lambda,
368 const Operator& op_lambda,
369 const TraceInfo& trace_info = TraceInfo())
371 _applyWithIndex<false>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Exclusive scan over spans; input element type may differ from output type.
374 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
375 void applyExclusive(
const OutputDataType& initial_value,
376 SmallSpan<const InputDataType> input,
377 SmallSpan<OutputDataType> output,
378 const Operator& op_lambda,
379 const TraceInfo& trace_info = TraceInfo())
381 _apply<true>(initial_value, input, output, op_lambda, trace_info);
// Inclusive counterpart of applyExclusive.
384 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
385 void applyInclusive(
const OutputDataType& initial_value,
386 SmallSpan<const InputDataType> input,
387 SmallSpan<OutputDataType> output,
388 const Operator& op_lambda,
389 const TraceInfo& trace_info = TraceInfo())
391 _apply<false>(initial_value, input, output, op_lambda, trace_info);
// Private implementation for the lambda-based variants: wraps the getter and
// setter lambdas in iterator adapters and runs impl::ScannerImpl on m_queue.
396 template <
bool IsExclusive,
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
397 void _applyWithIndex(Int32 nb_value,
const DataType& initial_value,
398 const GetterLambda& getter_lambda,
399 const SetterLambda& setter_lambda,
400 const Operator& op_lambda,
401 const TraceInfo& trace_info)
403 impl::GetterLambdaIterator<DataType, GetterLambda> input_iter(getter_lambda);
404 SetterLambdaIterator<DataType, SetterLambda> output_iter(setter_lambda);
405 impl::ScannerImpl scanner(m_queue);
406 scanner.apply<IsExclusive>(nb_value, input_iter, output_iter, initial_value, op_lambda, trace_info);
// Private implementation for the span-based variants: size check, then raw
// pointers into ScannerImpl::apply. NOTE(review): the 'op' parameter line
// (original line 413) is elided in this extract; 'op' at the call below refers
// to it — confirm against the full source.
409 template <
bool IsExclusive,
typename InputDataType,
typename OutputDataType,
typename Operator>
410 void _apply(
const OutputDataType& initial_value,
411 SmallSpan<const InputDataType> input,
412 SmallSpan<OutputDataType> output,
414 const TraceInfo& trace_info = TraceInfo())
416 const Int32 nb_item = input.size();
417 if (output.size() != nb_item)
418 ARCANE_FATAL(
"Sizes are not equals: input={0} output={1}", nb_item, output.size());
419 auto* input_data = input.data();
420 auto* output_data = output.data();
421 impl::ScannerImpl scanner(m_queue);
422 scanner.apply<IsExclusive>(nb_item, input_data, output_data, initial_value, op, trace_info);
#define ARCANE_CHECK_POINTER(ptr)
Macro retournant le pointeur ptr s'il est non nul ou lançant une exception s'il est nul.
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Types et macros pour gérer les boucles sur les accélérateurs.
#define RUNCOMMAND_LOOP1(iter_name, x1,...)
Boucle sur accélérateur avec arguments supplémentaires pour les réductions.
Permet de positionner un élément de l'itérateur de sortie.
Itérateur sur une lambda pour positionner une valeur via un index.
Gestion d'une commande sur accélérateur.
File d'exécution pour un accélérateur.
void barrier() const
Bloque tant que toutes les commandes associées à la file ne sont pas terminées.
eExecutionPolicy executionPolicy() const
Politique d'exécution de la file.
void beginExecute()
Indique qu'on commence l'exécution de la commande.
void endExecute()
Signale la fin de l'exécution.
Tableaux multi-dimensionnels pour les types numériques accessibles sur accélérateurs.
constexpr SmallSpan< DataType > to1DSmallSpan()
Vue 1D sur l'instance (uniquement si rank == 1)
Vue d'un tableau d'éléments de type T.
Espace de nom pour l'utilisation des accélérateurs.
RunCommand makeCommand(const RunQueue &run_queue)
Créé une commande associée à la file run_queue.
eExecutionPolicy
Politique d'exécution pour un Runner.
@ SYCL
Politique d'exécution utilisant l'environnement SYCL.
@ HIP
Politique d'exécution utilisant l'environnement HIP.
@ CUDA
Politique d'exécution utilisant l'environnement CUDA.
@ Sequential
Politique d'exécution séquentielle.
@ Thread
Politique d'exécution multi-thread.
std::int32_t Int32
Type entier signé sur 32 bits.