Project
Loading...
Searching...
No Matches
FastSimulations.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
16
17#ifndef O2_ZDC_FAST_SIMULATIONS_H
18#define O2_ZDC_FAST_SIMULATIONS_H
19
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
#else
#include <onnxruntime_cxx_api.h>
#endif

#include <cstddef>
#include <cstdint>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
27
28namespace o2::zdc::fastsim
29{
35{
36 public:
37 NeuralFastSimulation(const std::string& modelPath,
38 OrtAllocatorType allocatorType,
39 OrtMemType memoryType,
40 int64_t batchSize);
41 virtual ~NeuralFastSimulation() = default;
42
47 void initRunSession();
48
56 virtual bool setInput(std::vector<std::vector<float>>& input) = 0;
61 virtual void run() = 0;
62
64 virtual const std::vector<Ort::Value>& getResult() = 0;
65
66 [[nodiscard]] size_t getBatchSize() const;
67
68 protected:
70 void setInputOutputData();
72 void setTensors(std::vector<std::vector<float>>& input);
73
75 std::string mModelPath;
76
79 Ort::Env mEnv;
80 Ort::Session* mSession = nullptr; // a pointer so that we can set it up dynamically and independently of constructor
81 Ort::AllocatorWithDefaultOptions mAllocator;
82 Ort::MemoryInfo mMemoryInfo;
83
85 std::vector<std::string> mInputNames;
86 std::vector<std::string> mOutputNames;
87 std::vector<std::vector<int64_t>> mInputShapes;
90 int64_t mBatchSize;
91
93 std::vector<Ort::Value> mInputTensors;
94};
95
101{
102 public:
103 ConditionalModelSimulation(const std::string& modelPath, int64_t batchSize);
104 ~ConditionalModelSimulation() override = default;
105
113 bool setInput(std::vector<std::vector<float>>& input) override;
118 void run() override;
125 const std::vector<Ort::Value>& getResult() override;
126
127 private:
128 std::vector<Ort::Value> mModelOutput;
129};
130
/// @brief Thread-safe Meyers singleton responsible for collecting particle data
/// for batch processing.
class BatchHandler
{
 public:
  /// Returns the single shared instance, constructing it on the first call.
  /// NOTE(review): because a Meyers singleton is constructed only once, presumably
  /// only the batchSize passed on the first call takes effect and later values are
  /// ignored — confirm in the implementation.
  static BatchHandler& getInstance(size_t batchSize);

  /// Adds input to the pending batch.
  /// @return presumably a complete batch once mBatchSize entries have been collected,
  ///         std::nullopt otherwise — confirm in the implementation.
  std::optional<std::vector<std::vector<float>>> getBatch(const std::vector<float>& input);

  // Singleton: neither copyable nor copy-assignable.
  BatchHandler(const BatchHandler&) = delete;
  BatchHandler& operator=(const BatchHandler&) = delete;

 private:
  explicit BatchHandler(size_t batchSize);
  ~BatchHandler() = default;

  std::mutex mMutex;                      // NOTE(review): presumably guards mBatch in getBatch — confirm
  std::vector<std::vector<float>> mBatch; // entries collected so far
  size_t mBatchSize;
};
152
/// @brief Loads and parses model scales from the file at path.
/// @param path path of the scales file
/// @return the pair of parsed scale vectors, presumably std::nullopt if the file cannot
///         be read or parsed. NOTE(review): the meaning of .first / .second is not
///         visible in this header — confirm in the implementation.
[[nodiscard]] std::optional<std::pair<std::vector<float>, std::vector<float>>> loadScales(const std::string& path);
160
161} // namespace o2::zdc::fastsim
162#endif // O2_ZDC_FAST_SIMULATIONS_H
Thread-safe Meyers singleton. Responsible for collecting particle data for batch processing...
BatchHandler & operator=(const BatchHandler &)=delete
std::optional< std::vector< std::vector< float > > > getBatch(const std::vector< float > &input)
static BatchHandler & getInstance(size_t batchSize)
BatchHandler(const BatchHandler &)=delete
Derived class implementing interface for specific types of models.
bool setInput(std::vector< std::vector< float > > &input) override
Implements setInput.
const std::vector< Ort::Value > & getResult() override
Returns the single model output as a const reference. The returned vector has size 1.
Abstract class providing interface for various specialized implementations.
std::vector< Ort::Value > mInputTensors
Container for input tensors.
std::vector< std::string > mOutputNames
std::string mModelPath
model path (where to find the ONNX model)
std::vector< std::vector< int64_t > > mInputShapes
std::vector< std::string > mInputNames
Input/Output names and input shape.
void setInputOutputData()
Sets the model's metadata (input/output layer names, input shapes) in the ONNX session.
void setTensors(std::vector< std::vector< float > > &input)
Converts flattened input data to Ort::Value. Tensor shapes are taken from the loaded model metadata.
Ort::AllocatorWithDefaultOptions mAllocator
virtual bool setInput(std::vector< std::vector< float > > &input)=0
Wrapper for converting raw input to Ort::Value.
virtual void run()=0
Wraps Session.Run(). The result should be stored as a private member.
void initRunSession()
(Late) initializes the session and provides a mechanism to customize the ONNX session with external options.
virtual const std::vector< Ort::Value > & getResult()=0
Returns the model output as a const reference.
GLsizei const GLchar *const * path
Definition glcorearb.h:3591
std::optional< std::pair< std::vector< float >, std::vector< float > > > loadScales(const std::string &path)
Loads and parses model scales from the file at path.