Project
Loading...
Searching...
No Matches
OrtInterface.h
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
15
16#ifndef O2_ML_ORTINTERFACE_H
17#define O2_ML_ORTINTERFACE_H
18
// C++ and system includes
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
26
27// O2 includes
28#include "GPUCommonLogger.h"
29
// Forward declarations of the ONNX Runtime types that the OrtModel interface
// only uses through pointers. No ORT header is included here, which fits the
// pImpl note on OrtModel's private members.
namespace Ort
{
struct Env;
struct MemoryInfo;
struct SessionOptions;
} // namespace Ort
36
37namespace o2
38{
39
40namespace ml
41{
42
44{
45
46 public:
47 // Constructors & destructors
48 OrtModel() = default;
49 OrtModel(std::unordered_map<std::string, std::string> optionsMap) { init(optionsMap); }
/// Configure the model from a key/value option map.
/// The visible part forwards the map to initOptions().
50 void init(std::unordered_map<std::string, std::string> optionsMap)
51 {
52 initOptions(optionsMap);
// NOTE(review): header line 53 is missing from this rendered listing, so any further
// initialisation step performed here is not visible — check the actual header.
54 }
55 virtual ~OrtModel() = default;
56
57 // General purpose
58 void initOptions(std::unordered_map<std::string, std::string> optionsMap);
59 void initEnvironment();
60 void initSession();
61 void memoryOnDevice(int32_t = 0);
62 bool isInitialized() { return mInitialized; }
63 void resetSession();
64
65 // Getters
66 std::vector<std::vector<int64_t>> getNumInputNodes() const { return mInputShapes; }
67 std::vector<std::vector<int64_t>> getNumOutputNodes() const { return mOutputShapes; }
68 std::vector<std::string> getInputNames() const { return mInputNames; }
69 std::vector<std::string> getOutputNames() const { return mOutputNames; }
70 Ort::SessionOptions* getSessionOptions();
71 Ort::MemoryInfo* getMemoryInfo();
72 Ort::Env* getEnv();
73 int32_t getIntraOpNumThreads() const { return mIntraOpNumThreads; }
74 int32_t getInterOpNumThreads() const { return mInterOpNumThreads; }
75
76 // Setters
77 void setDeviceId(int32_t id) { mDeviceId = id; }
78 void setIO();
79 void setActiveThreads(int threads) { mIntraOpNumThreads = threads; }
80 void setIntraOpNumThreads(int threads)
81 {
82 if (mDeviceType == "CPU") {
83 mIntraOpNumThreads = threads;
84 }
85 }
86 void setInterOpNumThreads(int threads)
87 {
88 if (mDeviceType == "CPU") {
89 mInterOpNumThreads = threads;
90 }
91 }
92 void setEnv(Ort::Env*);
93
// Conversion
/// Convert a std::vector<I> into a std::vector<O>; defined out of line.
/// NOTE(review): the meaning of the bool flag (default true) is not visible here — confirm.
template <typename I, typename O>
std::vector<O> v2v(std::vector<I>&, bool = true);
97
// Inferencing
// Template parameter I is the input data type, e.g. float. O is the output data type, e.g.
// OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h.
// Every overload below is defined out of line.

/// Input given as one vector; the result is returned as a vector.
template <typename I, typename O>
std::vector<O> inference(std::vector<I>&);

/// Input given as a vector of vectors (presumably one per input node — confirm).
template <typename I, typename O>
std::vector<O> inference(std::vector<std::vector<I>>&);

/// Raw-buffer form: input pointer, an int64_t count (meaning not visible here), output pointer.
template <typename I, typename O>
void inference(I*, int64_t, O*);

/// Array-of-input-buffers form: input pointers, an int64_t count, output pointer.
template <typename I, typename O>
void inference(I**, int64_t, O*);

/// Defined out of line; the flag defaults to false (its meaning is not visible here).
void release(bool = false);
112
113 private:
114 // ORT variables -> need to be hidden as pImpl
115 struct OrtVariables;
116 OrtVariables* mPImplOrt;
117
118 // Input & Output specifications of the loaded network
119 std::vector<const char*> mInputNamesChar, mOutputNamesChar;
120 std::vector<std::string> mInputNames, mOutputNames;
121 std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes, mInputShapesCopy, mOutputShapesCopy; // Input shapes
122 std::vector<int64_t> mInputSizePerNode, mOutputSizePerNode; // Output shapes
123 int32_t mInputsTotal = 0, mOutputsTotal = 0; // Total number of inputs and outputs
124
125 // Environment settings
126 bool mInitialized = false;
127 std::string mModelPath, mEnvName = "", mDeviceType = "CPU", mThreadAffinity = ""; // device options should be cpu, rocm, migraphx, cuda
128 int32_t mIntraOpNumThreads = 1, mInterOpNumThreads = 1, mDeviceId = -1, mEnableProfiling = 0, mLoggingLevel = 0, mAllocateDeviceMemory = 0, mEnableOptimizations = 0;
129
130 std::string printShape(const std::vector<int64_t>&);
131 std::string printShape(const std::vector<std::vector<int64_t>>&, std::vector<std::string>&);
132};
133
134} // namespace ml
135
136} // namespace o2
137
138#endif // O2_ML_ORTINTERFACE_H
std::vector< std::vector< int64_t > > getNumInputNodes() const
void initOptions(std::unordered_map< std::string, std::string > optionsMap)
void memoryOnDevice(int32_t=0)
void setActiveThreads(int threads)
Ort::Env * getEnv()
OrtModel()=default
void release(bool=false)
void setEnv(Ort::Env *)
int32_t getInterOpNumThreads() const
std::vector< O > v2v(std::vector< I > &, bool=true)
void setInterOpNumThreads(int threads)
void setIntraOpNumThreads(int threads)
std::vector< std::string > getOutputNames() const
std::vector< std::vector< int64_t > > getNumOutputNodes() const
Ort::MemoryInfo * getMemoryInfo()
void setDeviceId(int32_t id)
std::vector< std::string > getInputNames() const
std::vector< O > inference(std::vector< I > &)
int32_t getIntraOpNumThreads() const
void init(std::unordered_map< std::string, std::string > optionsMap)
Ort::SessionOptions * getSessionOptions()
OrtModel(std::unordered_map< std::string, std::string > optionsMap)
virtual ~OrtModel()=default
GLuint id
Definition glcorearb.h:650
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...