// Type aliases and compile-time constants. The enclosing class header is not
// visible in this listing; ThisType names MatrixInternal<ValueT,EllPackSize>.
110 using ThisType = MatrixInternal<ValueT,EllPackSize>;
// Exposes the EllPackSize template argument as a class-level constant.
112 static const int ellpack_size = EllPackSize ;
// Scalar type of the coefficients (two spellings of the same alias).
114 using ValueType = ValueT;
115 using value_type = ValueT;
// Index-related types are taken from the profile's InternalType.
// ProfileType itself is declared outside the visible lines.
118 using InternalProfileType =
typename ProfileType::InternalType;
119 using IndexType =
typename InternalProfileType::IndexType;
120 using IndexBufferType =
typename InternalProfileType::IndexBufferType;
121 using IndexBufferPtrType = std::unique_ptr<IndexBufferType>;
// One-dimensional SYCL buffers of value_type (both aliases name the same type),
// plus a unique_ptr alias for buffers held through a pointer.
123 using value_buffer_type = sycl::buffer<value_type, 1>;
124 using ValueBufferType = sycl::buffer<value_type, 1>;
125 using ValueBufferPtrType = std::unique_ptr<ValueBufferType>;
// Queue type accepted by the *N kernels and by the queue-taking overloads.
127 using QueueType = sycl::queue;
// Number of scalar entries in one N x N block. N is declared outside the
// visible lines (presumably a template parameter of the enclosing helper — TODO confirm).
133 static constexpr int NxN = N*N ;
// Returns (k*NxN + i*N + j)*ellpack_size: the base offset of entry (i,j) of
// block-slot k when every (k,i,j) entry occupies ellpack_size consecutive values.
// Call sites in this file of the form tile.ijk(k,ieq,ju) + local_id add a
// per-work-item offset to this base (presumably this helper — TODO confirm).
// The function's brace lines are not part of this listing.
134 inline std::size_t ijk(std::size_t k,
int i,
int j)
const
136 return (k*NxN + i*N + j)*ellpack_size;
// Returns local_id*NxN + i*N + j: the index of entry (i,j) inside the
// local_id-th contiguous N x N block (i varies slowest, j fastest).
// The function's brace lines are not part of this listing.
139 inline std::size_t ij(std::size_t local_id,
int i,
int j)
const
141 return local_id*NxN+ i*N + j;
// Copy of the EllPackSize template argument for the helper that stores the block
// size at run time (m_N / m_NxN are used below; their declarations are not visible).
147 static const int ellpack_size = EllPackSize ;
// Run-time-sized counterpart of ijk(): returns (k*m_NxN + i*m_N + j)*ellpack_size.
// Callers below add local_id to the result to address one work-item's value.
// m_N and m_NxN are members declared outside the visible lines.
156 inline std::size_t _ijk(std::size_t k,
int i,
int j)
const
158 return (k*m_NxN + i*m_N + j)*ellpack_size;
// Run-time-sized counterpart of ij(): returns local_id*m_NxN + i*m_N + j,
// i.e. the index of entry (i,j) in the local_id-th contiguous m_N x m_N block.
161 inline std::size_t _ij(std::size_t local_id,
int i,
int j)
const
163 return local_id*m_NxN+ i*m_N + j;
// Accumulates, for equation row ieq of block-slot k, the sum over j in [0,m_N) of
//   matrix[_ijk(k,ieq,j)+local_id] * x[x_offset+j]
// where x_offset = cols[k*ellpack_size+local_id]*m_N.
// NOTE(review): `k` is used in the body but is not among the parameters visible
// in this listing (original line 171 is missing); lines 178-179 and the lines
// after 183 (braces, any guard, the return statement) are also not visible.
166 template<
typename MatrixValueAccessorT,
167 typename MatrixColAccessorT,
168 typename VectorAccessorT>
169 ValueType mult(
int ieq,
170 std::size_t local_id,
172 MatrixColAccessorT& cols,
173 MatrixValueAccessorT& matrix,
174 VectorAccessorT& x)
const
176 ValueType value = 0. ;
// Start of the column's m_N-long segment in x.
177 auto x_offset = cols[k*ellpack_size+local_id]*m_N ;
180 for(
int j=0;j<m_N;++j)
182 auto mat_offset = _ijk(k,ieq,j)+local_id ;
183 value += matrix[mat_offset]*x[x_offset+j] ;
// Masked variant of mult(): the same accumulation over j in [0,m_N) of
//   matrix[_ijk(k,ieq,j)+local_id] * x[x_offset+j]
// but performed only when x_offset >= 0 and the mask entry for (k, local_id) equals 1.
// NOTE(review): `k` and `mask` are used in the body but are not among the
// parameters visible in this listing (original lines 196 and 198 are missing);
// the brace lines and the return statement are not visible either.
190 template<
typename MatrixValueAccessorT,
191 typename MatrixColAccessorT,
192 typename MaskAccessorT,
193 typename VectorAccessorT>
194 ValueType mult(
int ieq,
195 std::size_t local_id,
197 MatrixColAccessorT& cols,
199 MatrixValueAccessorT& matrix,
200 VectorAccessorT& x)
const
202 ValueType value = 0. ;
203 auto x_offset = cols[k*ellpack_size+local_id]*m_N ;
204 auto ma = mask[k*ellpack_size+local_id] ;
// A negative offset presumably marks a padding slot with no real column — TODO confirm.
205 if(x_offset>=0 && ma==1)
207 for(
int j=0;j<m_N;++j)
209 auto mat_offset = _ijk(k,ieq,j)+local_id ;
210 value += matrix[mat_offset]*x[x_offset+j] ;
// Beginning of the LU helper: a template whose first parameter is MatrixAccT,
// with a constructor taking the block size N and a matrix accessor.
// NOTE(review): the remainder of the template parameter list, the struct/class
// line and the constructor body are not visible in this listing (the members
// m_N, m_NxN, m_matrix and m_LU used below are presumably set up there — TODO confirm).
218 template<
typename MatrixAccT,
223 static const int ellpack_size = EllPackSize ;
228 LU(
int N, MatrixAccT& matrix)
// Returns (k*m_NxN + i*m_N + j)*ellpack_size: base offset of entry (i,j) of
// block k; factorize()/inverse() add local_id to select one work-item's value.
234 inline std::size_t _ijk(std::size_t k,
int i,
int j)
const
236 return (k*m_NxN + i*m_N + j)*ellpack_size;
// Returns local_id*m_NxN + i*m_N + j: index of entry (i,j) inside the
// local_id-th contiguous m_N x m_N block. inverse() calls it with global_id as
// the first argument.
239 inline std::size_t _ij(std::size_t local_id,
int i,
int j)
const
241 return local_id*m_NxN+ i*m_N + j;
// In-place LU factorization of one m_N x m_N block per work-item.
// Step 1 copies the block stored at slot `kcol` of m_matrix into slot block_id of m_LU.
// Step 2 eliminates column by column; the diagonal entry is overwritten with its
// reciprocal, so m_LU(k,k) holds 1/pivot afterwards (inverse() relies on this).
// NOTE(review): `kcol` is used but is not among the visible parameters (original
// lines 247-250 are missing from this listing). No pivoting and no zero-pivot
// check appear in the visible lines — confirm against the full source.
244 void factorize(std::size_t global_id,
245 std::size_t local_id,
246 std::size_t block_id,
// Copy the source block into the LU storage.
251 for(
int i=0;i<m_N;++i)
252 for(
int j=0;j<m_N;++j)
253 m_LU[_ijk(block_id,i,j)+local_id] = m_matrix[_ijk(kcol,i,j)+local_id] ;
256 for (
int k = 0; k < m_N; ++k)
// Store the reciprocal of the pivot on the diagonal.
259 m_LU[_ijk(block_id,k,k)+local_id] = 1 / m_LU[_ijk(block_id,k,k)+local_id];
// Scale the sub-diagonal part of column k by the reciprocal pivot (L factors).
260 for (
int i = k + 1; i < m_N; ++i) {
261 m_LU[_ijk(block_id,i,k)+local_id] *= m_LU[_ijk(block_id,k,k)+local_id];
// Update the trailing sub-matrix: A(i,j) -= L(i,k) * U(k,j).
263 for (
int i = k + 1; i < m_N; ++i) {
264 for (
int j = k + 1; j < m_N; ++j) {
265 m_LU[_ijk(block_id,i,j)+local_id] -= m_LU[_ijk(block_id,i,k)+local_id] * m_LU[_ijk(block_id,k,j)+local_id];
// Fills the global_id-th m_N x m_N block of m_y using the factors stored in
// slot block_id of m_LU: the block is first set to the identity, then a forward
// sweep with the strictly-lower entries and a backward sweep with the upper
// entries are applied to all m_N columns (k index). This amounts to solving
// L*U*Y = I, assuming m_LU was produced by factorize() — TODO confirm.
// Part of the parameter list (original line 274) and the brace lines are not
// visible in this listing.
271 void inverse(std::size_t global_id,
272 std::size_t local_id,
273 std::size_t block_id,
275 VectorAccT m_y)
const
// Initialise the output block to the identity matrix.
278 for(
int i=0;i<m_N;++i)
279 for(
int j=0;j<m_N;++j)
280 m_y[_ij(global_id,i,j)] = 0. ;
281 for(
int i=0;i<m_N;++i)
282 m_y[_ij(global_id,i,i)] = 1. ;
// Forward sweep: y(i,:) -= LU(i,j) * y(j,:) for j < i (no division: unit diagonal).
285 for (
int i = 1; i < m_N; ++i)
287 for (
int j = 0; j < i; ++j)
289 for(
int k=0;k<m_N;++k)
290 m_y[_ij(global_id,i,k)] -= m_LU[_ijk(block_id,i,j)+local_id] * m_y[_ij(global_id,j,k)];
// Backward sweep: y(i,:) -= LU(i,j) * y(j,:) for j > i, from the last row up.
295 for (
int i = m_N - 1; i >= 0; --i)
297 for (
int j = m_N - 1; j > i; --j)
299 for(
int k=0;k<m_N;++k)
300 m_y[_ij(global_id,i,k)] -= m_LU[_ijk(block_id,i,j)+local_id] * m_y[_ij(global_id,j,k)];
// Multiply by the diagonal entry, which factorize() stores as the reciprocal pivot.
302 for(
int k=0;k<m_N;++k)
303 m_y[_ij(global_id,i,k)] *= m_LU[_ijk(block_id,i,i)+local_id];
// Constructor declaration: takes a pointer to a const profile and an optional
// block size (default 1). The definition is outside the visible lines.
309 MatrixInternal(ProfileType
const* profile,
int blk_size=1);
// Declarations of the value setters; definitions are outside the visible lines.
// Overloads accept a raw host pointer (with an only_host flag), no argument
// (setMatrixValuesFromHost), one SYCL buffer, or two SYCL buffers (values and
// ext_values). All return bool; the meaning of the result is not shown here.
313 bool setMatrixValues(ValueType
const* values,
bool only_host);
314 bool setMatrixValuesFromHost();
316 bool setMatrixValues(ValueBufferType& values);
317 bool setMatrixValues(ValueBufferType& values,
318 ValueBufferType& ext_values);
// Declarations of two copy() overloads; definitions are outside the visible lines.
// NOTE(review): one parameter of each overload (original lines 321 and 326) is
// missing from this listing, so the full signatures cannot be read here.
320 bool copy(std::size_t nb_blocks,
322 ValueBufferType& rhs_values,
323 Integer rhs_block_size);
325 bool copy(std::size_t nb_blocks,
327 ValueBufferType& rhs_values,
328 ValueBufferType& rhs_ext_values,
329 Integer rhs_block_size);
// Declaration only; the definition is outside the visible lines.
// Presumably related to the m_values_is_update flag — TODO confirm.
332 void notifyChanges();
// Builds a SYCL kernel that, for every row i < nrows and every ieq in [0,N), writes
//   y[i*N+ieq] = sum over k in [begin,end), ju in [0,N) of
//                values[tile.ijk(k,ieq,ju)+local_id] * x[col*N+ju]
// with col = block_cols[k*pack_size+local_id] and [begin,end) read from the
// block-row offsets of block i/pack_size.
// N is not declared in the visible lines (presumably a template parameter — TODO confirm).
// The call that submits the command-group lambda to `queue`, and all brace
// lines, are not part of this listing.
// `device`, `max_work_group_size`, `nnz` and `kcol` are not referenced in the
// visible lines of this function.
336 void multN(ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const
338 auto device = queue.get_device();
340 auto num_groups = queue.get_device().get_info<sycl::info::device::max_compute_units>();
342 auto max_work_group_size = queue.get_device().get_info<sycl::info::device::max_work_group_size>();
346 std::size_t pack_size = ellpack_size;
347 auto nrows = m_profile->getNRows();
348 auto nnz = m_profile->getNnz();
350 auto internal_profile = m_profile->internal();
351 auto& kcol = internal_profile->getKCol();
352 auto& block_row_offset = internal_profile->getBlockRowOffset();
353 auto& block_cols = internal_profile->getBlockCols();
// Launch size: at least ceil(nrows/ellpack_size) work-groups of pack_size items,
// and at least 4 work-groups per compute unit.
355 auto blocks_needed = (nrows + ellpack_size - 1) / ellpack_size;
356 auto blocks_target = std::max(blocks_needed, num_groups * 4UL);
357 auto total_threads = blocks_target * pack_size;
360 [&](sycl::handler& cgh)
362 auto access_block_row_offset = block_row_offset.template get_access<sycl::access::mode::read>(cgh);
363 auto access_cols = block_cols.template get_access<sycl::access::mode::read>(cgh);
364 auto access_values = m_values.template get_access<sycl::access::mode::read>(cgh);
366 auto access_x = x.template get_access<sycl::access::mode::read>(cgh);
// y is requested with discard_write: its previous contents are not read.
367 auto access_y = y.template get_access<sycl::access::mode::discard_write>(cgh);
369 auto tile = TileT<N>() ;
// Work-group local scratch: N values per work-item of the group.
371 sycl::local_accessor<ValueType, 1> lds_x{pack_size*N, cgh};
372 sycl::nd_range<1> r{sycl::range<1>{total_threads},sycl::range<1>{pack_size}};
373 cgh.parallel_for<
class compute_mult>(r,
374 [=](sycl::nd_item<1> item_id)
376 auto local_id = item_id.get_local_id(0);
377 auto global_id = item_id.get_global_id(0);
// Each work-item handles rows global_id, global_id + global_range, ...
// NOTE(review): the barriers below are inside this loop, whose bound `i < nrows`
// can differ between work-items of the same group when nrows is not a multiple
// of pack_size; confirm nrows is padded or that all items reach every barrier.
379 for (
auto i = global_id; i < nrows; i += item_id.get_global_range()[0])
381 auto block_id = i/pack_size ;
383 int begin = access_block_row_offset[block_id] ;
384 int end = access_block_row_offset[block_id+1] ;
387 for(
int ieq=0;ieq<N;++ieq)
389 ValueType value = 0. ;
390 for(
int k=begin;k<end;++k)
393 const int col = access_cols[k * pack_size + local_id];
// Stage this work-item's N entries of x in local memory.
// NOTE(review): each work-item reads back only the slots it wrote itself
// (index local_id*N+ju), so no cross-item sharing is visible here.
395 for(
int ju=0;ju<N;++ju)
396 lds_x[N*local_id+ju] = access_x[col*N+ju];
397 item_id.barrier(sycl::access::fence_space::local_space);
401 for(
int ju=0;ju<N;++ju)
402 value += access_values[tile.ijk(k,ieq,ju) + local_id] * lds_x[local_id*N+ju] ;
404 item_id.barrier(sycl::access::fence_space::local_space);
406 access_y[i*N+ieq] = value ;
// Declarations only; definitions are outside the visible lines.
// Each operation has an overload without a queue and one taking an explicit QueueType.
412 void mult(ValueBufferType& x, ValueBufferType& y)
const;
413 void mult(ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const;
415 void addExtMult(ValueBufferType& x, ValueBufferType& y)
const;
416 void addExtMult(ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const;
// Builds a SYCL kernel that, for every row i < nrows and every ieq in [0,N), performs
//   y[i*N+ieq] += alpha * sum of values[tile.ijk(k,ieq,ju)+local_id] * x[N*col+ju]
// over k in [begin,end) and ju in [0,N), keeping only the (k, local_id) slots
// whose entry in the profile's lower mask (getLowerMask) is non-zero.
// N is not declared in the visible lines (presumably a template parameter — TODO confirm).
// NOTE(review): `tile` is used below but its declaration is not among the visible
// lines of this function; the submit call and brace lines are not visible either.
// `device`, `max_work_group_size`, `nnz` and `kcol` are not referenced in the
// visible lines of this function.
419 void addLMultN(ValueType alpha, ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const
421 auto device = queue.get_device();
423 auto num_groups = queue.get_device().get_info<sycl::info::device::max_compute_units>();
425 auto max_work_group_size = queue.get_device().get_info<sycl::info::device::max_work_group_size>();
428 std::size_t pack_size = ellpack_size;
429 auto nrows = m_profile->getNRows();
430 auto nnz = m_profile->getNnz();
432 auto internal_profile = m_profile->internal();
433 auto& kcol = internal_profile->getKCol();
434 auto& block_row_offset = internal_profile->getBlockRowOffset();
435 auto& block_cols = internal_profile->getBlockCols();
437 auto& mask = internal_profile->getLowerMask();
441 [&](sycl::handler& cgh)
443 auto access_block_row_offset = block_row_offset.template get_access<sycl::access::mode::read>(cgh);
444 auto access_cols = block_cols.template get_access<sycl::access::mode::read>(cgh);
445 auto access_mask = mask.template get_access<sycl::access::mode::read>(cgh);
446 auto access_values = m_values.template get_access<sycl::access::mode::read>(cgh);
449 auto access_x = x.template get_access<sycl::access::mode::read>(cgh);
// y is both read and written because the result is accumulated into it.
450 auto access_y = y.template get_access<sycl::access::mode::read_write>(cgh);
// Launch size: at least ceil(nrows/ellpack_size) work-groups, and at least
// 4 work-groups per compute unit.
452 auto blocks_needed = (nrows + ellpack_size - 1) / ellpack_size;
453 auto blocks_target = std::max(blocks_needed, num_groups * 4UL);
454 auto total_threads = blocks_target * ellpack_size;
// Work-group local scratch: N values per work-item of the group.
457 sycl::local_accessor<ValueType, 1> lds_x{pack_size*N, cgh};
458 sycl::nd_range<1> r{sycl::range<1>{total_threads},sycl::range<1>{pack_size}};
459 cgh.parallel_for<
class compute_lmultn>(r,
460 [=](sycl::nd_item<1> item_id)
462 auto local_id = item_id.get_local_id(0);
463 auto global_id = item_id.get_global_id(0);
// Each work-item handles rows global_id, global_id + global_range, ...
// NOTE(review): the barriers below are inside this loop, whose bound `i < nrows`
// can differ between work-items of one group when nrows is not a multiple of
// pack_size; confirm all items of a group reach every barrier.
466 for (
auto i = global_id; i < nrows; i += item_id.get_global_range()[0])
468 auto block_id = i/pack_size ;
470 int begin = access_block_row_offset[block_id] ;
471 int end = access_block_row_offset[block_id+1] ;
473 for(
int ieq=0;ieq<N;++ieq)
475 ValueType value = 0. ;
476 for(
int k=begin;k<end;++k)
479 const int col = access_cols[k * pack_size + local_id];
// Stage this work-item's N entries of x in local memory.
482 for(
int ju=0;ju<N;++ju)
483 lds_x[N*local_id+ju] = access_x[N*col+ju];
484 item_id.barrier(sycl::access::fence_space::local_space);
// Only slots flagged in the lower mask contribute.
// NOTE(review): brace lines are missing from this listing; confirm the second
// barrier is outside this per-work-item conditional.
485 if(access_mask[k * pack_size + local_id])
487 for(
int ju=0;ju<N;++ju)
488 value += access_values[tile.ijk(k,ieq,ju) + local_id] * lds_x[N*local_id+ju] ;
489 item_id.barrier(sycl::access::fence_space::local_space);
491 access_y[i*N+ieq] += alpha*value ;
// Builds a SYCL kernel that, for every row i < nrows and every ieq in [0,N), performs
//   y[i*N+ieq] += alpha * sum of values[tile.ijk(k,ieq,ju)+local_id] * x[col*N+ju]
// over k in [begin,end) and ju in [0,N), keeping only the (k, local_id) slots
// whose entry in the profile's upper mask (getUpperMask) is non-zero.
// N is not declared in the visible lines (presumably a template parameter — TODO confirm).
// NOTE(review): `tile` is used below but its declaration is not among the visible
// lines of this function; the submit call and brace lines are not visible either.
// `device`, `max_work_group_size`, `nnz` and `kcol` are not referenced in the
// visible lines of this function.
499 void addUMultN(ValueType alpha, ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const
501 auto device = queue.get_device();
503 auto num_groups = queue.get_device().get_info<sycl::info::device::max_compute_units>();
505 auto max_work_group_size = queue.get_device().get_info<sycl::info::device::max_work_group_size>();
510 std::size_t pack_size = ellpack_size;
511 auto nrows = m_profile->getNRows() ;
512 auto nnz = m_profile->getNnz() ;
// Launch size: at least ceil(nrows/ellpack_size) work-groups, and at least
// 4 work-groups per compute unit.
514 auto blocks_needed = (nrows + ellpack_size - 1) / ellpack_size;
515 auto blocks_target = std::max(blocks_needed, num_groups * 4UL);
516 auto total_threads = blocks_target * ellpack_size;
518 auto internal_profile = m_profile->internal() ;
519 auto& kcol = internal_profile->getKCol() ;
520 auto& block_row_offset = internal_profile->getBlockRowOffset() ;
521 auto& block_cols = internal_profile->getBlockCols() ;
522 auto& mask = internal_profile->getUpperMask() ;
525 [&](sycl::handler& cgh)
527 auto access_block_row_offset = block_row_offset.template get_access<sycl::access::mode::read>(cgh);
528 auto access_cols = block_cols.template get_access<sycl::access::mode::read>(cgh);
529 auto access_mask = mask.template get_access<sycl::access::mode::read>(cgh);
530 auto access_values = m_values.template get_access<sycl::access::mode::read>(cgh);
533 auto access_x = x.template get_access<sycl::access::mode::read>(cgh);
// y is both read and written because the result is accumulated into it.
534 auto access_y = y.template get_access<sycl::access::mode::read_write>(cgh);
// Work-group local scratch: N values per work-item of the group.
537 sycl::local_accessor<ValueType, 1> lds_x{pack_size*N, cgh};
538 sycl::nd_range<1> r{sycl::range<1>{total_threads},sycl::range<1>{pack_size}};
539 cgh.parallel_for<
class compute_umultn>(r,
540 [=](sycl::nd_item<1> item_id)
542 auto local_id = item_id.get_local_id(0);
543 auto global_id = item_id.get_global_id(0);
// Each work-item handles rows global_id, global_id + global_range, ...
// NOTE(review): the barriers below are inside this loop, whose bound `i < nrows`
// can differ between work-items of one group when nrows is not a multiple of
// pack_size; confirm all items of a group reach every barrier.
544 for (
auto i = global_id; i < nrows; i += item_id.get_global_range()[0])
546 auto block_id = i/pack_size ;
548 auto begin = access_block_row_offset[block_id] ;
549 auto end = access_block_row_offset[block_id+1] ;
551 for(
int ieq=0;ieq<N;++ieq)
553 ValueType value = 0. ;
554 for(
int k=begin;k<end;++k)
556 const int col = access_cols[k * pack_size + local_id];
// Stage this work-item's N entries of x in local memory.
559 for(
int ju=0;ju<N;++ju)
560 lds_x[local_id*N+ju] = access_x[col*N+ju];
561 item_id.barrier(sycl::access::fence_space::local_space);
// Only slots flagged in the upper mask contribute.
// NOTE(review): brace lines are missing from this listing; confirm the second
// barrier is outside this per-work-item conditional.
562 if(access_mask[k * pack_size + local_id])
564 for(
int ju=0;ju<N;++ju)
565 value += access_values[tile.ijk(k,ieq,ju) + local_id] * lds_x[local_id*N+ju] ;
566 item_id.barrier(sycl::access::fence_space::local_space);
568 access_y[i*N+ieq] += alpha*value ;
// Declarations only; all definitions are outside the visible lines, so the
// exact semantics of each operation cannot be confirmed from this listing.
// Most operations come in pairs: one overload without a queue and one taking
// an explicit QueueType.
575 void addLMult(ValueType alpha, ValueBufferType& x, ValueBufferType& y)
const;
576 void addUMult(ValueType alpha, ValueBufferType& x, ValueBufferType& y)
const;
578 void addLMult(ValueType alpha, ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const;
579 void addUMult(ValueType alpha, ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const;
// Diagonal-related operations (two-buffer and single-buffer forms).
581 void multDiag(ValueBufferType& x, ValueBufferType& y)
const;
582 void multDiag(ValueBufferType& x, ValueBufferType& y, QueueType& queue)
const;
584 void multDiag(ValueBufferType& y)
const;
585 void multDiag(ValueBufferType& y, QueueType& queue)
const;
587 void computeDiag(ValueBufferType& y)
const;
588 void computeDiag(ValueBufferType& y, QueueType& queue)
const;
590 void computeBlockDiag(ValueBufferType& y)
const;
591 void computeBlockDiag(ValueBufferType& y, QueueType& queue)
const;
593 void multInvDiag(ValueBufferType& y)
const;
594 void multInvDiag(ValueBufferType& y, QueueType& queue)
const;
596 void computeInvDiag(ValueBufferType& y)
const;
597 void computeInvDiag(ValueBufferType& y, QueueType& queue)
const;
599 void computeInvBlockDiag(ValueBufferType& y)
const;
600 void computeInvBlockDiag(ValueBufferType& y, QueueType& queue)
const;
// scal() is the only non-const operation in this group.
602 void scal(ValueBufferType& y);
604 void scal(ValueBufferType& y, QueueType& queue);
// NOTE(review): several parameters of copyDevicePointers (original lines
// 607-611) are missing from this listing.
606 void copyDevicePointers(
int local_offset,
612 ValueT* values)
const ;
// Accessors for the coefficient buffer m_values.
// The non-const overload returns a reference to the member.
614 ValueBufferType& getValues() {
return m_values; }
// NOTE(review): the const overload returns `ValueBufferType const` by value,
// so callers receive a copy of the sycl::buffer object rather than a reference
// to the member — confirm this is intended (a `const&` may have been meant).
616 ValueBufferType
const getValues()
const {
return m_values; }
// Returns the stored profile pointer (m_profile); no ownership transfer is visible here.
620 ProfileType
const* getProfile()
const {
return m_profile; }
// Const and non-const access to the contiguous storage of m_h_csr_values
// (std::vector::data()). The brace lines are not part of this listing.
622 ValueType
const* getHCsrData()
const
624 return m_h_csr_values.data();
627 ValueType* getHCsrData()
629 return m_h_csr_values.data();
// Const accessors returning a non-const IndexBufferType reference.
// NOTE(review): their bodies are not visible in this listing; presumably they
// dereference the mutable m_send_ids / m_recv_ids pointers — TODO confirm.
632 IndexBufferType& getSendIds()
const
636 IndexBufferType& getRecvIds()
const
// Data members.
// Profile pointers (m_profile is dereferenced by the kernels above);
// both default to nullptr.
644 ProfileType
const* m_profile =
nullptr;
645 ProfileType
const* m_ext_profile =
nullptr;
// Host-side value storage and the SYCL buffer read by the kernels (m_values).
// m_values is mutable so const member functions can request accessors on it.
647 std::vector<ValueType> m_h_csr_values ;
648 std::vector<ValueType> m_h_values ;
649 mutable ValueBufferType m_values ;
// Counterparts for the "ext" values; the SYCL buffer is held through a unique_ptr.
651 std::vector<ValueType> m_h_csr_ext_values ;
652 std::vector<ValueType> m_h_ext_values ;
653 mutable ValueBufferPtrType m_ext_values ;
// Initialised to false; the code that updates it is outside the visible lines.
654 bool m_values_is_update = false ;
// Interface-row and send/receive index buffers; their use is outside the visible lines.
656 int const* m_h_interface_row_ids =
nullptr;
657 mutable IndexBufferPtrType m_interface_row_ids ;
658 mutable IndexBufferPtrType m_send_ids ;
659 mutable IndexBufferPtrType m_recv_ids ;
660 mutable IndexBufferPtrType m_recv_uids ;