30#include <fmt/format.h>
31#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
32#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
34#include <onnxruntime_cxx_api.h>
36#include <boost/range.hpp>
// NOTE(review): extraction-garbled fragment of the policy's init() routine —
// the enclosing signature is not visible here and original source line
// numbers ("53", "56", ...) are fused into the text. Tokens left
// byte-identical; comments only.
53 LOG(info) <<
"Initializating pid policy";
// Fetch the serialized ONNX model bytes from the CCDB via the DPL context.
56 std::string model_data{fetchModelCCDB(pc, getName().c_str())};
// Configure the ONNX Runtime environment/session: telemetry off, thread
// count and graph-optimization level taken from mParams.
59 mEnv.DisableTelemetryEvents();
60 LOG(info) <<
"Disabled Telemetry Events";
63 mSessionOptions.SetIntraOpNumThreads(
mParams.numOrtThreads);
64 LOG(info) <<
"Set number of threads to " <<
mParams.numOrtThreads;
67 mSessionOptions.SetGraphOptimizationLevel(
static_cast<GraphOptimizationLevel
>(
mParams.graphOptimizationLevel));
68 LOG(info) <<
"Set GraphOptimizationLevel to " <<
mParams.graphOptimizationLevel;
// Build the session from the in-memory model blob. The experimental C++ API
// is used when its header exists; the #else separating the two make_unique
// calls is presumably among the lines lost in extraction — TODO confirm.
71#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
72 mSession = std::make_unique<Ort::Experimental::Session>(mEnv,
reinterpret_cast<void*
>(model_data.data()), model_data.size(), mSessionOptions);
74 mSession = std::make_unique<Ort::Session>(mEnv,
reinterpret_cast<void*
>(model_data.data()), model_data.size(), mSessionOptions);
76 LOG(info) <<
"ONNX runtime session created";
// Cache the model's input node names and tensor shapes from the session.
79 for (
size_t i = 0;
i < mSession->GetInputCount(); ++
i) {
80 mInputNames.push_back(mSession->GetInputNameAllocated(
i, mAllocator).get());
82 for (
size_t i = 0;
i < mSession->GetInputCount(); ++
i) {
83 mInputShapes.emplace_back(mSession->GetInputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
// Cache the model's output node names and tensor shapes likewise.
85 for (
size_t i = 0;
i < mSession->GetOutputCount(); ++
i) {
86 mOutputNames.push_back(mSession->GetOutputNameAllocated(
i, mAllocator).get());
88 for (
size_t i = 0;
i < mSession->GetOutputCount(); ++
i) {
89 mOutputShapes.emplace_back(mSession->GetOutputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
// Log the cached input node names/shapes for debugging.
92 LOG(info) <<
"Input Node Name/Shape (" << mInputNames.size() <<
"):";
93 for (
size_t i = 0;
i < mInputNames.size();
i++) {
94 LOG(info) <<
"\t" << mInputNames[
i] <<
" : " << printShape(mInputShapes[
i]);
// Log the cached output node names/shapes for debugging.
98 LOG(info) <<
"Output Node Name/Shape (" << mOutputNames.size() <<
"):";
99 for (
size_t i = 0;
i < mOutputNames.size();
i++) {
100 LOG(info) <<
"\t" << mOutputNames[
i] <<
" : " << printShape(mOutputShapes[
i]);
103 LOG(info) <<
"Finalization done";
// NOTE(review): extraction-garbled fragment of the inference path of
// process() — enclosing signature/try not visible. Code byte-identical.
// Flatten the track(let) features into the model's input buffer.
109 auto input = prepareModelInput(trk, inputTracks);
// Wrap the buffer in an ORT tensor shaped (nSamples, nFeatures); the
// feature count comes from the cached input shape mInputShapes[0][1].
111#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
112 auto inputTensor = Ort::Experimental::Value::CreateTensor<float>(input.data(), input.size(),
113 {static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
115 Ort::MemoryInfo mem_info =
116 Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
117 auto inputTensor = Ort::Value::CreateTensor<float>(mem_info, input.data(), input.size(),
118 {static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
// Run inference and reduce the output tensor to an electron likelihood.
120 std::vector<Ort::Value> ortTensor;
121 ortTensor.push_back(std::move(inputTensor));
122 auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames);
124 return getELikelihood(outTensor);
125 }
// On ORT failure, log and fall through to a default PID value (the default
// return is among the lines lost in extraction — TODO confirm).
catch (
const Ort::Exception& e) {
126 LOG(error) <<
"Error running model inference, using defaults: " << e.what();
// NOTE(review): extraction-garbled fragment of fetchModelCCDB() — signature
// not visible. Code byte-identical; comments only.
// Validate that the DPL input with the given binding exists at all.
135 auto ref = pc.inputs().get(binding);
136 if (!
ref.spec || !
ref.payload) {
137 throw std::runtime_error(fmt::format(
"A ML model with '{}' as binding does not exist!", binding));
// Fetch the model payload as a string and reject an empty blob.
141 auto model_data = pc.inputs().get<std::string>(binding);
142 if (model_data.empty()) {
143 throw std::runtime_error(fmt::format(
"Did not get any data for {} model from ccdb!", binding));
// NOTE(review): extraction-garbled fragment of prepareModelInput() — the
// per-layer loop structure and most conditionals are missing from this view.
// Code byte-identical; comments only.
// Feature vector sized to the model's expected feature count.
151 std::vector<float> in(mInputShapes[0][1]);
152 const auto& trackletsRaw = inputTracks.getTRDTracklets();
155 int trkltId = trkTRD.getTrackletIndex(iLayer);
// Slot 18+iLayer presumably holds the per-layer momentum; -1 marks a layer
// without a matched tracklet — TODO confirm against the full source.
161 in[18 + iLayer] = -1.f;
164 const auto xCalib = input.getTRDCalibratedTracklets()[trkTRD.getTrackletIndex(iLayer)].getX();
166 const auto tgl = trk.getTgl();
168 const auto& trklt = trackletsRaw[trkltId];
// Extract the three tracklet charges (Q0/Q1/Q2), corrected for track angles.
169 const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkTRD, input, snp, tgl);
173 in[18 + iLayer] = trk.getP();
180std::string ML::printShape(
const std::vector<int64_t>&
v)
const noexcept
182 std::stringstream ss(
"");
183 for (
size_t i = 0;
i <
v.size() - 1;
i++) {
186 ss <<
v[
v.size() - 1];
Global index for barrel track: provides provenance (detectors combination), index in respective array...
This file provides the base for ML policies.
Result of refitting TPC-ITS matched track.
GPUd() value_type estimateLTFast(o2 static GPUd() float estimateLTIncrement(const o2 PropagatorImpl * Instance(bool uninitialized=false)
void init(o2::framework::ProcessingContext &pc) final
Initialize the policy.
float process(const TrackTRD &trk, const o2::globaltracking::RecoContainer &input, bool isTPCTRD) const final
Calculate a PID for a given track.
const TRDPIDParams & mParams
parameters
float sector2Angle(int sect)
constexpr int NLAYER
the number of layers
constexpr int NCHARGES
the number of charges per tracklet (Q0/1/2)
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
static int getSector(int det)
LOG(info) << "Compressed in " << sw.CpuTime() << " s"