20#include <onnxruntime_cxx_api.h>
// NOTE(review): this file is a garbled extraction — original line numbers are
// fused into the text and some tokens were dropped; code kept byte-identical.
// Fragment of the pimpl struct (OrtVariables) holding all ONNX Runtime state.
42 std::unique_ptr<Ort::Env>
env =
nullptr;
// ONNX session; created later from a buffer or from mModelPath.
43 std::unique_ptr<Ort::Session>
session =
nullptr;
// Defaults to host ("Cpu") memory; re-assigned to device memory in
// memoryOnDevice() further down.
46 Ort::MemoryInfo
memoryInfo = Ort::MemoryInfo(
"Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
// I/O binding used by the pointer-based inference overload.
47 std::unique_ptr<Ort::IoBinding>
ioBinding =
nullptr;
// Body fragment of OrtModel::initOptions(optionsMap) — the signature was lost
// in extraction. Allocates the pimpl state, reads all settings from the
// options map (with defaults), then configures CPU threading, profiling,
// deterministic compute and graph optimization on the session options.
53 mPImplOrt = std::make_unique<OrtVariables>();
// "model-path" is mandatory; LOG(fatal) is presumably aborting — TODO confirm.
56 if (!optionsMap.contains(
"model-path")) {
57 LOG(fatal) <<
"(ORT) Model path must be contained in options map!";
60 if (!optionsMap[
"model-path"].
empty()) {
61 mModelPath = optionsMap[
"model-path"];
// Optional settings with defaults. NOTE(review): std::stoi throws on
// non-numeric values — no input validation is visible here.
62 mDeviceType = (optionsMap.contains(
"device-type") ? optionsMap[
"device-type"] :
"CPU");
63 mDeviceId = (optionsMap.contains(
"device-id") ? std::stoi(optionsMap[
"device-id"]) : -1);
64 mAllocateDeviceMemory = (optionsMap.contains(
"allocate-device-memory") ? std::stoi(optionsMap[
"allocate-device-memory"]) : 0);
65 mIntraOpNumThreads = (optionsMap.contains(
"intra-op-num-threads") ? std::stoi(optionsMap[
"intra-op-num-threads"]) : 0);
66 mInterOpNumThreads = (optionsMap.contains(
"inter-op-num-threads") ? std::stoi(optionsMap[
"inter-op-num-threads"]) : 0);
67 mLoggingLevel = (optionsMap.contains(
"logging-level") ? std::stoi(optionsMap[
"logging-level"]) : 0);
68 mEnableProfiling = (optionsMap.contains(
"enable-profiling") ? std::stoi(optionsMap[
"enable-profiling"]) : 0);
69 mEnableOptimizations = (optionsMap.contains(
"enable-optimizations") ? std::stoi(optionsMap[
"enable-optimizations"]) : 0);
70 mEnvName = (optionsMap.contains(
"onnx-environment-name") ? optionsMap[
"onnx-environment-name"] :
"onnx_model_inference");
71 mDeterministicMode = (optionsMap.contains(
"deterministic-compute") ? std::stoi(optionsMap[
"deterministic-compute"]) : 0);
// CPU execution: parallel mode only when more than one thread is requested.
73 if (mDeviceType ==
"CPU") {
74 (mPImplOrt->sessionOptions).SetIntraOpNumThreads(mIntraOpNumThreads);
75 (mPImplOrt->sessionOptions).SetInterOpNumThreads(mInterOpNumThreads);
76 if (mIntraOpNumThreads > 1 || mInterOpNumThreads > 1) {
77 (mPImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
78 }
else if (mIntraOpNumThreads == 1) {
79 (mPImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
81 if (mLoggingLevel < 2) {
82 LOG(info) <<
"(ORT) CPU execution provider set with " << mIntraOpNumThreads <<
" (mIntraOpNumThreads) and " << mInterOpNumThreads <<
" (mInterOpNumThreads) threads";
// Memory pattern / CPU arena disabled — presumably to keep allocations
// predictable for device-memory use; TODO confirm intent.
89 (mPImplOrt->sessionOptions).DisableMemPattern();
90 (mPImplOrt->sessionOptions).DisableCpuMemArena();
// Profiling requires an output path; otherwise it is disabled with a warning.
92 if (mEnableProfiling) {
93 if (optionsMap.contains(
"profiling-output-path")) {
94 (mPImplOrt->sessionOptions).EnableProfiling((optionsMap[
"profiling-output-path"] +
"/ORT_LOG_").c_str());
96 LOG(warning) <<
"(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
97 (mPImplOrt->sessionOptions).DisableProfiling();
100 (mPImplOrt->sessionOptions).DisableProfiling();
103 if (mDeterministicMode > 0) {
104 (mPImplOrt->sessionOptions).AddConfigEntry(
"session_options.use_deterministic_compute",
"1");
// mEnableOptimizations / mLoggingLevel are cast straight to the ORT enums;
// assumes the configured integers are valid enum values — TODO confirm.
107 (mPImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(mEnableOptimizations));
108 (mPImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(mLoggingLevel));
// Fallback branch (enclosing else lost in extraction): empty model path is fatal.
112 LOG(fatal) <<
"(ORT) Model path cannot be empty!";
// Fragment: creates the Ort::Env with a custom logging callback that routes
// ORT log records into the framework's LOG(...) severities, then builds the
// session directly from an in-memory model buffer (the buffer/size arguments
// between lines 146 and 149 were dropped by the extraction) and an IoBinding.
118 mPImplOrt->env = std::make_unique<Ort::Env>(
119 OrtLoggingLevel(mLoggingLevel),
120 (mEnvName.empty() ?
"ORT" : mEnvName.c_str()),
// Captureless lambda — converts to the C logging-function pointer the
// Ort::Env constructor expects.
122 [](
void*
param, OrtLoggingLevel
severity,
const char* category,
const char* logid,
const char* code_location,
const char*
message) {
123 if (
severity == ORT_LOGGING_LEVEL_VERBOSE) {
124 LOG(
debug) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
125 }
else if (
severity == ORT_LOGGING_LEVEL_INFO) {
126 LOG(info) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
127 }
else if (
severity == ORT_LOGGING_LEVEL_WARNING) {
128 LOG(warning) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
129 }
else if (
severity == ORT_LOGGING_LEVEL_ERROR) {
130 LOG(error) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
131 }
else if (
severity == ORT_LOGGING_LEVEL_FATAL) {
132 LOG(fatal) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
// Presumably the final else branch (the "} else {" was lost in extraction):
// any unknown severity falls back to info.
134 LOG(info) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
138 (mPImplOrt->env)->DisableTelemetryEvents();
// Tell ORT the bytes are a plain ONNX model and may be used directly
// without an extra copy.
143 mPImplOrt->sessionOptions.AddConfigEntry(
"session.load_model_format",
"ONNX");
144 mPImplOrt->sessionOptions.AddConfigEntry(
"session.use_ort_model_bytes_directly",
"1");
146 mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env,
149 mPImplOrt->sessionOptions);
150 mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);
154 if (mLoggingLevel < 2) {
155 LOG(info) <<
"(ORT) Model loaded successfully from buffer! (inputs: " << printShape(mInputShapes, mInputNames) <<
", outputs: " << printShape(mOutputShapes, mInputNames) <<
")";
// Fragment (presumably from reset()/re-initialisation): when device memory
// is requested, the session and its IoBinding are rebuilt from mModelPath.
161 if (mAllocateDeviceMemory) {
164 mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env, mModelPath.c_str(), mPImplOrt->sessionOptions);
165 mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);
169 if (mLoggingLevel < 2) {
170 LOG(info) <<
"(ORT) Model loaded successfully! (inputs: " << printShape(mInputShapes, mInputNames) <<
", outputs: " << printShape(mOutputShapes, mInputNames) <<
")";
// Fragment of OrtModel::memoryOnDevice(deviceIndex) plus the trailing
// one-line getters (getSessionOptions / getMemoryInfo / getEnv bodies).
// Configures device allocators, arena shrinkage for the given GPU, and
// switches memoryInfo to the device allocator.
176 if (deviceIndex >= 0) {
177 (mPImplOrt->runOptions).AddConfigEntry(
"disable_synchronize_execution_providers",
"1");
178 (mPImplOrt->sessionOptions).AddConfigEntry(
"session.use_device_allocator_for_initializers",
"1");
179 (mPImplOrt->sessionOptions).AddConfigEntry(
"session.use_env_allocators",
"1");
180 (mPImplOrt->sessionOptions).AddConfigEntry(
"session_options.enable_cpu_mem_arena",
"0");
// Shrink the device arena after each run for "gpu:<index>".
183 (mPImplOrt->runOptions).AddConfigEntry(
"memory.enable_memory_arena_shrinkage", (
"gpu:" +
std::to_string(deviceIndex)).c_str());
// NOTE(review): "HipPinned" (pinned host memory) vs "Cuda" (device memory)
// is asymmetric between the ROCM and CUDA branches — confirm intended.
185 std::string dev_mem_str =
"";
186 if (mDeviceType ==
"ROCM") {
187 dev_mem_str =
"HipPinned";
189 if (mDeviceType ==
"CUDA") {
190 dev_mem_str =
"Cuda";
192 mPImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
193 if (mLoggingLevel < 2) {
194 LOG(info) <<
"(ORT) Memory info set to on-device memory for device type " << mDeviceType <<
" with ID " << deviceIndex <<
" and mPImplOrt pointer " << mPImplOrt;
// Session rebuilt so the new options/memory settings take effect.
201 mPImplOrt->session = std::make_unique<Ort::Session>(*(mPImplOrt->env), mModelPath.c_str(), mPImplOrt->sessionOptions);
// Accessor bodies (signatures lost in extraction): return non-owning
// pointers into the pimpl state.
207 return &mPImplOrt->sessionOptions;
212 return &mPImplOrt->memoryInfo;
217 return (mPImplOrt->env).get();
// Fragment of template v2v<I, O>: element-wise conversion of a vector of I
// to a vector of O. The same-type fast path's body (after the if constexpr)
// was dropped by the extraction — presumably it returns/moves the input.
220template <
class I,
class O>
223 if constexpr (std::is_same_v<I, O>) {
// Different types: construct an O from every I via std::transform.
226 std::vector<O>
output(input.size());
227 std::transform(std::begin(input), std::end(input), std::begin(
output), [](I
f) {
return O(
f); });
// Fragment: caches input/output tensor names and shapes from the session,
// builds the const char* views used by Run(), and precomputes per-node
// element counts (product of all dims except dim 0, skipping dynamic dims).
237 for (
size_t i = 0;
i < (mPImplOrt->session)->GetInputCount(); ++
i) {
// GetInputNameAllocated returns an owning holder; .get() is copied into a
// std::string, so the allocation may be released afterwards.
238 mInputNames.push_back((mPImplOrt->session)->GetInputNameAllocated(
i, mPImplOrt->allocator).get());
240 for (
size_t i = 0;
i < (mPImplOrt->session)->GetInputCount(); ++
i) {
241 mInputShapes.emplace_back((mPImplOrt->session)->GetInputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
243 for (
size_t i = 0;
i < (mPImplOrt->session)->GetOutputCount(); ++
i) {
244 mOutputNames.push_back((mPImplOrt->session)->GetOutputNameAllocated(
i, mPImplOrt->allocator).get());
246 for (
size_t i = 0;
i < (mPImplOrt->session)->GetOutputCount(); ++
i) {
247 mOutputShapes.emplace_back((mPImplOrt->session)->GetOutputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
// NOTE(review): the cached c_str() pointers below stay valid only as long
// as mInputNames/mOutputNames are not modified afterwards.
250 mInputNamesChar.resize(mInputNames.size(),
nullptr);
251 std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(mInputNamesChar),
252 [&](
const std::string&
str) {
return str.c_str(); });
253 mOutputNamesChar.resize(mOutputNames.size(),
nullptr);
254 std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(mOutputNamesChar),
255 [&](
const std::string&
str) {
return str.c_str(); });
// Working copies whose dim 0 (batch) is overwritten per inference call.
257 mInputShapesCopy = mInputShapes;
258 mOutputShapesCopy = mOutputShapes;
259 mInputSizePerNode.resize(mInputShapes.size(), 1);
260 mOutputSizePerNode.resize(mOutputShapes.size(), 1);
// Per-node element counts: product over dims 1..n, only counting dims > 0
// (dynamic dims are reported as <= 0 and skipped).
262 for (
size_t i = 0;
i < mInputShapes.size(); ++
i) {
263 if (mInputShapes[
i].
size() > 0) {
264 for (
size_t j = 1;
j < mInputShapes[
i].size(); ++
j) {
265 if (mInputShapes[
i][
j] > 0) {
266 mInputsTotal *= mInputShapes[
i][
j];
267 mInputSizePerNode[
i] *= mInputShapes[
i][
j];
273 for (
size_t i = 0;
i < mOutputShapes.size(); ++
i) {
274 if (mOutputShapes[
i].
size() > 0) {
275 for (
size_t j = 1;
j < mOutputShapes[
i].size(); ++
j) {
276 if (mOutputShapes[
i][
j] > 0) {
277 mOutputsTotal *= mOutputShapes[
i][
j];
278 mOutputSizePerNode[
i] *= mOutputShapes[
i][
j];
// NOTE(review): reset() takes ownership of an externally supplied Ort::Env
// raw pointer here — confirm the caller relinquishes ownership (double-free
// risk otherwise).
287 mPImplOrt->env.reset(env);
// Fragment of template inference<I, O>(std::vector<I>& input):
// derives the batch size by dividing the flat input size by the per-sample
// element count, runs the session, and copies the first output tensor out.
291template <
class I,
class O>
294 std::vector<int64_t> inputShape = mInputShapes[0];
295 inputShape[0] = input.size();
296 for (
size_t i = 1;
i < mInputShapes[0].size(); ++
i) {
297 inputShape[0] /= mInputShapes[0][
i];
299 std::vector<Ort::Value> inputTensor;
// Float16 inputs are reinterpreted as Ort::Float16_t (assumes identical
// bit layout of OrtDataType::Float16_t — TODO confirm).
300 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
301 inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(input.data()), input.size(), inputShape.data(), inputShape.size()));
303 inputTensor.emplace_back(Ort::Value::CreateTensor<I>(mPImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
306 auto outputTensors = (mPImplOrt->session)->Run(mPImplOrt->runOptions, mInputNamesChar.data(), inputTensor.data(), inputTensor.size(), mOutputNamesChar.data(), mOutputNamesChar.size());
// NOTE(review): assumes a 2-D first output (mOutputShapes[0][1]) —
// confirm for models with different output rank.
307 O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
308 std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
309 outputTensors.clear();
310 return outputValuesVec;
// Explicit instantiations for the supported float / float16 combinations.
313template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&);
314template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
315template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
// Fragment of template inference<I, O>(I* input, int64_t input_size, O* output):
// zero-copy variant using IoBinding — binds caller-owned input and output
// buffers, then runs the session; results land directly in `output`.
317template <
class I,
class O>
// NOTE(review): assumes 2-D shapes (dims [batch, features]) for both the
// first input and first output node — confirm for other model topologies.
324 std::vector<int64_t> inputShape{input_size, (int64_t)mInputShapes[0][1]};
325 Ort::Value inputTensor = Ort::Value(
nullptr);
326 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
327 inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(input), input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
329 inputTensor = Ort::Value::CreateTensor<I>(mPImplOrt->memoryInfo, input, input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
331 (mPImplOrt->ioBinding)->BindInput(mInputNames[0].c_str(), inputTensor);
333 std::vector<int64_t> outputShape{input_size, mOutputShapes[0][1]};
334 Ort::Value outputTensor = Ort::Value(
nullptr);
335 if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
336 outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(
output), input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
338 outputTensor = Ort::Value::CreateTensor<O>(mPImplOrt->memoryInfo,
output, input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
340 (mPImplOrt->ioBinding)->BindOutput(mOutputNames[0].c_str(), outputTensor);
342 (mPImplOrt->session)->Run(mPImplOrt->runOptions, *mPImplOrt->ioBinding);
// Explicit instantiations for the supported float / float16 combinations.
353template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
354template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t,
float*);
355template void OrtModel::inference<float, OrtDataType::Float16_t>(
float*, int64_t, OrtDataType::Float16_t*);
356template void OrtModel::inference<float, float>(
float*, int64_t,
float*);
// Fragment of template inference<I, O>(I** input, int64_t input_size, O* output):
// multi-input variant — builds one tensor per input node over caller-owned
// buffers, one output tensor over `output`, then runs the session.
358template <
class I,
class O>
361 std::vector<Ort::Value> inputTensors(mInputShapesCopy.size());
363 for (
size_t i = 0;
i < mInputShapesCopy.size(); ++
i) {
365 mInputShapesCopy[
i][0] = input_size;
// NOTE(review): mOutputShapesCopy is indexed by the *input* loop counter
// here — out-of-range if the model has fewer outputs than inputs; confirm.
366 mOutputShapesCopy[
i][0] = input_size;
368 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
369 inputTensors[
i] = Ort::Value::CreateTensor<Ort::Float16_t>(
370 mPImplOrt->memoryInfo,
371 reinterpret_cast<Ort::Float16_t*
>(input[
i]),
372 mInputSizePerNode[
i] * input_size,
373 mInputShapesCopy[
i].data(),
374 mInputShapesCopy[
i].size());
376 inputTensors[
i] = Ort::Value::CreateTensor<I>(
377 mPImplOrt->memoryInfo,
379 mInputSizePerNode[
i] * input_size,
380 mInputShapesCopy[
i].data(),
381 mInputShapesCopy[
i].size());
// Single output tensor over the caller's buffer (first output node only).
385 Ort::Value outputTensor = Ort::Value(
nullptr);
386 if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
387 outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(
388 mPImplOrt->memoryInfo,
389 reinterpret_cast<Ort::Float16_t*
>(
output),
390 mOutputSizePerNode[0] * input_size,
391 mOutputShapesCopy[0].data(),
392 mOutputShapesCopy[0].size());
394 outputTensor = Ort::Value::CreateTensor<O>(
395 mPImplOrt->memoryInfo,
397 mOutputSizePerNode[0] * input_size,
398 mOutputShapesCopy[0].data(),
399 mOutputShapesCopy[0].size());
403 mPImplOrt->session->Run(
404 mPImplOrt->runOptions,
405 mInputNamesChar.data(),
407 mInputNamesChar.size(),
408 mOutputNamesChar.data(),
410 mOutputNamesChar.size());
// Explicit instantiations for the supported float / float16 combinations.
413template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*);
414template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t,
float*);
415template void OrtModel::inference<float, OrtDataType::Float16_t>(
float**, int64_t, OrtDataType::Float16_t*);
416template void OrtModel::inference<float, float>(
float**, int64_t,
float*);
// Fragment of template inference<I, O>(std::vector<std::vector<I>>& inputs):
// one inner vector per input node; batch size per node is derived from the
// inner vector's length divided by the per-node element count. Returns the
// first output flattened (the `return output_vec;` line was apparently lost
// in extraction).
418template <
class I,
class O>
421 std::vector<Ort::Value> input_tensors;
423 for (
size_t i = 0;
i < inputs.size(); ++
i) {
// Batch dim inferred per node; assumes inputs[i].size() is an exact
// multiple of mInputSizePerNode[i] — TODO confirm callers guarantee this.
425 mInputShapesCopy[
i][0] = inputs[
i].size() / mInputSizePerNode[
i];
427 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
428 input_tensors.emplace_back(
429 Ort::Value::CreateTensor<Ort::Float16_t>(
430 mPImplOrt->memoryInfo,
431 reinterpret_cast<Ort::Float16_t*
>(inputs[
i].data()),
432 mInputSizePerNode[
i] * mInputShapesCopy[
i][0],
433 mInputShapesCopy[
i].data(),
434 mInputShapesCopy[
i].size()));
436 input_tensors.emplace_back(
437 Ort::Value::CreateTensor<I>(
438 mPImplOrt->memoryInfo,
440 mInputSizePerNode[
i] * mInputShapesCopy[
i][0],
441 mInputShapesCopy[
i].data(),
442 mInputShapesCopy[
i].size()));
// NOTE(review): int32_t may overflow for large batch * output sizes.
446 int32_t totalOutputSize = mOutputsTotal * mInputShapesCopy[0][0];
449 auto output_tensors = mPImplOrt->session->Run(
450 mPImplOrt->runOptions,
451 mInputNamesChar.data(),
452 input_tensors.data(),
453 input_tensors.size(),
454 mOutputNamesChar.data(),
455 mOutputNamesChar.size());
// Copy the first output tensor into an owned vector before the ORT values
// are released.
458 O* output_data = output_tensors[0].template GetTensorMutableData<O>();
459 std::vector<O> output_vec(output_data, output_data + totalOutputSize);
460 output_tensors.clear();
// Explicit instantiations.
464template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
465template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
474std::string OrtModel::printShape(
const std::vector<int64_t>&
v)
476 std::stringstream ss(
"");
477 for (
size_t i = 0;
i <
v.size() - 1;
i++) {
480 ss <<
v[
v.size() - 1];
484std::string OrtModel::printShape(
const std::vector<std::vector<int64_t>>&
v, std::vector<std::string>&
n)
486 std::stringstream ss(
"");
487 for (
size_t i = 0;
i <
v.size();
i++) {
488 ss <<
n[
i] <<
" -> (";
489 for (
size_t j = 0;
j <
v[
i].size() - 1;
j++) {
490 ss <<
v[
i][
j] <<
"x";
492 ss <<
v[
i][
v[
i].size() - 1] <<
"); ";
A header library for loading ONNX models and running inference on CPU and GPU.
void initOptions(std::unordered_map< std::string, std::string > optionsMap)
void memoryOnDevice(int32_t=0)
std::vector< O > v2v(std::vector< I > &, bool=true)
void initSessionFromBuffer(const char *buffer, size_t bufferSize)
Ort::MemoryInfo * getMemoryInfo()
std::vector< O > inference(std::vector< I > &)
void init(std::unordered_map< std::string, std::string > optionsMap)
Ort::SessionOptions * getSessionOptions()
GLuint GLsizei const GLchar * message
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
std::string to_string(gsl::span< T, Size > span)
Ort::RunOptions runOptions
Ort::AllocatorWithDefaultOptions allocator
Ort::MemoryInfo memoryInfo
std::unique_ptr< Ort::Session > session
ONNX session.
Ort::SessionOptions sessionOptions
std::unique_ptr< Ort::IoBinding > ioBinding
std::unique_ptr< Ort::Env > env
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"