OrtInterface.h
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
#ifndef O2_ML_ORTINTERFACE_H
#define O2_ML_ORTINTERFACE_H

// C++ and system includes
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <thread>

// O2 includes
#include "Framework/Logger.h"

namespace o2
{

namespace ml
{

class OrtModel
{

 public:
  // Constructor
  OrtModel() = default;
  OrtModel(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
  void init(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
  void reset(std::unordered_map<std::string, std::string>);
  bool isInitialized() { return mInitialized; }

  virtual ~OrtModel() = default;

  // Conversion
  template <class I, class O>
  std::vector<O> v2v(std::vector<I>&, bool = true);

  // Inferencing
  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
  std::vector<O> inference(std::vector<I>&);

  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
  std::vector<O> inference(std::vector<std::vector<I>>&);

  template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
  void inference(I*, size_t, O*);

  // template<class I, class T, class O> // class I is the input data type, e.g. float, class T the throughput data type and class O is the output data type
  // std::vector<O> inference(std::vector<I>&);

  // Reset session
  void resetSession();

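  // Accessors for the input/output tensor shapes and names of the loaded network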
  std::vector<std::vector<int64_t>> getNumInputNodes() const { return mInputShapes; }
  std::vector<std::vector<int64_t>> getNumOutputNodes() const { return mOutputShapes; }
  std::vector<std::string> getInputNames() const { return mInputNames; }
  std::vector<std::string> getOutputNames() const { return mOutputNames; }

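  // Set the number of intra-op threads (intraOpNumThreads) used for inference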
  void setActiveThreads(int threads) { intraOpNumThreads = threads; }

 private:
  // ORT variables -> need to be hidden as Pimpl
  struct OrtVariables;
  OrtVariables* pImplOrt;

  // Input & Output specifications of the loaded network
  std::vector<const char*> inputNamesChar, outputNamesChar;
  std::vector<std::string> mInputNames, mOutputNames;
  std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;

  // Environment settings
  bool mInitialized = false;
  std::string modelPath, device = "cpu", dtype = "float", thread_affinity = ""; // device options should be cpu, rocm, migraphx, cuda
  int intraOpNumThreads = 1, interOpNumThreads = 1, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;

  std::string printShape(const std::vector<int64_t>&);
};

} // namespace ml

} // namespace o2

#endif // O2_ML_ORTINTERFACE_H
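
A minimal usage sketch (not part of the header): the option keys accepted by reset() are defined in the corresponding OrtInterface.cxx and are not visible here, so the key strings below ("model-path", "device", "intra-op-num-threads"), as well as the include path, are illustrative assumptions rather than a confirmed interface.

// Assumed include path; adjust to wherever OrtInterface.h is installed in your build.
#include "ML/OrtInterface.h"

#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  // Hypothetical option keys -- the real keys are parsed by OrtModel::reset() in OrtInterface.cxx.
  std::unordered_map<std::string, std::string> options{
    {"model-path", "/path/to/model.onnx"}, // assumed key for modelPath
    {"device", "cpu"},                     // cpu, rocm, migraphx or cuda, per the header comment
    {"intra-op-num-threads", "2"}};        // assumed key for intraOpNumThreads

  o2::ml::OrtModel model(options); // equivalent to default-constructing and calling model.init(options)

  if (model.isInitialized()) {
    std::vector<float> input(128, 0.f); // flat input buffer, sized to match the model's input shape
    // Single-vector overload declared above: I = float input, O = float output
    std::vector<float> output = model.inference<float, float>(input);
  }
  return 0;
}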