Project
Loading...
Searching...
No Matches
ML.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
15
16#ifndef O2_TRD_ML_H
17#define O2_TRD_ML_H
18
19#include "Rtypes.h"
20#include "TRDPID/PIDBase.h"
21#include "DataFormatsTRD/PID.h"
24#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
25#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
26#else
27#include <onnxruntime_cxx_api.h>
28#endif
29#include <memory>
30#include <vector>
31#include <array>
32#include <string>
33
34namespace o2::trd
35{
36
39class ML : public PIDBase
40{
41 using PIDBase::PIDBase;
42
43 public:
45 float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final;
46
47 private:
50 virtual inline float getELikelihood(const std::vector<Ort::Value>& tensorData) const noexcept = 0;
51
53 std::string fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept;
54
57 std::vector<float> prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept;
58
60 std::string printShape(const std::vector<int64_t>& v) const noexcept;
61
63 virtual inline std::string getName() const noexcept = 0;
64
65 // ONNX runtime
66 Ort::Env mEnv{ORT_LOGGING_LEVEL_WARNING, "TRD-PID",
67 // Integrate ORT logging into Fairlogger this way we can have
68 // all the nice logging while taking advantage of ORT telling us
69 // what to do.
70 [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
71 LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]");
72 },
73 (void*)3};
74 const OrtApi& mApi{Ort::GetApi()};
75#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
76 std::unique_ptr<Ort::Experimental::Session> mSession;
77#else
78 std::unique_ptr<Ort::Session> mSession;
79#endif
80 Ort::SessionOptions mSessionOptions;
81 Ort::AllocatorWithDefaultOptions mAllocator;
82
83 // Input/Output
84 std::vector<std::string> mInputNames;
85 std::vector<std::vector<int64_t>> mInputShapes;
86 std::vector<std::string> mOutputNames;
87 std::vector<std::vector<int64_t>> mOutputShapes;
88
89 ClassDefOverride(ML, 1);
90};
91
93class XGB final : public ML
94{
95 using ML::ML;
96
97 public:
98 ~XGB() = default;
99
100 private:
103 inline float getELikelihood(const std::vector<Ort::Value>& tensorData) const noexcept
104 {
105 return tensorData[1].GetTensorData<float>()[1];
106 }
107
108 inline std::string getName() const noexcept { return "xgb"; }
109
110 ClassDefNV(XGB, 1);
111};
112
114class PY final : public ML
115{
116 using ML::ML;
117
118 public:
119 ~PY() = default;
120
121 private:
122 inline float getELikelihood(const std::vector<Ort::Value>& tensorData) const noexcept
123 {
124 return tensorData[0].GetTensorData<float>()[0];
125 }
126
127 inline std::string getName() const noexcept { return "py"; }
128
129 ClassDefNV(PY, 1);
130};
131
132} // namespace o2::trd
133
134#endif
This file provides the base interface for pid policies.
void init(o2::framework::ProcessingContext &pc) final
Initialize the policy.
Definition ML.cxx:51
PIDBase(PIDPolicy policy)
Definition PIDBase.h:52
PyTorch Model.
Definition ML.h:115
~PY()=default
XGBoost Model.
Definition ML.h:94
~XGB()=default
const GLdouble * v
Definition glcorearb.h:832
GLuint GLsizei const GLchar * message
Definition glcorearb.h:2517
GLenum GLfloat param
Definition glcorearb.h:271
GLenum GLenum severity
Definition glcorearb.h:2513
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"