Project
Loading...
Searching...
No Matches
run-example.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
16
17#include "FastSimulations.h"
18#include "Config.h"
19#include "Processors.h"
20#include "Utils.h"
21
22#include <iostream>
23
24// Basic test which checks whether the models can be properly loaded and run.
25// The model results are printed to stdout (to check that they look correct).
// Entry point of the example: scales one hard-coded particle, combines it with
// random noise, feeds that input to the VAE and the SAE ONNX models and prints
// the channel values computed from each model's output.
26int main()
27{
28 // Sample particle data (9 hard-coded values)
29 const std::vector<float> particleData = {5.133179999999999836e+02,
30 1.454299999999999993e-08,
31 3.650509999999999381e-08,
32 -2.731009999999999861e-03,
33 3.545600000000000140e-02,
34 -5.182060000000000138e-02,
35 -5.133179999999999836e+02,
36 0.000000000000000000e+00,
37 0.000000000000000000e+00};
38
39 // Loading VAE scales
40 std::cout << "Loading ONNX model VAE from: " << o2::zdc::fastsim::gZDCModelPath << std::endl;
 // NOTE(review): listing line 41 is absent from this view; vaeScales is used
 // below without a visible definition (presumably the result of loadScales() — confirm).
42
43 // If vaeScales.has_value() is false, an error occurred while loading the scales
44 if (!vaeScales.has_value()) {
45 std::cout << "error loading vae model scales" << std::endl;
 // NOTE(review): exits with status 0 although loading failed — confirm this is intended.
46 return 0;
47 }
48 // Loading actual model and setting scales
 // NOTE(review): listing line 49 is absent from this view; onnxVAEDemo is used
 // below without a visible definition.
50 std::cout << " ONNX VAE model loaded: " << std::endl;
51
52 // Loading SAE scales
53 std::cout << "Loading ONNX model SAE from: " << o2::zdc::fastsim::gSAEModelPath << std::endl;
 // NOTE(review): listing line 54 is absent from this view; saeScales is used
 // below without a visible definition (presumably the result of loadScales() — confirm).
55
56 // If saeScales.has_value() is false, an error occurred while loading the scales
57 if (!saeScales.has_value()) {
58 std::cout << "error loading sae model scales" << std::endl;
 // NOTE(review): exits with status 0 although loading failed — confirm this is intended.
59 return 0;
60 }
61 // Loading actual model and setting scales
 // NOTE(review): listing line 62 is absent from this view; onnxSAEDemo is used
 // below without a visible definition.
63 std::cout << " ONNX SAE model loaded: " << std::endl;
64
65 // Create scaler object, set scales and scale particleData
 // NOTE(review): listing line 66 (the definition of scaler) is absent from this view.
 // NOTE(review): in the visible lines only the VAE scales are passed to the scaler;
 // saeScales is only checked for presence — confirm this is intended.
67 scaler.setScales(vaeScales->first, vaeScales->second);
68 auto scaled = scaler.scale(particleData);
69
70 // Create noise vector: 10 values requested with mean 0.0 and stddev 1.0
71 auto noise = o2::zdc::fastsim::normal_distribution(0.0, 1.0, 10);
72
73 // Create input vector: the noise first, then the scaled particle data.
 // NOTE(review): scaled.value() is called without a preceding has_value() check —
 // TODO confirm scale() cannot fail for this input.
74 std::vector<std::vector<float>> input = {noise, std::move(scaled.value())};
75
76 // Run the VAE model; throws only if an error is encountered in ONNX
77 onnxVAEDemo.setInput(input);
78 onnxVAEDemo.run();
 // Batch size is 1, so only the first entry of the calculateChannels() result is taken.
79 auto vaeResult = o2::zdc::fastsim::processors::calculateChannels(onnxVAEDemo.getResult()[0], 1)[0];
80
81 // Run the SAE model on the same input; throws only if an error is encountered in ONNX
82 onnxSAEDemo.setInput(input);
83 onnxSAEDemo.run();
 // Batch size is 1, so only the first entry of the calculateChannels() result is taken.
84 auto saeResult = o2::zdc::fastsim::processors::calculateChannels(onnxSAEDemo.getResult()[0], 1)[0];
85
86 // Print output: VAE channel values on one line, then SAE channel values
87 for (auto& element : vaeResult) {
88 std::cout << element << ", ";
89 }
90 std::cout << std::endl;
 // The SAE line is not followed by a newline or an explicit flush.
91 for (auto& element : saeResult) {
92 std::cout << element << ", ";
93 }
94
95 return 0;
96}
Derived class implementing interface for specific types of models.
bool setInput(std::vector< std::vector< float > > &input) override
Implements setInput.
const std::vector< Ort::Value > & getResult() override
Returns single model output as const&. Returned vector is of size 1.
std::vector< std::array< long, 5 > > calculateChannels(const Ort::Value &value, size_t batchSize)
Calculates 5 channel values from a 44x44 float array (for every batch).
std::vector< float > normal_distribution(double mean, double stddev, size_t size)
Generates a vector of numbers with a given normal distribution and length.
Definition Utils.cxx:19
const std::string gZDCModelPath
Global paths to models and scales files.
Definition Config.h:29
std::optional< std::pair< std::vector< float >, std::vector< float > > > loadScales(const std::string &path)
Loads and parses model scales from the file at path.
const std::string gSAEModelPath
Definition Config.h:31
const std::string gZDCModelConfig
Definition Config.h:30
const std::string gSAEModelConfig
Definition Config.h:32
int main()