Project
Loading...
Searching...
No Matches
test_Expressions.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
16#include <catch_amalgamated.hpp>
17#include <arrow/util/config.h>
18
19using namespace o2::framework;
20using namespace o2::framework::expressions;
21
22namespace nodes
23{
24static BindingNode pt{"pt", 1, atype::FLOAT};
25static BindingNode phi{"phi", 2, atype::FLOAT};
26static BindingNode eta{"eta", 3, atype::FLOAT};
27
28static BindingNode tgl{"tgl", 4, atype::FLOAT};
29static BindingNode signed1Pt{"signed1Pt", 5, atype::FLOAT};
30static BindingNode testInt{"testInt", 6, atype::INT32};
31} // namespace nodes
32
33namespace o2::aod::track
34{
// Test-only expression column: `Pze` (getter `pz`, type float) is defined by
// the expression tgl * (1.f / signed1Pt). TestGandivaTreeCreation asserts the
// exact operation specs and Gandiva expression string produced from this
// expression, so any change here must be mirrored in those expectations.
35DECLARE_SOA_EXPRESSION_COLUMN(Pze, pz, float, o2::aod::track::tgl * (1.f / o2::aod::track::signed1Pt));
36} // namespace o2::aod::track
37
// Checks how createOperations() flattens filter expressions into a sequence of
// operation specs. The assertions below pin these conventions:
//  - specs[0] is the root operation and writes intermediate result 0; a
//    DatumSpec{Nu, type} operand refers to the result slot N of another
//    operation.
//  - Operations are listed depth-first, with the right operand's subtree
//    listed before the left operand's.
//  - Integer literals are typed INT32, double literals DOUBLE and float
//    literals FLOAT. Unary functions (nabs, nexp) have an empty right
//    DatumSpec{}.
38TEST_CASE("TestTreeParsing")
39{
 // (phi > 1 && phi < 2) && eta < 1 : nested logical AND over comparisons.
40 expressions::Filter f = ((nodes::phi > 1) && (nodes::phi < 2)) && (nodes::eta < 1);
41 auto specs = createOperations(f);
42 REQUIRE(specs[0].left == (DatumSpec{1u, atype::BOOL}));
43 REQUIRE(specs[0].right == (DatumSpec{2u, atype::BOOL}));
44 REQUIRE(specs[0].result == (DatumSpec{0u, atype::BOOL}));
45
46 REQUIRE(specs[1].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
47 REQUIRE(specs[1].right == (DatumSpec{LiteralNode::var_t{1}, atype::INT32}));
48 REQUIRE(specs[1].result == (DatumSpec{2u, atype::BOOL}));
49
50 REQUIRE(specs[2].left == (DatumSpec{3u, atype::BOOL}));
51 REQUIRE(specs[2].right == (DatumSpec{4u, atype::BOOL}));
52 REQUIRE(specs[2].result == (DatumSpec{1u, atype::BOOL}));
53
54 REQUIRE(specs[3].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
55 REQUIRE(specs[3].right == (DatumSpec{LiteralNode::var_t{2}, atype::INT32}));
56 REQUIRE(specs[3].result == (DatumSpec{4u, atype::BOOL}));
57
58 REQUIRE(specs[4].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
59 REQUIRE(specs[4].right == (DatumSpec{LiteralNode::var_t{1}, atype::INT32}));
60 REQUIRE(specs[4].result == (DatumSpec{3u, atype::BOOL}));
61
 // Arithmetic sub-expressions: phi - M_PI yields a DOUBLE intermediate,
 // eta + 2.f a FLOAT one; both are then compared and OR-ed.
62 expressions::Filter g = ((nodes::eta + 2.f) > 0.5) || ((nodes::phi - M_PI) < 3);
63 auto gspecs = createOperations(g);
64 REQUIRE(gspecs[0].left == (DatumSpec{1u, atype::BOOL}));
65 REQUIRE(gspecs[0].right == (DatumSpec{2u, atype::BOOL}));
66 REQUIRE(gspecs[0].result == (DatumSpec{0u, atype::BOOL}));
67
68 REQUIRE(gspecs[1].left == (DatumSpec{3u, atype::DOUBLE}));
69 REQUIRE(gspecs[1].right == (DatumSpec{LiteralNode::var_t{3}, atype::INT32}));
70 REQUIRE(gspecs[1].result == (DatumSpec{2u, atype::BOOL}));
71
72 REQUIRE(gspecs[2].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
73 REQUIRE(gspecs[2].right == (DatumSpec{LiteralNode::var_t{M_PI}, atype::DOUBLE}));
74 REQUIRE(gspecs[2].result == (DatumSpec{3u, atype::DOUBLE}));
75
76 REQUIRE(gspecs[3].left == (DatumSpec{4u, atype::FLOAT}));
77 REQUIRE(gspecs[3].right == (DatumSpec{LiteralNode::var_t{0.5}, atype::DOUBLE}));
78 REQUIRE(gspecs[3].result == (DatumSpec{1u, atype::BOOL}));
79
80 REQUIRE(gspecs[4].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
81 REQUIRE(gspecs[4].right == (DatumSpec{LiteralNode::var_t{2.f}, atype::FLOAT}));
82 REQUIRE(gspecs[4].result == (DatumSpec{4u, atype::FLOAT}));
83
 // Equality comparisons joined by logical OR.
84 expressions::Filter h = (nodes::phi == 0) || (nodes::phi == 3);
85 auto hspecs = createOperations(h);
86
87 REQUIRE(hspecs[0].left == (DatumSpec{1u, atype::BOOL}));
88 REQUIRE(hspecs[0].right == (DatumSpec{2u, atype::BOOL}));
89 REQUIRE(hspecs[0].result == (DatumSpec{0u, atype::BOOL}));
90
91 REQUIRE(hspecs[1].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
92 REQUIRE(hspecs[1].right == (DatumSpec{LiteralNode::var_t{3}, atype::INT32}));
93 REQUIRE(hspecs[1].result == (DatumSpec{2u, atype::BOOL}));
94
95 REQUIRE(hspecs[2].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
96 REQUIRE(hspecs[2].right == (DatumSpec{LiteralNode::var_t{0}, atype::INT32}));
97 REQUIRE(hspecs[2].result == (DatumSpec{1u, atype::BOOL}));
98
 // Unary functions: nabs/nexp take a single operand, so their right
 // DatumSpec is empty. nabs(FLOAT) stays FLOAT; nexp of a DOUBLE sum is DOUBLE.
99 expressions::Filter u = nabs(nodes::eta) < 1.0 && nexp(nodes::phi + 2.0 * M_PI) > 3.0;
 // NOTE(review): createOperations is declared as taking `Filter const&`, so
 // this std::move does not actually move anything; u is left intact.
100 auto uspecs = createOperations(std::move(u));
101 REQUIRE(uspecs[0].left == (DatumSpec{1u, atype::BOOL}));
102 REQUIRE(uspecs[0].right == (DatumSpec{2u, atype::BOOL}));
103 REQUIRE(uspecs[0].result == (DatumSpec{0u, atype::BOOL}));
104
105 REQUIRE(uspecs[1].left == (DatumSpec{3u, atype::DOUBLE}));
106 REQUIRE(uspecs[1].right == (DatumSpec{LiteralNode::var_t{3.0}, atype::DOUBLE}));
107 REQUIRE(uspecs[1].result == (DatumSpec{2u, atype::BOOL}));
108
109 REQUIRE(uspecs[2].left == (DatumSpec{4u, atype::DOUBLE}));
110 REQUIRE(uspecs[2].right == (DatumSpec{}));
111 REQUIRE(uspecs[2].result == (DatumSpec{3u, atype::DOUBLE}));
112
113 REQUIRE(uspecs[3].left == (DatumSpec{std::string{"phi"}, 2, atype::FLOAT}));
114 REQUIRE(uspecs[3].right == (DatumSpec{LiteralNode::var_t{2.0 * M_PI}, atype::DOUBLE}));
115 REQUIRE(uspecs[3].result == (DatumSpec{4u, atype::DOUBLE}));
116
117 REQUIRE(uspecs[4].left == (DatumSpec{5u, atype::FLOAT}));
118 REQUIRE(uspecs[4].right == (DatumSpec{LiteralNode::var_t{1.0}, atype::DOUBLE}));
119 REQUIRE(uspecs[4].result == (DatumSpec{1u, atype::BOOL}));
120
121 REQUIRE(uspecs[5].left == (DatumSpec{std::string{"eta"}, 3, atype::FLOAT}));
122 REQUIRE(uspecs[5].right == (DatumSpec{}));
123 REQUIRE(uspecs[5].result == (DatumSpec{5u, atype::FLOAT}));
124
 // A Configurable on the right-hand side is stored as a placeholder node
 // (variant index 3, the alternative std::get<PlaceholderNode> reads further
 // down). createOperations substitutes the configurable's current value
 // (0.5f) as a FLOAT literal. Framework columns bind by "f"-prefixed name
 // plus the column type's typeid hash.
125 Configurable<float> pTCut{"pTCut", 0.5f, "Lower pT limit"};
126 Filter ptfilter = o2::aod::track::pt > pTCut;
127 REQUIRE(ptfilter.node->self.index() == 2);
128 REQUIRE(ptfilter.node->left->self.index() == 1);
129 REQUIRE(ptfilter.node->right->self.index() == 3);
130 auto ptfilterspecs = createOperations(ptfilter);
131 REQUIRE(ptfilterspecs[0].left == (DatumSpec{std::string{"fPt"}, typeid(o2::aod::track::Pt).hash_code(), atype::FLOAT}));
132 REQUIRE(ptfilterspecs[0].right == (DatumSpec{LiteralNode::var_t{0.5f}, atype::FLOAT}));
133 REQUIRE(ptfilterspecs[0].result == (DatumSpec{0u, atype::BOOL}));
134
 // Configurable inside a ConfigurableGroup. The name is prefixed with
 // "prefix." only *after* the filter has been built (presumably mimicking
 // how grouped configurables get their prefix - confirm). The placeholder is
 // then expected to report the prefixed name, i.e. it reflects later changes
 // to the configurable's name. Keep the two insert() calls between the
 // filter construction and the name check.
135 struct : ConfigurableGroup {
136 std::string prefix = "prefix";
137 Configurable<float> pTCut{"pTCut", 1.0f, "Lower pT limit"};
138 } group;
139 Filter ptfilter2 = o2::aod::track::pt > group.pTCut;
140 group.pTCut.name.insert(0, 1, '.');
141 group.pTCut.name.insert(0, group.prefix);
142 REQUIRE(ptfilter2.node->self.index() == 2);
143 REQUIRE(ptfilter2.node->left->self.index() == 1);
144 REQUIRE(ptfilter2.node->right->self.index() == 3);
145 REQUIRE(std::get<PlaceholderNode>(ptfilter2.node->right->self).name == "prefix.pTCut");
146 auto ptfilterspecs2 = createOperations(ptfilter2);
147 REQUIRE(ptfilterspecs2[0].left == (DatumSpec{std::string{"fPt"}, typeid(o2::aod::track::Pt).hash_code(), atype::FLOAT}));
148 REQUIRE(ptfilterspecs2[0].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
149 REQUIRE(ptfilterspecs2[0].result == (DatumSpec{0u, atype::BOOL}));
150}
151
// Exercises the path from framework expressions to Gandiva objects:
// operation specs -> createExpressionTree -> makeExpression / makeCondition
// -> projector / filter creation. Expected ToString() outputs are compared
// verbatim, so they are tied to Gandiva's textual formatting (including the
// raw(...) hex encoding of float constants).
152TEST_CASE("TestGandivaTreeCreation")
153{
 // Projector for the test-only Pze column (tgl * (1.f / signed1Pt)).
154 Projector pze = o2::aod::track::Pze::Projector();
155 auto pzspecs = createOperations(pze);
156 REQUIRE(pzspecs[0].left == (DatumSpec{std::string{"fTgl"}, typeid(o2::aod::track::Tgl).hash_code(), atype::FLOAT}));
157 REQUIRE(pzspecs[0].right == (DatumSpec{1u, atype::FLOAT}));
158 REQUIRE(pzspecs[0].result == (DatumSpec{0u, atype::FLOAT}));
159
160 REQUIRE(pzspecs[1].left == (DatumSpec{LiteralNode::var_t{1.f}, atype::FLOAT}));
161 REQUIRE(pzspecs[1].right == (DatumSpec{std::string{"fSigned1Pt"}, typeid(o2::aod::track::Signed1Pt).hash_code(), atype::FLOAT}));
162 REQUIRE(pzspecs[1].result == (DatumSpec{1u, atype::FLOAT}));
 // The schema lists the input columns plus the output field.
163 auto infield1 = o2::aod::track::Signed1Pt::asArrowField();
164 auto infield2 = o2::aod::track::Tgl::asArrowField();
165 auto resfield = o2::aod::track::Pze::asArrowField();
166 auto schema = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, resfield});
167 auto gandiva_tree = createExpressionTree(pzspecs, schema);
168
169 auto gandiva_expression = makeExpression(gandiva_tree, resfield);
170 REQUIRE(std::string(gandiva_expression->ToString()) == std::string("float multiply((float) fTgl, float divide((const float) 1 raw(3f800000), (float) fSigned1Pt))"));
 // NOTE(review): the result is not checked. Presumably this only verifies
 // that projector creation does not fail; consider asserting on it.
171 auto projector = createProjector(schema, pzspecs, resfield);
172
 // Pt's own projector. Per the expected string, Pt is computed as
 // |1 / fSigned1Pt|, with a constant fallback (8.50706e+37) when
 // |fSigned1Pt| <= 1.17549e-38.
173 Projector pte = o2::aod::track::Pt::Projector();
174 auto ptespecs = createOperations(pte);
175
176 auto infield3 = o2::aod::track::Signed1Pt::asArrowField();
177 auto resfield2 = o2::aod::track::Pt::asArrowField();
178 auto schema2 = std::make_shared<arrow::Schema>(std::vector{infield3, resfield2});
179 auto gandiva_tree2 = createExpressionTree(ptespecs, schema2);
180
181 auto gandiva_expression2 = makeExpression(gandiva_tree2, resfield2);
182 REQUIRE(gandiva_expression2->ToString() == "if (bool less_than_or_equal_to(float absf((float) fSigned1Pt), (const float) 1.17549e-38 raw(800000))) { (const float) 8.50706e+37 raw(7e800000) } else { float absf(float divide((const float) 1 raw(3f800000), (float) fSigned1Pt)) }");
183
184 auto projector_b = createProjector(schema2, ptespecs, resfield2);
 // NOTE(review): `fields` and `schema_p` are not used by any visible
 // statement below. Confirm they are consumed (e.g. by a projector-creation
 // call for the Tracks persistent columns) or remove them.
185 auto fields = o2::soa::createFieldsFromColumns(o2::aod::Tracks::persistent_columns_t{});
186 auto schema_p = std::make_shared<arrow::Schema>(fields);
188
 // Bitwise AND on the uint32 flags column, compared against 0u. All
 // operands stay UINT32; the expected tree shows TPCrefit as constant 2.
189 Filter bitwiseFilter = (o2::aod::track::flags & static_cast<uint32_t>(o2::aod::track::TPCrefit)) != 0u;
190 auto bwf = createOperations(bitwiseFilter);
191 REQUIRE(bwf[0].left == (DatumSpec{1u, atype::UINT32}));
192 REQUIRE(bwf[0].right == (DatumSpec{LiteralNode::var_t{0u}, atype::UINT32}));
193 REQUIRE(bwf[0].result == (DatumSpec{0u, atype::BOOL}));
194
195 REQUIRE(bwf[1].left == (DatumSpec{std::string{"fFlags"}, typeid(o2::aod::track::Flags).hash_code(), atype::UINT32}));
196 REQUIRE(bwf[1].right == (DatumSpec{LiteralNode::var_t{static_cast<uint32_t>(o2::aod::track::TPCrefit)}, atype::UINT32}));
197 REQUIRE(bwf[1].result == (DatumSpec{1u, atype::UINT32}));
198
199 auto infield4 = o2::aod::track::Flags::asArrowField();
200 auto resfield3 = std::make_shared<arrow::Field>("out", arrow::boolean());
201 auto schema_b = std::make_shared<arrow::Schema>(std::vector{infield4, resfield3});
202 auto gandiva_tree3 = createExpressionTree(bwf, schema_b);
203 REQUIRE(gandiva_tree3->ToString() == "bool not_equal(uint32 bitwise_and((uint32) fFlags, (const uint32) 2), (const uint32) 0)");
 // Building a gandiva::Filter from the condition must succeed (status ok).
204 auto condition = expressions::makeCondition(gandiva_tree3);
205 std::shared_ptr<gandiva::Filter> flt;
206 auto s = gandiva::Filter::Make(schema_b, condition, &flt);
207 REQUIRE(s.ok());
208
 // nround is unary (empty right DatumSpec) and maps to Gandiva's `round`.
209 Filter rounding = nround(o2::aod::track::pt) > 0.1f;
210 auto rf = createOperations(rounding);
211 REQUIRE(rf[0].left == (DatumSpec{1u, atype::FLOAT}));
212 REQUIRE(rf[0].right == (DatumSpec{LiteralNode::var_t{0.1f}, atype::FLOAT}));
213 REQUIRE(rf[0].result == (DatumSpec{0u, atype::BOOL}));
214
215 REQUIRE(rf[1].left == (DatumSpec{std::string{"fPt"}, typeid(o2::aod::track::Pt).hash_code(), atype::FLOAT}));
216 REQUIRE(rf[1].right == (DatumSpec{}));
217 REQUIRE(rf[1].result == (DatumSpec{1u, atype::FLOAT}));
218
219 auto infield5 = o2::aod::track::Pt::asArrowField();
220 auto resfield4 = std::make_shared<arrow::Field>("out", arrow::boolean());
221 auto schema_c = std::make_shared<arrow::Schema>(std::vector{infield5, resfield4});
222 auto gandiva_tree4 = createExpressionTree(rf, schema_c);
223 REQUIRE(gandiva_tree4->ToString() == "bool greater_than(float round((float) fPt), (const float) 0.1 raw(3dcccccd))");
224 auto condition2 = expressions::makeCondition(gandiva_tree4);
225 std::shared_ptr<gandiva::Filter> flt2;
226 auto s2 = gandiva::Filter::Make(schema_c, condition2, &flt2);
227 REQUIRE(s2.ok());
228}
229
// Checks ifnode() conditionals. Their operation spec carries a third operand
// (`condition`) besides left/right, and createExpressionTree renders them as
// Gandiva `if (...) { ... } else { ... }` nodes. Also covers a conditional
// nested inside the condition of another conditional.
230TEST_CASE("TestConditionalExpressions")
231{
232 // simple conditional
233 Filter cf = nabs(o2::aod::track::eta) < 1.0f && ifnode((o2::aod::track::pt < 1.0f), (o2::aod::track::phi > (float)(M_PI / 2.)), (o2::aod::track::phi < (float)(M_PI / 2.)));
234 auto cfspecs = createOperations(cf);
235 REQUIRE(cfspecs[0].left == (DatumSpec{1u, atype::BOOL}));
236 REQUIRE(cfspecs[0].right == (DatumSpec{2u, atype::BOOL}));
237 REQUIRE(cfspecs[0].result == (DatumSpec{0u, atype::BOOL}));
238
 // The ifnode operation: branch results in slots 3 and 4, condition in slot 5.
239 REQUIRE(cfspecs[1].left == (DatumSpec{3u, atype::BOOL}));
240 REQUIRE(cfspecs[1].right == (DatumSpec{4u, atype::BOOL}));
241 REQUIRE(cfspecs[1].condition == (DatumSpec{5u, atype::BOOL}));
242 REQUIRE(cfspecs[1].result == (DatumSpec{2u, atype::BOOL}));
243
244 REQUIRE(cfspecs[2].left == (DatumSpec{std::string{"fPt"}, typeid(o2::aod::track::Pt).hash_code(), atype::FLOAT}));
245 REQUIRE(cfspecs[2].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
246 REQUIRE(cfspecs[2].result == (DatumSpec{5u, atype::BOOL}));
247
 // NOTE(review): specs [3] and [4] differ only in their result slot, so
 // these assertions do not tell the `>` and `<` comparisons apart. The
 // ToString() check below is what pins the branch contents.
248 REQUIRE(cfspecs[3].left == (DatumSpec{std::string{"fPhi"}, typeid(o2::aod::track::Phi).hash_code(), atype::FLOAT}));
249 REQUIRE(cfspecs[3].right == (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT}));
250 REQUIRE(cfspecs[3].result == (DatumSpec{4u, atype::BOOL}));
251
252 REQUIRE(cfspecs[4].left == (DatumSpec{std::string{"fPhi"}, typeid(o2::aod::track::Phi).hash_code(), atype::FLOAT}));
253 REQUIRE(cfspecs[4].right == (DatumSpec{LiteralNode::var_t{(float)(M_PI / 2.)}, atype::FLOAT}));
254 REQUIRE(cfspecs[4].result == (DatumSpec{3u, atype::BOOL}));
255
256 REQUIRE(cfspecs[5].left == (DatumSpec{6u, atype::FLOAT}));
257 REQUIRE(cfspecs[5].right == (DatumSpec{LiteralNode::var_t{1.0f}, atype::FLOAT}));
258 REQUIRE(cfspecs[5].result == (DatumSpec{1u, atype::BOOL}));
259
260 REQUIRE(cfspecs[6].left == (DatumSpec{std::string{"fEta"}, typeid(o2::aod::track::Eta).hash_code(), atype::FLOAT}));
261 REQUIRE(cfspecs[6].right == (DatumSpec{}));
262 REQUIRE(cfspecs[6].result == (DatumSpec{6u, atype::FLOAT}));
263
264 auto infield1 = o2::aod::track::Pt::asArrowField();
265 auto infield2 = o2::aod::track::Eta::asArrowField();
266 auto infield3 = o2::aod::track::Phi::asArrowField();
267 auto schema = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, infield3});
268 auto gandiva_tree = createExpressionTree(cfspecs, schema);
269 auto gandiva_condition = makeCondition(gandiva_tree);
270 auto gandiva_filter = createFilter(schema, gandiva_condition);
271
272 REQUIRE(gandiva_tree->ToString() == "bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }");
273
274 // nested conditional
 // The whole previous filter becomes the condition of a new ifnode.
 // std::move(*cf.node) moves cf's root node out, so cf must not be used
 // after this statement. schema2 adds the columns referenced only by the
 // nested expression (Signed1Pt, X, Y).
275 Filter cfn = o2::aod::track::signed1Pt > 0.f && ifnode(std::move(*cf.node), nabs(o2::aod::track::x) > 1.0f, nabs(o2::aod::track::y) > 1.0f);
276 auto cfnspecs = createOperations(cfn);
277 auto infield4 = o2::aod::track::Signed1Pt::asArrowField();
278 auto infield5 = o2::aod::track::X::asArrowField();
279 auto infield6 = o2::aod::track::Y::asArrowField();
280 auto schema2 = std::make_shared<arrow::Schema>(std::vector{infield1, infield2, infield3, infield4, infield5, infield6});
281 auto gandiva_tree2 = createExpressionTree(cfnspecs, schema2);
282 auto gandiva_condition2 = makeCondition(gandiva_tree2);
283 auto gandiva_filter2 = createFilter(schema2, gandiva_condition2);
284 REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }");
285}
#define DECLARE_SOA_EXPRESSION_COLUMN(_Name_, _Getter_, _Type_, _Expression_)
Definition ASoA.h:2388
Class for time synchronization of RawReader instances.
GLuint64EXT * result
Definition glcorearb.h:5662
GLdouble GLdouble right
Definition glcorearb.h:4077
GLdouble f
Definition glcorearb.h:310
GLboolean GLuint group
Definition glcorearb.h:3991
GLboolean GLboolean g
Definition glcorearb.h:1233
std::shared_ptr< gandiva::Filter > createFilter(gandiva::SchemaPtr const &Schema, gandiva::ConditionPtr condition)
Function to create gandiva filter from gandiva condition.
gandiva::ExpressionPtr makeExpression(gandiva::NodePtr node, gandiva::FieldPtr result)
Function to create gandiva projecting expression from generic gandiva expression tree.
Operations createOperations(Filter const &expression)
Function to create an internal operation sequence from a filter tree.
gandiva::ConditionPtr makeCondition(gandiva::NodePtr node)
Function to create gandiva condition expression from generic gandiva expression tree.
Node ifnode(Node &&condition_, Node &&then_, Node &&else_)
conditionals
gandiva::NodePtr createExpressionTree(Operations const &opSpecs, gandiva::SchemaPtr const &Schema)
Function to create gandiva expression tree from operation sequence.
std::shared_ptr< gandiva::Projector > createProjector(gandiva::SchemaPtr const &Schema, Operations const &opSpecs, gandiva::FieldPtr result)
Function to create gandiva projector from operation sequence.
std::shared_ptr< gandiva::Projector > createProjectors(framework::pack< C... >, std::vector< std::shared_ptr< arrow::Field > > const &fields, gandiva::SchemaPtr schema)
Defining PrimaryVertex explicitly as messageable.
Definition TFIDInfo.h:20
TEST_CASE("test_prepareArguments")
auto createFieldsFromColumns(framework::pack< C... >)
Definition ASoA.h:401
An expression tree node corresponding to a column binding.
A struct, containing the root of the expression tree.
std::unique_ptr< Node > node
LiteralValue::stored_type var_t