Project
Loading...
Searching...
No Matches
OrtInterface.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
15
16#include "ML/OrtInterface.h"
18
19// ONNX includes
20#include <onnxruntime_cxx_api.h>
21
22namespace o2
23{
24
25namespace ml
26{
27
struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file
  // ORT runtime objects
  Ort::RunOptions runOptions;                      // per-Run() options; currently left at defaults
  std::shared_ptr<Ort::Env> env = nullptr;         // ONNX runtime environment (owns logging sink); created in reset()
  std::shared_ptr<Ort::Session> session = nullptr; // loaded model session; created in reset()
  Ort::SessionOptions sessionOptions;              // execution provider, threading, profiling and optimization settings
  Ort::AllocatorWithDefaultOptions allocator;      // used to fetch allocated input/output names from the session
  // Defaults to host ("Cpu") memory; reset() replaces it with a device MemoryInfo
  // ("Hip"/"Cuda") when "allocate-device-memory" is requested in the options map.
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
};
37
38void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
39{
40
41 pImplOrt = new OrtVariables();
42
43 // Load from options map
44 if (!optionsMap.contains("model-path")) {
45 LOG(fatal) << "(ORT) Model path cannot be empty!";
46 }
47
48 if (!optionsMap["model-path"].empty()) {
49 modelPath = optionsMap["model-path"];
50 device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
51 dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
52 deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
53 allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
54 intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
55 interOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
56 loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
57 enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
58 enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
59
60 std::string dev_mem_str = "Hip";
61#if defined(ORT_ROCM_BUILD)
62#if ORT_ROCM_BUILD == 1
63 if (device == "ROCM") {
64 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
65 LOG(info) << "(ORT) ROCM execution provider set";
66 }
67#endif
68#endif
69#if defined(ORT_MIGRAPHX_BUILD)
70#if ORT_MIGRAPHX_BUILD == 1
71 if (device == "MIGRAPHX") {
72 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
73 LOG(info) << "(ORT) MIGraphX execution provider set";
74 }
75#endif
76#endif
77#if defined(ORT_CUDA_BUILD)
78#if ORT_CUDA_BUILD == 1
79 if (device == "CUDA") {
80 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
81 LOG(info) << "(ORT) CUDA execution provider set";
82 dev_mem_str = "Cuda";
83 }
84#endif
85#endif
86
87 if (allocateDeviceMemory) {
88 pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
89 LOG(info) << "(ORT) Memory info set to on-device memory";
90 }
91
92 if (device == "CPU") {
93 (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
94 (pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
95 if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
96 (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
97 } else if (intraOpNumThreads == 1) {
98 (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
99 }
100 if (loggingLevel < 2) {
101 LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads";
102 }
103 }
104
105 (pImplOrt->sessionOptions).DisableMemPattern();
106 (pImplOrt->sessionOptions).DisableCpuMemArena();
107
108 if (enableProfiling) {
109 if (optionsMap.contains("profiling-output-path")) {
110 (pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
111 } else {
112 LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
113 (pImplOrt->sessionOptions).DisableProfiling();
114 }
115 } else {
116 (pImplOrt->sessionOptions).DisableProfiling();
117 }
118
119 mInitialized = true;
120
121 (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
122 (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
123
124 pImplOrt->env = std::make_shared<Ort::Env>(
125 OrtLoggingLevel(loggingLevel),
126 (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
127 // Integrate ORT logging into Fairlogger
128 [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
129 if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
130 LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
131 } else if (severity == ORT_LOGGING_LEVEL_INFO) {
132 LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
133 } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
134 LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
135 } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
136 LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
137 } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
138 LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
139 } else {
140 LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
141 }
142 },
143 (void*)3);
144 (pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
145 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
146
147 for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
148 mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
149 }
150 for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
151 mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
152 }
153 for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
154 mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
155 }
156 for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
157 mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
158 }
159
160 inputNamesChar.resize(mInputNames.size(), nullptr);
161 std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
162 [&](const std::string& str) { return str.c_str(); });
163 outputNamesChar.resize(mOutputNames.size(), nullptr);
164 std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
165 [&](const std::string& str) { return str.c_str(); });
166 }
167 if (loggingLevel < 2) {
168 LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
169 }
170}
171
173{
174 pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
175}
176
177template <class I, class O>
178std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
179{
180 if constexpr (std::is_same_v<I, O>) {
181 return input;
182 } else {
183 std::vector<O> output(input.size());
184 std::transform(std::begin(input), std::end(input), std::begin(output), [](I f) { return O(f); });
185 if (clearInput) {
186 input.clear();
187 }
188 return output;
189 }
190}
191
192std::string OrtModel::printShape(const std::vector<int64_t>& v)
193{
194 std::stringstream ss("");
195 for (size_t i = 0; i < v.size() - 1; i++) {
196 ss << v[i] << "x";
197 }
198 ss << v[v.size() - 1];
199 return ss.str();
200}
201
/// \brief Run inference on a flat input vector and return the output as a vector.
///
/// The batch size is derived as input.size() / mInputShapes[0][1], i.e. the model
/// is assumed to have a single 2D input (batch x features) — TODO confirm for
/// models with more inputs or higher-rank shapes.
///
/// \param input flat input values; NOT copied — the tensor wraps input.data(),
///              so the vector must stay alive until Run() returns
/// \return output values copied out of the first output tensor
template <class I, class O>
std::vector<O> OrtModel::inference(std::vector<I>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
    // Reinterpret the project's half-precision type as ORT's Float16_t; assumes
    // both are bit-compatible 16-bit layouts — TODO confirm.
    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  } else {
    inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  // Copy the first output tensor into an owned vector before the tensors are freed.
  O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}
