Project
Loading...
Searching...
No Matches
GPUTPCNNClusterizer.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
14
15#ifndef O2_GPUTPCNNCLUSTERIZER_H
16#define O2_GPUTPCNNCLUSTERIZER_H
17
18#include "CfChargePos.h"
19#include "GPUProcessor.h"
20
21namespace o2::OrtDataType
22{
23struct Float16_t; // forward declaration only; this header uses the type solely through Float16_t* members
24}
25
26namespace o2::gpu
27{
28
30{
31 public:
33 void* setIOPointers(void*); // takes and returns a void*; NOTE(review): presumably assigns the buffer pointers declared below from the given memory - confirm against the implementation file
37
38 // Neural network clusterization
39
45 float mNnClassThreshold = 0.01; // NOTE(review): presumably the cut applied to the network's class output - confirm
55 int mNnInferenceInputDType = 0; // 0: float16, 1: float32
56 int mNnInferenceOutputDType = 0; // 0: float16, 1: float32
57 int mISector = -1; // defaults to -1; NOTE(review): presumably "not yet assigned" - confirm
58 int mDeviceId = -1; // defaults to -1; NOTE(review): presumably "not yet assigned" - confirm
59
60 // Memory allocation for neural network
61 // All buffer pointers below default to nullptr.
62 bool* mClusterFlags = nullptr; // mSplitInTime, mSplitInPad. Technically both flags are set in the same way -> ClusterAccumulator.cx=nullptr (NOTE(review): the tail of this comment looks garbled, probably meant ClusterAccumulator.cxx - confirm)
63 int* mOutputDataClass = nullptr;
64
65 // FP32 variants of the input / probability / regression buffers
66 float* mInputData_32 = nullptr;
67 float* mModelProbabilities_32 = nullptr;
68 float* mOutputDataReg1_32 = nullptr;
69 float* mOutputDataReg2_32 = nullptr;
70
71 // FP16 variants of the same buffers (Float16_t is only forward-declared in this header)
72 OrtDataType::Float16_t* mInputData_16 = nullptr;
73 OrtDataType::Float16_t* mModelProbabilities_16 = nullptr;
74 OrtDataType::Float16_t* mOutputDataReg1_16 = nullptr;
75 OrtDataType::Float16_t* mOutputDataReg2_16 = nullptr;
76
77 int16_t mMemoryId = -1; // defaults to -1; NOTE(review): presumably an id for this processor's registered memory - confirm
78}; // class GPUTPCNNClusterizer
79
80} // namespace o2::gpu
81
82#endif
OrtDataType::Float16_t * mInputData_16
OrtDataType::Float16_t * mOutputDataReg2_16
void SetMaxData(const GPUTrackingInOutPointers &)
OrtDataType::Float16_t * mModelProbabilities_16
OrtDataType::Float16_t * mOutputDataReg1_16