// ILU(0) smoother specialization for the CUDA backend; the members below drive it through cuSPARSE.
50struct ilu0<backend::cuda<real>>
// Scalar type used for the matrix values and the work vectors of this specialization.
52 typedef real value_type;
// Parameter handling: `damping` is the only parameter named by the import/check/export lines below.
63 : ARCCORE_ALINA_PARAMS_IMPORT_VALUE(p,
damping)
// NOTE(review): presumably validates that `p` holds no key other than "damping" — confirm against check_params.
65 p.check_params({
"damping" });
// Export counterpart of the import above: the macro receives the target `p`, a `path`, and `damping`.
70 ARCCORE_ALINA_PARAMS_EXPORT_VALUE(p, path, damping);
// Constructor: copies the CSR arrays of A to the device, runs the cuSPARSE ILU(0) factorization on them,
// and then prepares the lower/upper triangular solves (generic SpSV API when CUDART_VERSION >= 11000,
// csrsv2 API otherwise).
//   A    - source matrix; A.ptr / A.col / A.val are read as CSR arrays of n+1, nnz and nnz entries.
//   prm  - smoother parameters (prm.damping is used by apply_pre/apply_post).
//   bprm - backend parameters; only bprm.cusparse_handle is read here.
74 template <
class Matrix>
75 ilu0(
const Matrix& A,
const params& prm,
const typename Backend::params& bprm)
77 , handle(bprm.cusparse_handle)
78 , n(backend::nbRow(A))
79 , nnz(backend::nonzeros(A))
// Device copies of the CSR arrays. `val` is handed to cusparseXcsrilu02 below and afterwards serves as
// the value array of both triangular factor descriptors.
80 , ptr(A.ptr, A.ptr + n + 1)
81 , col(A.col, A.col + nnz)
82 , val(A.val, A.val + nnz)
// Matrix descriptor and csrilu02 info object for the factorization stage, held through shared_ptr.
86 std::shared_ptr<std::remove_pointer<cusparseMatDescr_t>::type> descr_M;
87 std::shared_ptr<std::remove_pointer<csrilu02Info_t>::type> info_M;
// Factorization descriptor: zero-based indices, general matrix type.
90 cusparseMatDescr_t descr;
93 ARCCORE_ALINA_CALL_CUDA(cusparseCreateMatDescr(&descr));
94 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
95 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
97 ARCCORE_ALINA_CALL_CUDA(cusparseCreateCsrilu02Info(&info));
// Step 1: ask cuSPARSE how many workspace bytes the ILU(0) routines need.
104 ARCCORE_ALINA_CALL_CUDA(
105 cusparseXcsrilu02_bufferSize(handle, n, nnz, descr_M.get(),
106 thrust::raw_pointer_cast(&val[0]),
107 thrust::raw_pointer_cast(&ptr[0]),
108 thrust::raw_pointer_cast(&col[0]),
109 info_M.get(), &buf_size));
// Workspace for the factorization, sized in bytes from the query above.
111 thrust::device_vector<char> bufLU(buf_size);
// Step 2: analysis phase with the level-based solve policy.
117 ARCCORE_ALINA_CALL_CUDA(
118 cusparseXcsrilu02_analysis(handle,
122 thrust::raw_pointer_cast(&val[0]),
123 thrust::raw_pointer_cast(&ptr[0]),
124 thrust::raw_pointer_cast(&col[0]),
126 CUSPARSE_SOLVE_POLICY_USE_LEVEL,
127 thrust::raw_pointer_cast(&bufLU[0])));
// Condition/message pair for a structural zero-pivot check after the analysis.
// NOTE(review): presumably arguments of a precondition-style check — confirm the enclosing call.
130 CUSPARSE_STATUS_ZERO_PIVOT != cusparseXcsrilu02_zeroPivot(handle, info_M.get(), &structural_zero),
131 "Zero pivot in cuSPARSE ILU0");
// Step 3: numeric ILU(0) factorization, using the same workspace and policy as the analysis.
133 ARCCORE_ALINA_CALL_CUDA(
134 cusparseXcsrilu02(handle,
138 thrust::raw_pointer_cast(&val[0]),
139 thrust::raw_pointer_cast(&ptr[0]),
140 thrust::raw_pointer_cast(&col[0]),
142 CUSPARSE_SOLVE_POLICY_USE_LEVEL,
143 thrust::raw_pointer_cast(&bufLU[0])));
// Same check again after the numeric phase, this time recording a numerical zero pivot.
145 CUSPARSE_STATUS_ZERO_PIVOT != cusparseXcsrilu02_zeroPivot(handle, info_M.get(), &numerical_zero),
146 "Zero pivot in cuSPARSE ILU0");
// CUDA 11+ branch: set up the triangular solves through the generic SpSV API.
150#if CUDART_VERSION >= 11000
151 const real alpha = 1;
// Temporary device vector of length n; it only backs the descr_t descriptor created below.
152 thrust::device_vector<value_type> t(n);
// Dense-vector descriptor built over the scratch vector y, paired with cuda_deleter.
155 backend::detail::cuda_vector_description(y),
156 backend::detail::cuda_deleter());
// Dense-vector descriptor over the temporary t, released through cuda_deleter.
158 std::shared_ptr<std::remove_pointer<cusparseDnVecDescr_t>::type> descr_t(
159 backend::detail::cuda_vector_description(t),
160 backend::detail::cuda_deleter());
// Attribute values for the two triangular views of the factored matrix.
162 cusparseFillMode_t fill_lower = CUSPARSE_FILL_MODE_LOWER;
163 cusparseFillMode_t fill_upper = CUSPARSE_FILL_MODE_UPPER;
164 cusparseDiagType_t diag_unit = CUSPARSE_DIAG_TYPE_UNIT;
165 cusparseDiagType_t diag_non_unit = CUSPARSE_DIAG_TYPE_NON_UNIT;
// Sparse-matrix descriptor over (ptr, col, val); the attribute calls that follow target descr_L.
170 backend::detail::cuda_matrix_description(n, n, nnz, ptr, col, val),
171 backend::detail::cuda_deleter());
// Mark descr_L as the lower-triangular view.
173 ARCCORE_ALINA_CALL_CUDA(
174 cusparseSpMatSetAttribute(descr_L.get(),
175 CUSPARSE_SPMAT_FILL_MODE,
177 sizeof(fill_lower)));
// Diagonal type of descr_L.
// NOTE(review): presumably diag_unit, mirroring the legacy branch — confirm.
179 ARCCORE_ALINA_CALL_CUDA(
180 cusparseSpMatSetAttribute(descr_L.get(),
181 CUSPARSE_SPMAT_DIAG_TYPE,
// SpSV descriptor for the lower solve, owned by descr_SL.
187 cusparseSpSVDescr_t desc;
188 ARCCORE_ALINA_CALL_CUDA(cusparseSpSV_createDescr(&desc));
189 descr_SL.reset(desc, backend::detail::cuda_deleter());
// Query the workspace size for the lower solve, then size bufL accordingly.
191 ARCCORE_ALINA_CALL_CUDA(
192 cusparseSpSV_bufferSize(handle,
193 CUSPARSE_OPERATION_NON_TRANSPOSE,
198 backend::detail::cuda_datatype<real>(),
199 CUSPARSE_SPSV_ALG_DEFAULT,
203 bufL.resize(buf_size);
// Analysis for the lower solve; bufL is the workspace handed to cuSPARSE.
205 ARCCORE_ALINA_CALL_CUDA(
206 cusparseSpSV_analysis(handle,
207 CUSPARSE_OPERATION_NON_TRANSPOSE,
212 backend::detail::cuda_datatype<real>(),
213 CUSPARSE_SPSV_ALG_DEFAULT,
215 thrust::raw_pointer_cast(&bufL[0])));
// Second sparse-matrix descriptor over the same (ptr, col, val); the attribute calls that follow target descr_U.
221 backend::detail::cuda_matrix_description(n, n, nnz, ptr, col, val),
222 backend::detail::cuda_deleter());
// Mark descr_U as the upper-triangular view with a non-unit diagonal.
224 ARCCORE_ALINA_CALL_CUDA(
225 cusparseSpMatSetAttribute(descr_U.get(),
226 CUSPARSE_SPMAT_FILL_MODE,
228 sizeof(fill_upper)));
230 ARCCORE_ALINA_CALL_CUDA(
231 cusparseSpMatSetAttribute(descr_U.get(),
232 CUSPARSE_SPMAT_DIAG_TYPE,
234 sizeof(diag_non_unit)));
// SpSV descriptor for the upper solve, owned by descr_SU.
238 cusparseSpSVDescr_t desc;
239 ARCCORE_ALINA_CALL_CUDA(cusparseSpSV_createDescr(&desc));
240 descr_SU.reset(desc, backend::detail::cuda_deleter());
// Query the workspace size for the upper solve, then size bufU accordingly.
242 ARCCORE_ALINA_CALL_CUDA(
243 cusparseSpSV_bufferSize(handle,
244 CUSPARSE_OPERATION_NON_TRANSPOSE,
249 backend::detail::cuda_datatype<real>(),
250 CUSPARSE_SPSV_ALG_DEFAULT,
254 bufU.resize(buf_size);
// Analysis for the upper solve; bufU is the workspace handed to cuSPARSE.
256 ARCCORE_ALINA_CALL_CUDA(
257 cusparseSpSV_analysis(handle,
258 CUSPARSE_OPERATION_NON_TRANSPOSE,
263 backend::detail::cuda_datatype<real>(),
264 CUSPARSE_SPSV_ALG_DEFAULT,
266 thrust::raw_pointer_cast(&bufU[0])));
// Legacy csrsv2 setup (the other preprocessor branch, i.e. CUDART_VERSION < 11000).
// descr_L: zero-based, general type, lower fill mode, unit diagonal.
270 cusparseMatDescr_t descr;
272 ARCCORE_ALINA_CALL_CUDA(cusparseCreateMatDescr(&descr));
273 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
274 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
275 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatFillMode(descr, CUSPARSE_FILL_MODE_LOWER));
276 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatDiagType(descr, CUSPARSE_DIAG_TYPE_UNIT));
278 descr_L.reset(descr, backend::detail::cuda_deleter());
// descr_U: zero-based, general type, upper fill mode, non-unit diagonal.
281 cusparseMatDescr_t descr;
283 ARCCORE_ALINA_CALL_CUDA(cusparseCreateMatDescr(&descr));
284 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
285 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
286 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatFillMode(descr, CUSPARSE_FILL_MODE_UPPER));
287 ARCCORE_ALINA_CALL_CUDA(cusparseSetMatDiagType(descr, CUSPARSE_DIAG_TYPE_NON_UNIT));
289 descr_U.reset(descr, backend::detail::cuda_deleter());
// One csrsv2 info object per triangular factor.
295 ARCCORE_ALINA_CALL_CUDA(cusparseCreateCsrsv2Info(&info));
296 info_L.reset(info, backend::detail::cuda_deleter());
300 ARCCORE_ALINA_CALL_CUDA(cusparseCreateCsrsv2Info(&info));
301 info_U.reset(info, backend::detail::cuda_deleter());
// Workspace queries for the lower and upper solves.
309 ARCCORE_ALINA_CALL_CUDA(
310 cusparseXcsrsv2_bufferSize(handle,
311 CUSPARSE_OPERATION_NON_TRANSPOSE,
315 thrust::raw_pointer_cast(&val[0]),
316 thrust::raw_pointer_cast(&ptr[0]),
317 thrust::raw_pointer_cast(&col[0]),
318 info_L.get(), &buf_size_L));
320 ARCCORE_ALINA_CALL_CUDA(
321 cusparseXcsrsv2_bufferSize(handle,
322 CUSPARSE_OPERATION_NON_TRANSPOSE,
326 thrust::raw_pointer_cast(&val[0]),
327 thrust::raw_pointer_cast(&ptr[0]),
328 thrust::raw_pointer_cast(&col[0]),
329 info_U.get(), &buf_size_U));
// A single buffer serves both solves, so it is sized to the larger of the two requirements.
331 buf.resize(std::max(buf_size_L, buf_size_U));
// Analysis for the lower factor, then for the upper factor, both with the level-based policy.
334 ARCCORE_ALINA_CALL_CUDA(
335 cusparseXcsrsv2_analysis(handle,
336 CUSPARSE_OPERATION_NON_TRANSPOSE,
340 thrust::raw_pointer_cast(&val[0]),
341 thrust::raw_pointer_cast(&ptr[0]),
342 thrust::raw_pointer_cast(&col[0]),
343 info_L.get(), CUSPARSE_SOLVE_POLICY_USE_LEVEL,
344 thrust::raw_pointer_cast(&buf[0])));
346 ARCCORE_ALINA_CALL_CUDA(
347 cusparseXcsrsv2_analysis(handle,
348 CUSPARSE_OPERATION_NON_TRANSPOSE,
352 thrust::raw_pointer_cast(&val[0]),
353 thrust::raw_pointer_cast(&ptr[0]),
354 thrust::raw_pointer_cast(&col[0]),
355 info_U.get(), CUSPARSE_SOLVE_POLICY_USE_LEVEL,
356 thrust::raw_pointer_cast(&buf[0])));
// Pre-relaxation sweep on x.
//   A   - system matrix, passed to backend::residual.
//   rhs - right-hand side.
//   x   - current iterate, updated in place.
//   tmp - work vector that receives the residual.
360 template <
class Matrix,
class VectorRHS,
class VectorX,
class VectorTMP>
361 void apply_pre(
const Matrix& A,
const VectorRHS& rhs, VectorX& x, VectorTMP& tmp)
const
// Residual of (rhs, A, x) is written into tmp.
363 backend::residual(rhs, A, x, tmp);
// Combine tmp into x with coefficients prm.damping (on tmp) and 1 (on x).
// NOTE(review): tmp is presumably passed through the triangular solves before this update — confirm.
365 backend::axpby(prm.damping, tmp, 1, x);
// Post-relaxation sweep on x; the visible statements are the same as in apply_pre.
//   A   - system matrix, passed to backend::residual.
//   rhs - right-hand side.
//   x   - current iterate, updated in place.
//   tmp - work vector that receives the residual.
368 template <
class Matrix,
class VectorRHS,
class VectorX,
class VectorTMP>
369 void apply_post(
const Matrix& A,
const VectorRHS& rhs, VectorX& x, VectorTMP& tmp)
const
// Residual of (rhs, A, x) is written into tmp.
371 backend::residual(rhs, A, x, tmp);
// Combine tmp into x with coefficients prm.damping (on tmp) and 1 (on x).
// NOTE(review): tmp is presumably passed through the triangular solves before this update — confirm.
373 backend::axpby(prm.damping, tmp, 1, x);
// Applies the smoother as a stand-alone operator: x is initialised from rhs.
// The matrix argument A is not used by the visible statement.
// NOTE(review): presumably followed by the in-place triangular solves on x — confirm.
376 template <
class Matrix,
class VectorRHS,
class VectorX>
377 void apply(
const Matrix& A,
const VectorRHS& rhs, VectorX& x)
const
// Copy rhs into x.
379 backend::copy(rhs, x);
// Memory footprint: sums backend::bytes over the stored CSR arrays (ptr, col, val)
// and, when CUDART_VERSION >= 11000, the SpSV workspace bufL.
386 return backend::bytes(ptr) +
387 backend::bytes(col) +
388 backend::bytes(val) +
390#if CUDART_VERSION >= 11000
391 backend::bytes(bufL) +
// cuSPARSE handle; the constructor initialises it from bprm.cusparse_handle.
401 cusparseHandle_t handle;
// Device copies of the CSR row pointers and column indices.
404 thrust::device_vector<int> ptr, col;
// Device copy of the CSR values; passed to the ILU(0) factorization and to both triangular solves.
405 thrust::device_vector<value_type> val;
// Intermediate vector between the lower and upper solves; mutable so the const solve() can write to it.
406 mutable thrust::device_vector<value_type> y;
// CUDA 11+ (generic SpSV API): sparse-matrix descriptors for L and U, one SpSV descriptor per solve,
// a dense-vector descriptor for y, and one workspace buffer per solve.
408#if CUDART_VERSION >= 11000
409 std::shared_ptr<std::remove_pointer<cusparseSpMatDescr_t>::type> descr_L, descr_U;
410 std::shared_ptr<std::remove_pointer<cusparseSpSVDescr_t>::type> descr_SL, descr_SU;
411 std::shared_ptr<std::remove_pointer<cusparseDnVecDescr_t>::type> descr_y;
412 mutable thrust::device_vector<char> bufL, bufU;
// Legacy csrsv2 path (other preprocessor branch): matrix descriptors, csrsv2 info objects,
// and a single workspace buffer used by both solves.
414 std::shared_ptr<std::remove_pointer<cusparseMatDescr_t>::type> descr_L, descr_U;
415 std::shared_ptr<std::remove_pointer<csrsv2Info_t>::type> info_L, info_U;
416 mutable thrust::device_vector<char> buf;
// Runs two triangular solves back to back on x, using the factors prepared by the constructor.
// In the legacy branch the first solve reads x and writes y, and the second reads y and writes x,
// so the result ends up in x.
419 template <
class VectorX>
420 void solve(VectorX& x)
const
// NOTE(review): presumably the alpha scaling argument of the solve calls — confirm.
422 value_type alpha = 1;
// CUDA 11+ branch: wrap x in a dense-vector descriptor, then issue two SpSV solves.
424#if CUDART_VERSION >= 11000
425 std::shared_ptr<std::remove_pointer<cusparseDnVecDescr_t>::type> descr_x(
426 backend::detail::cuda_vector_description(x),
427 backend::detail::cuda_deleter());
// First SpSV solve.
// NOTE(review): presumably the lower factor, x -> y — confirm operands.
430 ARCCORE_ALINA_CALL_CUDA(
431 cusparseSpSV_solve(handle,
432 CUSPARSE_OPERATION_NON_TRANSPOSE,
437 backend::detail::cuda_datatype<real>(),
438 CUSPARSE_SPSV_ALG_DEFAULT,
// Second SpSV solve.
// NOTE(review): presumably the upper factor, y -> x — confirm operands.
442 ARCCORE_ALINA_CALL_CUDA(
443 cusparseSpSV_solve(handle,
444 CUSPARSE_OPERATION_NON_TRANSPOSE,
449 backend::detail::cuda_datatype<real>(),
450 CUSPARSE_SPSV_ALG_DEFAULT,
// Legacy branch, first csrsv2 solve: x is passed in the `f` position and y in the `x` position of the wrapper.
454 ARCCORE_ALINA_CALL_CUDA(
455 cusparseXcsrsv2_solve(handle,
456 CUSPARSE_OPERATION_NON_TRANSPOSE,
461 thrust::raw_pointer_cast(&val[0]),
462 thrust::raw_pointer_cast(&ptr[0]),
463 thrust::raw_pointer_cast(&col[0]),
465 thrust::raw_pointer_cast(&x[0]),
466 thrust::raw_pointer_cast(&y[0]),
467 CUSPARSE_SOLVE_POLICY_USE_LEVEL,
468 thrust::raw_pointer_cast(&buf[0])));
// Legacy branch, second csrsv2 solve: operands swapped, so y is the input and x receives the result.
471 ARCCORE_ALINA_CALL_CUDA(
472 cusparseXcsrsv2_solve(handle,
473 CUSPARSE_OPERATION_NON_TRANSPOSE,
478 thrust::raw_pointer_cast(&val[0]),
479 thrust::raw_pointer_cast(&ptr[0]),
480 thrust::raw_pointer_cast(&col[0]),
482 thrust::raw_pointer_cast(&y[0]),
483 thrust::raw_pointer_cast(&x[0]),
484 CUSPARSE_SOLVE_POLICY_USE_LEVEL,
485 thrust::raw_pointer_cast(&buf[0])));
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrilu02_bufferSize.
// Together with the float overload it gives templated code a single name, selected by the value pointer type.
489 static cusparseStatus_t
490 cusparseXcsrilu02_bufferSize(cusparseHandle_t handle,
493 const cusparseMatDescr_t descrA,
494 double* csrSortedValA,
495 const int* csrSortedRowPtrA,
496 const int* csrSortedColIndA,
498 int* pBufferSizeInBytes)
500 return cusparseDcsrilu02_bufferSize(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA,
501 csrSortedColIndA, info, pBufferSizeInBytes);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrilu02_bufferSize.
504 static cusparseStatus_t
505 cusparseXcsrilu02_bufferSize(cusparseHandle_t handle,
508 const cusparseMatDescr_t descrA,
509 float* csrSortedValA,
510 const int* csrSortedRowPtrA,
511 const int* csrSortedColIndA,
513 int* pBufferSizeInBytes)
515 return cusparseScsrilu02_bufferSize(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA,
516 csrSortedColIndA, info, pBufferSizeInBytes);
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrilu02_analysis.
519 static cusparseStatus_t
520 cusparseXcsrilu02_analysis(cusparseHandle_t handle,
523 const cusparseMatDescr_t descrA,
524 const double* csrSortedValA,
525 const int* csrSortedRowPtrA,
526 const int* csrSortedColIndA,
528 cusparseSolvePolicy_t policy,
531 return cusparseDcsrilu02_analysis(handle, m, nnz, descrA, csrSortedValA,
532 csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrilu02_analysis.
535 static cusparseStatus_t
536 cusparseXcsrilu02_analysis(cusparseHandle_t handle,
539 const cusparseMatDescr_t descrA,
540 const float* csrSortedValA,
541 const int* csrSortedRowPtrA,
542 const int* csrSortedColIndA,
544 cusparseSolvePolicy_t policy,
547 return cusparseScsrilu02_analysis(handle, m, nnz, descrA, csrSortedValA,
548 csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer);
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrilu02.
// The value array is taken through a non-const pointer (csrSortedValA_valM).
551 static cusparseStatus_t
552 cusparseXcsrilu02(cusparseHandle_t handle,
555 const cusparseMatDescr_t descrA,
556 double* csrSortedValA_valM,
557 const int* csrSortedRowPtrA,
558 const int* csrSortedColIndA,
560 cusparseSolvePolicy_t policy,
563 return cusparseDcsrilu02(handle, m, nnz, descrA,
564 csrSortedValA_valM, csrSortedRowPtrA, csrSortedColIndA,
565 info, policy, pBuffer);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrilu02.
// The value array is taken through a non-const pointer (csrSortedValA_valM).
568 static cusparseStatus_t
569 cusparseXcsrilu02(cusparseHandle_t handle,
572 const cusparseMatDescr_t descrA,
573 float* csrSortedValA_valM,
574 const int* csrSortedRowPtrA,
575 const int* csrSortedColIndA,
577 cusparseSolvePolicy_t policy,
580 return cusparseScsrilu02(handle, m, nnz, descrA,
581 csrSortedValA_valM, csrSortedRowPtrA, csrSortedColIndA,
582 info, policy, pBuffer);
585#if CUDART_VERSION < 11000
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrsv2_bufferSize.
586 static cusparseStatus_t
587 cusparseXcsrsv2_bufferSize(cusparseHandle_t handle,
588 cusparseOperation_t transA,
591 const cusparseMatDescr_t descrA,
592 double* csrSortedValA,
593 const int* csrSortedRowPtrA,
594 const int* csrSortedColIndA,
596 int* pBufferSizeInBytes)
598 return cusparseDcsrsv2_bufferSize(handle, transA, m, nnz, descrA, csrSortedValA,
599 csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrsv2_bufferSize.
602 static cusparseStatus_t
603 cusparseXcsrsv2_bufferSize(cusparseHandle_t handle,
604 cusparseOperation_t transA,
607 const cusparseMatDescr_t descrA,
608 float* csrSortedValA,
609 const int* csrSortedRowPtrA,
610 const int* csrSortedColIndA,
612 int* pBufferSizeInBytes)
614 return cusparseScsrsv2_bufferSize(handle, transA, m, nnz, descrA, csrSortedValA,
615 csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes);
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrsv2_analysis.
618 static cusparseStatus_t
619 cusparseXcsrsv2_analysis(cusparseHandle_t handle,
620 cusparseOperation_t transA,
623 const cusparseMatDescr_t descrA,
624 const double* csrSortedValA,
625 const int* csrSortedRowPtrA,
626 const int* csrSortedColIndA,
628 cusparseSolvePolicy_t policy,
631 return cusparseDcsrsv2_analysis(handle, transA, m, nnz, descrA, csrSortedValA,
632 csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrsv2_analysis.
635 static cusparseStatus_t
636 cusparseXcsrsv2_analysis(cusparseHandle_t handle,
637 cusparseOperation_t transA,
640 const cusparseMatDescr_t descrA,
641 const float* csrSortedValA,
642 const int* csrSortedRowPtrA,
643 const int* csrSortedColIndA,
645 cusparseSolvePolicy_t policy,
648 return cusparseScsrsv2_analysis(handle, transA, m, nnz, descrA, csrSortedValA,
649 csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer);
// Double-precision overload: forwards its arguments unchanged to cusparseDcsrsv2_solve.
// `f` and `x` are passed on in that order, right after `info`.
652 static cusparseStatus_t
653 cusparseXcsrsv2_solve(cusparseHandle_t handle,
654 cusparseOperation_t transA,
658 const cusparseMatDescr_t descrA,
659 const double* csrSortedValA,
660 const int* csrSortedRowPtrA,
661 const int* csrSortedColIndA,
665 cusparseSolvePolicy_t policy,
668 return cusparseDcsrsv2_solve(handle, transA, m,
669 nnz, alpha, descrA, csrSortedValA, csrSortedRowPtrA,
670 csrSortedColIndA, info, f, x, policy, pBuffer);
// Single-precision overload: forwards its arguments unchanged to cusparseScsrsv2_solve.
// `f` and `x` are passed on in that order, right after `info`.
673 static cusparseStatus_t
674 cusparseXcsrsv2_solve(cusparseHandle_t handle,
675 cusparseOperation_t transA,
679 const cusparseMatDescr_t descrA,
680 const float* csrSortedValA,
681 const int* csrSortedRowPtrA,
682 const int* csrSortedColIndA,
686 cusparseSolvePolicy_t policy,
689 return cusparseScsrsv2_solve(handle, transA, m,
690 nnz, alpha, descrA, csrSortedValA, csrSortedRowPtrA,
691 csrSortedColIndA, info, f, x, policy, pBuffer);