Arcane  v3.14.10.0
Documentation développeur
Chargement...
Recherche...
Aucune correspondance
MpiSerializeDispatcher.cc
1// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
2//-----------------------------------------------------------------------------
3// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
4// See the top-level COPYRIGHT file for details.
5// SPDX-License-Identifier: Apache-2.0
6//-----------------------------------------------------------------------------
7/*---------------------------------------------------------------------------*/
8/* MpiSerializeDispatcher.cc (C) 2000-2024 */
9/* */
10/* Gestion des messages de sérialisation avec MPI. */
11/*---------------------------------------------------------------------------*/
12/*---------------------------------------------------------------------------*/
13
14#include "arccore/message_passing_mpi/MpiSerializeDispatcher.h"
15
16#include "arccore/message_passing_mpi/MpiAdapter.h"
17#include "arccore/message_passing_mpi/MpiMessagePassingMng.h"
18#include "arccore/message_passing_mpi/MpiSerializeMessageList.h"
19#include "arccore/message_passing_mpi/MpiLock.h"
20#include "arccore/message_passing/Request.h"
21#include "arccore/message_passing/internal/SubRequestCompletionInfo.h"
22#include "arccore/serialize/BasicSerializer.h"
23#include "arccore/base/NotImplementedException.h"
24#include "arccore/base/FatalErrorException.h"
25#include "arccore/base/NotSupportedException.h"
26#include "arccore/base/ArgumentException.h"
27#include "arccore/base/PlatformUtils.h"
28#include "arccore/trace/ITraceMng.h"
29
30/*---------------------------------------------------------------------------*/
31/*---------------------------------------------------------------------------*/
32
33namespace Arccore::MessagePassing::Mpi
34{
35
36/*---------------------------------------------------------------------------*/
37/*---------------------------------------------------------------------------*/
57template<typename SpanType>
59{
60 public:
62 : m_buffer(buffer), m_datatype(byte_serializer_datatype), m_final_size(-1)
63 {
64 Int64 size = buffer.size();
66 if ((size%align_size)!=0)
67 ARCCORE_FATAL("Buffer size '{0}' is not a multiple of '{1}' Invalid size",size,align_size);
68 m_final_size = size / align_size;
69 }
70 SpanType* data() { return m_buffer.data(); }
71 Int64 size() const { return m_final_size; }
72 Int64 messageSize() const { return m_buffer.size() * sizeof(Byte); }
73 Int64 elementSize() const { return BasicSerializer::paddingSize(); }
74 MPI_Datatype datatype() const { return m_datatype; }
75 private:
76 Span<SpanType> m_buffer;
77 MPI_Datatype m_datatype;
78 Int64 m_final_size;
79};
80
81/*---------------------------------------------------------------------------*/
82/*---------------------------------------------------------------------------*/
93: public ISubRequest
94{
95 public:
96
99 : m_dispatcher(pm), m_serialize_buffer(buf), m_rank(rank), m_mpi_tag(mpi_tag) {}
100
101 public:
102
104 {
105 if (!m_is_message_sent)
106 sendMessage();
107 return m_send_request;
108 }
109 public:
110 void sendMessage()
111 {
112 if (m_is_message_sent)
113 ARCCORE_FATAL("Message already sent");
114 bool do_print = m_dispatcher->m_is_trace_serializer;
115 if (do_print){
116 ITraceMng* tm = m_dispatcher->traceMng();
117 tm->info() << " SendSerializerSubRequest::sendMessage()"
118 << " rank=" << m_rank << " tag=" << m_mpi_tag;
119 }
120 Span<Byte> bytes = m_serialize_buffer->globalBuffer();
121 m_send_request = m_dispatcher->_sendSerializerBytes(bytes,m_rank,m_mpi_tag,false);
122 m_is_message_sent = true;
123 }
124 private:
125 MpiSerializeDispatcher* m_dispatcher;
126 BasicSerializer* m_serialize_buffer;
127 MessageRank m_rank;
128 MessageTag m_mpi_tag;
129 Request m_send_request;
130 bool m_is_message_sent = false;
131};
132
133/*---------------------------------------------------------------------------*/
134/*---------------------------------------------------------------------------*/
135
137: public ISubRequest
138{
139 public:
140
142 MessageTag mpi_tag, Integer action)
143 : m_dispatcher(d)
144 , m_serialize_buffer(buf)
145 , m_mpi_tag(mpi_tag)
146 , m_action(action)
147 {}
148
149 public:
150
152 {
153 MessageRank rank = completion_info.sourceRank();
154 bool is_trace = m_dispatcher->m_is_trace_serializer;
155 ITraceMng* tm = m_dispatcher->traceMng();
156 if (is_trace) {
157 tm->info() << " ReceiveSerializerSubRequest::executeOnCompletion()"
158 << " rank=" << rank << " wanted_tag=" << m_mpi_tag << " action=" << m_action;
159 }
160 if (m_action==1){
161 BasicSerializer* sbuf = m_serialize_buffer;
162 Int64 total_recv_size = sbuf->totalSize();
163
164 if (is_trace) {
165 tm->info() << " ReceiveSerializerSubRequest::executeOnCompletion() total_size=" << total_recv_size
166 << BasicSerializer::SizesPrinter(*m_serialize_buffer);
167 }
168 // Si le message est plus petit que le buffer, le désérialise simplement
169 if (total_recv_size<=m_dispatcher->m_serialize_buffer_size){
170 sbuf->setFromSizes();
171 return {};
172 }
173
174 sbuf->preallocate(total_recv_size);
175 auto bytes = sbuf->globalBuffer();
176
177 // La nouvelle requête doit utiliser le même rang source que celui de cette requête
178 // pour être certain qu'il n'y a pas d'incohérence.
179 Request r2 = m_dispatcher->_recvSerializerBytes(bytes, rank, m_mpi_tag, false);
180 ISubRequest* sr = new ReceiveSerializerSubRequest(m_dispatcher, m_serialize_buffer, m_mpi_tag, 2);
181 r2.setSubRequest(makeRef(sr));
182 return r2;
183 }
184 if (m_action==2){
185 m_serialize_buffer->setFromSizes();
186 }
187 return {};
188 }
189
190 private:
191
192 MpiSerializeDispatcher* m_dispatcher = nullptr;
193 BasicSerializer* m_serialize_buffer = nullptr;
194 MessageTag m_mpi_tag;
195 Int32 m_action = 0;
196};
197
198/*---------------------------------------------------------------------------*/
199/*---------------------------------------------------------------------------*/
200
201/*---------------------------------------------------------------------------*/
202/*---------------------------------------------------------------------------*/
203
204MpiSerializeDispatcher::
205MpiSerializeDispatcher(MpiAdapter* adapter)
206: m_adapter(adapter)
207, m_trace(adapter->traceMng())
208, m_serialize_buffer_size(50000)
209//, m_serialize_buffer_size(20000000)
210, m_max_serialize_buffer_size(m_serialize_buffer_size)
211, m_byte_serializer_datatype(MPI_DATATYPE_NULL)
212{
213 _init();
214}
215
216/*---------------------------------------------------------------------------*/
217/*---------------------------------------------------------------------------*/
218
219MpiSerializeDispatcher::
220~MpiSerializeDispatcher()
221{
222 if (m_byte_serializer_datatype!=MPI_DATATYPE_NULL)
223 MPI_Type_free(&m_byte_serializer_datatype);
224}
225
226/*---------------------------------------------------------------------------*/
227/*---------------------------------------------------------------------------*/
228
229MessageTag MpiSerializeDispatcher::
230nextSerializeTag(MessageTag tag)
231{
232 return MessageTag(tag.value()+1);
233}
234
235/*---------------------------------------------------------------------------*/
236/*---------------------------------------------------------------------------*/
237
238void MpiSerializeDispatcher::
239_init()
240{
241 // Type pour la sérialisation en octet.
242 MPI_Datatype mpi_datatype;
243 MPI_Type_contiguous(BasicSerializer::paddingSize(),MPI_CHAR,&mpi_datatype);
244 MPI_Type_commit(&mpi_datatype);
245 m_byte_serializer_datatype = mpi_datatype;
246
247 if (!Platform::getEnvironmentVariable("ARCCORE_TRACE_MESSAGE_PASSING_SERIALIZE").empty())
248 m_is_trace_serializer = true;
249}
250
251/*---------------------------------------------------------------------------*/
252/*---------------------------------------------------------------------------*/
253
254Request MpiSerializeDispatcher::
255legacySendSerializer(ISerializer* values,const PointToPointMessageInfo& message)
256{
257 if (!message.isRankTag())
258 ARCCORE_FATAL("Only message.isRangTag()==true are allowed for legacy mode");
259
260 MessageRank rank = message.destinationRank();
261 MessageTag mpi_tag = message.tag();
262 bool is_blocking = message.isBlocking();
263
264 BasicSerializer* sbuf = _castSerializer(values);
265 ITraceMng* tm = m_trace;
266
267 Span<Byte> bytes = sbuf->globalBuffer();
268
269 Int64 total_size = sbuf->totalSize();
270 _checkBigMessage(total_size);
271
272 if (m_is_trace_serializer)
273 tm->info() << "legacySendSerializer(): sending to "
274 << " rank=" << rank << " bytes " << bytes.size()
275 << BasicSerializer::SizesPrinter(*sbuf)
276 << " tag=" << mpi_tag << " is_blocking=" << is_blocking;
277
278 // Si le message est plus petit que le buffer par défaut de sérialisation,
279 // envoie tout le message
280 if (total_size<=m_serialize_buffer_size){
281 if (m_is_trace_serializer)
282 tm->info() << "Small message size=" << bytes.size();
283 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
284 }
285
286 {
287 // le message est trop grand pour tenir dans le buffer, envoie d'abord les tailles,
288 // puis le message sérialisé.
289 auto x = sbuf->copyAndGetSizesBuffer();
290 if (m_is_trace_serializer)
291 tm->info() << "Big message first size=" << x.size();
292 Request r = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
293 if (!is_blocking){
294 SerializeSubRequest* sub_request = new SerializeSubRequest();
295 sub_request->m_request = r;
296 //m_trace->info() << "** ADD SUB REQUEST r=" << r;
297 {
298 MpiLock::Section ls(m_adapter->mpiLock());
299 m_sub_requests.add(sub_request);
300 }
301 }
302 }
303
304 if (m_is_trace_serializer)
305 tm->info() << "Big message second size=" << bytes.size();
306 return _sendSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),is_blocking);
307}
308
309/*---------------------------------------------------------------------------*/
310/*---------------------------------------------------------------------------*/
311
312Request MpiSerializeDispatcher::
313_recvSerializerBytes(Span<Byte> bytes,MessageId message_id,bool is_blocking)
314{
315 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
316 MPI_Datatype dt = sbc.datatype();
317 if (m_is_trace_serializer)
318 m_trace->info() << "_recvSerializerBytes: size=" << bytes.size()
319 << " message_id=" << message_id << " is_blocking=" << is_blocking;
320 return m_adapter->directRecv(sbc.data(),sbc.size(),message_id,sbc.elementSize(),dt,is_blocking);
321}
322
323/*---------------------------------------------------------------------------*/
324/*---------------------------------------------------------------------------*/
325
326Request MpiSerializeDispatcher::
327_recvSerializerBytes(Span<Byte> bytes,MessageRank rank,MessageTag tag,bool is_blocking)
328{
329 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
330 MPI_Datatype dt = sbc.datatype();
331 if (m_is_trace_serializer)
332 m_trace->info() << "_recvSerializerBytes: size=" << bytes.size()
333 << " rank=" << rank << " tag=" << tag << " is_blocking=" << is_blocking;
334 Request r = m_adapter->directRecv(sbc.data(),sbc.size(),rank.value(),
335 sbc.elementSize(),dt,tag.value(),is_blocking);
336 if (m_is_trace_serializer)
337 m_trace->info() << "_recvSerializerBytes: request=" << r;
338 return r;
339}
340
341/*---------------------------------------------------------------------------*/
342/*---------------------------------------------------------------------------*/
343
344Request MpiSerializeDispatcher::
345_sendSerializerBytes(Span<const Byte> bytes,MessageRank rank,MessageTag tag,
346 bool is_blocking)
347{
348 SerializeByteConverter<const Byte> sbc(bytes,m_byte_serializer_datatype);
349 MPI_Datatype dt = sbc.datatype();
350 if (m_is_trace_serializer)
351 m_trace->info() << "_sendSerializerBytes: orig_size=" << bytes.size()
352 << " rank=" << rank << " tag=" << tag
353 << " second_size=" << sbc.size()
354 << " message_size=" << sbc.messageSize();
355 Request r = m_adapter->directSend(sbc.data(),sbc.size(),rank.value(),
356 sbc.elementSize(),dt,tag.value(),is_blocking);
357 if (m_is_trace_serializer)
358 m_trace->info() << "_sendSerializerBytes: request=" << r;
359 return r;
360}
361
362/*---------------------------------------------------------------------------*/
363/*---------------------------------------------------------------------------*/
364
365void MpiSerializeDispatcher::
366legacyReceiveSerializer(ISerializer* values,MessageRank rank,MessageTag mpi_tag)
367{
368 BasicSerializer* sbuf = _castSerializer(values);
369 ITraceMng* tm = m_trace;
370
371 if (m_is_trace_serializer)
372 tm->info() << "legacyReceiveSerializer() begin receive"
373 << " rank=" << rank << " tag=" << mpi_tag;
374 sbuf->preallocate(m_serialize_buffer_size);
375 Span<Byte> bytes = sbuf->globalBuffer();
376
377 _recvSerializerBytes(bytes,rank,mpi_tag,true);
378 Int64 total_recv_size = sbuf->totalSize();
379
380 if (m_is_trace_serializer)
381 tm->info() << "legacyReceiveSerializer total_size=" << total_recv_size
382 << " from=" << rank
383 << BasicSerializer::SizesPrinter(*sbuf);
384
385
386 // Si le message est plus petit que le buffer, le désérialise simplement
387 if (total_recv_size<=m_serialize_buffer_size){
388 sbuf->setFromSizes();
389 return;
390 }
391
392 if (m_is_trace_serializer)
393 tm->info() << "Receive overflow buffer: " << total_recv_size;
394 sbuf->preallocate(total_recv_size);
395 bytes = sbuf->globalBuffer();
396 _recvSerializerBytes(bytes,rank,nextSerializeTag(mpi_tag),true);
397 sbuf->setFromSizes();
398 if (m_is_trace_serializer)
399 tm->info() << "End receive overflow buffer: " << total_recv_size;
400}
401
402/*---------------------------------------------------------------------------*/
403/*---------------------------------------------------------------------------*/
404
405void MpiSerializeDispatcher::
406checkFinishedSubRequests()
407{
408 // Regarde si les sous-requêtes sont terminées pour les libérer
409 // Cela est uniquement utilisé avec le mode historique où on utilise
410 // la classe 'MpiSerializeMessageList'.
411 UniqueArray<SerializeSubRequest*> new_sub_requests;
412 for( Integer i=0, n=m_sub_requests.size(); i<n; ++i ){
413 SerializeSubRequest* ssr = m_sub_requests[i];
414 bool is_finished = m_adapter->testRequest(ssr->m_request);
415 if (!is_finished){
416 new_sub_requests.add(ssr);
417 }
418 else{
419 delete ssr;
420 }
421 }
422 m_sub_requests = new_sub_requests;
423}
424
425/*---------------------------------------------------------------------------*/
426/*---------------------------------------------------------------------------*/
427
428void MpiSerializeDispatcher::
429_checkBigMessage(Int64 message_size)
430{
431 if (message_size>m_max_serialize_buffer_size){
432 m_max_serialize_buffer_size = message_size;
433 m_trace->info() << "big buffer: " << message_size;
434 }
435}
436
437/*---------------------------------------------------------------------------*/
438/*---------------------------------------------------------------------------*/
439
442{
443 return sendSerializer(s,message,false);
444}
445
446/*---------------------------------------------------------------------------*/
447/*---------------------------------------------------------------------------*/
448
452{
453 BasicSerializer* sbuf = _castSerializer(const_cast<ISerializer*>(s));
454
455 MessageRank rank = message.destinationRank();
456 MessageTag mpi_tag = message.tag();
457 bool is_blocking = message.isBlocking();
458
459 ITraceMng* tm = m_trace;
460
461 Span<const Byte> bytes = sbuf->globalBuffer();
462 Int64 total_size = sbuf->totalSize();
463 _checkBigMessage(total_size);
464
465 if (m_is_trace_serializer)
466 tm->info() << "sendSerializer(): sending to "
467 << " p2p_message=" << message
468 << " rank=" << rank << " bytes " << bytes.size()
470 << " tag=" << mpi_tag
471 << " total_size=" << total_size;
472
473
474 // Si le message est plus petit que le buffer par défaut de sérialisation
475 // ou qu'on choisit de n'envoyer qu'un seul message, envoie tout le message
476 if (total_size<=m_serialize_buffer_size || force_one_message){
477 if (m_is_trace_serializer)
478 tm->info() << "Small message size=" << bytes.size();
479 return _sendSerializerBytes(bytes,rank,mpi_tag,is_blocking);
480 }
481
482 // Sinon, envoie d'abord les tailles puis une autre requête qui
483 // va envoyer tout le message.
484 auto x = sbuf->copyAndGetSizesBuffer();
485 Request r1 = _sendSerializerBytes(x,rank,mpi_tag,is_blocking);
486 auto* x2 = new SendSerializerSubRequest(this,sbuf,rank,nextSerializeTag(mpi_tag));
487 // Envoi directement le message pour des raisons de performance.
488 x2->sendMessage();
489 r1.setSubRequest(makeRef<ISubRequest>(x2));
490 return r1;
491}
492
493/*---------------------------------------------------------------------------*/
494/*---------------------------------------------------------------------------*/
495
498{
499 BasicSerializer* sbuf = _castSerializer(s);
500 MessageRank rank = message.destinationRank();
501 MessageTag tag = message.tag();
502 bool is_blocking = message.isBlocking();
503
504 sbuf->preallocate(m_serialize_buffer_size);
505 Span<Byte> bytes = sbuf->globalBuffer();
506
507 Request r;
508 if (message.isRankTag())
509 r = _recvSerializerBytes(bytes,rank,tag,is_blocking);
510 else if (message.isMessageId())
511 r = _recvSerializerBytes(bytes,message.messageId(),is_blocking);
512 else
513 ARCCORE_THROW(NotSupportedException,"Only message.isRankTag() or message.isMessageId() is supported");
514 auto* sr = new ReceiveSerializerSubRequest(this, sbuf, nextSerializeTag(tag), 1);
515 r.setSubRequest(makeRef<ISubRequest>(sr));
516 return r;
517}
518
519/*---------------------------------------------------------------------------*/
520/*---------------------------------------------------------------------------*/
521
522void MpiSerializeDispatcher::
523broadcastSerializer(ISerializer* values,MessageRank rank)
524{
525 BasicSerializer* sbuf = _castSerializer(values);
526 ITraceMng* tm = m_trace;
527 MessageRank my_rank(m_adapter->commRank());
528 bool is_broadcaster = (rank==my_rank);
529
530 MPI_Datatype int64_datatype = MpiBuiltIn::datatype(Int64());
531 // Effectue l'envoie en deux phases. Envoie d'abord le nombre d'éléments
532 // puis envoie les éléments.
533 // TODO: il serait possible de le faire en une fois pour les messages
534 // ne dépassant pas une certaine taille.
535 if (is_broadcaster){
536 Int64 total_size = sbuf->totalSize();
537 Span<Byte> bytes = sbuf->globalBuffer();
538 _checkBigMessage(total_size);
540 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.value(),int64_datatype);
541 if (m_is_trace_serializer)
542 tm->info() << "MpiSerializeDispatcher::broadcastSerializer(): sending "
544 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
545 m_adapter->broadcast(sbc.data(),sbc.size(),rank.value(),sbc.datatype());
546 }
547 else{
548 Int64 total_size = 0;
549 Int64ArrayView total_size_buf(1,&total_size);
550 m_adapter->broadcast(total_size_buf.data(),total_size_buf.size(),rank.value(),int64_datatype);
551 sbuf->preallocate(total_size);
552 Span<Byte> bytes = sbuf->globalBuffer();
553 SerializeByteConverter<Byte> sbc(bytes,m_byte_serializer_datatype);
554 m_adapter->broadcast(sbc.data(),sbc.size(),rank.value(),sbc.datatype());
555 sbuf->setFromSizes();
556 if (m_is_trace_serializer)
557 tm->info() << "MpiSerializeDispatcher::broadcastSerializer(): receiving from "
558 << " rank=" << rank << " bytes " << bytes.size()
560 }
561}
562
563/*---------------------------------------------------------------------------*/
564/*---------------------------------------------------------------------------*/
565
566BasicSerializer* MpiSerializeDispatcher::
567_castSerializer(ISerializer* serializer)
568{
569 BasicSerializer* sbuf = dynamic_cast<BasicSerializer*>(serializer);
570 if (!sbuf)
571 ARCCORE_THROW(ArgumentException,"Can not cast 'ISerializer' to 'BasicSerializer'");
572 return sbuf;
573}
574
575/*---------------------------------------------------------------------------*/
576/*---------------------------------------------------------------------------*/
577
584
585/*---------------------------------------------------------------------------*/
586/*---------------------------------------------------------------------------*/
587
588} // namespace Arccore::MessagePassing::Mpi
589
590/*---------------------------------------------------------------------------*/
591/*---------------------------------------------------------------------------*/
Lecteur des fichiers de maillage via la bibliothèque LIMA.
Definition Lima.cc:120
Vue modifiable d'un tableau d'un type T.
Implémentation basique de 'ISerializer'.
static ARCCORE_CONSTEXPR Integer paddingSize()
Taille du padding et de l'alignement.
Interface du gestionnaire de traces.
virtual TraceMessage info()=0
Flot pour un message d'information.
Sous-requête d'une requête.
Definition Request.h:35
Int32 value() const
Valeur du rang.
Definition MessageRank.h:72
Request executeOnCompletion(const SubRequestCompletionInfo &completion_info) override
Callback appelé lorsque la requête associée est terminée.
Request executeOnCompletion(const SubRequestCompletionInfo &) override
Callback appelé lorsque la requête associée est terminée.
Request receiveSerializer(ISerializer *s, const PointToPointMessageInfo &message) override
Message de réception.
Request sendSerializer(const ISerializer *s, const PointToPointMessageInfo &message) override
Message d'envoi.
Ref< ISerializeMessageList > createSerializeMessageListRef() override
Crée une liste de messages de sérialisation.
Wrappeur pour envoyer un tableau d'octets d'un sérialiseur.
Informations pour envoyer/recevoir un message point à point.
MessageId messageId() const
Identifiant du message.
MessageRank destinationRank() const
Rang de la destination du message.
bool isBlocking() const
Indique si le message est bloquant.
bool isMessageId() const
Vrai si l'instance a été créée avec un MessageId. Dans ce cas messageId() est valide.
bool isRankTag() const
Vrai si l'instance a été créée avec un couple (rank,tag). Dans ce cas rank() et tag() sont valides.
Requête d'un message.
Definition Request.h:77
Informations de complétion d'une sous-requête.
Exception lorsqu'une opération n'est pas supportée.
constexpr ARCCORE_HOST_DEVICE pointer data() const noexcept
Pointeur sur le début de la vue.
Definition Span.h:419
constexpr ARCCORE_HOST_DEVICE SizeType size() const noexcept
Retourne la taille du tableau.
Definition Span.h:209
Vue d'un tableau d'éléments de type T.
Definition Span.h:510
TraceMessage info() const
Flot pour un message d'information.
unsigned char Byte
Type d'un octet.
Definition UtilsTypes.h:142
ARCCORE_BASE_EXPORT String getEnvironmentVariable(const String &name)
Variable d'environnement du nom name.
Int32 Integer
Type représentant un entier.
auto makeRef(InstanceType *t) -> Ref< InstanceType >
Crée une référence sur un pointeur.
std::int64_t Int64
Type entier signé sur 64 bits.