OrtInterface.cxx
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file OrtInterface.cxx
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU

#include "ML/OrtInterface.h"

// ONNX includes
#include <onnxruntime_cxx_api.h>

namespace o2
{

namespace ml
{

struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file
  // ORT runtime objects
  Ort::RunOptions runOptions;
  std::shared_ptr<Ort::Env> env = nullptr;
  std::shared_ptr<Ort::Session> session = nullptr;
  Ort::SessionOptions sessionOptions;
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
};

void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
{

  pImplOrt = new OrtVariables();

  // Load from options map
  if (!optionsMap.contains("model-path")) {
    LOG(fatal) << "(ORT) Model path cannot be empty!";
  }
  modelPath = optionsMap["model-path"];
  device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
  dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
  deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
  allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
  intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
  loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
  enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
  enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

  std::string dev_mem_str = "Hip";
#if defined(ORT_ROCM_BUILD)
#if ORT_ROCM_BUILD == 1
  if (device == "ROCM") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) ROCM execution provider set";
  }
#endif
#endif
#if defined(ORT_MIGRAPHX_BUILD)
#if ORT_MIGRAPHX_BUILD == 1
  if (device == "MIGRAPHX") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) MIGraphX execution provider set";
  }
#endif
#endif
#if defined(ORT_CUDA_BUILD)
#if ORT_CUDA_BUILD == 1
  if (device == "CUDA") {
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
    LOG(info) << "(ORT) CUDA execution provider set";
    dev_mem_str = "Cuda";
  }
#endif
#endif

  if (allocateDeviceMemory) {
    pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
    LOG(info) << "(ORT) Memory info set to on-device memory";
  }

  if (device == "CPU") {
    (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
    if (intraOpNumThreads > 1) {
      (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
    } else if (intraOpNumThreads == 1) {
      (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
    }
    LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
  }

  (pImplOrt->sessionOptions).DisableMemPattern();
  (pImplOrt->sessionOptions).DisableCpuMemArena();

  if (enableProfiling) {
    if (optionsMap.contains("profiling-output-path")) {
      (pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
    } else {
      LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
      (pImplOrt->sessionOptions).DisableProfiling();
    }
  } else {
    (pImplOrt->sessionOptions).DisableProfiling();
  }
  (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
  (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

  pImplOrt->env = std::make_shared<Ort::Env>(
    OrtLoggingLevel(loggingLevel),
    (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
    // Integrate ORT logging into Fairlogger
    [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
      if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
        LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_INFO) {
        LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
        LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
        LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
        LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else {
        LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      }
    },
    (void*)3);
  (pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
  pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

  for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
    mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
    mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
    mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
    mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }

  inputNamesChar.resize(mInputNames.size(), nullptr);
  std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });
  outputNamesChar.resize(mOutputNames.size(), nullptr);
  std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });

  // Print names
  LOG(info) << "\tInput Nodes:";
  for (size_t i = 0; i < mInputNames.size(); i++) {
    LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
  }

  LOG(info) << "\tOutput Nodes:";
  for (size_t i = 0; i < mOutputNames.size(); i++) {
    LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
  }
}
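
// Usage sketch (editorial, not part of the original source): how a caller might
// configure and run the model via the options map parsed above. The option keys
// mirror those read in reset(); "model.onnx" and the sizes are placeholders.
//
//   o2::ml::OrtModel model;
//   model.reset({{"model-path", "model.onnx"},
//                {"device", "CPU"},
//                {"intra-op-num-threads", "2"}});
//   std::vector<float> in(nSamples * nFeatures); // flattened [nSamples, nFeatures]
//   std::vector<float> out = model.inference<float, float>(in);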

void OrtModel::resetSession()
{
  pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
}

template <class I, class O>
std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
{
  if constexpr (std::is_same_v<I, O>) {
    return input;
  } else {
    std::vector<O> output(input.size());
    std::transform(std::begin(input), std::end(input), std::begin(output), [](I f) { return O(f); });
    if (clearInput) {
      input.clear();
    }
    return output;
  }
}
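
// Usage sketch (editorial, not part of the original source): v2v converts a vector
// element-wise between numeric types, e.g. to prepare fp16 input for a half-precision
// model. Values are placeholders; this assumes v2v is accessible to the caller.
//
//   std::vector<float> raw = {0.1f, 0.2f, 0.3f};
//   auto raw16 = model.v2v<float, OrtDataType::Float16_t>(raw);         // raw is cleared (clearInput defaults to true)
//   auto back = model.v2v<OrtDataType::Float16_t, float>(raw16, false); // keep raw16 intact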

template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
std::vector<O> OrtModel::inference(std::vector<I>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}
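
// Note (editorial): this overload assumes a model with a single 2D input
// [N, nFeatures] (nFeatures = mInputShapes[0][1]) and a 2D output [N, nOutputs]
// (nOutputs = mOutputShapes[0][1]). For example, with nFeatures = 7 and
// nOutputs = 2, a flattened input of 700 values is treated as N = 100 samples
// and yields 200 output values.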

template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
{
  std::vector<Ort::Value> inputTensor;
  for (auto& i : input) { // take by reference: the tensors below only view this memory
    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
    inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
  // Size the result from the first output tensor itself rather than from the number of input tensors
  std::vector<O> outputValuesVec{outputValues, outputValues + outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount()};
  outputTensors.clear();
  return outputValuesVec;
}

std::string OrtModel::printShape(const std::vector<int64_t>& v)
{
  if (v.empty()) {
    return "";
  }
  std::stringstream ss("");
  for (size_t i = 0; i < v.size() - 1; i++) {
    ss << v[i] << "x";
  }
  ss << v[v.size() - 1];
  return ss.str();
}
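
// Example (editorial): for a shape vector {64, 10}, printShape returns "64x10",
// as used in the "Input Nodes" / "Output Nodes" log lines in reset() above.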

template <>
std::vector<float> OrtModel::inference<float, float>(std::vector<float>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<float>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<float, OrtDataType::Float16_t>(std::vector<float>& input)
{
  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
  std::vector<Ort::Value> inputTensor;
  // Note: the float buffer is reinterpreted as fp16 elements here, not converted; use v2v for an element-wise conversion
  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template <>
std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>& input)
{
  std::vector<Ort::Value> inputTensor;
  for (auto& i : input) { // take by reference: the tensors below only view this memory
    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
  // Size the result from the first output tensor itself rather than from the number of input tensors
  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount()};
  outputTensors.clear();
  return outputValuesVec;
}

} // namespace ml

} // namespace o2