Project
Loading...
Searching...
No Matches
test_Expressions.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
15#include "../src/ExpressionJSONHelpers.h"
16#include <catch_amalgamated.hpp>
17#include <arrow/util/config.h>
18#include <iostream>
19
20using namespace o2::framework;
21using namespace o2::framework::expressions;
22
23namespace nodes
24{
25static BindingNode pt{"pt", 1, atype::FLOAT};
26static BindingNode phi{"phi", 2, atype::FLOAT};
27static BindingNode eta{"eta", 3, atype::FLOAT};
28
29static BindingNode tgl{"tgl", 4, atype::FLOAT};
30static BindingNode signed1Pt{"signed1Pt", 5, atype::FLOAT};
31static BindingNode testInt{"testInt", 6, atype::INT32};
32} // namespace nodes
33
34namespace o2::aod::track
35{
36DECLARE_SOA_EXPRESSION_COLUMN(Pze, pz, float, o2::aod::track::tgl * (1.f / o2::aod::track::signed1Pt));
37} // namespace o2::aod::track
38
39TEST_CASE("TestTreeParsing")
40{
41 expressions::Filter f = ((nodes::phi > 1) && (nodes::phi < 2)) && (nodes::eta < 1);
42 auto specs = createOperations(f);
43 REQUIRE(specs[0].left == (DatumSpec{1u, atype::BOOL}));
44 REQUIRE(specs[0].right == (DatumSpec{2u, atype::BOOL}));
45 REQUIRE(specs[0].result == (DatumSpec{0u, atype::BOOL}));
46
47 REQUIRE(specs[1].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
48 REQUIRE(specs[1].right == (DatumSpec{LiteralNode::var_t{1}, atype::INT32}));
49 REQUIRE(specs[1].result == (DatumSpec{2u, atype::BOOL}));
50
51 REQUIRE(specs[2].left == (DatumSpec{3u, atype::BOOL}));
52 REQUIRE(specs[2].right == (DatumSpec{4u, atype::BOOL}));
53 REQUIRE(specs[2].result == (DatumSpec{1u, atype::BOOL}));
54
55 REQUIRE(specs[3].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
56 REQUIRE(specs[3].right == (DatumSpec{LiteralNode::var_t{2}, atype::INT32}));
57 REQUIRE(specs[3].result == (DatumSpec{4u, atype::BOOL}));
58
59 REQUIRE(specs[4].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
60 REQUIRE(specs[4].right == (DatumSpec{LiteralNode::var_t{1}, atype::INT32}));
61 REQUIRE(specs[4].result == (DatumSpec{3u, atype::BOOL}));
62
63 expressions::Filter g = ((nodes::eta + 2.f) > 0.5) || ((nodes::phi - M_PI) < 3);
64 auto gspecs = createOperations(g);
65 REQUIRE(gspecs[0].left == (DatumSpec{1u, atype::BOOL}));
66 REQUIRE(gspecs[0].right == (DatumSpec{2u, atype::BOOL}));
67 REQUIRE(gspecs[0].result == (DatumSpec{0u, atype::BOOL}));
68
69 REQUIRE(gspecs[1].left == (DatumSpec{3u, atype::DOUBLE}));
70 REQUIRE(gspecs[1].right == (DatumSpec{LiteralNode::var_t{3}, atype::INT32}));
71 REQUIRE(gspecs[1].result == (DatumSpec{2u, atype::BOOL}));
72
73 REQUIRE(gspecs[2].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
74 REQUIRE(gspecs[2].right == (DatumSpec{LiteralNode::var_t{M_PI}, atype::DOUBLE}));
75 REQUIRE(gspecs[2].result == (DatumSpec{3u, atype::DOUBLE}));
76
77 REQUIRE(gspecs[3].left == (DatumSpec{4u, atype::FLOAT}));
78 REQUIRE(gspecs[3].right == (DatumSpec{LiteralNode::var_t{0.5}, atype::DOUBLE}));
79 REQUIRE(gspecs[3].result == (DatumSpec{1u, atype::BOOL}));
80
81 REQUIRE(gspecs[4].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
82 REQUIRE(gspecs[4].right == (DatumSpec{LiteralNode::var_t{2.f}, atype::FLOAT}));
83 REQUIRE(gspecs[4].result == (DatumSpec{4u, atype::FLOAT}));
84
85 expressions::Filter h = (nodes::phi == 0) || (nodes::phi == 3);
86 auto hspecs = createOperations(h);
87
88 REQUIRE(hspecs[0].left == (DatumSpec{1u, atype::BOOL}));
89 REQUIRE(hspecs[0].right == (DatumSpec{2u, atype::BOOL}));
90 REQUIRE(hspecs[0].result == (DatumSpec{0u, atype::BOOL}));
91
92 REQUIRE(hspecs[1].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
93 REQUIRE(hspecs[1].right == (DatumSpec{LiteralNode::var_t{3}, atype::INT32}));
94 REQUIRE(hspecs[1].result == (DatumSpec{2u, atype::BOOL}));
95
96 REQUIRE(hspecs[2].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
97 REQUIRE(hspecs[2].right == (DatumSpec{LiteralNode::var_t{0}, atype::INT32}));
98 REQUIRE(hspecs[2].result == (DatumSpec{1u, atype::BOOL}));
99
100 expressions::Filter u = nabs(nodes::eta) < 1.0 && nexp(nodes::phi + 2.0 * M_PI) > 3.0;
101 auto uspecs = createOperations(std::move(u));
102 REQUIRE(uspecs[0].left == (DatumSpec{1u, atype::BOOL}));
103 REQUIRE(uspecs[0].right == (DatumSpec{2u, atype::BOOL}));
104 REQUIRE(uspecs[0].result == (DatumSpec{0u, atype::BOOL}));
105
106 REQUIRE(uspecs[1].left == (DatumSpec{3u, atype::DOUBLE}));
107 REQUIRE(uspecs[1].right == (DatumSpec{LiteralNode::var_t{3.0}, atype::DOUBLE}));
108 REQUIRE(uspecs[1].result == (DatumSpec{2u, atype::BOOL}));
109
110 REQUIRE(uspecs[2].left == (DatumSpec{4u, atype::DOUBLE}));
111 REQUIRE(uspecs[2].right == (DatumSpec{}));
112 REQUIRE(uspecs[2].result == (DatumSpec{3u, atype::DOUBLE}));
113
114 REQUIRE(uspecs[3].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
115 REQUIRE(uspecs[3].right == (DatumSpec{LiteralNode::var_t{2.0 * M_PI}, atype::DOUBLE}));
116 REQUIRE(uspecs[3].result == (DatumSpec{4u, atype::DOUBLE}));
117
118 REQUIRE(uspecs[4].left == (DatumSpec{5u, atype::FLOAT}));
119 REQUIRE(uspecs[4].right == (DatumSpec{LiteralNode::var_t{1.0}, atype::DOUBLE}));
120 REQUIRE(uspecs[4].result == (DatumSpec{1u, atype::BOOL}));
121
122 REQUIRE(uspecs[5].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
123 REQUIRE(uspecs[5].right == (DatumSpec{}));
124 REQUIRE(uspecs[5].result == (DatumSpec{5u, atype::FLOAT}));
125
126 Configurable<float> pTCut{"pTCut", 0.5f, "Lower pT limit"};
127 Filter ptfilter = o2::aod::track::pt > pTCut;
128 REQUIRE(ptfilter.node->self.index() == 2);
129 REQUIRE(ptfilter.node->left->self.index() == 1);
130 REQUIRE(ptfilter.node->right->self.index() == 3);
131 auto ptfilterspecs = createOperations(ptfilter);
132 REQUIRE(ptfilterspecs[0].left == (DatumSpec{std::string{"fPt"}, "o2::aod::track::pt"_h, atype::FLOAT}));
133 REQUIRE(ptfilterspecs[0].right == (DatumSpec{LiteralNode::var_t{0.5f}, atype::FLOAT}));
134 REQUIRE(ptfilterspecs[0].result == (DatumSpec{0u, atype::BOOL}));
135
136 struct : ConfigurableGroup {
137 std::string prefix = "prefix";
138 Configurable<float> pTCut{"pTCut", 1.0f, "Lower pT limit"};
139 } group;
140 Filter ptfilter2 = o2::aod::track::pt > group.pTCut;
141 group.pTCut.name.insert(0, 1, '.');
142 group.pTCut.name.insert(0, group.prefix);
143 REQUIRE(ptfilter2.node->self.index() == 2);
144 REQUIRE(ptfilter2.node->left->self.index() == 1);
145 REQUIRE(ptfilter2.node->right->self.index() == 3);
146 REQUIRE(std::get<PlaceholderNode>(ptfilter2.node->right->self).name == "prefix.pTCut");
147 auto ptfilterspecs2 = createOperations(ptfilter2);
148 REQUIRE(ptfilterspecs2[0].left == (DatumSpec{std::string{"fPt"}, "o2::aod::track::pt"_h, atype::FLOAT}));
149 REQUIRE(ptfilterspecs2[0].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
150 REQUIRE(ptfilterspecs2[0].result == (DatumSpec{0u, atype::BOOL}));
151
152 Configurable<int> cvalue{"cvalue", 1, "test value"};
153 Filter testFilter = o2::aod::track::tpcNClsShared < as<uint8_t>(cvalue);
154 REQUIRE(testFilter.node->self.index() == 2);
155 REQUIRE(testFilter.node->left->self.index() == 1);
156 REQUIRE(testFilter.node->right->self.index() == 3);
157 REQUIRE(std::get<PlaceholderNode>(testFilter.node->right->self).name == "cvalue");
158 auto testSpecs = createOperations(testFilter);
159 REQUIRE(testSpecs[0].right == (DatumSpec{LiteralNode::var_t{(uint8_t)1}, atype::UINT8}));
160}
161
162TEST_CASE("TestGandivaTreeCreation")
163{
164 Projector pze = o2::aod::track::Pze::Projector();
165 auto pzspecs = createOperations(pze);
166 REQUIRE(pzspecs[0].left == (DatumSpec{std::string{"fTgl"}, "o2::aod::track::tgl"_h, atype::FLOAT}));
167 REQUIRE(pzspecs[0].right == (DatumSpec{1u, atype::FLOAT}));
168 REQUIRE(pzspecs[0].result == (DatumSpec{0u, atype::FLOAT}));
169
170 REQUIRE(pzspecs[1].left == (DatumSpec{LiteralNode::var_t{1.f}, atype::FLOAT}));
171 REQUIRE(pzspecs[1].right == (DatumSpec{std::string{"fSigned1Pt"}, "o2::aod::track::signed1Pt"_h, atype::FLOAT}));
172 REQUIRE(pzspecs[1].result == (DatumSpec{1u, atype::FLOAT}));
173 auto infield1 = o2::aod::track::Signed1Pt::asArrowField();
174 auto infield2 = o2::aod::track::Tgl::asArrowField();
175 auto resfield = o2::aod::track::Pze::asArrowField();
176 auto schema = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, resfield});
177 auto gandiva_tree = createExpressionTree(pzspecs, schema);
178
179 auto gandiva_expression = makeExpression(gandiva_tree, resfield);
180 REQUIRE(std::string(gandiva_expression->ToString()) == std::string("float multiply((float) fTgl, float divide((const float) 1 raw(3f800000), (float) fSigned1Pt))"));
181 auto projector = createProjector(schema, pzspecs, resfield);
182
183 Projector pte = o2::aod::track::Pt::Projector();
184 auto ptespecs = createOperations(pte);
185
186 auto infield3 = o2::aod::track::Signed1Pt::asArrowField();
187 auto resfield2 = o2::aod::track::Pt::asArrowField();
188 auto schema2 = std::make_shared<arrow::Schema>(std::vector{infield3, resfield2});
189 auto gandiva_tree2 = createExpressionTree(ptespecs, schema2);
190
191 auto gandiva_expression2 = makeExpression(gandiva_tree2, resfield2);
192 REQUIRE(gandiva_expression2->ToString() == "if (bool less_than_or_equal_to(float absf((float) fSigned1Pt), (const float) 1.17549e-38 raw(800000))) { (const float) 8.50706e+37 raw(7e800000) } else { float absf(float divide((const float) 1 raw(3f800000), (float) fSigned1Pt)) }");
193
194 auto projector_b = createProjector(schema2, ptespecs, resfield2);
195 auto fields = o2::soa::createFieldsFromColumns(o2::aod::Tracks::persistent_columns_t{});
196 auto schema_p = std::make_shared<arrow::Schema>(fields);
198
199 Filter bitwiseFilter = (o2::aod::track::flags & static_cast<uint32_t>(o2::aod::track::TPCrefit)) != 0u;
200 auto bwf = createOperations(bitwiseFilter);
201 REQUIRE(bwf[0].left == (DatumSpec{1u, atype::UINT32}));
202 REQUIRE(bwf[0].right == (DatumSpec{LiteralNode::var_t{0u}, atype::UINT32}));
203 REQUIRE(bwf[0].result == (DatumSpec{0u, atype::BOOL}));
204
205 REQUIRE(bwf[1].left == (DatumSpec{std::string{"fFlags"}, "o2::aod::track::flags"_h, atype::UINT32}));
206 REQUIRE(bwf[1].right == (DatumSpec{LiteralNode::var_t{static_cast<uint32_t>(o2::aod::track::TPCrefit)}, atype::UINT32}));
207 REQUIRE(bwf[1].result == (DatumSpec{1u, atype::UINT32}));
208
209 auto infield4 = o2::aod::track::Flags::asArrowField();
210 auto resfield3 = std::make_shared<arrow::Field>("out", arrow::boolean());
211 auto schema_b = std::make_shared<arrow::Schema>(std::vector{infield4, resfield3});
212 auto gandiva_tree3 = createExpressionTree(bwf, schema_b);
213 REQUIRE(gandiva_tree3->ToString() == "bool not_equal(uint32 bitwise_and((uint32) fFlags, (const uint32) 2), (const uint32) 0)");
214 auto condition = expressions::makeCondition(gandiva_tree3);
215 std::shared_ptr<gandiva::Filter> flt;
216 auto s = gandiva::Filter::Make(schema_b, condition, &flt);
217 REQUIRE(s.ok());
218
219 Filter rounding = nround(o2::aod::track::pt) > 0.1f;
220 auto rf = createOperations(rounding);
221 REQUIRE(rf[0].left == (DatumSpec{1u, atype::FLOAT}));
222 REQUIRE(rf[0].right == (DatumSpec{LiteralNode::var_t{0.1f}, atype::FLOAT}));
223 REQUIRE(rf[0].result == (DatumSpec{0u, atype::BOOL}));
224
225 REQUIRE(rf[1].left == (DatumSpec{std::string{"fPt"}, "o2::aod::track::pt"_h, atype::FLOAT}));
226 REQUIRE(rf[1].right == (DatumSpec{}));
227 REQUIRE(rf[1].result == (DatumSpec{1u, atype::FLOAT}));
228
229 auto infield5 = o2::aod::track::Pt::asArrowField();
230 auto resfield4 = std::make_shared<arrow::Field>("out", arrow::boolean());
231 auto schema_c = std::make_shared<arrow::Schema>(std::vector{infield5, resfield4});
232 auto gandiva_tree4 = createExpressionTree(rf, schema_c);
233 REQUIRE(gandiva_tree4->ToString() == "bool greater_than(float round((float) fPt), (const float) 0.1 raw(3dcccccd))");
234 auto condition2 = expressions::makeCondition(gandiva_tree4);
235 std::shared_ptr<gandiva::Filter> flt2;
236 auto s2 = gandiva::Filter::Make(schema_c, condition2, &flt2);
237 REQUIRE(s2.ok());
238}
239
240TEST_CASE("TestConditionalExpressions")
241{
242 // simple conditional
243 Filter cf = nabs(o2::aod::track::eta) < 1.0f && ifnode((o2::aod::track::pt < 1.0f), (o2::aod::track::phi > (float)(M_PI / 2.)), (o2::aod::track::phi < (float)(M_PI / 2.)));
244 auto cfspecs = createOperations(cf);
245 REQUIRE(cfspecs[0].left == (DatumSpec{1u, atype::BOOL}));
246 REQUIRE(cfspecs[0].right == (DatumSpec{2u, atype::BOOL}));
247 REQUIRE(cfspecs[0].result == (DatumSpec{0u, atype::BOOL}));
248
249 REQUIRE(cfspecs[1].left == (DatumSpec{3u, atype::BOOL}));
250 REQUIRE(cfspecs[1].right == (DatumSpec{4u, atype::BOOL}));
251 REQUIRE(cfspecs[1].condition == (DatumSpec{5u, atype::BOOL}));
252 REQUIRE(cfspecs[1].result == (DatumSpec{2u, atype::BOOL}));
253
254 REQUIRE(cfspecs[2].left == (DatumSpec{std::string{"fPt"}, "o2::aod::track::pt"_h, atype::FLOAT}));
255 REQUIRE(cfspecs[2].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
256 REQUIRE(cfspecs[2].result == (DatumSpec{5u, atype::BOOL}));
257
258 REQUIRE(cfspecs[3].left == (DatumSpec{std::string{"fPhi"}, "o2::aod::track::phi"_h, atype::FLOAT}));
259 REQUIRE(cfspecs[3].right == (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT}));
260 REQUIRE(cfspecs[3].result == (DatumSpec{4u, atype::BOOL}));
261
262 REQUIRE(cfspecs[4].left == (DatumSpec{std::string{"fPhi"}, "o2::aod::track::phi"_h, atype::FLOAT}));
263 REQUIRE(cfspecs[4].right == (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT}));
264 REQUIRE(cfspecs[4].result == (DatumSpec{3u, atype::BOOL}));
265
266 REQUIRE(cfspecs[5].left == (DatumSpec{6u, atype::FLOAT}));
267 REQUIRE(cfspecs[5].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
268 REQUIRE(cfspecs[5].result == (DatumSpec{1u, atype::BOOL}));
269
270 REQUIRE(cfspecs[6].left == (DatumSpec{std::string{"fEta"}, "o2::aod::track::eta"_h, atype::FLOAT}));
271 REQUIRE(cfspecs[6].right == (DatumSpec{}));
272 REQUIRE(cfspecs[6].result == (DatumSpec{6u, atype::FLOAT}));
273
274 auto infield1 = o2::aod::track::Pt::asArrowField();
275 auto infield2 = o2::aod::track::Eta::asArrowField();
276 auto infield3 = o2::aod::track::Phi::asArrowField();
277 auto schema = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, infield3});
278 auto gandiva_tree = createExpressionTree(cfspecs, schema);
279 auto gandiva_condition = makeCondition(gandiva_tree);
280 auto gandiva_filter = createFilter(schema, gandiva_condition);
281
282 REQUIRE(gandiva_tree->ToString() == "bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }");
283
284 // nested conditional
285 Filter cfn = o2::aod::track::signed1Pt > 0.f && ifnode(std::move(*cf.node), nabs(o2::aod::track::x) > 1.0f, nabs(o2::aod::track::y) > 1.0f);
286 auto cfnspecs = createOperations(cfn);
287 auto infield4 = o2::aod::track::Signed1Pt::asArrowField();
288 auto infield5 = o2::aod::track::X::asArrowField();
289 auto infield6 = o2::aod::track::Y::asArrowField();
290 auto schema2 = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, infield3, infield4, infield5, infield6});
291 auto gandiva_tree2 = createExpressionTree(cfnspecs, schema2);
292 auto gandiva_condition2 = makeCondition(gandiva_tree2);
293 auto gandiva_filter2 = createFilter(schema2, gandiva_condition2);
294 REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }");
295
296 // clamp
297 Projector clp = clamp(o2::aod::track::pt, 1.0f, 10.f);
298 auto clpspecs = createOperations(clp);
299 auto schemaclp = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField()});
300 auto gandiva_tree_clp = createExpressionTree(clpspecs, schemaclp);
301 REQUIRE(gandiva_tree_clp->ToString() == "if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { (const float) 1 raw(3f800000) } else { if (bool greater_than((float) fPt, (const float) 10 raw(41200000))) { (const float) 10 raw(41200000) } else { (float) fPt } }");
302}
303
304TEST_CASE("TestBinnedExpressions")
305{
306 std::vector<float> bins{0.5, 1.5, 2.5, 3.5, 4.5};
307 std::vector<float> params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3};
308 Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f});
309 auto pspecs = createOperations(p);
310 auto schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()});
311 auto tree = createExpressionTree(pspecs, schema);
312 REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 2 raw(40000000), (float) fY)), float multiply((const float) 3 raw(40400000), (float) fZ)), float multiply((const float) 4 raw(40800000), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 3.1 raw(40466666), (float) fZ)), float multiply((const float) 4.1 raw(40833333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 1.2 raw(3f99999a), (float) fX), float multiply((const float) 2.2 raw(400ccccd), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 4.2 raw(40866666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 1.3 raw(3fa66666), (float) fX), float multiply((const float) 2.3 raw(40133333), (float) fY)), float multiply((const float) 3.3 raw(40533333), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }");
313
315 std::vector<float> parameters{1.0, 1.1, 1.2, 1.3, // par 0
316 2.0, 2.1, 2.2, 2.3, // par 1
317 3.0, 3.1, 3.2, 3.3, // par 2
318 4.0, 4.1, 4.2, 4.3}; // par 3
319
320 Projector p2 = binned((std::vector<float>)binning,
321 (std::vector<float>)parameters,
322 o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z,
323 LiteralNode{-1.f});
324 auto p2specs = createOperations(p2);
325 auto schema2 = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()});
326 auto tree2 = createExpressionTree(p2specs, schema2);
327 REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 2 raw(40000000), (float) fY), (float) fY)), float multiply(float multiply((const float) 3 raw(40400000), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fX), (float) fX), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 1.3 raw(3fa66666), (float) fX), (float) fX), float multiply(float multiply((const float) 2.3 raw(40133333), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.3 raw(40533333), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }");
328}
329
331{
332 int token;
333 while ((token = t.nextToken()) && (token != Token::EoL)) {
334 std::cout << t.TokenStr << " ";
335 };
336 std::cout << std::endl;
337}
338
339TEST_CASE("TestStringExpressionsParsing")
340{
341 Filter f = (o2::aod::track::flags & 1u) != 0u && (o2::aod::track::pt <= 10.f);
342 std::string input = "(o2::aod::track::flags & 1u) != 0u && (o2::aod::track::pt <= 10.f)";
343
344 auto t1 = createOperations(f);
345 Filter ff = Parser::parse(input);
346 auto t2 = createOperations(ff);
347
348 auto schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Flags::asArrowField(), o2::aod::track::Pt::asArrowField()});
349 auto tree1 = createExpressionTree(t1, schema);
350 auto tree2 = createExpressionTree(t2, schema);
351
352 REQUIRE(tree1->ToString() == tree2->ToString());
353
354 Projector p = -1.f * nlog(ntan(o2::constants::math::PIQuarter - 0.5f * natan(o2::aod::fwdtrack::tgl)));
355 input = "-1.f * nlog(ntan(PIQuarter - 0.5f * natan(o2::aod::fwdtrack::tgl)))";
356
357 auto tp1 = createOperations(p);
358 Projector pp = Parser::parse(input);
359 auto tp2 = createOperations(pp);
360
361 schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::fwdtrack::Tgl::asArrowField()});
362 auto treep1 = createExpressionTree(tp1, schema);
363 auto treep2 = createExpressionTree(tp2, schema);
364
365 REQUIRE(treep1->ToString() == treep2->ToString());
366
367 Filter f2 = o2::aod::track::signed1Pt > 0.f && ifnode(nabs(o2::aod::track::eta) < 1.0f, nabs(o2::aod::track::x) > 2.0f, nabs(o2::aod::track::y) > 3.0f);
368 input = "o2::aod::track::signed1Pt > 0.f && ifnode(nabs(o2::aod::track::eta) < 1.0f, nabs(o2::aod::track::x) > 2.0f, nabs(o2::aod::track::y) > 3.0f)";
369
370 auto tf1 = createOperations(f2);
371 Filter ff2 = Parser::parse(input);
372 auto tf2 = createOperations(ff2);
373
374 schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Eta::asArrowField(), o2::aod::track::Signed1Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField()});
375 auto treef1 = createExpressionTree(tf1, schema);
376 auto treef2 = createExpressionTree(tf2, schema);
377
378 REQUIRE(treef1->ToString() == treef2->ToString());
379
380 Configurable<float> pTCut{"pTCut", 0.5f, "Lower pT limit"};
381 Filter pcfg1 = o2::aod::track::pt > pTCut;
382 Filter pcfg2 = Parser::parse("o2::aod::track::pt > ncfg(float, 0.5, \"pTCut\")");
383 auto pcfg1specs = createOperations(pcfg1);
384 auto pcfg2specs = createOperations(pcfg2);
385
386 REQUIRE(pcfg2.node->right->self.index() == 3);
387 REQUIRE(pcfg2specs[0].right == (DatumSpec{LiteralNode::var_t{0.5f}, atype::FLOAT}));
388
389 schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField()});
390 auto tree1c = createExpressionTree(pcfg1specs, schema);
391 auto tree2c = createExpressionTree(pcfg2specs, schema);
392
393 REQUIRE(tree1c->ToString() == tree2c->ToString());
394}
395
396TEST_CASE("TestExpressionSerialization")
397{
398 Filter f = o2::aod::track::signed1Pt > 0.f && ifnode(nabs(o2::aod::track::eta) < 1.0f, nabs(o2::aod::track::x) > 2.0f, nabs(o2::aod::track::y) > 3.0f);
399 Projector p = -1.f * nlog(ntan(o2::constants::math::PIQuarter - 0.5f * natan(o2::aod::fwdtrack::tgl)));
400 Projector p1 = ifnode(o2::aod::track::itsClusterSizes > (uint32_t)0, static_cast<uint8_t>(o2::aod::track::ITS), (uint8_t)0x0) |
401 ifnode(o2::aod::track::tpcNClsFindable > (uint8_t)0, static_cast<uint8_t>(o2::aod::track::TPC), (uint8_t)0x0) |
402 ifnode(o2::aod::track::trdPattern > (uint8_t)0, static_cast<uint8_t>(o2::aod::track::TRD), (uint8_t)0x0) |
403 ifnode((o2::aod::track::tofChi2 >= 0.f) && (o2::aod::track::tofExpMom > 0.f), static_cast<uint8_t>(o2::aod::track::TOF), (uint8_t)0x0);
404
405 std::vector<Projector> projectors;
406 projectors.emplace_back(std::move(f));
407 projectors.emplace_back(std::move(p));
408 projectors.emplace_back(std::move(p1));
409
410 std::stringstream osm;
412
413 std::stringstream ism;
414 ism.str(osm.str());
415 auto ps = ExpressionJSONHelpers::read(ism);
416
417 auto s1 = createOperations(projectors[0]);
418 auto s2 = createOperations(ps[0]);
419 auto schemaf = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Eta::asArrowField(), o2::aod::track::Signed1Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField()});
420 auto t1 = createExpressionTree(s1, schemaf);
421 auto t2 = createExpressionTree(s2, schemaf);
422 REQUIRE(t1->ToString() == t2->ToString());
423
424 auto s12 = createOperations(projectors[1]);
425 auto s22 = createOperations(ps[1]);
426 auto schemap = std::make_shared<arrow::Schema>(std::vector{o2::aod::fwdtrack::Tgl::asArrowField()});
427 auto t12 = createExpressionTree(s12, schemap);
428 auto t22 = createExpressionTree(s22, schemap);
429 REQUIRE(t12->ToString() == t22->ToString());
430
431 auto s13 = createOperations(projectors[2]);
432 auto s23 = createOperations(ps[2]);
433 auto schemap1 = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::ITSClusterSizes::asArrowField(), o2::aod::track::TPCNClsFindable::asArrowField(),
434 o2::aod::track::TRDPattern::asArrowField(), o2::aod::track::TOFChi2::asArrowField(),
435 o2::aod::track::TOFExpMom::asArrowField()});
436 auto t13 = createExpressionTree(s13, schemap1);
437 auto t23 = createExpressionTree(s23, schemap1);
438 REQUIRE(t13->ToString() == t23->ToString());
439
440 osm.clear();
441 osm.str("");
442 ArrowJSONHelpers::write(osm, schemaf);
443
444 ism.clear();
445 ism.str(osm.str());
446 auto newSchemaf = ArrowJSONHelpers::read(ism);
447 REQUIRE(schemaf->ToString() == newSchemaf->ToString());
448
449 osm.clear();
450 osm.str("");
451 ArrowJSONHelpers::write(osm, schemap);
452
453 ism.clear();
454 ism.str(osm.str());
455 auto newSchemap = ArrowJSONHelpers::read(ism);
456 REQUIRE(schemap->ToString() == newSchemap->ToString());
457
458 osm.clear();
459 osm.str("");
460 ArrowJSONHelpers::write(osm, schemap1);
461
462 ism.clear();
463 ism.str(osm.str());
464 auto newSchemap1 = ArrowJSONHelpers::read(ism);
465 REQUIRE(schemap1->ToString() == newSchemap1->ToString());
466
467 osm.clear();
468 osm.str("");
469 auto realisticSchema = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait<o2::aod::Hash<"HMPID/1"_h>>::metadata::persistent_columns_t{}));
470 ArrowJSONHelpers::write(osm, realisticSchema);
471
472 ism.clear();
473 ism.str(osm.str());
474 auto restoredSchema = ArrowJSONHelpers::read(ism);
475 REQUIRE(realisticSchema->ToString() == restoredSchema->ToString());
476
477 osm.clear();
478 osm.str("");
479 auto realisticSchema1 = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait<o2::aod::Hash<"ZDC/1"_h>>::metadata::persistent_columns_t{}));
480 ArrowJSONHelpers::write(osm, realisticSchema1);
481
482 ism.clear();
483 ism.str(osm.str());
484 auto restoredSchema1 = ArrowJSONHelpers::read(ism);
485 REQUIRE(realisticSchema1->ToString() == restoredSchema1->ToString());
486}
std::vector< expressions::Projector > projectors
#define DECLARE_SOA_EXPRESSION_COLUMN(_Name_, _Getter_, _Type_, _Expression_)
Definition ASoA.h:2497
const auto bins
Definition PID.cxx:49
std::shared_ptr< arrow::Schema > schema
std::vector< std::shared_ptr< arrow::Field > > fields
constexpr int p2()
constexpr int p1()
constexpr to accelerate the coordinates changing
Class for time synchronization of RawReader instances.
GLuint64EXT * result
Definition glcorearb.h:5662
GLuint GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat s1
Definition glcorearb.h:5034
GLdouble GLdouble right
Definition glcorearb.h:4077
GLdouble f
Definition glcorearb.h:310
GLenum const GLfloat * params
Definition glcorearb.h:272
GLboolean GLuint group
Definition glcorearb.h:3991
GLboolean GLboolean g
Definition glcorearb.h:1233
GLenum clamp
Definition glcorearb.h:1245
GLuint GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat GLfloat t1
Definition glcorearb.h:5034
constexpr float TwoPI
constexpr float PI
constexpr float PIQuarter
constexpr float PIHalf
std::shared_ptr< gandiva::Filter > createFilter(gandiva::SchemaPtr const &Schema, gandiva::ConditionPtr condition)
Function to create gandiva filter from gandiva condition.
gandiva::ExpressionPtr makeExpression(gandiva::NodePtr node, gandiva::FieldPtr result)
Function to create gandiva projecting expression from generic gandiva expression tree.
Node par(int index)
Parameters.
Operations createOperations(Filter const &expression)
Function to create an internal operation sequence from a filter tree.
gandiva::ConditionPtr makeCondition(gandiva::NodePtr node)
Function to create gandiva condition expression from generic gandiva expression tree.
Node ifnode(Node &&condition_, Node &&then_, Node &&else_)
conditionals
gandiva::NodePtr createExpressionTree(Operations const &opSpecs, gandiva::SchemaPtr const &Schema)
Function to create gandiva expression tree from operation sequence.
Node binned(std::vector< T > const &binning, std::vector< T > const &parameters, Node &&binned, Node &&pexp, Node &&out)
binned functional
std::shared_ptr< gandiva::Projector > createProjector(gandiva::SchemaPtr const &Schema, Operations const &opSpecs, gandiva::FieldPtr result)
Function to create gandiva projector from operation sequence.
std::shared_ptr< gandiva::Projector > createProjectors(framework::pack< C... >, std::vector< std::shared_ptr< arrow::Field > > const &fields, gandiva::SchemaPtr schema)
Defining PrimaryVertex explicitly as messageable.
TEST_CASE("test_prepareArguments")
auto createFieldsFromColumns(framework::pack< C... >)
Definition ASoA.h:57
static void write(std::ostream &o, std::shared_ptr< arrow::Schema > &schema)
static std::shared_ptr< arrow::Schema > read(std::istream &s)
static std::vector< expressions::Projector > read(std::istream &s)
static void write(std::ostream &o, std::vector< expressions::Projector > &projectors)
An expression tree node corresponding to a column binding.
A struct, containing the root of the expression tree.
std::unique_ptr< Node > node
An expression tree node corresponding to a literal value.
LiteralValue::stored_type var_t
static Node parse(std::string const &input)
void printTokens(Tokenizer &t)
std::unique_ptr< TTree > tree((TTree *) flIn.Get(std::string(o2::base::NameConf::CTFTREENAME).c_str()))