Project
Loading...
Searching...
No Matches
BandMatrixSolver.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
16
17#ifndef ALICEO2_GPUCOMMON_TPCFASTTRANSFORMATION_BANDMATRIXSOLVER_H
18#define ALICEO2_GPUCOMMON_TPCFASTTRANSFORMATION_BANDMATRIXSOLVER_H
19
20#include "GPUCommonDef.h"
21#include "GPUCommonRtypes.h"
22#include <vector>
23#include <cassert>
24#include <cstdlib>
25#include <algorithm>
26#include <limits>
27
28namespace o2
29{
30namespace gpu
31{
32
54template <int32_t BandWidthT>
56{
57 public:
59 BandMatrixSolver(int32_t N, int32_t Bdim) : mN(N), mBdim(Bdim)
60 {
61 assert(N > 0 && Bdim > 0);
62 mA.resize(mN * BandWidthT, 0.);
63 mB.resize(mN * mBdim, 0.);
64 }
65
68 {
69 // Assign NaN's to ensure that uninitialized elements (for the matrix type 1) are not used in calculations.
70 mA.assign(mA.size(), std::numeric_limits<double>::signaling_NaN());
71 mB.assign(mB.size(), std::numeric_limits<double>::signaling_NaN());
72 }
73
75 double& A(int32_t i, int32_t j)
76 {
77 auto ij = std::minmax(i, j);
78 assert(ij.first >= 0 && ij.second < mN);
79 int32_t k = ij.second - ij.first;
80 assert(k < BandWidthT);
81 return mA[ij.first * BandWidthT + k];
82 }
83
85 double& B(int32_t i, int32_t j)
86 {
87 assert(i >= 0 && i < mN && j >= 0 && j < mBdim);
88 return mB[i * mBdim + j];
89 }
90
92 void solve();
93
95 void solveType1();
96
98 static int32_t test(bool prn = 0)
99 {
100 return BandMatrixSolver<0>::test(prn);
101 }
102
103 private:
104 template <int32_t nRows>
105 void triangulateBlock(double AA[], double bb[]);
106
107 template <int32_t nCols>
108 void dioganalizeBlock(double A[], double b[]);
109
110 private:
111 int32_t mN = 0;
112 int32_t mBdim = 0;
113 std::vector<double> mA;
114 std::vector<double> mB;
115
116 ClassDefNV(BandMatrixSolver, 0);
117};
118
119template <>
121
122template <int32_t BandWidthT>
123template <int32_t nRows>
124inline void BandMatrixSolver<BandWidthT>::triangulateBlock(double AA[], double bb[])
125{
126 {
127 int32_t m = BandWidthT;
128 double* A = AA;
129 for (int32_t rows = 0; rows < nRows; rows++) {
130 double c = 1. / A[0];
131 A[0] = c; // store 1/a[0][0]
132 double* rowi = A + BandWidthT - 1;
133 for (int32_t i = 1; i < m; i++) { // row 0+i
134 double ai = c * A[i]; // A[0][i]
135 for (int32_t j = i; j < m; j++) {
136 rowi[j] -= ai * A[j]; // A[i][j] -= A[0][j]/A[0][0]*A[i][0]
137 }
138 A[i] = ai; // A[0][i] /= A[0][0]
139 rowi += BandWidthT - 1;
140 }
141 m--;
142 A += BandWidthT;
143 }
144 }
145
146 for (int32_t k = 0; k < mBdim; k++) {
147 int32_t m = BandWidthT;
148 double* A = AA;
149 double* b = bb;
150 for (int32_t rows = 0; rows < nRows; rows++) {
151 double bk = b[k];
152 for (int32_t i = 1; i < m; i++) {
153 b[mBdim * i + k] -= A[i] * bk;
154 }
155 b[k] *= A[0];
156 m--;
157 A += BandWidthT;
158 b += mBdim;
159 }
160 }
161}
162
163template <int32_t BandWidthT>
164template <int32_t nCols>
165inline void BandMatrixSolver<BandWidthT>::dioganalizeBlock(double AA[], double bb[])
166{
167 for (int32_t k = 0; k < mBdim; k++) {
168 int32_t rows = BandWidthT;
169 double* A = AA;
170 double* b = bb;
171 for (int32_t col = 0; col < nCols; col++) {
172 double bk = b[k];
173 for (int32_t i = 1; i < rows; i++) {
174 b[-i * mBdim + k] -= A[BandWidthT * (-i) + i] * bk;
175 }
176 A -= BandWidthT;
177 b -= mBdim;
178 rows--;
179 }
180 }
181}
182
183template <int32_t BandWidthT>
185{
187
188 const int32_t stepA = BandWidthT;
189 const int32_t stepB = mBdim;
190 // Upper Triangulization
191 {
192 int32_t k = 0;
193 double* Ak = &mA[0];
194 double* bk = &mB[0];
195 for (; k < mN - BandWidthT; k += 1, Ak += stepA, bk += stepB) { // for each row k
196 triangulateBlock<1>(Ak, bk);
197 }
198 // last m rows
199 triangulateBlock<BandWidthT>(Ak, bk);
200 }
201
202 // Diagonalization
203 {
204 int32_t k = mN - 1;
205 double* Ak = &mA[BandWidthT * k];
206 double* bk = &mB[mBdim * k];
207 for (; k > BandWidthT - 1; k -= 1, Ak -= stepA, bk -= stepB) { // for each row k
208 dioganalizeBlock<1>(Ak, bk);
209 }
210 // first m rows
211 dioganalizeBlock<BandWidthT>(Ak, bk);
212 }
213}
214
215template <int32_t BandWidthT>
217{
232
233 const int32_t stepA = 2 * BandWidthT;
234 const int32_t stepB = 2 * mBdim;
235 // Upper Triangulization
236 {
237 int32_t k = 0;
238 double* Ak = &mA[0];
239 double* bk = &mB[0];
240 for (; k < mN - BandWidthT; k += 2, Ak += stepA, bk += stepB) { // for each row k
241 triangulateBlock<2>(Ak, bk);
242 }
243 // last m rows
244 triangulateBlock<BandWidthT>(Ak, bk);
245 }
246
247 // Diagonalization
248 {
249 int32_t k = mN - 1;
250 double* Ak = &mA[BandWidthT * k];
251 double* bk = &mB[mBdim * k];
252 for (; k > BandWidthT - 1; k -= 2, Ak -= stepA, bk -= stepB) { // for each row k
253 dioganalizeBlock<2>(Ak, bk);
254 }
255 // first m rows
256 dioganalizeBlock<BandWidthT>(Ak, bk);
257 }
258}
259
260} // namespace gpu
261} // namespace o2
262
263#endif
int32_t i
const int16_t bb
uint32_t j
Definition RawData.h:0
uint32_t col
Definition RawData.h:4
uint32_t c
Definition RawData.h:2
Definition A.h:16
double & A(int32_t i, int32_t j)
access to A elements
double & B(int32_t i, int32_t j)
access to B elements
static int32_t test(bool prn=0)
Test the class functionality. Returns 1 when ok, 0 when not ok.
void initWithNaN()
debug tool: init arrays with NaN's
BandMatrixSolver(int32_t N, int32_t Bdim)
Constructor.
void solve()
solve the equation
void solveType1()
solve an equation of a special type
const GLfloat * m
Definition glcorearb.h:4066
GLboolean GLboolean GLboolean b
Definition glcorearb.h:1233
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
std::vector< ReadoutWindowData > rows