Arcane  v4.1.1.0
Documentation utilisateur
Chargement...
Recherche...
Aucune correspondance
GenericSorter.h
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2025 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* GenericSorter.h (C) 2000-2025 */
9/* */
10/* Algorithme de tri. */
11/*---------------------------------------------------------------------------*/
12#ifndef ARCANE_ACCELERATOR_GENERICSORTER_H
13#define ARCANE_ACCELERATOR_GENERICSORTER_H
14/*---------------------------------------------------------------------------*/
15/*---------------------------------------------------------------------------*/
16
17#include "arcane/utils/ArrayView.h"
18#include "arcane/utils/FatalErrorException.h"
19#include "arcane/utils/NotImplementedException.h"
20#include "arcane/utils/NumArray.h"
21
22#include "arcane/accelerator/AcceleratorGlobal.h"
23#include "arcane/accelerator/core/RunQueue.h"
24#include "arcane/accelerator/CommonUtils.h"
25
26#if defined(ARCANE_COMPILING_SYCL)
28#endif
29
30#include <algorithm>
31
32/*---------------------------------------------------------------------------*/
33/*---------------------------------------------------------------------------*/
34
35namespace Arcane::Accelerator::impl
36{
37
38/*---------------------------------------------------------------------------*/
39/*---------------------------------------------------------------------------*/
40/*!
41 * \internal
42 * \brief Classe de base pour effectuer un tri.
43 *
44 * Contient les arguments nécessaires pour effectuer le tri.
45 */
46class ARCANE_ACCELERATOR_EXPORT GenericSorterBase
47{
48 friend class GenericSorterMergeSort;
49
50 public:
51
52 explicit GenericSorterBase(const RunQueue& queue);
53
54 protected:
55
56 RunQueue m_queue;
57 GenericDeviceStorage m_algo_storage;
58
59 protected:
60
61 void _checkBarrier()
62 {
63 // Les fonctions cub ou rocprim pour le scan sont asynchrones par défaut.
64 // Si on a une RunQueue synchrone, alors on fait une barrière.
65 if (!m_queue.isAsync())
66 m_queue.barrier();
67 }
68};
69
70/*---------------------------------------------------------------------------*/
71/*---------------------------------------------------------------------------*/
72/*!
73 * \internal
74 * \brief Classe pour effectuer le tri d'une liste.
75 *
76 * La classe utilisateur associée est GenericSorter
77 */
79{
80 // TODO: Faire le malloc sur le device associé à la queue.
81 // et aussi regarder si on peut utiliser mallocAsync().
82
83 public:
84
85 template <typename CompareLambda, typename InputIterator, typename OutputIterator>
86 void apply(GenericSorterBase& s, Int32 nb_item, InputIterator input_iter,
87 OutputIterator output_iter, const CompareLambda& compare_lambda)
88 {
89 RunQueue queue = s.m_queue;
90 eExecutionPolicy exec_policy = queue.executionPolicy();
91 switch (exec_policy) {
92#if defined(ARCANE_COMPILING_CUDA)
93 case eExecutionPolicy::CUDA: {
94 size_t temp_storage_size = 0;
95 cudaStream_t stream = Impl::CudaUtils::toNativeStream(&queue);
96 // Premier appel pour connaitre la taille pour l'allocation
97 ARCANE_CHECK_CUDA(::cub::DeviceMergeSort::SortKeysCopy(nullptr, temp_storage_size,
98 input_iter, output_iter, nb_item,
99 compare_lambda, stream));
100
101 s.m_algo_storage.allocate(temp_storage_size);
102 ARCANE_CHECK_CUDA(::cub::DeviceMergeSort::SortKeysCopy(s.m_algo_storage.address(), temp_storage_size,
103 input_iter, output_iter, nb_item,
104 compare_lambda, stream));
105 } break;
106#endif
107#if defined(ARCANE_COMPILING_HIP)
108 case eExecutionPolicy::HIP: {
109 size_t temp_storage_size = 0;
110 hipStream_t stream = Impl::HipUtils::toNativeStream(&queue);
111 // Premier appel pour connaitre la taille pour l'allocation
112 ARCANE_CHECK_HIP(rocprim::merge_sort(nullptr, temp_storage_size, input_iter, output_iter,
113 nb_item, compare_lambda, stream));
114
115 s.m_algo_storage.allocate(temp_storage_size);
116
117 ARCANE_CHECK_HIP(rocprim::merge_sort(s.m_algo_storage.address(), temp_storage_size, input_iter, output_iter,
118 nb_item, compare_lambda, stream));
119 } break;
120#endif
121#if defined(ARCANE_COMPILING_SYCL)
122 case eExecutionPolicy::SYCL: {
123 {
124 // Copie input dans output
125 auto command = makeCommand(queue);
126 command << RUNCOMMAND_LOOP1(iter, nb_item)
127 {
128 auto [i] = iter();
129 *(output_iter + i) = *(input_iter + i);
130 };
131 }
132#if defined(ARCANE_HAS_ONEDPL)
133 sycl::queue true_queue = AcceleratorUtils::toSyclNativeStream(queue);
134 auto policy = oneapi::dpl::execution::make_device_policy(true_queue);
135 oneapi::dpl::sort(policy, output_iter, output_iter + nb_item, compare_lambda);
136#elif defined(__ADAPTIVECPP__)
137 sycl::queue true_queue = AcceleratorUtils::toSyclNativeStream(queue);
138 sycl::event e = acpp::algorithms::sort(true_queue, output_iter, output_iter + nb_item, compare_lambda);
139 e.wait();
140#else
141 ARCANE_THROW(NotImplementedException, "Sort is only implemented for SYCL back-end using oneDPL or AdaptiveCpp");
142#endif
143 } break;
144#endif
145 case eExecutionPolicy::Thread:
146 // Pas encore implémenté en multi-thread
147 [[fallthrough]];
148 case eExecutionPolicy::Sequential: {
149 // Copie input dans output
150 auto output_iter_begin = output_iter;
151 for (Int32 i = 0; i < nb_item; ++i) {
152 *output_iter = *input_iter;
153 ++output_iter;
154 ++input_iter;
155 }
156 std::sort(output_iter_begin, output_iter, compare_lambda);
157 } break;
158 default:
159 ARCANE_FATAL(getBadPolicyMessage(exec_policy));
160 }
161 }
162};
163
164/*---------------------------------------------------------------------------*/
165/*---------------------------------------------------------------------------*/
166
167} // namespace Arcane::Accelerator::impl
168
169namespace Arcane::Accelerator
170{
171
172/*---------------------------------------------------------------------------*/
173/*---------------------------------------------------------------------------*/
174/*!
175 * \brief Algorithme générique de tri sur accélérateur.
176 */
177class GenericSorter
179{
180 public:
181
182 explicit GenericSorter(const RunQueue& queue)
184 {
185 }
186
187 public:
188
189 /*!
190 * \brief Tri les entités.
191 *
192 * Remplit \a output avec les valeurs de \a input triées via le comparateur
193 * par défaut pour le type \a DataType. Le tableau \a input n'est pas modifié.
194 *
195 * \pre output.size() >= input.size()
196 */
197 template <typename DataType>
199 {
200 impl::GenericSorterBase* base_ptr = this;
202 Int32 nb_item = input.size();
203 if (output.size() < nb_item)
204 ARCANE_FATAL("Output size '{0}' is smaller than input size '{1}'",
205 output.size(), nb_item);
206 gf.apply(*base_ptr, nb_item, input.data(), output.data(), std::less<DataType>{});
207 }
208};
209
210/*---------------------------------------------------------------------------*/
211/*---------------------------------------------------------------------------*/
212
213} // namespace Arcane::Accelerator
214
215/*---------------------------------------------------------------------------*/
216/*---------------------------------------------------------------------------*/
217
218#endif
219
220/*---------------------------------------------------------------------------*/
221/*---------------------------------------------------------------------------*/
#define ARCANE_THROW(exception_class,...)
Macro pour envoyer une exception avec formattage.
#define ARCANE_FATAL(...)
Macro envoyant une exception FatalErrorException.
Types et macros pour gérer les boucles sur les accélérateurs.
#define RUNCOMMAND_LOOP1(iter_name, x1,...)
Boucle 1D sur accélérateur avec arguments supplémentaires.
void apply(SmallSpan< const DataType > input, SmallSpan< DataType > output)
Tri les entités.
Vue d'un tableau d'éléments de type T.
Definition Span.h:801
constexpr __host__ __device__ pointer data() const noexcept
Pointeur sur le début de la vue.
Definition Span.h:537
constexpr __host__ __device__ SizeType size() const noexcept
Retourne la taille du tableau.
Definition Span.h:325
std::int32_t Int32
Type entier signé sur 32 bits.