// Build a scanner implementation bound to the given run queue.
// NOTE(review): constructor body is elided from this excerpt — the queue is
// presumably stored in m_queue (used by apply() below); confirm in full source.
47 explicit ScannerImpl(
const RunQueue& queue)
// Apply an exclusive (IsExclusive==true) or inclusive scan of 'op' over
// 'nb_item' elements read through 'input_data', writing results through
// 'output_data'. 'init_value' seeds the exclusive scan. Dispatches on the
// queue's execution policy: CUDA (CUB), HIP (rocPRIM), SYCL (oneDPL when
// available, otherwise a hand-written SyclScanner), and sequential host code.
// NOTE(review): this excerpt is elided — 'else' branches, 'break's, '#endif's
// and several statements are not visible; comments describe only visible code.
53 template <
bool IsExclusive,
typename InputIterator,
typename OutputIterator,
54 typename Operator,
typename DataType>
55 void apply(
Int32 nb_item, InputIterator input_data, OutputIterator output_data,
56 DataType init_value, Operator op,
const TraceInfo& trace_info)
// Attach caller trace information to the run command (profiling/debugging).
59 command << trace_info;
// Select the backend implementation from the queue's execution policy.
63 switch (exec_policy) {
64#if defined(ARCANE_COMPILING_CUDA)
// CUB two-phase pattern: a first call with a null storage pointer only
// computes the required temporary-storage size ...
66 size_t temp_storage_size = 0;
67 cudaStream_t stream = Impl::CudaUtils::toNativeStream(&m_queue);
69 if constexpr (IsExclusive)
70 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(
nullptr, temp_storage_size,
71 input_data, output_data, op, init_value, nb_item, stream));
73 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(
nullptr, temp_storage_size,
74 input_data, output_data, op, nb_item, stream));
// ... then the scan proper runs with the allocated scratch buffer.
75 void* temp_storage = m_storage.allocate(temp_storage_size);
76 if constexpr (IsExclusive)
77 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(temp_storage, temp_storage_size,
78 input_data, output_data, op, init_value, nb_item, stream));
80 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(temp_storage, temp_storage_size,
81 input_data, output_data, op, nb_item, stream));
84#if defined(ARCANE_COMPILING_HIP)
// Same size-query / allocate / run pattern with rocPRIM. Note the argument
// order differs from CUB: init_value and nb_item come before the operator.
86 size_t temp_storage_size = 0;
88 hipStream_t stream = Impl::HipUtils::toNativeStream(&m_queue);
89 if constexpr (IsExclusive)
90 ARCANE_CHECK_HIP(rocprim::exclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
91 init_value, nb_item, op, stream));
93 ARCANE_CHECK_HIP(rocprim::inclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
94 nb_item, op, stream));
95 void* temp_storage = m_storage.allocate(temp_storage_size);
96 if constexpr (IsExclusive)
97 ARCANE_CHECK_HIP(rocprim::exclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
98 init_value, nb_item, op, stream));
100 ARCANE_CHECK_HIP(rocprim::inclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
101 nb_item, op, stream));
104#if defined(ARCANE_COMPILING_SYCL)
106#if defined(ARCANE_USE_SCAN_ONEDPL) && defined(__INTEL_LLVM_COMPILER)
// Preferred SYCL path: delegate to oneDPL scans with a device policy built
// from the native queue.
107 sycl::queue queue = impl::SyclUtils::toNativeStream(&m_queue);
108 auto policy = oneapi::dpl::execution::make_device_policy(queue);
109 if constexpr (IsExclusive) {
110 oneapi::dpl::exclusive_scan(policy, input_data, input_data + nb_item, output_data, init_value, op);
113 oneapi::dpl::inclusive_scan(policy, input_data, input_data + nb_item, output_data, op);
// Fallback SYCL path: stage the input into a local buffer, run the
// hand-written SyclScanner, then copy results back to the output iterator.
// (The loops around the element copies are elided from this excerpt.)
117 copy_input_data(nb_item);
126 in_data[i] = input_data[i];
130 SyclScanner<IsExclusive, DataType, Operator> scanner;
131 scanner.doScan(m_queue, in_data, out_data, init_value);
137 output_data[i] = out_data[i];
// Host (threaded) path — 'scanner' and 'launch_info' declarations are elided
// from this excerpt; presumably a multi-thread scanner, confirm in full source.
150 scanner.
doScan<IsExclusive, DataType>(launch_info.
loopRunInfo(), nb_item, input_data, output_data, init_value, op);
// Sequential host fallback: plain running-sum loop over the input iterator.
155 DataType sum = init_value;
156 for (
Int32 i = 0; i < nb_item; ++i) {
157 DataType v = *input_data;
158 if constexpr (IsExclusive) {
// Output random-access iterator that forwards each assigned value to a user
// lambda as (index, value). Used as the scan's output iterator so results can
// be scattered through arbitrary setter logic instead of a plain array.
// NOTE(review): excerpt is elided — access specifiers, braces and some member
// declarations (e.g. m_index) are not visible here.
271 template <
typename DataType,
typename SetterLambda>
272 class SetterLambdaIterator
// Proxy returned by operator*/operator[]: assigning a value to it invokes the
// stored lambda with the proxy's index.
281 ARCCORE_HOST_DEVICE
explicit Setter(
const SetterLambda& s,
Int32 index)
285 ARCCORE_HOST_DEVICE
void operator=(
const DataType& value)
287 m_lambda(m_index, value);
293 SetterLambda m_lambda;
// Iterator traits so the type qualifies as a random-access iterator for the
// backend scan implementations.
296 using value_type = DataType;
297 using iterator_category = std::random_access_iterator_tag;
299 using difference_type = ptrdiff_t;
300 using pointer = void;
301 using ThatClass = SetterLambdaIterator<DataType, SetterLambda>;
// Construct from the setter lambda (starting index presumably 0 — the
// constructor body is not visible; confirm in full source).
305 ARCCORE_HOST_DEVICE SetterLambdaIterator(
const SetterLambda& s)
315 ARCCORE_HOST_DEVICE ThatClass& operator++()
// The arithmetic/comparison operators below act on m_index only; the lambda
// is carried along unchanged.
320 ARCCORE_HOST_DEVICE
friend ThatClass operator+(
const ThatClass& iter,
Int32 x)
322 return ThatClass(iter.m_lambda, iter.m_index + x);
324 ARCCORE_HOST_DEVICE
friend ThatClass operator+(Int32 x,
const ThatClass& iter)
326 return ThatClass(iter.m_lambda, iter.m_index + x);
328 ARCCORE_HOST_DEVICE
friend bool operator<(
const ThatClass& iter1,
const ThatClass& iter2)
330 return iter1.m_index < iter2.m_index;
332 ARCCORE_HOST_DEVICE ThatClass operator-(
Int32 x)
334 return ThatClass(m_lambda, m_index - x);
336 ARCCORE_HOST_DEVICE
Int32 operator-(
const ThatClass& x)
const
338 return m_index - x.m_index;
// Dereference yields a Setter proxy bound to the current index; writing
// through the iterator therefore calls the lambda.
340 ARCCORE_HOST_DEVICE reference operator*()
const
342 return Setter(m_lambda, m_index);
344 ARCCORE_HOST_DEVICE reference operator[](
Int32 x)
const {
return Setter(m_lambda, m_index + x); }
345 ARCCORE_HOST_DEVICE
friend bool operator!=(
const ThatClass& a,
const ThatClass& b)
347 return a.m_index != b.m_index;
353 SetterLambda m_lambda;
// Bind the scanner facade to the run queue used for all subsequent scans
// (constructor body is elided from this excerpt).
358 explicit GenericScanner(
const RunQueue& queue)
// Exclusive index-based scan: element i is read via getter_lambda(i), values
// are combined with op_lambda seeded by initial_value, and result i is written
// via setter_lambda(i, value). Thin wrapper over _applyWithIndex<true>.
// (Surrounding braces are elided from this excerpt.)
364 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
365 void applyWithIndexExclusive(
Int32 nb_value,
const DataType& initial_value,
366 const GetterLambda& getter_lambda,
367 const SetterLambda& setter_lambda,
368 const Operator& op_lambda,
369 const TraceInfo& trace_info = TraceInfo())
371 _applyWithIndex<true>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Inclusive index-based scan: same contract as applyWithIndexExclusive but
// delegates to _applyWithIndex<false> (element i's own value is included in
// result i). (Surrounding braces are elided from this excerpt.)
374 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
375 void applyWithIndexInclusive(
Int32 nb_value,
const DataType& initial_value,
376 const GetterLambda& getter_lambda,
377 const SetterLambda& setter_lambda,
378 const Operator& op_lambda,
379 const TraceInfo& trace_info = TraceInfo())
381 _applyWithIndex<false>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Exclusive scan over spans: combines 'input' with op_lambda seeded by
// initial_value and writes into 'output'. Thin wrapper over _apply<true>.
// (Surrounding braces are elided from this excerpt.)
384 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
385 void applyExclusive(
const OutputDataType& initial_value,
386 SmallSpan<const InputDataType> input,
387 SmallSpan<OutputDataType> output,
388 const Operator& op_lambda,
389 const TraceInfo& trace_info = TraceInfo())
391 _apply<true>(initial_value, input, output, op_lambda, trace_info);
// Inclusive scan over spans: same contract as applyExclusive but delegates to
// _apply<false>. (Surrounding braces are elided from this excerpt.)
394 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
395 void applyInclusive(
const OutputDataType& initial_value,
396 SmallSpan<const InputDataType> input,
397 SmallSpan<OutputDataType> output,
398 const Operator& op_lambda,
399 const TraceInfo& trace_info = TraceInfo())
401 _apply<false>(initial_value, input, output, op_lambda, trace_info);
// Common implementation for the applyWithIndex* entry points: wraps the getter
// in a GetterLambdaIterator and runs ScannerImpl::apply on it.
// NOTE(review): 'output_iter' is used below but its declaration (presumably a
// SetterLambdaIterator over setter_lambda, original line 414) is elided from
// this excerpt — confirm in full source.
406 template <
bool IsExclusive,
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
407 void _applyWithIndex(
Int32 nb_value,
const DataType& initial_value,
408 const GetterLambda& getter_lambda,
409 const SetterLambda& setter_lambda,
410 const Operator& op_lambda,
411 const TraceInfo& trace_info)
413 impl::GetterLambdaIterator<DataType, GetterLambda> input_iter(getter_lambda);
415 impl::ScannerImpl scanner(m_queue);
416 scanner.apply<IsExclusive>(nb_value, input_iter, output_iter, initial_value, op_lambda, trace_info);
// Common implementation for applyExclusive/applyInclusive: validates that the
// spans match in size, then runs ScannerImpl::apply on the raw pointers.
// NOTE(review): the operator parameter ('op' used in the last call, original
// line 424) is elided from this excerpt's signature — confirm in full source.
420 template <
bool IsExclusive,
typename InputDataType,
typename OutputDataType,
typename Operator>
421 void _apply(
const OutputDataType& initial_value,
422 SmallSpan<const InputDataType> input,
423 SmallSpan<OutputDataType> output,
425 const TraceInfo& trace_info = TraceInfo())
// Fail fast on mismatched span sizes rather than reading/writing out of range.
427 const Int32 nb_item = input.size();
428 if (output.size() != nb_item)
429 ARCANE_FATAL(
"Sizes are not equals: input={0} output={1}", nb_item, output.size());
430 auto* input_data = input.data();
431 auto* output_data = output.data();
432 impl::ScannerImpl scanner(m_queue);
433 scanner.apply<IsExclusive>(nb_item, input_data, output_data, initial_value, op, trace_info);
441 if (!m_queue.isAsync())