Project
Loading...
Searching...
No Matches
BandMatrixSolver.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
16
17#ifndef ALICEO2_GPUCOMMON_TPCFASTTRANSFORMATION_BANDMATRIXSOLVER_H
18#define ALICEO2_GPUCOMMON_TPCFASTTRANSFORMATION_BANDMATRIXSOLVER_H
19
20#include "GPUCommonDef.h"
21#include "GPUCommonRtypes.h"
22#include <vector>
23#include <cassert>
24#include <cstdlib>
25#include <algorithm>
26#include <limits>
27
28namespace o2::gpu
29{
30
52template <int32_t BandWidthT>
54{
55 public:
57 BandMatrixSolver(int32_t N, int32_t Bdim) : mN(N), mBdim(Bdim)
58 {
59 assert(N > 0 && Bdim > 0);
60 mA.resize(mN * BandWidthT, 0.);
61 mB.resize(mN * mBdim, 0.);
62 }
63
66 {
67 // Assign NaN's to ensure that uninitialized elements (for the matrix type 1) are not used in calculations.
68 mA.assign(mA.size(), std::numeric_limits<double>::signaling_NaN());
69 mB.assign(mB.size(), std::numeric_limits<double>::signaling_NaN());
70 }
71
73 double& A(int32_t i, int32_t j)
74 {
75 auto ij = std::minmax(i, j);
76 assert(ij.first >= 0 && ij.second < mN);
77 int32_t k = ij.second - ij.first;
78 assert(k < BandWidthT);
79 return mA[ij.first * BandWidthT + k];
80 }
81
83 double& B(int32_t i, int32_t j)
84 {
85 assert(i >= 0 && i < mN && j >= 0 && j < mBdim);
86 return mB[i * mBdim + j];
87 }
88
90 void solve();
91
93 void solveType1();
94
96 static int32_t test(bool prn = 0)
97 {
98 return BandMatrixSolver<0>::test(prn);
99 }
100
101 private:
102 template <int32_t nRows>
103 void triangulateBlock(double AA[], double bb[]);
104
105 template <int32_t nCols>
106 void dioganalizeBlock(double A[], double b[]);
107
108 private:
109 int32_t mN = 0;
110 int32_t mBdim = 0;
111 std::vector<double> mA;
112 std::vector<double> mB;
113
114 ClassDefNV(BandMatrixSolver, 0);
115};
116
117template <>
119
120template <int32_t BandWidthT>
121template <int32_t nRows>
122inline void BandMatrixSolver<BandWidthT>::triangulateBlock(double AA[], double bb[])
123{
124 {
125 int32_t m = BandWidthT;
126 double* A = AA;
127 for (int32_t rows = 0; rows < nRows; rows++) {
128 double c = 1. / A[0];
129 A[0] = c; // store 1/a[0][0]
130 double* rowi = A + BandWidthT - 1;
131 for (int32_t i = 1; i < m; i++) { // row 0+i
132 double ai = c * A[i]; // A[0][i]
133 for (int32_t j = i; j < m; j++) {
134 rowi[j] -= ai * A[j]; // A[i][j] -= A[0][j]/A[0][0]*A[i][0]
135 }
136 A[i] = ai; // A[0][i] /= A[0][0]
137 rowi += BandWidthT - 1;
138 }
139 m--;
140 A += BandWidthT;
141 }
142 }
143
144 for (int32_t k = 0; k < mBdim; k++) {
145 int32_t m = BandWidthT;
146 double* A = AA;
147 double* b = bb;
148 for (int32_t rows = 0; rows < nRows; rows++) {
149 double bk = b[k];
150 for (int32_t i = 1; i < m; i++) {
151 b[mBdim * i + k] -= A[i] * bk;
152 }
153 b[k] *= A[0];
154 m--;
155 A += BandWidthT;
156 b += mBdim;
157 }
158 }
159}
160
161template <int32_t BandWidthT>
162template <int32_t nCols>
163inline void BandMatrixSolver<BandWidthT>::dioganalizeBlock(double AA[], double bb[])
164{
165 for (int32_t k = 0; k < mBdim; k++) {
166 int32_t rows = BandWidthT;
167 double* A = AA;
168 double* b = bb;
169 for (int32_t col = 0; col < nCols; col++) {
170 double bk = b[k];
171 for (int32_t i = 1; i < rows; i++) {
172 b[-i * mBdim + k] -= A[BandWidthT * (-i) + i] * bk;
173 }
174 A -= BandWidthT;
175 b -= mBdim;
176 rows--;
177 }
178 }
179}
180
181template <int32_t BandWidthT>
183{
185
186 const int32_t stepA = BandWidthT;
187 const int32_t stepB = mBdim;
188 // Upper Triangulization
189 {
190 int32_t k = 0;
191 double* Ak = &mA[0];
192 double* bk = &mB[0];
193 for (; k < mN - BandWidthT; k += 1, Ak += stepA, bk += stepB) { // for each row k
194 triangulateBlock<1>(Ak, bk);
195 }
196 // last m rows
197 triangulateBlock<BandWidthT>(Ak, bk);
198 }
199
200 // Diagonalization
201 {
202 int32_t k = mN - 1;
203 double* Ak = &mA[BandWidthT * k];
204 double* bk = &mB[mBdim * k];
205 for (; k > BandWidthT - 1; k -= 1, Ak -= stepA, bk -= stepB) { // for each row k
206 dioganalizeBlock<1>(Ak, bk);
207 }
208 // first m rows
209 dioganalizeBlock<BandWidthT>(Ak, bk);
210 }
211}
212
213template <int32_t BandWidthT>
215{
230
231 const int32_t stepA = 2 * BandWidthT;
232 const int32_t stepB = 2 * mBdim;
233 // Upper Triangulization
234 {
235 int32_t k = 0;
236 double* Ak = &mA[0];
237 double* bk = &mB[0];
238 for (; k < mN - BandWidthT; k += 2, Ak += stepA, bk += stepB) { // for each row k
239 triangulateBlock<2>(Ak, bk);
240 }
241 // last m rows
242 triangulateBlock<BandWidthT>(Ak, bk);
243 }
244
245 // Diagonalization
246 {
247 int32_t k = mN - 1;
248 double* Ak = &mA[BandWidthT * k];
249 double* bk = &mB[mBdim * k];
250 for (; k > BandWidthT - 1; k -= 2, Ak -= stepA, bk -= stepB) { // for each row k
251 dioganalizeBlock<2>(Ak, bk);
252 }
253 // first m rows
254 dioganalizeBlock<BandWidthT>(Ak, bk);
255 }
256}
257
258} // namespace o2::gpu
259
260#endif
int32_t i
const int16_t bb
uint32_t j
Definition RawData.h:0
uint32_t col
Definition RawData.h:4
uint32_t c
Definition RawData.h:2
Definition A.h:16
double & A(int32_t i, int32_t j)
access to A elements
double & B(int32_t i, int32_t j)
access to B elements
static int32_t test(bool prn=0)
Test the class functionality. Returns 1 when ok, 0 when not ok.
void initWithNaN()
debug tool: init arrays with NaN's
BandMatrixSolver(int32_t N, int32_t Bdim)
Constructor.
void solve()
solve the equation
void solveType1()
solve an equation of a special type
const GLfloat * m
Definition glcorearb.h:4066
GLboolean GLboolean GLboolean b
Definition glcorearb.h:1233
constexpr std::array< int, nLayers > nRows
Definition Specs.h:57
std::vector< ReadoutWindowData > rows