17#ifndef O2_ZDC_FAST_SIMULATIONS_H
18#define O2_ZDC_FAST_SIMULATIONS_H
20#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
21#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
23#include <onnxruntime_cxx_api.h>
38 OrtAllocatorType allocatorType,
39 OrtMemType memoryType,
56 virtual bool setInput(std::vector<std::vector<float>>& input) = 0;
61 virtual void run() = 0;
64 virtual const std::vector<Ort::Value>&
getResult() = 0;
72 void setTensors(std::vector<std::vector<float>>& input);
113 bool setInput(std::vector<std::vector<float>>& input)
override;
125 const std::vector<Ort::Value>&
getResult()
override;
128 std::vector<Ort::Value> mModelOutput;
139 std::optional<std::vector<std::vector<float>>>
getBatch(
const std::vector<float>& input);
149 std::vector<std::vector<float>> mBatch;
159std::optional<std::pair<std::vector<float>, std::vector<float>>>
loadScales(
const std::string&
path);
Thread-safe Meyers singleton. Responsible for collecting particle data for batch processing...
BatchHandler & operator=(const BatchHandler &)=delete
std::optional< std::vector< std::vector< float > > > getBatch(const std::vector< float > &input)
static BatchHandler & getInstance(size_t batchSize)
BatchHandler(const BatchHandler &)=delete
Derived class implementing interface for specific types of models.
bool setInput(std::vector< std::vector< float > > &input) override
Implements setInput.
const std::vector< Ort::Value > & getResult() override
Returns single model output as const&. Returned vector is of size 1.
void run() override
Implements run().
~ConditionalModelSimulation() override=default
Abstract class providing interface for various specialized implementations.
std::vector< Ort::Value > mInputTensors
Container for input tensors.
std::vector< std::string > mOutputNames
std::string mModelPath
model path (where to find the ONNX model)
std::vector< std::vector< int64_t > > mInputShapes
virtual ~NeuralFastSimulation()=default
std::vector< std::string > mInputNames
Input/Output names and input shape.
void setInputOutputData()
Sets the model's metadata (input/output layer names, input shapes) in the ONNX session.
Ort::MemoryInfo mMemoryInfo
void setTensors(std::vector< std::vector< float > > &input)
Converts flattened input data to Ort::Value. Tensor shapes are taken from the loaded model metadata.
Ort::AllocatorWithDefaultOptions mAllocator
virtual bool setInput(std::vector< std::vector< float > > &input)=0
Wrapper for converting raw input to Ort::Value.
virtual void run()=0
Wraps Session.Run(). The result should be stored as a private member.
void initRunSession()
(Late) initializes the session and provides a mechanism to customize the ONNX session with external options.
virtual const std::vector< Ort::Value > & getResult()=0
Returns the model output as a const reference.
size_t getBatchSize() const
GLsizei const GLchar *const * path
std::optional< std::pair< std::vector< float >, std::vector< float > > > loadScales(const std::string &path)
Loads and parses model scales from the file at path.