Project
Loading...
Searching...
No Matches
TPCLoopers.h
Go to the documentation of this file.
1// Copyright 2024-2025 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
13
14#ifndef ALICEO2_EVENTGEN_TPCLOOPERS_H_
15#define ALICEO2_EVENTGEN_TPCLOOPERS_H_
16
#ifdef GENERATORS_WITH_TPCLOOPERS
#include <onnxruntime_cxx_api.h>
#endif
#include <array>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include <rapidjson/document.h>
#include "TRandom3.h"
#include "TParticle.h"
25
26#ifdef GENERATORS_WITH_TPCLOOPERS
// Single ONNX Runtime environment intended to be shared by every ONNXGenerator,
// so that several ONNX models (e.g. the pair and Compton models) can be loaded
// against one Ort::Env. It has static storage duration and external linkage:
// this is only a declaration, and the object must be defined in exactly one
// translation unit (the definition creates it with ORT_LOGGING_LEVEL_WARNING).
extern Ort::Env global_env;
29
30// This class is responsible for loading the scaler parameters from a JSON file
31// and applying the inverse transformation to the generated data.
32// Inferenced output is scaled (min-max normalization or robust scaling for outlier features) during training,
33// so we need to revert this transformation to get physical values.
34struct Scaler {
35 std::vector<double> normal_min;
36 std::vector<double> normal_max;
37 std::vector<double> outlier_center;
38 std::vector<double> outlier_scale;
39
40 void load(const std::string& filename);
41
42 std::vector<double> inverse_transform(const std::vector<double>& input);
43
44 private:
45 std::vector<double> jsonArrayToVector(const rapidjson::Value& jsonArray);
46};
47
48// This class loads the ONNX model and generates samples using it.
49class ONNXGenerator
50{
51 public:
52 ONNXGenerator(Ort::Env& shared_env, const std::string& model_path);
53
54 std::vector<double> generate_sample();
55
56 private:
57 Ort::Env& env;
58 Ort::Session session;
59 TRandom3 rand_gen;
60};
61#endif // GENERATORS_WITH_TPCLOOPERS
62
namespace o2
{
namespace eventgen
{

#ifdef GENERATORS_WITH_TPCLOOPERS
/// Generator of TPC looper particles.
///
/// Loopers come from two sources, "pairs" and "Compton" electrons. Each source is
/// sampled from its own ONNX model (mONNX_pair / mONNX_compton), and the model output
/// is mapped back to physical values by the matching Scaler (mScaler_pair /
/// mScaler_compton).
///
/// NOTE(review): this header names TFile, o2::steer::DigitizationContext and
/// o2::InteractionTimeRecord but does not visibly include or declare them, so it
/// relies on the including translation unit. TODO confirm and add the includes.
class GenTPCLoopers
{
 public:
  /// @param model_pairs    ONNX model file for the pair loopers.
  /// @param model_compton  ONNX model file for the Compton loopers.
  /// @param poisson        CSV file, presumably providing mPoisson (mu, min, max).
  /// @param gauss          CSV file, presumably providing mGauss (mean, std, min, max).
  /// @param scaler_pair    JSON scaler parameters for the pair model.
  /// @param scaler_compton JSON scaler parameters for the Compton model.
  GenTPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
                std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
                std::string scaler_compton = "scaler_compton.json");

  /// Generates the loopers for one event.
  Bool_t generateEvent();

  /// Generates the loopers for one event, given a time limit
  /// (presumably stored in mTimeLimit).
  Bool_t generateEvent(double time_limit);

  /// Returns the generated particles (presumably those of the last generated event).
  std::vector<TParticle> importParticles();

  /// Draws the number of pair loopers (presumably from the Poissonian in mPoisson).
  unsigned int PoissonPairs();

  /// Draws the number of Compton electrons (presumably from the Gaussian in mGauss).
  unsigned int GaussianElectrons();

