// --- ORT model initialization (fragment) ---
// Reads runtime configuration from optionsMap, appends a device-specific
// execution provider to the session options, configures threading /
// profiling / optimization, builds the Ort::Env with a logger that forwards
// ORT messages to the framework LOG channels, and caches the model's
// input/output tensor names and shapes for later Run() calls.
// NOTE(review): this excerpt is missing several original lines (closing
// braces, #endif pairs, else branches); comments describe only the visible
// code — verify structure against the full file.
// A model path is mandatory; LOG(fatal) presumably terminates — confirm.
44 if (!optionsMap.contains(
"model-path")) {
45 LOG(fatal) <<
"(ORT) Model path cannot be empty!";
// All remaining options are optional with defaults; integer-valued options
// (device id, thread counts, flags, levels) are parsed with std::stoi.
// NOTE(review): std::stoi throws on non-numeric values — no guard visible.
48 if (!optionsMap[
"model-path"].
empty()) {
49 modelPath = optionsMap[
"model-path"];
50 device = (optionsMap.contains(
"device") ? optionsMap[
"device"] :
"CPU");
51 dtype = (optionsMap.contains(
"dtype") ? optionsMap[
"dtype"] :
"float");
52 deviceId = (optionsMap.contains(
"device-id") ? std::stoi(optionsMap[
"device-id"]) : 0);
53 allocateDeviceMemory = (optionsMap.contains(
"allocate-device-memory") ? std::stoi(optionsMap[
"allocate-device-memory"]) : 0);
54 intraOpNumThreads = (optionsMap.contains(
"intra-op-num-threads") ? std::stoi(optionsMap[
"intra-op-num-threads"]) : 0);
55 interOpNumThreads = (optionsMap.contains(
"inter-op-num-threads") ? std::stoi(optionsMap[
"inter-op-num-threads"]) : 0);
56 loggingLevel = (optionsMap.contains(
"logging-level") ? std::stoi(optionsMap[
"logging-level"]) : 0);
57 enableProfiling = (optionsMap.contains(
"enable-profiling") ? std::stoi(optionsMap[
"enable-profiling"]) : 0);
58 enableOptimizations = (optionsMap.contains(
"enable-optimizations") ? std::stoi(optionsMap[
"enable-optimizations"]) : 0);
// Allocator name used when requesting on-device memory below; "Hip"
// suggests a ROCm/HIP-oriented default — TODO confirm whether CUDA builds
// override this before the MemoryInfo is created.
60 std::string dev_mem_str =
"Hip";
// Execution-provider selection is compile-time gated; each provider is
// appended only when the matching ORT_*_BUILD macro is defined and == 1,
// and only when the runtime "device" option names it.
61#if defined(ORT_ROCM_BUILD)
62#if ORT_ROCM_BUILD == 1
63 if (device ==
"ROCM") {
64 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->
sessionOptions, deviceId));
65 LOG(info) <<
"(ORT) ROCM execution provider set";
69#if defined(ORT_MIGRAPHX_BUILD)
70#if ORT_MIGRAPHX_BUILD == 1
71 if (device ==
"MIGRAPHX") {
72 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->
sessionOptions, deviceId));
73 LOG(info) <<
"(ORT) MIGraphX execution provider set";
77#if defined(ORT_CUDA_BUILD)
78#if ORT_CUDA_BUILD == 1
79 if (device ==
"CUDA") {
80 Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->
sessionOptions, deviceId));
81 LOG(info) <<
"(ORT) CUDA execution provider set";
// Optionally point the session's MemoryInfo at device-resident memory
// (OrtDeviceAllocator) instead of the default CPU allocator.
87 if (allocateDeviceMemory) {
88 pImplOrt->
memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
89 LOG(info) <<
"(ORT) Memory info set to on-device memory";
// CPU path: apply the configured thread-pool sizes; more than one thread in
// either pool switches ORT into parallel execution mode, exactly one
// intra-op thread forces sequential mode.
92 if (device ==
"CPU") {
93 (pImplOrt->
sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
94 (pImplOrt->
sessionOptions).SetInterOpNumThreads(interOpNumThreads);
95 if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
96 (pImplOrt->
sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
97 }
else if (intraOpNumThreads == 1) {
98 (pImplOrt->
sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
100 if (loggingLevel < 2) {
101 LOG(info) <<
"(ORT) CPU execution provider set with " << intraOpNumThreads <<
" (intraOpNumThreads) and " << interOpNumThreads <<
" (interOpNumThreads) threads";
// Profiling needs an output path; without one it is left disabled and a
// warning is emitted (the else branch is among the lines not visible here).
108 if (enableProfiling) {
109 if (optionsMap.contains(
"profiling-output-path")) {
110 (pImplOrt->
sessionOptions).EnableProfiling((optionsMap[
"profiling-output-path"] +
"/ORT_LOG_").c_str());
112 LOG(warning) <<
"(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
// Integer options are cast straight onto the ORT enums; assumes callers
// pass values inside the valid enum ranges — TODO confirm.
121 (pImplOrt->
sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
122 (pImplOrt->
sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
// Build the environment with a custom logging callback that maps each ORT
// severity onto the corresponding framework LOG channel; the environment
// name falls back to "onnx_model_inference" when the option is empty.
124 pImplOrt->
env = std::make_shared<Ort::Env>(
125 OrtLoggingLevel(loggingLevel),
126 (optionsMap[
"onnx-environment-name"].
empty() ?
"onnx_model_inference" : optionsMap[
"onnx-environment-name"].c_str()),
128 [](
void*
param, OrtLoggingLevel
severity,
const char* category,
const char* logid,
const char* code_location,
const char*
message) {
129 if (
severity == ORT_LOGGING_LEVEL_VERBOSE) {
130 LOG(
debug) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
131 }
else if (
severity == ORT_LOGGING_LEVEL_INFO) {
132 LOG(info) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
133 }
else if (
severity == ORT_LOGGING_LEVEL_WARNING) {
134 LOG(warning) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
135 }
else if (
severity == ORT_LOGGING_LEVEL_ERROR) {
136 LOG(error) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
137 }
else if (
severity == ORT_LOGGING_LEVEL_FATAL) {
138 LOG(fatal) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
// Presumably the fallback branch for any other severity — the enclosing
// else is among the lines not visible here; verify against the full file.
140 LOG(info) <<
"(ORT) [" << logid <<
"|" << category <<
"|" << code_location <<
"]: " <<
message;
// Opt out of ONNX Runtime telemetry.
144 (pImplOrt->
env)->DisableTelemetryEvents();
// Cache input/output tensor names and shapes. GetInputNameAllocated()
// returns an allocator-owned holder; .get() is copied into a std::string
// by push_back immediately, so no dangling pointer is kept.
147 for (
size_t i = 0;
i < (pImplOrt->
session)->GetInputCount(); ++
i) {
148 mInputNames.push_back((pImplOrt->
session)->GetInputNameAllocated(
i, pImplOrt->
allocator).get());
150 for (
size_t i = 0;
i < (pImplOrt->
session)->GetInputCount(); ++
i) {
151 mInputShapes.emplace_back((pImplOrt->
session)->GetInputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
153 for (
size_t i = 0;
i < (pImplOrt->
session)->GetOutputCount(); ++
i) {
154 mOutputNames.push_back((pImplOrt->
session)->GetOutputNameAllocated(
i, pImplOrt->
allocator).get());
156 for (
size_t i = 0;
i < (pImplOrt->
session)->GetOutputCount(); ++
i) {
157 mOutputShapes.emplace_back((pImplOrt->
session)->GetOutputTypeInfo(
i).GetTensorTypeAndShapeInfo().GetShape());
// Build parallel const char* views for Session::Run(). These pointers stay
// valid only while mInputNames / mOutputNames are not modified or
// reallocated — do not mutate those vectors after this point.
160 inputNamesChar.resize(mInputNames.size(),
nullptr);
161 std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
162 [&](
const std::string&
str) {
return str.c_str(); });
163 outputNamesChar.resize(mOutputNames.size(),
nullptr);
164 std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
165 [&](
const std::string&
str) {
return str.c_str(); });
// Success report (suppressed at higher logging levels); only the first
// input/output shape is printed.
167 if (loggingLevel < 2) {
168 LOG(info) <<
"(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) <<
", output: " << printShape(mOutputShapes[0]) <<
")";
// --- Inference: std::vector in, std::vector out (fragment) ---
// Wraps `input` in a 2-D tensor of shape {input.size() / W, W} with
// W = mInputShapes[0][1], runs the session once, and copies the first
// output tensor into a std::vector<O>.
// Assumes a single 2-D model input and that input.size() is divisible by
// mInputShapes[0][1] — no check is made here.
// NOTE(review): the enclosing signature and some lines (e.g. the else
// introducing the non-fp16 branch) are not visible in this excerpt.
205 std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
206 std::vector<Ort::Value> inputTensor;
// fp16 path: reinterpret the caller's half-precision buffer as
// Ort::Float16_t; relies on both being layout-compatible 16-bit floats —
// TODO confirm for OrtDataType::Float16_t.
207 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
208 inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->
memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(input.data()), input.size(), inputShape.data(), inputShape.size()));
// Non-fp16 path: the tensor views the caller's buffer directly (no copy),
// so `input` must outlive the Run() call.
210 inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->
memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
// One Run() over all registered input/output names; the result is copied
// out as rows * mOutputShapes[0][1] elements of O.
213 auto outputTensors = (pImplOrt->
session)->Run(pImplOrt->
runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
214 O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
215 std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
216 outputTensors.clear();
217 return outputValuesVec;
// --- Inference: raw input pointer, caller-provided output buffer (fragment) ---
// Zero-copy variant: both input and output tensors are created over
// caller-owned memory, so `output` must hold at least
// input_size * mOutputShapes[0][1] / mInputShapes[0][1] elements of O and
// both buffers must outlive the Run() call.
// NOTE(review): signature and some surrounding lines (else branch, closing
// brace) are not visible in this excerpt.
229 std::vector<int64_t> inputShape{(int64_t)(input_size / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
230 Ort::Value inputTensor = Ort::Value(
nullptr);
// fp16 path: reinterpret the half-precision input buffer as Ort::Float16_t
// (assumes layout compatibility — TODO confirm).
231 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
232 inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->
memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(input), input_size, inputShape.data(), inputShape.size());
234 inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->
memoryInfo, input, input_size, inputShape.data(), inputShape.size());
// Output tensor wraps the caller's buffer; ORT writes results in place,
// nothing is returned.
237 std::vector<int64_t> outputShape{inputShape[0], mOutputShapes[0][1]};
238 size_t outputSize = (int64_t)(input_size * mOutputShapes[0][1] / mInputShapes[0][1]);
239 Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->
memoryInfo,
output, outputSize, outputShape.data(), outputShape.size());
// Single input / single output Run() directly into the caller's tensor.
241 (pImplOrt->
session)->Run(pImplOrt->
runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size());
// --- Inference: one tensor per input container (fragment) ---
// Builds one 2-D tensor per element of `input` (each reshaped to
// {i.size() / W, W} with W = mInputShapes[0][1]), runs the session once,
// and copies the first output tensor into a std::vector<O>.
// NOTE(review): signature and some surrounding lines are not visible here.
251 std::vector<Ort::Value> inputTensor;
// NOTE(review): `auto i` copies each input container per iteration, and the
// tensor then views the copy's buffer, which dies at end of iteration —
// `auto&` (or const ref) looks required for the tensors to stay valid
// through Run(); verify against the full file.
252 for (
auto i : input) {
253 std::vector<int64_t> inputShape{(int64_t)(
i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
254 if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
255 inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->
memoryInfo,
reinterpret_cast<Ort::Float16_t*
>(
i.data()),
i.size(), inputShape.data(), inputShape.size()));
257 inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->
memoryInfo,
i.data(),
i.size(), inputShape.data(), inputShape.size()));
// One Run() over all input tensors.
261 auto outputTensors = (pImplOrt->
session)->Run(pImplOrt->
runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
// NOTE(review): the reinterpret_cast is redundant — GetTensorMutableData<O>()
// already returns O*.
262 O* outputValues =
reinterpret_cast<O*
>(outputTensors[0].template GetTensorMutableData<O>());
// NOTE(review): inputTensor.size() is the NUMBER OF TENSORS, not the total
// element count — dividing it by mInputShapes[0][1] makes the copied span
// size look wrong (compare the per-vector variant, which uses
// rows * mOutputShapes[0][1]); verify against callers.
263 std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
264 outputTensors.clear();
265 return outputValuesVec;