OrtInterface.cxx
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

#include "ML/OrtInterface.h"
#include "Framework/Logger.h" // LOG macro

// ONNX includes
#include <onnxruntime_cxx_api.h>

#include <sstream>

namespace o2
{

namespace ml
{

OrtModel::OrtModel() = default;
OrtModel::OrtModel(std::unordered_map<std::string, std::string> optionsMap) { init(optionsMap); }
OrtModel::~OrtModel() = default;
void OrtModel::init(std::unordered_map<std::string, std::string> optionsMap)
{
  initOptions(optionsMap);
}

struct OrtModel::OrtVariables { // The actual implementation is hidden in the .cxx file
  // ORT runtime objects
  Ort::RunOptions runOptions;
  std::unique_ptr<Ort::Env> env = nullptr;
  std::unique_ptr<Ort::Session> session = nullptr;
  Ort::SessionOptions sessionOptions;
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo("Cpu", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
  std::unique_ptr<Ort::IoBinding> ioBinding = nullptr;
};

// General purpose
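// Recognized keys in the options map: "model-path" (required), "device-type", "device-id",
// "allocate-device-memory", "intra-op-num-threads", "inter-op-num-threads", "logging-level",
// "enable-profiling", "profiling-output-path", "enable-optimizations", "onnx-environment-name"
// and "deterministic-compute".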
void OrtModel::initOptions(std::unordered_map<std::string, std::string> optionsMap)
{
  mPImplOrt = std::make_unique<OrtVariables>();

  // Load from options map
  if (!optionsMap.contains("model-path")) {
    LOG(fatal) << "(ORT) Model path must be contained in options map!";
  }

  if (!optionsMap["model-path"].empty()) {
    mModelPath = optionsMap["model-path"];
    mDeviceType = (optionsMap.contains("device-type") ? optionsMap["device-type"] : "CPU");
    mDeviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : -1);
    mAllocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
    mIntraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
    mInterOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
    mLoggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
    mEnableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
    mEnableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
    mEnvName = (optionsMap.contains("onnx-environment-name") ? optionsMap["onnx-environment-name"] : "onnx_model_inference");
    mDeterministicMode = (optionsMap.contains("deterministic-compute") ? std::stoi(optionsMap["deterministic-compute"]) : 0);

    if (mDeviceType == "CPU") {
      (mPImplOrt->sessionOptions).SetIntraOpNumThreads(mIntraOpNumThreads);
      (mPImplOrt->sessionOptions).SetInterOpNumThreads(mInterOpNumThreads);
      if (mIntraOpNumThreads > 1 || mInterOpNumThreads > 1) {
        (mPImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
      } else if (mIntraOpNumThreads == 1) {
        (mPImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
      }
      if (mLoggingLevel < 2) {
        LOG(info) << "(ORT) CPU execution provider set with " << mIntraOpNumThreads << " (mIntraOpNumThreads) and " << mInterOpNumThreads << " (mInterOpNumThreads) threads";
      }
    }

    // OrtROCMProviderOptions rocm_options{};
    // (mPImplOrt->sessionOptions).AppendExecutionProvider_ROCM(rocm_options);

    (mPImplOrt->sessionOptions).DisableMemPattern();
    (mPImplOrt->sessionOptions).DisableCpuMemArena();

    if (mEnableProfiling) {
      if (optionsMap.contains("profiling-output-path")) {
        (mPImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
      } else {
        LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
        (mPImplOrt->sessionOptions).DisableProfiling();
      }
    } else {
      (mPImplOrt->sessionOptions).DisableProfiling();
    }

    if (mDeterministicMode > 0) {
      (mPImplOrt->sessionOptions).AddConfigEntry("session_options.use_deterministic_compute", "1");
    }

    (mPImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(mEnableOptimizations));
    (mPImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(mLoggingLevel));

    mInitialized = true;
  } else {
    LOG(fatal) << "(ORT) Model path cannot be empty!";
  }
}

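// Create the ONNX Runtime environment and route its log messages into Fairlogger,
// mapping each ORT severity level onto the corresponding LOG(...) level.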
void OrtModel::initEnvironment()
{
  mPImplOrt->env = std::make_unique<Ort::Env>(
    OrtLoggingLevel(mLoggingLevel),
    (mEnvName.empty() ? "ORT" : mEnvName.c_str()),
    // Integrate ORT logging into Fairlogger
    [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
      if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
        LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_INFO) {
        LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_WARNING) {
        LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_ERROR) {
        LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else if (severity == ORT_LOGGING_LEVEL_FATAL) {
        LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      } else {
        LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
      }
    },
    (void*)3);
  (mPImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
}

void OrtModel::initSessionFromBuffer(const char* buffer, size_t bufferSize)
{
  mPImplOrt->sessionOptions.AddConfigEntry("session.load_model_format", "ONNX");
  mPImplOrt->sessionOptions.AddConfigEntry("session.use_ort_model_bytes_directly", "1");

  mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env,
                                                      buffer,
                                                      bufferSize,
                                                      mPImplOrt->sessionOptions);
  mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);

  setIO();

  if (mLoggingLevel < 2) {
    LOG(info) << "(ORT) Model loaded successfully from buffer! (inputs: " << printShape(mInputShapes, mInputNames) << ", outputs: " << printShape(mOutputShapes, mOutputNames) << ")";
  }
}

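// Create the inference session from the model file at mModelPath; if requested, switch the
// memory info to device memory first so that I/O bindings can live on the device.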
void OrtModel::initSession()
{
  if (mAllocateDeviceMemory) {
    memoryOnDevice(mDeviceId);
  }
  mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env, mModelPath.c_str(), mPImplOrt->sessionOptions);
  mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);

  setIO();

  if (mLoggingLevel < 2) {
    LOG(info) << "(ORT) Model loaded successfully! (inputs: " << printShape(mInputShapes, mInputNames) << ", outputs: " << printShape(mOutputShapes, mOutputNames) << ")";
  }
}

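// Point the Ort::MemoryInfo at the device-side allocator ("HipPinned" for ROCM, "Cuda" for CUDA)
// and set the session/run options needed for device-resident allocations.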
void OrtModel::memoryOnDevice(int32_t deviceIndex)
{
  if (deviceIndex >= 0) {
    (mPImplOrt->runOptions).AddConfigEntry("disable_synchronize_execution_providers", "1");
    (mPImplOrt->sessionOptions).AddConfigEntry("session.use_device_allocator_for_initializers", "1"); // See kOrtSessionOptionsUseDeviceAllocatorForInitializers, https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
    (mPImplOrt->sessionOptions).AddConfigEntry("session.use_env_allocators", "1");                    // This should enable the use of the volatile memory allocation defined in O2/GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx; not working yet: ONNX still assigns new memory at init time
    (mPImplOrt->sessionOptions).AddConfigEntry("session_options.enable_cpu_mem_arena", "0");          // This should enable the use of the volatile memory allocation defined in O2/GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx; not working yet: ONNX still assigns new memory at init time
    // Arena memory shrinkage comes at a performance cost.
    // For now prefer a single allocation, enabled by O2/GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu -> SetONNXGPUStream -> rocm_options.arena_extend_strategy = 0;
    (mPImplOrt->runOptions).AddConfigEntry("memory.enable_memory_arena_shrinkage", ("gpu:" + std::to_string(deviceIndex)).c_str()); // See kOrtRunOptionsConfigEnableMemoryArenaShrinkage, https://github.com/microsoft/onnxruntime/blob/90c263f471bbce724e77d8e62831d3a9fa838b2f/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h#L27

    std::string dev_mem_str = "";
    if (mDeviceType == "ROCM") {
      dev_mem_str = "HipPinned";
    }
    if (mDeviceType == "CUDA") {
      dev_mem_str = "Cuda";
    }
    mPImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceIndex, OrtMemType::OrtMemTypeDefault);
    if (mLoggingLevel < 2) {
      LOG(info) << "(ORT) Memory info set to on-device memory for device type " << mDeviceType << " with ID " << deviceIndex << " and mPImplOrt pointer " << mPImplOrt.get();
    }
  }
}

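// Recreate the inference session from the model file using the current environment and session options.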
void OrtModel::resetSession()
{
  mPImplOrt->session = std::make_unique<Ort::Session>(*(mPImplOrt->env), mModelPath.c_str(), mPImplOrt->sessionOptions);
}

// Getters
Ort::SessionOptions* OrtModel::getSessionOptions()
{
  return &mPImplOrt->sessionOptions;
}

Ort::MemoryInfo* OrtModel::getMemoryInfo()
{
  return &mPImplOrt->memoryInfo;
}

Ort::Env* OrtModel::getEnv()
{
  return (mPImplOrt->env).get();
}

template <class I, class O>
std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
{
  if constexpr (std::is_same_v<I, O>) {
    return input;
  } else {
    std::vector<O> output(input.size());
    std::transform(std::begin(input), std::end(input), std::begin(output), [](I f) { return O(f); });
    if (clearInput) {
      input.clear();
    }
    return output;
  }
}

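// Query the session for input/output names and shapes, cache them (also as C strings for the
// Run() calls), and precompute the per-node and total element counts used by the inference methods.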
void OrtModel::setIO()
{
  for (size_t i = 0; i < (mPImplOrt->session)->GetInputCount(); ++i) {
    mInputNames.push_back((mPImplOrt->session)->GetInputNameAllocated(i, mPImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (mPImplOrt->session)->GetInputCount(); ++i) {
    mInputShapes.emplace_back((mPImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }
  for (size_t i = 0; i < (mPImplOrt->session)->GetOutputCount(); ++i) {
    mOutputNames.push_back((mPImplOrt->session)->GetOutputNameAllocated(i, mPImplOrt->allocator).get());
  }
  for (size_t i = 0; i < (mPImplOrt->session)->GetOutputCount(); ++i) {
    mOutputShapes.emplace_back((mPImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
  }

  mInputNamesChar.resize(mInputNames.size(), nullptr);
  std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(mInputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });
  mOutputNamesChar.resize(mOutputNames.size(), nullptr);
  std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(mOutputNamesChar),
                 [&](const std::string& str) { return str.c_str(); });

  mInputShapesCopy = mInputShapes;
  mOutputShapesCopy = mOutputShapes;
  mInputSizePerNode.resize(mInputShapes.size(), 1);
  mOutputSizePerNode.resize(mOutputShapes.size(), 1);
  mInputsTotal = 1;
  for (size_t i = 0; i < mInputShapes.size(); ++i) {
    if (mInputShapes[i].size() > 0) {
      for (size_t j = 1; j < mInputShapes[i].size(); ++j) {
        if (mInputShapes[i][j] > 0) {
          mInputsTotal *= mInputShapes[i][j];
          mInputSizePerNode[i] *= mInputShapes[i][j];
        }
      }
    }
  }
  mOutputsTotal = 1;
  for (size_t i = 0; i < mOutputShapes.size(); ++i) {
    if (mOutputShapes[i].size() > 0) {
      for (size_t j = 1; j < mOutputShapes[i].size(); ++j) {
        if (mOutputShapes[i][j] > 0) {
          mOutputsTotal *= mOutputShapes[i][j];
          mOutputSizePerNode[i] *= mOutputShapes[i][j];
        }
      }
    }
  }
}

void OrtModel::setEnv(Ort::Env* env)
{
  mPImplOrt->env.reset(env);
}

// Inference
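// Single-input overload: wraps the flat input vector in a tensor (the batch size is derived from
// the vector length and the model's per-sample input size) and returns the flattened output values.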
template <class I, class O>
std::vector<O> OrtModel::inference(std::vector<I>& input)
{
  std::vector<int64_t> inputShape = mInputShapes[0];
  inputShape[0] = input.size();
  for (size_t i = 1; i < mInputShapes[0].size(); ++i) {
    inputShape[0] /= mInputShapes[0][i];
  }
  std::vector<Ort::Value> inputTensor;
  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
  } else {
    inputTensor.emplace_back(Ort::Value::CreateTensor<I>(mPImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
  }
  // input.clear();
  auto outputTensors = (mPImplOrt->session)->Run(mPImplOrt->runOptions, mInputNamesChar.data(), inputTensor.data(), inputTensor.size(), mOutputNamesChar.data(), mOutputNamesChar.size());
  O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
  outputTensors.clear();
  return outputValuesVec;
}

template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&);
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);

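// Pointer-based overload: binds the caller-provided input and output buffers directly via
// Ort::IoBinding, so no copies are made; input_size is the batch size.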
template <class I, class O>
void OrtModel::inference(I* input, int64_t input_size, O* output)
{
  // std::vector<std::string> providers = Ort::GetAvailableProviders();
  // for (const auto& provider : providers) {
  //   LOG(info) << "Available Execution Provider: " << provider;
  // }
  std::vector<int64_t> inputShape{input_size, (int64_t)mInputShapes[0][1]};
  Ort::Value inputTensor = Ort::Value(nullptr);
  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
    inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
  } else {
    inputTensor = Ort::Value::CreateTensor<I>(mPImplOrt->memoryInfo, input, input_size * mInputShapes[0][1], inputShape.data(), inputShape.size());
  }
  (mPImplOrt->ioBinding)->BindInput(mInputNames[0].c_str(), inputTensor);

  std::vector<int64_t> outputShape{input_size, mOutputShapes[0][1]};
  Ort::Value outputTensor = Ort::Value(nullptr);
  if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
    outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(mPImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(output), input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
  } else {
    outputTensor = Ort::Value::CreateTensor<O>(mPImplOrt->memoryInfo, output, input_size * mOutputShapes[0][1], outputShape.data(), outputShape.size());
  }
  (mPImplOrt->ioBinding)->BindOutput(mOutputNames[0].c_str(), outputTensor);

  (mPImplOrt->session)->Run(mPImplOrt->runOptions, *mPImplOrt->ioBinding);
  // mPImplOrt->session->Run(
  //   mPImplOrt->runOptions,
  //   mInputNamesChar.data(),
  //   &inputTensor,
  //   mInputNamesChar.size(),
  //   mOutputNamesChar.data(),
  //   &outputTensor,
  //   mOutputNamesChar.size());
}

template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
template void OrtModel::inference<float, float>(float*, int64_t, float*);

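// Multi-input overload: takes one raw buffer per input node plus a common batch size and runs the
// session without I/O binding; the single output node is written into the caller-provided buffer.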
template <class I, class O>
void OrtModel::inference(I** input, int64_t input_size, O* output)
{
  std::vector<Ort::Value> inputTensors(mInputShapesCopy.size());

  for (size_t i = 0; i < mInputShapesCopy.size(); ++i) {

    mInputShapesCopy[i][0] = input_size;  // batch-size
    mOutputShapesCopy[i][0] = input_size; // batch-size

    if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
      inputTensors[i] = Ort::Value::CreateTensor<Ort::Float16_t>(
        mPImplOrt->memoryInfo,
        reinterpret_cast<Ort::Float16_t*>(input[i]),
        mInputSizePerNode[i] * input_size,
        mInputShapesCopy[i].data(),
        mInputShapesCopy[i].size());
    } else {
      inputTensors[i] = Ort::Value::CreateTensor<I>(
        mPImplOrt->memoryInfo,
        input[i],
        mInputSizePerNode[i] * input_size,
        mInputShapesCopy[i].data(),
        mInputShapesCopy[i].size());
    }
  }

  Ort::Value outputTensor = Ort::Value(nullptr);
  if constexpr (std::is_same_v<O, OrtDataType::Float16_t>) {
    outputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(
      mPImplOrt->memoryInfo,
      reinterpret_cast<Ort::Float16_t*>(output),
      mOutputSizePerNode[0] * input_size, // assumes that there is only one output node
      mOutputShapesCopy[0].data(),
      mOutputShapesCopy[0].size());
  } else {
    outputTensor = Ort::Value::CreateTensor<O>(
      mPImplOrt->memoryInfo,
      output,
      mOutputSizePerNode[0] * input_size, // assumes that there is only one output node
      mOutputShapesCopy[0].data(),
      mOutputShapesCopy[0].size());
  }

  // === Run inference ===
  mPImplOrt->session->Run(
    mPImplOrt->runOptions,
    mInputNamesChar.data(),
    inputTensors.data(),
    mInputNamesChar.size(),
    mOutputNamesChar.data(),
    &outputTensor,
    mOutputNamesChar.size());
}

template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*);
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
template void OrtModel::inference<float, float>(float**, int64_t, float*);

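// Vector-of-vectors overload: one flat vector per input node; the batch size of each node is
// derived from its vector length, and the flattened values of the first output node are returned.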
template <class I, class O>
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
{
  std::vector<Ort::Value> input_tensors;

  for (size_t i = 0; i < inputs.size(); ++i) {

    mInputShapesCopy[i][0] = inputs[i].size() / mInputSizePerNode[i]; // batch-size

    if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
      input_tensors.emplace_back(
        Ort::Value::CreateTensor<Ort::Float16_t>(
          mPImplOrt->memoryInfo,
          reinterpret_cast<Ort::Float16_t*>(inputs[i].data()),
          mInputSizePerNode[i] * mInputShapesCopy[i][0],
          mInputShapesCopy[i].data(),
          mInputShapesCopy[i].size()));
    } else {
      input_tensors.emplace_back(
        Ort::Value::CreateTensor<I>(
          mPImplOrt->memoryInfo,
          inputs[i].data(),
          mInputSizePerNode[i] * mInputShapesCopy[i][0],
          mInputShapesCopy[i].data(),
          mInputShapesCopy[i].size()));
    }
  }

  int32_t totalOutputSize = mOutputsTotal * mInputShapesCopy[0][0];

  // === Run inference ===
  auto output_tensors = mPImplOrt->session->Run(
    mPImplOrt->runOptions,
    mInputNamesChar.data(),
    input_tensors.data(),
    input_tensors.size(),
    mOutputNamesChar.data(),
    mOutputNamesChar.size());

  // === Extract output values ===
  O* output_data = output_tensors[0].template GetTensorMutableData<O>();
  std::vector<O> output_vec(output_data, output_data + totalOutputSize);
  output_tensors.clear();
  return output_vec;
}

template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);

// Release session
void OrtModel::release(bool profilingEnabled)
{
  mPImplOrt.reset();
}

// private
std::string OrtModel::printShape(const std::vector<int64_t>& v)
{
  std::stringstream ss("");
  for (size_t i = 0; i < v.size() - 1; i++) {
    ss << v[i] << "x";
  }
  ss << v[v.size() - 1];
  return ss.str();
}

std::string OrtModel::printShape(const std::vector<std::vector<int64_t>>& v, std::vector<std::string>& n)
{
  std::stringstream ss("");
  for (size_t i = 0; i < v.size(); i++) {
    ss << n[i] << " -> (";
    for (size_t j = 0; j < v[i].size() - 1; j++) {
      ss << v[i][j] << "x";
    }
    ss << v[i][v[i].size() - 1] << "); ";
  }
  return ss.str();
}

} // namespace ml

} // namespace o2
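
A minimal usage sketch (not part of the source file above): the model path "model.onnx", the thread setting and the input sizes are placeholders, and it assumes that constructing OrtModel from an options map (or a subsequent init call) fully sets up the environment and session before the explicitly instantiated inference<float, float> overload is called.

#include "ML/OrtInterface.h"

#include <unordered_map>
#include <vector>

int main()
{
  // Options map as parsed by OrtModel::initOptions(); "model.onnx" is a placeholder path.
  std::unordered_map<std::string, std::string> options{
    {"model-path", "model.onnx"},
    {"device-type", "CPU"},
    {"intra-op-num-threads", "1"}};

  o2::ml::OrtModel model(options);

  // Flat input buffer: batch size times per-sample input size (placeholder values).
  std::vector<float> input(10 * 4, 0.f);
  std::vector<float> output = model.inference<float, float>(input);
  return output.empty();
}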