Project
Loading...
Searching...
No Matches
ML.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
14
15#include "TRDPID/ML.h"
26#include "Framework/Logger.h"
29
30#include <fmt/format.h>
31#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
32#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
33#else
34#include <onnxruntime_cxx_api.h>
35#endif
36#include <boost/range.hpp>
37
38#include <array>
39#include <algorithm>
40#include <stdexcept>
41#include <sstream>
42#include <string>
43
44using namespace o2::trd::constants;
45
46namespace o2
47{
48namespace trd
49{
50
52{
53 LOG(info) << "Initializating pid policy";
54
55 // fetch the onnx model from the ccdb
56 std::string model_data{fetchModelCCDB(pc, getName().c_str())};
57
58 // disable telemtry events
59 mEnv.DisableTelemetryEvents();
60 LOG(info) << "Disabled Telemetry Events";
61
62 // create session options
63 mSessionOptions.SetIntraOpNumThreads(mParams.numOrtThreads);
64 LOG(info) << "Set number of threads to " << mParams.numOrtThreads;
65
66 // Sets graph optimization level
67 mSessionOptions.SetGraphOptimizationLevel(static_cast<GraphOptimizationLevel>(mParams.graphOptimizationLevel));
68 LOG(info) << "Set GraphOptimizationLevel to " << mParams.graphOptimizationLevel;
69
70 // create actual session
71#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
72 mSession = std::make_unique<Ort::Experimental::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
73#else
74 mSession = std::make_unique<Ort::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
75#endif
76 LOG(info) << "ONNX runtime session created";
77
78 // print name/shape of inputs
79 for (size_t i = 0; i < mSession->GetInputCount(); ++i) {
80 mInputNames.push_back(mSession->GetInputNameAllocated(i, mAllocator).get());
81 }
82 for (size_t i = 0; i < mSession->GetInputCount(); ++i) {
83 mInputShapes.emplace_back(mSession->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
84 }
85 for (size_t i = 0; i < mSession->GetOutputCount(); ++i) {
86 mOutputNames.push_back(mSession->GetOutputNameAllocated(i, mAllocator).get());
87 }
88 for (size_t i = 0; i < mSession->GetOutputCount(); ++i) {
89 mOutputShapes.emplace_back(mSession->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
90 }
91
92 LOG(info) << "Input Node Name/Shape (" << mInputNames.size() << "):";
93 for (size_t i = 0; i < mInputNames.size(); i++) {
94 LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
95 }
96
97 // print name/shape of outputs
98 LOG(info) << "Output Node Name/Shape (" << mOutputNames.size() << "):";
99 for (size_t i = 0; i < mOutputNames.size(); i++) {
100 LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
101 }
102
103 LOG(info) << "Finalization done";
104}
105
106float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& inputTracks, bool isTPCTRD) const
107{
108 try {
109 auto input = prepareModelInput(trk, inputTracks);
110 // create memory mapping to vector above
111#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
112 auto inputTensor = Ort::Experimental::Value::CreateTensor<float>(input.data(), input.size(),
113 {static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
114#else
115 Ort::MemoryInfo mem_info =
116 Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
117 auto inputTensor = Ort::Value::CreateTensor<float>(mem_info, input.data(), input.size(),
118 {static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
119#endif
120 std::vector<Ort::Value> ortTensor;
121 ortTensor.push_back(std::move(inputTensor));
122 auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames);
123 // every model defines its own output
124 return getELikelihood(outTensor);
125 } catch (const Ort::Exception& e) {
126 LOG(error) << "Error running model inference, using defaults: " << e.what();
127 // fill with negative elikelihood means no information
128 return -1.f;
129 }
130}
131
132std::string ML::fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept
133{
134 // sanity checks
135 auto ref = pc.inputs().get(binding);
136 if (!ref.spec || !ref.payload) {
137 throw std::runtime_error(fmt::format("A ML model with '{}' as binding does not exist!", binding));
138 }
139
140 // the model is in binary string format
141 auto model_data = pc.inputs().get<std::string>(binding);
142 if (model_data.empty()) {
143 throw std::runtime_error(fmt::format("Did not get any data for {} model from ccdb!", binding));
144 }
145 return model_data;
146}
147
148std::vector<float> ML::prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept
149{
150 // input is [charge0.0, charge0.1, charge0.2, charge1.0, ..., charge5.2, p0, ..., p5]
151 std::vector<float> in(mInputShapes[0][1]);
152 const auto& trackletsRaw = inputTracks.getTRDTracklets();
153 auto trk = trdTRD;
154 for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) {
155 int trkltId = trkTRD.getTrackletIndex(iLayer);
156 if (trkltId < 0) {
157 // no tracklet attached, fill with default values e.g. charge=-1.,
158 in[iLayer * NCHARGES + 0] = -1.f;
159 in[iLayer * NCHARGES + 1] = -1.f;
160 in[iLayer * NCHARGES + 2] = -1.f;
161 in[18 + iLayer] = -1.f;
162 continue;
163 } else {
164 const auto xCalib = input.getTRDCalibratedTracklets()[trkTRD.getTrackletIndex(iLayer)].getX();
165 auto bz = o2::base::Propagator::Instance()->getNominalBz();
166 const auto tgl = trk.getTgl();
167 const auto snp = trk.getSnpAt(o2::math_utils::sector2Angle(HelperMethods::getSector(input.getTRDTracklets()[trkIn.getTrackletIndex(iLayer)].getDetector())), xCalib, bz);
168 const auto& trklt = trackletsRaw[trkltId];
169 const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkTRD, input, snp, tgl); // correct charges
170 in[iLayer * NCHARGES + 0] = q0;
171 in[iLayer * NCHARGES + 1] = q1;
172 in[iLayer * NCHARGES + 2] = q2;
173 in[18 + iLayer] = trk.getP();
174 }
175 }
176
177 return in;
178}
179
180std::string ML::printShape(const std::vector<int64_t>& v) const noexcept
181{
182 std::stringstream ss("");
183 for (size_t i = 0; i < v.size() - 1; i++) {
184 ss << v[i] << "x";
185 }
186 ss << v[v.size() - 1];
187 return ss.str();
188}
189
190} // namespace trd
191} // namespace o2
Global TRD definitions and constants.
int32_t i
Global index for barrel track: provides provenance (detectors combination), index in respective array...
This file provides the base for ML policies.
Result of refitting TPC-ITS matched track.
GPUd() value_type estimateLTFast(o2 static GPUd() float estimateLTIncrement(const o2 PropagatorImpl * Instance(bool uninitialized=false)
Definition Propagator.h:143
void init(o2::framework::ProcessingContext &pc) final
Initialize the policy.
Definition ML.cxx:51
float process(const TrackTRD &trk, const o2::globaltracking::RecoContainer &input, bool isTPCTRD) const final
Calculate a PID for a given track.
Definition ML.cxx:106
const TRDPIDParams & mParams
parameters
Definition PIDBase.h:70
const GLdouble * v
Definition glcorearb.h:832
float sector2Angle(int sect)
Definition Utils.h:193
constexpr int NLAYER
the number of layers
Definition Constants.h:27
constexpr int NCHARGES
the number of charges per tracklet (Q0/1/2)
Definition Constants.h:61
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
static int getSector(int det)
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"