219
220template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
221
222template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
223
224template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
225
/// \brief Run inference from a raw input buffer into a caller-provided output buffer.
///
/// The batch size is input_size / mInputShapes[0][1]; the output buffer must hold
/// at least input_size * mOutputShapes[0][1] / mInputShapes[0][1] elements, since
/// the output tensor is created directly over it (no copy).
///
/// \param input pointer to input_size elements of type I
/// \param input_size number of input elements
/// \param output pre-allocated destination buffer written by ORT
template <class I, class O>
void OrtModel::inference(I* input, size_t input_size, O* output)
{
  std::vector<int64_t> inputShape{(int64_t)(input_size / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  Ort::Value inputTensor = Ort::Value(nullptr);
  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
    // Reinterpret the project half type as ORT's Float16_t (assumed bit-compatible).
    inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size, inputShape.data(), inputShape.size());
  } else {
    inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input, input_size, inputShape.data(), inputShape.size());
  }

  std::vector<int64_t> outputShape{inputShape[0], mOutputShapes[0][1]};
  size_t outputSize = (int64_t)(input_size * mOutputShapes[0][1] / mInputShapes[0][1]);
  Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, outputSize, outputShape.data(), outputShape.size());

  // NOTE(review): passes 1 input tensor but outputNamesChar.size() output slots with a
  // single output tensor; correct only for single-output models — confirm intent.
  (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is correct here
}
243
244template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
245
246template void OrtModel::inference<float, float>(float*, size_t, float*);
247
248template <class I, class O>
249std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
250{
251 std::vector<Ort::Value> inputTensor;
252 for (auto i : input) {
253 std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
254 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
255 inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
256 } else {
257 inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, i.data(), i.size(), inputShape.data(), inputShape.size()));
258 }
259 }
260 // input.clear();
261 auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
262 O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
263 std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
264 outputTensors.clear();
265 return outputValuesVec;
266}
267
268} // namespace ml
269
270} // namespace o2
int32_t i
void output(const std::map< std::string, ChannelStat > &channels)
Definition rawdump.cxx:197
A header library for loading ONNX models and inferencing them on CPU and GPU.
std::ostringstream debug
void reset(std::unordered_map< std::string, std::string >)
std::vector< O > v2v(std::vector< I > &, bool=true)
std::vector< O > inference(std::vector< I > &)
const GLdouble * v
Definition glcorearb.h:832
GLdouble f
Definition glcorearb.h:310
GLuint GLsizei const GLchar * message
Definition glcorearb.h:2517
GLenum GLfloat param
Definition glcorearb.h:271
GLenum GLenum severity
Definition glcorearb.h:2513
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
void empty(int)
Ort::AllocatorWithDefaultOptions allocator
std::shared_ptr< Ort::Env > env
std::shared_ptr< Ort::Session > session
ONNX session.
Ort::SessionOptions sessionOptions
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"
const std::string str