Alien  1.3.0
Developer documentation
Loading...
Searching...
No Matches
CSRModifierViewT.h
1/*
2 * Copyright 2020 IFPEN-CEA
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * SPDX-License-Identifier: Apache-2.0
17 */
18
19/*
20 * CSRModifierViewT.h
21 *
22 * Created on: Dec 24, 2021
23 * Author: gratienj
24 */
25
26// -*- C++ -*-
27#pragma once
28
29namespace Alien
30{
31
/*!
 * \brief Read-only view over the structure (profile) of a CSR matrix.
 *
 * The view stores only references to the profile and to the distribution
 * information, so both objects must outlive the view.
 *
 * When the view is built with \p is_parallel == true, cols() and dcol() are
 * taken from the distribution information instead of the profile.
 * nrows(), nnz() and kcol() always come from the profile.
 *
 * All accessors are const: they only read through the stored const
 * references, and can therefore be used on a const view.
 *
 * \tparam ProfileT         profile type. It must expose IndexType,
 *                          getNRows(), getNnz(), kcol(), cols() and dcol().
 * \tparam DistProfileInfoT distribution-info type. It must expose a member
 *                          m_cols with data(), and dcol(ProfileT const&).
 */
template <typename ProfileT, typename DistProfileInfoT>
class CSRProfileConstViewT
{
 protected:
  // clang-format off
  ProfileT const&         m_profile;
  DistProfileInfoT const& m_dist_info;
  bool                    m_is_parallel = false;
  // clang-format on

 public:
  // clang-format off
  typedef ProfileT                       ProfileType ;
  typedef typename ProfileType::IndexType IndexType ;
  typedef DistProfileInfoT               DistInfoType ;
  // clang-format on

  /*!
   * \param profile     profile to view (referenced, not copied)
   * \param dist_info   distribution info (referenced, not copied)
   * \param is_parallel when true, cols()/dcol() read from \p dist_info
   */
  CSRProfileConstViewT(ProfileT const& profile,
                       DistInfoType const& dist_info,
                       bool is_parallel = false)
  : m_profile(profile)
  , m_dist_info(dist_info)
  , m_is_parallel(is_parallel)
  {}

  //! Number of rows, as reported by the profile's getNRows().
  std::size_t nrows() const
  {
    return m_profile.getNRows();
  }

  //! Number of stored entries, as reported by the profile's getNnz().
  std::size_t nnz() const
  {
    return m_profile.getNnz();
  }

  //! Row-offset array of the profile.
  //! Presumably nrows()+1 entries; TODO confirm against ProfileT.
  IndexType const* kcol() const
  {
    return m_profile.kcol();
  }

  //! Column-index array.
  //! In parallel mode this is the distribution info's m_cols buffer;
  //! otherwise it is the profile's own cols().
  IndexType const* cols() const
  {
    if (m_is_parallel)
      return m_dist_info.m_cols.data();
    else
      return m_profile.cols();
  }

  //! Per-row array returned by dcol(). Presumably the diagonal-entry
  //! positions; verify against ProfileT::dcol / DistProfileInfoT::dcol.
  //! In parallel mode it is computed by the distribution info from the
  //! profile.
  IndexType const* dcol() const
  {
    if (m_is_parallel)
      return m_dist_info.dcol(m_profile);
    else
      return m_profile.dcol();
  }
};
88
89template <typename MatrixT>
90class CSRConstViewT
91: public CSRProfileConstViewT<typename MatrixT::ProfileType,
92 typename MatrixT::DistStructInfo>
93{
94 public:
95 // clang-format off
96 typedef MatrixT MatrixType ;
97 typedef typename MatrixType::ProfileType ProfileType ;
98 typedef typename MatrixType::DistStructInfo DistStructInfo ;
99 typedef
100 CSRProfileConstViewT<ProfileType,DistStructInfo> BaseType ;
101 typedef typename MatrixType::ValueType ValueType ;
102 typedef typename BaseType::IndexType IndexType ;
103 // clang-format on
104
105 CSRConstViewT(MatrixT const& matrix)
106 : BaseType(matrix.getProfile(), matrix.getDistStructInfo(), matrix.isParallel())
107 , m_matrix(matrix)
108 {}
109
110 ValueType const* data()
111 {
112 return this->m_matrix.data();
113 }
114
115 private:
116 MatrixType const& m_matrix;
117};
118
119template <typename MatrixT>
120class CSRModifierViewT
121: public CSRProfileConstViewT<typename MatrixT::ProfileType,
122 typename MatrixT::DistStructInfo>
123{
124 public:
125 // clang-format off
126 typedef MatrixT MatrixType ;
127 typedef typename MatrixType::ProfileType ProfileType ;
128 typedef typename MatrixType::DistStructInfo DistStructInfo ;
129 typedef
130 CSRProfileConstViewT<ProfileType,DistStructInfo> BaseType ;
131 typedef typename MatrixType::ValueType ValueType ;
132 typedef typename BaseType::IndexType IndexType ;
133 // clang-format on
134
135 CSRModifierViewT(MatrixT& matrix)
136 : BaseType(matrix.getProfile(), matrix.getDistStructInfo(), matrix.isParallel())
137 , m_matrix(matrix)
138 {
139 m_matrix.notifyChanges();
140 }
141
142 virtual ~CSRModifierViewT()
143 {
144 m_matrix.endUpdate();
145 }
146
147 ValueType* data()
148 {
149 return this->m_matrix.data();
150 }
151
152 private:
153 MatrixType& m_matrix;
154};
155
156} // end namespace Alien
-*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
Definition BackEnd.h:17