  /// Fixes the number of pair and Compton loopers.
  void SetNLoopers(unsigned int nsig_pair, unsigned int nsig_compton);

  /// Sets the (pairs, Compton) multipliers stored in mMultiplier.
  void SetMultiplier(const std::array<float, 2>& mult);

  /// Configures flat-gas loopers (see mFlatGas / mFlatGasNumber / mFlatGasOrbit);
  /// presumably @p number is per event and @p nloopers_orbit per orbit.
  void setFlatGas(Bool_t flat, Int_t number = -1, Int_t nloopers_orbit = -1);

  /// Sets the fraction of loopers coming from pairs (mLoopsFractionPairs).
  void setFractionPairs(float fractionPairs);

  /// Configures the rate-dependent looper count from @p rateFile; @p intRate is
  /// the interaction rate in Hz (see mInteractionRate).
  void SetRate(const std::string& rateFile, bool isPbPb, int intRate = 50000);

  /// Sets an adjustment parameter. NOTE(review): semantics are not visible from
  /// this header; TODO document.
  void SetAdjust(float adjust = 0.f);

  /// Total number of loopers (pairs + Compton).
  /// NOTE(review): both counters start at UINT_MAX, so calling this before they
  /// are set returns a wrapped, meaningless sum.
  unsigned int getNLoopers() const { return (mNLoopersPairs + mNLoopersCompton); }

 private:
  std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
  std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
  std::unique_ptr<Scaler> mScaler_pair = nullptr;
  std::unique_ptr<Scaler> mScaler_compton = nullptr;
  double mPoisson[3] = {0.0, 0.0, 0.0};      // Mu, Min and Max of Poissonian
  double mGauss[4] = {0.0, 0.0, 0.0, 0.0};   // Mean, Std, Min, Max
  std::vector<std::vector<double>> mGenPairs;
  std::vector<std::vector<double>> mGenElectrons;
  // All bits set, presumably used as a "not set yet" sentinel. This is the same
  // value the former `= -1` produced, but written explicitly so the signed->unsigned
  // conversion is visible and does not trigger -Wsign-conversion.
  unsigned int mNLoopersPairs = std::numeric_limits<unsigned int>::max();
  unsigned int mNLoopersCompton = std::numeric_limits<unsigned int>::max();
  std::array<float, 2> mMultiplier = {1.f, 1.f}; // float literals: no double->float narrowing
  bool mPoissonSet = false;
  bool mGaussSet = false;
  // Random number generator
  TRandom3 mRandGen;
  int mCurrentEvent = 0;                                           // Current event number, used for adaptive loopers
  TFile* mContextFile = nullptr;                                   // Input collision context file. NOTE(review): raw pointer, ownership/closing not visible here
  o2::steer::DigitizationContext* mCollisionContext = nullptr;     // Pointer to the digitization context
  std::vector<o2::InteractionTimeRecord> mInteractionTimeRecords;  // Interaction time records from collision context
  Bool_t mFlatGas = false;                                         // Flag to indicate if flat gas loopers are used
  Bool_t mFlatGasOrbit = false;                                    // Flag to indicate if flat gas loopers are per orbit
  Int_t mFlatGasNumber = -1;                                       // Number of flat gas loopers per event
  double mIntTimeRecMean = 1.0;                                    // Average interaction time record used for the reference
  double mTimeLimit = 0.0;                                         // Time limit for the current event
  double mTimeEnd = 0.0;                                           // Time limit for the last event
  float mLoopsFractionPairs = 0.08f;                               // Fraction of loopers from Pairs
  int mInteractionRate = 50000;                                    // Interaction rate in Hz
};
#endif // GENERATORS_WITH_TPCLOOPERS

} // namespace eventgen
} // namespace o2
147
148#endif // ALICEO2_EVENTGEN_TPCLOOPERS_H_
Ort::Env global_env(ORT_LOGGING_LEVEL_WARNING, "GlobalEnv")
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
std::string filename()