// Builds a scanner implementation from a run queue.
// `explicit` prevents an implicit conversion from RunQueue to ScannerImpl.
// NOTE(review): the initializer list/body is not visible here; presumably the
// queue is kept in `m_queue`, which apply() reads — confirm.
47 explicit ScannerImpl(
const RunQueue& queue)
// Runs a prefix scan over `nb_item` elements read through `input_data` and
// written through `output_data`, combining values with `op`.
//
// Template parameters:
//   IsExclusive    - compile-time choice between the exclusive (true) and
//                    inclusive (false) variant of the scan.
//   InputIterator  - type used to read the input values.
//   OutputIterator - type used to write the results.
//   Operator       - binary combining operator type.
//   DataType       - type of `init_value` and of the running value of the
//                    loop-based scan below.
//
// Parameters:
//   nb_item     - number of elements to process.
//   input_data  - start of the input sequence.
//   output_data - start of the output sequence.
//   init_value  - initial value; passed to the exclusive CUB/rocPRIM/oneDPL
//                 calls but NOT to their inclusive counterparts.
//   op          - combining operator.
//   trace_info  - caller trace information, streamed into `command`.
//
// The back-end is selected by a switch on `exec_policy`; the CUDA, HIP and
// SYCL sections are only compiled when the matching ARCANE_COMPILING_* macro
// is defined.
53 template <
bool IsExclusive,
typename InputIterator,
typename OutputIterator,
54 typename Operator,
typename DataType>
55 void apply(
Int32 nb_item, InputIterator input_data, OutputIterator output_data,
56 DataType init_value, Operator op,
const TraceInfo& trace_info)
// `command` is declared on a line not shown in this excerpt.
59 command << trace_info;
63 switch (exec_policy) {
// ---- CUDA section: scan performed with CUB's DeviceScan.
64#if defined(ARCANE_COMPILING_CUDA)
66 size_t temp_storage_size = 0;
// Native CUDA stream obtained from the run queue; all CUB calls below use it.
67 cudaStream_t stream = impl::CudaUtils::toNativeStream(&m_queue);
// First call: a null temporary-storage pointer is passed together with
// temp_storage_size. Per CUB's documented two-call convention this only
// reports the required temporary-storage size and performs no scan.
69 if constexpr (IsExclusive)
70 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(
nullptr, temp_storage_size,
71 input_data, output_data, op, init_value, nb_item, stream));
73 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(
nullptr, temp_storage_size,
74 input_data, output_data, op, nb_item, stream));
// Second call: same arguments, now with storage of the reported size
// obtained from m_storage.
75 void* temp_storage = m_storage.allocate(temp_storage_size);
76 if constexpr (IsExclusive)
77 ARCANE_CHECK_CUDA(::cub::DeviceScan::ExclusiveScan(temp_storage, temp_storage_size,
78 input_data, output_data, op, init_value, nb_item, stream));
80 ARCANE_CHECK_CUDA(::cub::DeviceScan::InclusiveScan(temp_storage, temp_storage_size,
81 input_data, output_data, op, nb_item, stream));
// ---- HIP section: same two-call pattern with rocPRIM.
// Note the argument order differs from CUB: rocPRIM takes
// (init_value, nb_item, op) where CUB takes (op, init_value, nb_item).
84#if defined(ARCANE_COMPILING_HIP)
86 size_t temp_storage_size = 0;
88 hipStream_t stream = impl::HipUtils::toNativeStream(&m_queue);
// First call with a null storage pointer (size query, see the CUDA section).
89 if constexpr (IsExclusive)
90 ARCANE_CHECK_HIP(rocprim::exclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
91 init_value, nb_item, op, stream));
93 ARCANE_CHECK_HIP(rocprim::inclusive_scan(
nullptr, temp_storage_size, input_data, output_data,
94 nb_item, op, stream));
// Second call with the storage obtained from m_storage.
95 void* temp_storage = m_storage.allocate(temp_storage_size);
96 if constexpr (IsExclusive)
97 ARCANE_CHECK_HIP(rocprim::exclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
98 init_value, nb_item, op, stream));
100 ARCANE_CHECK_HIP(rocprim::inclusive_scan(temp_storage, temp_storage_size, input_data, output_data,
101 nb_item, op, stream));
// ---- SYCL section.
104#if defined(ARCANE_COMPILING_SYCL)
// oneDPL is used only when ARCANE_USE_SCAN_ONEDPL is defined and the compiler
// is the Intel LLVM compiler.
106#if defined(ARCANE_USE_SCAN_ONEDPL) && defined(__INTEL_LLVM_COMPILER)
107 sycl::queue queue = impl::SyclUtils::toNativeStream(&m_queue);
108 auto policy = oneapi::dpl::execution::make_device_policy(queue);
109 if constexpr (IsExclusive) {
110 oneapi::dpl::exclusive_scan(policy, input_data, input_data + nb_item, output_data, init_value, op);
113 oneapi::dpl::inclusive_scan(policy, input_data, input_data + nb_item, output_data, op);
// NOTE(review): the lines below appear to be the non-oneDPL path (the
// separating #else is not visible here): the input is copied element by
// element into `in_data`, scanned by SyclScanner into `out_data`, and the
// result is copied element by element into `output_data`.
117 copy_input_data(nb_item);
126 in_data[i] = input_data[i];
// `op` is not passed to doScan here; the operator is only conveyed through
// the SyclScanner template argument.
130 SyclScanner<IsExclusive, DataType, Operator> scanner;
131 scanner.doScan(m_queue, in_data, out_data, init_value);
137 output_data[i] = out_data[i];
// Scan delegated to `scanner.doScan` with the loop run information taken from
// `launch_info`; the declarations of `scanner` and `launch_info` and the
// policy this belongs to are on lines not shown here.
150 scanner.
doScan<IsExclusive, DataType>(launch_info.
loopRunInfo(), nb_item, input_data, output_data, init_value, op);
// Loop-based scan: the running value `sum` starts at init_value and the
// elements are visited in increasing index order. The rest of the loop body
// lies outside this excerpt.
155 DataType sum = init_value;
156 for (
Int32 i = 0; i < nb_item; ++i) {
157 DataType v = *input_data;
158 if constexpr (IsExclusive) {
// Iterator-like class whose "write" operation is a call to a user lambda.
//
// Dereferencing (operator* / operator[]) returns a Setter built from the
// lambda and an index; assigning a DataType value to that Setter calls
// `m_lambda(m_index, value)`.
//
// Template parameters:
//   DataType     - type of the values assigned through the iterator.
//   SetterLambda - callable invoked as lambda(index, value).
//
// NOTE(review): the Setter class header, the `reference` alias and the
// two-argument (lambda, index) constructor used by the arithmetic operators
// are on lines not shown in this excerpt.
269 template <
typename DataType,
typename SetterLambda>
270 class SetterLambdaIterator
// Setter: proxy remembering a lambda and the index it stands for.
279 ARCCORE_HOST_DEVICE
explicit Setter(
const SetterLambda& s,
Int32 index)
// Assignment to the proxy forwards (index, value) to the lambda.
283 ARCCORE_HOST_DEVICE
void operator=(
const DataType& value)
285 m_lambda(m_index, value);
// Lambda held by the Setter proxy.
291 SetterLambda m_lambda;
// Standard iterator member types. The category is declared as random access;
// `pointer` is void.
294 using value_type = DataType;
295 using iterator_category = std::random_access_iterator_tag;
297 using difference_type = ptrdiff_t;
298 using pointer = void;
299 using ThatClass = SetterLambdaIterator<DataType, SetterLambda>;
// Builds an iterator from the lambda alone (the starting index is set on a
// line not shown here).
303 ARCCORE_HOST_DEVICE SetterLambdaIterator(
const SetterLambda& s)
// Prefix increment (body not visible in this excerpt).
313 ARCCORE_HOST_DEVICE ThatClass& operator++()
// iter + x and x + iter: new iterator with the same lambda, index advanced by x.
318 ARCCORE_HOST_DEVICE
friend ThatClass operator+(
const ThatClass& iter,
Int32 x)
320 return ThatClass(iter.m_lambda, iter.m_index + x);
322 ARCCORE_HOST_DEVICE
friend ThatClass operator+(Int32 x,
const ThatClass& iter)
324 return ThatClass(iter.m_lambda, iter.m_index + x);
// Ordering compares the indices only; the lambdas are not compared.
326 ARCCORE_HOST_DEVICE
friend bool operator<(
const ThatClass& iter1,
const ThatClass& iter2)
328 return iter1.m_index < iter2.m_index;
// iter - x: new iterator with the index moved back by x.
// NOTE(review): no const qualifier is visible on this overload, unlike the
// other operators of the class — confirm whether that is intended.
330 ARCCORE_HOST_DEVICE ThatClass operator-(
Int32 x)
332 return ThatClass(m_lambda, m_index - x);
// Distance between two iterators, as the difference of their indices.
// Returned as Int32 although difference_type is declared as ptrdiff_t.
334 ARCCORE_HOST_DEVICE
Int32 operator-(
const ThatClass& x)
const
336 return m_index - x.m_index;
// Dereference: proxy for the current index.
338 ARCCORE_HOST_DEVICE reference operator*()
const
340 return Setter(m_lambda, m_index);
// Subscript: proxy for the index located x positions after the current one.
342 ARCCORE_HOST_DEVICE reference operator[](
Int32 x)
const {
return Setter(m_lambda, m_index + x); }
// Inequality compares the indices only.
343 ARCCORE_HOST_DEVICE
friend bool operator!=(
const ThatClass& a,
const ThatClass& b)
345 return a.m_index != b.m_index;
// Lambda held by the iterator itself.
351 SetterLambda m_lambda;
// Builds a scanner bound to a run queue.
// `explicit` prevents an implicit conversion from RunQueue to GenericScanner.
// NOTE(review): the initializer is not visible here; presumably the queue is
// kept in `m_queue`, which the private helpers pass to impl::ScannerImpl — confirm.
356 explicit GenericScanner(
const RunQueue& queue)
// Exclusive scan driven by lambdas instead of arrays.
// Forwards every argument unchanged to _applyWithIndex<true>.
//
// Parameters:
//   nb_value      - number of values to process.
//   initial_value - initial value of the scan.
//   getter_lambda - callable providing the input values.
//   setter_lambda - callable receiving the results.
//   op_lambda     - binary combining operator.
//   trace_info    - caller trace information (defaults to an empty TraceInfo).
362 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
363 void applyWithIndexExclusive(
Int32 nb_value,
const DataType& initial_value,
364 const GetterLambda& getter_lambda,
365 const SetterLambda& setter_lambda,
366 const Operator& op_lambda,
367 const TraceInfo& trace_info = TraceInfo())
369 _applyWithIndex<true>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Inclusive scan driven by lambdas instead of arrays.
// Forwards every argument unchanged to _applyWithIndex<false>.
//
// Parameters:
//   nb_value      - number of values to process.
//   initial_value - forwarded as is to the implementation.
//   getter_lambda - callable providing the input values.
//   setter_lambda - callable receiving the results.
//   op_lambda     - binary combining operator.
//   trace_info    - caller trace information (defaults to an empty TraceInfo).
372 template <
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
373 void applyWithIndexInclusive(
Int32 nb_value,
const DataType& initial_value,
374 const GetterLambda& getter_lambda,
375 const SetterLambda& setter_lambda,
376 const Operator& op_lambda,
377 const TraceInfo& trace_info = TraceInfo())
379 _applyWithIndex<false>(nb_value, initial_value, getter_lambda, setter_lambda, op_lambda, trace_info);
// Exclusive scan from an input span to an output span.
// Forwards every argument unchanged to _apply<true>, which requires both
// spans to have the same size.
//
// Parameters:
//   initial_value - initial value of the scan.
//   input         - read-only view of the input values.
//   output        - view receiving the results.
//   op_lambda     - binary combining operator.
//   trace_info    - caller trace information (defaults to an empty TraceInfo).
382 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
383 void applyExclusive(
const OutputDataType& initial_value,
384 SmallSpan<const InputDataType> input,
385 SmallSpan<OutputDataType> output,
386 const Operator& op_lambda,
387 const TraceInfo& trace_info = TraceInfo())
389 _apply<true>(initial_value, input, output, op_lambda, trace_info);
// Inclusive scan from an input span to an output span.
// Forwards every argument unchanged to _apply<false>, which requires both
// spans to have the same size.
//
// Parameters:
//   initial_value - forwarded as is to the implementation.
//   input         - read-only view of the input values.
//   output        - view receiving the results.
//   op_lambda     - binary combining operator.
//   trace_info    - caller trace information (defaults to an empty TraceInfo).
392 template <
typename InputDataType,
typename OutputDataType,
typename Operator>
393 void applyInclusive(
const OutputDataType& initial_value,
394 SmallSpan<const InputDataType> input,
395 SmallSpan<OutputDataType> output,
396 const Operator& op_lambda,
397 const TraceInfo& trace_info = TraceInfo())
399 _apply<false>(initial_value, input, output, op_lambda, trace_info);
// Shared implementation of the lambda-based scans.
// `IsExclusive` selects the exclusive (true) or inclusive (false) variant and
// is passed on to impl::ScannerImpl::apply.
//
// The getter lambda is wrapped in an impl::GetterLambdaIterator used as the
// input of the scan; a ScannerImpl is then created on `m_queue` to run it.
// NOTE(review): the declaration of `output_iter` is on a line not shown in
// this excerpt; presumably it wraps `setter_lambda` — confirm.
404 template <
bool IsExclusive,
typename DataType,
typename GetterLambda,
typename SetterLambda,
typename Operator>
405 void _applyWithIndex(
Int32 nb_value,
const DataType& initial_value,
406 const GetterLambda& getter_lambda,
407 const SetterLambda& setter_lambda,
408 const Operator& op_lambda,
409 const TraceInfo& trace_info)
411 impl::GetterLambdaIterator<DataType, GetterLambda> input_iter(getter_lambda);
413 impl::ScannerImpl scanner(m_queue);
414 scanner.apply<IsExclusive>(nb_value, input_iter, output_iter, initial_value, op_lambda, trace_info);
// Shared implementation of the span-based scans.
// `IsExclusive` selects the exclusive (true) or inclusive (false) variant and
// is passed on to impl::ScannerImpl::apply.
//
// The number of items is the size of `input`; a fatal error is raised when
// `output` does not have the same size. The scan then runs on the raw data
// pointers of both spans through a ScannerImpl created on `m_queue`.
// NOTE(review): the declaration of the `op` parameter used below is on a line
// not shown in this excerpt.
417 template <
bool IsExclusive,
typename InputDataType,
typename OutputDataType,
typename Operator>
418 void _apply(
const OutputDataType& initial_value,
419 SmallSpan<const InputDataType> input,
420 SmallSpan<OutputDataType> output,
422 const TraceInfo& trace_info = TraceInfo())
424 const Int32 nb_item = input.size();
// Input and output must describe the same number of elements.
425 if (output.size() != nb_item)
426 ARCANE_FATAL(
"Sizes are not equals: input={0} output={1}", nb_item, output.size());
427 auto* input_data = input.data();
428 auto* output_data = output.data();
429 impl::ScannerImpl scanner(m_queue);
430 scanner.apply<IsExclusive>(nb_item, input_data, output_data, initial_value, op, trace_info);