Project
Loading...
Searching...
No Matches
TPCLoopers.h
Go to the documentation of this file.
1// Copyright 2024-2025 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
13
14#ifndef ALICEO2_EVENTGEN_TPCLOOPERS_H_
15#define ALICEO2_EVENTGEN_TPCLOOPERS_H_
16
17#ifdef GENERATORS_WITH_TPCLOOPERS
#include <array>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
#include <rapidjson/document.h>

#include "TRandom3.h"
#include "TParticle.h"
24
25// Static Ort::Env instance for multiple onnx model loading
26extern Ort::Env global_env;
27
28// This class is responsible for loading the scaler parameters from a JSON file
29// and applying the inverse transformation to the generated data.
30// Inferenced output is scaled (min-max normalization or robust scaling for outlier features) during training,
31// so we need to revert this transformation to get physical values.
32struct Scaler {
33 std::vector<double> normal_min;
34 std::vector<double> normal_max;
35 std::vector<double> outlier_center;
36 std::vector<double> outlier_scale;
37
38 void load(const std::string& filename);
39
40 std::vector<double> inverse_transform(const std::vector<double>& input);
41
42 private:
43 std::vector<double> jsonArrayToVector(const rapidjson::Value& jsonArray);
44};
45
46// This class loads the ONNX model and generates samples using it.
47class ONNXGenerator
48{
49 public:
50 ONNXGenerator(Ort::Env& shared_env, const std::string& model_path);
51
52 std::vector<double> generate_sample();
53
54 private:
55 Ort::Env& env;
56 Ort::Session session;
57 TRandom3 rand_gen;
58};
59#endif // GENERATORS_WITH_TPCLOOPERS
60
61namespace o2
62{
63namespace eventgen
64{
65
66#ifdef GENERATORS_WITH_TPCLOOPERS
81class GenTPCLoopers
82{
83 public:
84 GenTPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
85 std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
86 std::string scaler_compton = "scaler_compton.json");
87
88 Bool_t generateEvent();
89
90 Bool_t generateEvent(double time_limit);
91
92 std::vector<TParticle> importParticles();
93
94 unsigned int PoissonPairs();
95
96 unsigned int GaussianElectrons();
97
98 void SetNLoopers(unsigned int nsig_pair, unsigned int nsig_compton);
99
100 void SetMultiplier(const std::array<float, 2>& mult);
101
102 void setFlatGas(Bool_t flat, Int_t number = -1, Int_t nloopers_orbit = -1);
103
104 void setFractionPairs(float fractionPairs);
105
106 void SetRate(const std::string& rateFile, bool isPbPb, int intRate = 50000);
107
108 void SetAdjust(float adjust = 0.f);
109
110 unsigned int getNLoopers() const { return (mNLoopersPairs + mNLoopersCompton); }
111
112 private:
113 std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
114 std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
115 std::unique_ptr<Scaler> mScaler_pair = nullptr;
116 std::unique_ptr<Scaler> mScaler_compton = nullptr;
117 double mPoisson[3] = {0.0, 0.0, 0.0}; // Mu, Min and Max of Poissonian
118 double mGauss[4] = {0.0, 0.0, 0.0, 0.0}; // Mean, Std, Min, Max
119 std::vector<std::vector<double>> mGenPairs;
120 std::vector<std::vector<double>> mGenElectrons;
121 unsigned int mNLoopersPairs = -1;
122 unsigned int mNLoopersCompton = -1;
123 std::array<float, 2> mMultiplier = {1., 1.};
124 bool mPoissonSet = false;
125 bool mGaussSet = false;
126 // Random number generator
127 TRandom3 mRandGen;
128 int mCurrentEvent = 0; // Current event number, used for adaptive loopers
129 TFile* mContextFile = nullptr; // Input collision context file
130 o2::steer::DigitizationContext* mCollisionContext = nullptr; // Pointer to the digitization context
131 std::vector<o2::InteractionTimeRecord> mInteractionTimeRecords; // Interaction time records from collision context
132 Bool_t mFlatGas = false; // Flag to indicate if flat gas loopers are used
133 Bool_t mFlatGasOrbit = false; // Flag to indicate if flat gas loopers are per orbit
134 Int_t mFlatGasNumber = -1; // Number of flat gas loopers per event
135 double mIntTimeRecMean = 1.0; // Average interaction time record used for the reference
136 double mTimeLimit = 0.0; // Time limit for the current event
137 double mTimeEnd = 0.0; // Time limit for the last event
138 float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
139 int mInteractionRate = 50000; // Interaction rate in Hz
140};
141#endif // GENERATORS_WITH_TPCLOOPERS
142
143} // namespace eventgen
144} // namespace o2
145
146#endif // ALICEO2_EVENTGEN_TPCLOOPERS_H_
Ort::Env global_env(ORT_LOGGING_LEVEL_WARNING, "GlobalEnv")
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
std::string filename()