12#ifndef ARCANE_ACCELERATOR_SCANIMPL_H
13#define ARCANE_ACCELERATOR_SCANIMPL_H
17#include "arcane/utils/ArrayView.h"
18#include "arcane/utils/FatalErrorException.h"
20#include "arcane/utils/NumArray.h"
22#include "arcane/accelerator/core/RunQueue.h"
24#include "arcane/accelerator/AcceleratorGlobal.h"
25#include "arcane/accelerator/CommonUtils.h"
26#include "arcane/accelerator/RunCommandLaunchInfo.h"
32namespace Arcane::Accelerator::impl
38#if defined(ARCANE_COMPILING_SYCL)
48template <
bool IsExclusive,
typename DataType,
typename Operator>
55 DataType _getInputValue(Int32 index)
const
57 DataType local_value = identity;
58 if constexpr (IsExclusive) {
60 local_value = init_value;
62 local_value = ((index - 1) < nb_value) ? input_values[index - 1] : identity;
65 local_value = (index < nb_value) ? input_values[index] : identity;
71 SmallSpan<const DataType> input_values;
72 DataType identity = {};
73 DataType init_value = {};
79 void doScan(RunQueue& rq, SmallSpan<const DataType> input, SmallSpan<DataType> output, DataType init_value)
81 DataType identity = Operator::defaultValue();
82 sycl::queue q = impl::SyclUtils::toNativeStream(&rq);
84 NumArray<DataType, MDDim1> tmp;
86 NumArray<DataType, MDDim1> tmp2;
87 Int32 nb_item = input.size();
88 Int32 block_size = 256;
89 Int32 nb_block = (nb_item / block_size);
90 if ((nb_item % block_size) != 0)
93 tmp2.resize(nb_block);
95 input_info.nb_value = nb_item;
96 input_info.init_value = init_value;
97 input_info.identity = identity;
98 input_info.input_values = input;
100 std::cout <<
"DO_SCAN nb_item=" << nb_item <<
" nb_block=" << nb_block <<
"\n";
101 doscan1(q, input_info, tmp.to1DSpan(), nb_item, block_size);
103 for (
int i = 0; i < nb_block; ++i)
104 std::cout <<
"DO_SCAN_X1 i=" << i <<
" tmp[i]=" << tmp[i] <<
"\n";
105 doscan2(q, tmp.to1DSpan(), nb_block, block_size, identity);
107 for (
int i = 0; i < nb_block; ++i)
108 std::cout <<
"DO_SCAN_X2 i=" << i <<
" tmp[i]=" << tmp[i] <<
"\n";
109 doscan2_bis(q, tmp.to1DSpan(), tmp2.to1DSpan(), nb_block, block_size, identity);
111 for (
int i = 0; i < nb_block; ++i)
112 std::cout <<
"DO_SCAN_X2_BIS i=" << i <<
" tmp[i]=" << tmp[i] <<
" tmp2[i]=" << tmp2[i] <<
"\n";
113 doscan3(q, input_info, output, tmp2, nb_item, block_size);
118 void doscan1(sycl::queue& q,
const InputInfo& input_info, Span<DataType> tmp,
119 int nb_value,
int block_size)
122 std::cout <<
"DO_SCAN1 nb_value=" << nb_value <<
" L=" << block_size <<
"\n";
125 q.submit([&](sycl::handler& h) {
126 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
127 h.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
128 const int i =
static_cast<int>(it.get_global_id(0));
129 const int li =
static_cast<int>(it.get_local_id(0));
130 const int gid =
static_cast<int>(it.get_group(0));
131 const int local_range0 =
static_cast<int>(it.get_local_range()[0]);
133 DataType local_value = input_info._getInputValue(i);
134 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
136 if (li == local_range0 - 1)
137 tmp[gid] = local[li];
143 void doscan2(sycl::queue& q, Span<DataType> tmp,
int nb_block,
const int block_size, DataType identity)
146 std::cout <<
"DO_SCAN2 nb_block=" << nb_block <<
" block_size=" << block_size <<
"\n";
149 q.submit([&](sycl::handler& h) {
150 auto local = sycl::local_accessor<DataType, 1>(block_size, h);
151 h.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
152 int i =
static_cast<int>(it.get_global_id(0));
153 int li =
static_cast<int>(it.get_local_id(0));
155 DataType local_value = (i < nb_block) ? tmp[i] : identity;
156 local[li] = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
165 void doscan2_bis(sycl::queue& q, Span<const DataType> tmp, Span<DataType> tmp2,
int nb_block,
int block_size, DataType identity)
168 std::cout <<
"DO_SCAN2_bis nb_block=" << nb_block <<
" L=" << block_size <<
"\n";
170 q.parallel_for(_getNDRange(nb_block, block_size), [=](sycl::nd_item<1> it) {
171 const int g =
static_cast<int>(it.get_group(0));
172 const int i =
static_cast<int>(it.get_global_id(0));
174 DataType init_value = identity;
175 for (
int j = 1; j <= g; ++j)
176 init_value = scan_op(init_value, tmp[(j * block_size) - 1]);
177 tmp2[i] = scan_op(init_value, tmp[i]);
183 void doscan3(sycl::queue& q,
const InputInfo& input_info, SmallSpan<DataType> output, SmallSpan<DataType> tmp2,
int nb_value,
int block_size)
186 std::cout <<
"DO_SCAN3 nb_value=" << nb_value <<
" L=" << block_size <<
"\n";
189 q.parallel_for(_getNDRange(nb_value, block_size), [=](sycl::nd_item<1> it) {
190 const int i =
static_cast<int>(it.get_global_id(0));
191 const int g =
static_cast<int>(it.get_group(0));
192 DataType local_value = input_info._getInputValue(i);
193 DataType output_value = sycl::inclusive_scan_over_group(it.get_group(), local_value, scan_op.syclFunctor());
195 output[i] = (g > 0) ? scan_op(output_value, tmp2[g - 1]) : output_value;
203 bool m_is_verbose =
false;
208 sycl::nd_range<1> _getNDRange(Int32 nb_value, Int32 block_size)
210 int x = nb_value / block_size;
211 if ((nb_value % block_size) != 0)
214 return sycl::nd_range<1>(x, block_size);
// Types and macros for handling loops on accelerators.