#ifndef ARCCORE_ACCELERATOR_COMMONCUDAHIPREDUCEIMPL_H
#define ARCCORE_ACCELERATOR_COMMONCUDAHIPREDUCEIMPL_H

#include "arccore/accelerator/AcceleratorGlobal.h"

namespace Arcane::Accelerator::Impl
{
__device__ __forceinline__ unsigned int
getThreadId()
{
  int threadId = threadIdx.x;
  return threadId;
}

__device__ __forceinline__ unsigned int
getBlockId()
{
  int blockId = blockIdx.x;
  return blockId;
}
//! Maximum block size assumed when sizing the per-warp shared-memory buffer.
constexpr const Int32 MAX_BLOCK_SIZE = 1024;
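// The overloads below wrap the warp shuffle intrinsics so that block_reduce()
// and grid_reduce() can be written independently of the back end: with CUDA
// the '*_sync' intrinsics are used with a full-warp mask, with HIP (ROCm) the
// mask-less '__shfl' / '__shfl_xor' intrinsics are used.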
#if defined(__CUDACC__)

// CUDA back end: '*_sync' shuffle intrinsics with a full-warp mask.

ARCCORE_DEVICE inline double
shfl_xor_sync(double var, int laneMask)
{
  return ::__shfl_xor_sync(0xffffffffu, var, laneMask);
}

ARCCORE_DEVICE inline int
shfl_xor_sync(int var, int laneMask)
{
  return ::__shfl_xor_sync(0xffffffffu, var, laneMask);
}

ARCCORE_DEVICE inline Int64
shfl_xor_sync(Int64 var, int laneMask)
{
  return ::__shfl_xor_sync(0xffffffffu, var, laneMask);
}

ARCCORE_DEVICE inline double
shfl_sync(double var, int laneMask)
{
  return ::__shfl_sync(0xffffffffu, var, laneMask);
}

ARCCORE_DEVICE inline int
shfl_sync(int var, int laneMask)
{
  return ::__shfl_sync(0xffffffffu, var, laneMask);
}

ARCCORE_DEVICE inline Int64
shfl_sync(Int64 var, int laneMask)
{
  return ::__shfl_sync(0xffffffffu, var, laneMask);
}
#else

// HIP (ROCm) back end: the intrinsics do not take an explicit lane mask.

ARCCORE_DEVICE inline double
shfl_xor_sync(double var, int laneMask)
{
  return ::__shfl_xor(var, laneMask);
}

ARCCORE_DEVICE inline int
shfl_xor_sync(int var, int laneMask)
{
  return ::__shfl_xor(var, laneMask);
}

ARCCORE_DEVICE inline Int64
shfl_xor_sync(Int64 var, int laneMask)
{
  return ::__shfl_xor(var, laneMask);
}

ARCCORE_DEVICE inline double
shfl_sync(double var, int laneMask)
{
  return ::__shfl(var, laneMask);
}

ARCCORE_DEVICE inline int
shfl_sync(int var, int laneMask)
{
  return ::__shfl(var, laneMask);
}

ARCCORE_DEVICE inline Int64
shfl_sync(Int64 var, int laneMask)
{
  return ::__shfl(var, laneMask);
}

#endif // defined(__CUDACC__)
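//! Reduce the values 'val' of all the threads of the current block.
//  Each warp is first reduced with a shuffle butterfly, then the per-warp
//  values are combined through shared memory. Only the threads of the first
//  warp (in particular thread 0) are guaranteed to hold the final block value.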
template <typename ReduceOperator, Int32 WarpSize, typename T>
ARCCORE_DEVICE inline T
block_reduce(T val)
{
  constexpr Int32 WARP_SIZE = WarpSize;
  constexpr const Int32 MAX_WARPS = MAX_BLOCK_SIZE / WARP_SIZE;
  int numThreads = blockDim.x;
  int threadId = getThreadId();
  int warpId = threadId % WARP_SIZE;
  int warpNum = threadId / WARP_SIZE;

  T temp = val;

  if (numThreads % WARP_SIZE == 0) {
    // reduce each full warp with a XOR butterfly
    for (int i = 1; i < WARP_SIZE; i *= 2) {
      T rhs = Impl::shfl_xor_sync(temp, i);
      ReduceOperator::combine(temp, rhs);
    }
  }
  else {
    // partial warp: only combine values coming from existing threads
    for (int i = 1; i < WARP_SIZE; i *= 2) {
      int srcLane = threadId ^ i;
      T rhs = Impl::shfl_sync(temp, srcLane);
      if (srcLane < numThreads) {
        ReduceOperator::combine(temp, rhs);
      }
    }
  }

  // combine the per-warp values
  if (numThreads > WARP_SIZE) {
    __shared__ T sd[MAX_WARPS];

    // lane 0 of each warp writes its warp value to shared memory
    if (warpId == 0)
      sd[warpNum] = temp;
    __syncthreads();

    // the first warp reduces the per-warp values
    if (warpNum == 0) {
      if (warpId * WARP_SIZE < numThreads)
        temp = sd[warpId];
      else
        temp = ReduceOperator::identity();
      for (int i = 1; i < WARP_SIZE; i *= 2) {
        T rhs = Impl::shfl_xor_sync(temp, i);
        ReduceOperator::combine(temp, rhs);
      }
    }
    __syncthreads();
  }

  return temp;
}
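//! Reduce the values 'val' of all the threads of the grid.
//  Each block first does a block_reduce(), then the last block to finish
//  combines the per-block partial results stored in 'device_mem'.
//  Returns true on the single thread that holds the final value in 'val'.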
template <typename ReduceOperator, Int32 WarpSize, typename T>
ARCCORE_DEVICE inline bool
grid_reduce(T& val, SmallSpan<T> device_mem, unsigned int* device_count)
{
  int numBlocks = gridDim.x;
  int numThreads = blockDim.x;
  int wrap_around = numBlocks - 1;

  int blockId = blockIdx.x;
  int threadId = threadIdx.x;

  T temp = block_reduce<ReduceOperator, WarpSize, T>(val);

  // one thread per block writes the block value to device memory
  bool lastBlock = false;
  if (threadId == 0) {
    device_mem[blockId] = temp;
    // make the write visible to all blocks before counting
    __threadfence();
    // atomicInc wraps back to zero when the old count equals wrap_around
    unsigned int old_count = ::atomicInc(device_count, wrap_around);
    lastBlock = ((int)old_count == wrap_around);
  }

  // returns a non-zero value if any thread passes in a non-zero value
  lastBlock = __syncthreads_or(lastBlock);

  // the last block combines the per-block values
  if (lastBlock) {
    temp = ReduceOperator::identity();
    for (int i = threadId; i < numBlocks; i += numThreads) {
      ReduceOperator::combine(temp, device_mem[i]);
    }
    temp = block_reduce<ReduceOperator, WarpSize, T>(temp);

    // one thread returns the final value
    if (threadId == 0)
      val = temp;
  }

  return lastBlock && threadId == 0;
}
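//! Apply the reduction on the device.
//  The grid-wide result is written to the host pinned buffer of 'dev_info'
//  by the single thread for which grid_reduce() returns true.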
template <typename ReduceOperator>
ARCCORE_INLINE_REDUCE ARCCORE_DEVICE void
_applyDeviceGeneric(const ReduceDeviceInfo<typename ReduceOperator::DataType>& dev_info)
{
  using DataType = typename ReduceOperator::DataType;
  SmallSpan<DataType> grid_buffer = dev_info.m_grid_buffer;
  unsigned int* device_count = dev_info.m_device_count;
  DataType* host_pinned_ptr = dev_info.m_host_pinned_final_ptr;
  DataType v = dev_info.m_current_value;

#if HIP_VERSION_MAJOR >= 7
  // Starting with ROCm 7, 'warpSize' is no longer a compile-time constant,
  // so the warp size is read from the runtime device information.
  const Int32 warp_size = dev_info.m_warp_size;
#else
#if defined(__HIP_DEVICE_COMPILE__)
  constexpr const Int32 WARP_SIZE = warpSize;
#else
  constexpr const Int32 WARP_SIZE = 32;
#endif
#endif

#if HIP_VERSION_MAJOR >= 7
  bool is_done = false;
  if (warp_size == 64)
    is_done = grid_reduce<ReduceOperator, 64, DataType>(v, grid_buffer, device_count);
  else if (warp_size == 32)
    is_done = grid_reduce<ReduceOperator, 32, DataType>(v, grid_buffer, device_count);
  else
    assert(false && "Bad warp size (should be 32 or 64)");
#else
  bool is_done = grid_reduce<ReduceOperator, WARP_SIZE, DataType>(v, grid_buffer, device_count);
#endif

  // Only the thread holding the final reduced value writes it back
  // to the host pinned memory.
  if (is_done) {
    *host_pinned_ptr = v;
  }
}
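/*
 * Illustrative sketch only (not part of this header): the 'ReduceOperator'
 * template parameter used above is expected to provide a 'DataType' alias
 * plus static 'combine()' and 'identity()' members, as inferred from the
 * calls in block_reduce() and grid_reduce(). A hypothetical sum operator
 * following that contract could look like:
 *
 *   struct HypotheticalSumOperator
 *   {
 *     using DataType = double;
 *     static ARCCORE_DEVICE void combine(DataType& target, DataType v) { target += v; }
 *     static ARCCORE_DEVICE DataType identity() { return 0.0; }
 *   };
 */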