Project
Loading...
Searching...
No Matches
RNTuplePlugin.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
14#include "Framework/Plugins.h"
16#include <ROOT/RNTupleModel.hxx>
17#include <ROOT/RNTupleWriteOptions.hxx>
18#include <ROOT/RNTupleWriter.hxx>
19#include <ROOT/RField.hxx>
20#include <ROOT/RNTuple.hxx>
21#include <ROOT/RNTupleReader.hxx>
22#include <ROOT/RFieldVisitor.hxx>
23#include <ROOT/RNTupleInspector.hxx>
24#include <ROOT/RVec.hxx>
25#include <memory>
26#include <TBufferFile.h>
27
28#include <TDirectory.h>
29#include <arrow/array/array_nested.h>
30#include <arrow/array/array_primitive.h>
31#include <arrow/array/builder_nested.h>
32#include <arrow/array/builder_primitive.h>
33#include <arrow/array/util.h>
34#include <arrow/record_batch.h>
35#include <arrow/dataset/file_base.h>
36
37#if __has_include(<ROOT/RFieldBase.hxx>)
38#include <ROOT/RFieldBase.hxx>
39namespace rns = ROOT;
40using DPLFieldToken = rns::RFieldToken;
41using DPLLocalIndex = rns::RNTupleLocalIndex;
42#else
43namespace rns = ROOT::Experimental;
44using DPLFieldToken = rns::REntry::RFieldToken;
45using DPLLocalIndex = rns::RClusterIndex;
46#endif
47
48template class
49 std::unique_ptr<rns::RNTupleReader>;
50
51namespace o2::framework
52{
53
55{
56 public:
57 RNTupleFileWriteOptions(std::shared_ptr<arrow::dataset::FileFormat> format)
58 : FileWriteOptions(format)
59 {
60 }
61};
62
63// A filesystem which allows me to get a RNTuple
65{
66 public:
68
69 virtual rns::RNTuple* GetRNTuple(arrow::dataset::FileSource source) = 0;
70};
71
73{
74 public:
75 SingleRNTupleFileSystem(rns::RNTuple* tuple)
77 mTuple(tuple)
78 {
79 }
80
81 arrow::Result<arrow::fs::FileInfo> GetFileInfo(std::string const& path) override;
82
83 std::string type_name() const override
84 {
85 return "rntuple";
86 }
87
88 rns::RNTuple* GetRNTuple(arrow::dataset::FileSource) override
89 {
90 // Simply return the only TTree we have
91 return mTuple;
92 }
93
94 private:
95 rns::RNTuple* mTuple;
96};
97
98arrow::Result<arrow::fs::FileInfo> SingleRNTupleFileSystem::GetFileInfo(std::string const& path)
99{
100 arrow::dataset::FileSource source(path, shared_from_this());
101 arrow::fs::FileInfo result;
102 result.set_path(path);
103 result.set_type(arrow::fs::FileType::File);
104 return result;
105}
106
108{
109 public:
110 RNTupleFileFragment(arrow::dataset::FileSource source,
111 std::shared_ptr<arrow::dataset::FileFormat> format,
112 arrow::compute::Expression partition_expression,
113 std::shared_ptr<arrow::Schema> physical_schema)
114 : FileFragment(source, format, partition_expression, physical_schema)
115 {
116 auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem());
117 if (!fs.get()) {
118 throw runtime_error_f("Do not know how to extract %s from %s", source.path().c_str(), fs->type_name().c_str());
119 }
120 auto handler = fs->GetObjectHandler(source);
121 if (!handler->format->Equals(*format)) {
122 throw runtime_error_f("Format for %s does not match. Found %s, expected %s.", source.path().c_str(),
123 handler->format->type_name().c_str(),
124 format->type_name().c_str());
125 }
126 mNTuple = handler->GetObjectAsOwner<rns::RNTuple>();
127 }
128
129 rns::RNTuple* GetRNTuple()
130 {
131 return mNTuple.get();
132 }
133
134 private:
135 std::unique_ptr<rns::RNTuple> mNTuple;
136};
137
139{
140 size_t& mTotCompressedSize;
141 size_t& mTotUncompressedSize;
142
143 public:
144 RNTupleFileFormat(size_t& totalCompressedSize, size_t& totalUncompressedSize)
145 : FileFormat({}),
146 mTotCompressedSize(totalCompressedSize),
147 mTotUncompressedSize(totalUncompressedSize)
148 {
149 }
150
151 ~RNTupleFileFormat() override = default;
152
153 std::string type_name() const override
154 {
155 return "rntuple";
156 }
157
158 bool Equals(const FileFormat& other) const override
159 {
160 return other.type_name() == this->type_name();
161 }
162
163 arrow::Result<bool> IsSupported(const arrow::dataset::FileSource& source) const override
164 {
165 auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem());
166 if (!fs) {
167 return false;
168 }
169 return fs->CheckSupport(source);
170 }
171
172 arrow::Result<std::shared_ptr<arrow::Schema>> Inspect(const arrow::dataset::FileSource& source) const override;
173
174 arrow::Result<arrow::RecordBatchGenerator> ScanBatchesAsync(
175 const std::shared_ptr<arrow::dataset::ScanOptions>& options,
176 const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const override;
177
178 std::shared_ptr<arrow::dataset::FileWriteOptions> DefaultWriteOptions() override;
179
180 arrow::Result<std::shared_ptr<arrow::dataset::FileWriter>> MakeWriter(std::shared_ptr<arrow::io::OutputStream> destination,
181 std::shared_ptr<arrow::Schema> schema,
182 std::shared_ptr<arrow::dataset::FileWriteOptions> options,
183 arrow::fs::FileLocator destination_locator) const override;
184 arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> MakeFragment(
185 arrow::dataset::FileSource source, arrow::compute::Expression partition_expression,
186 std::shared_ptr<arrow::Schema> physical_schema) override;
187};
188
189template <typename T>
190 requires requires(T&& f) { f.GetSubFields(); }
191auto getSubfields(T const& field)
192{
193 return field.GetSubFields();
194}
195
196template <typename T>
197 requires requires(T&& f) { f.GetConstSubfields(); }
198auto getSubfields(T const& field)
199{
200 return field.GetConstSubfields();
201}
202
203struct RootNTupleVisitor : public rns::Detail::RFieldVisitor {
204 void VisitArrayField(const rns::RArrayField& field) override
205 {
206 int size = field.GetLength();
207 RootNTupleVisitor valueVisitor{};
208 auto valueField = getSubfields(field)[0];
209 valueField->AcceptVisitor(valueVisitor);
210 auto type = valueVisitor.datatype;
211 this->datatype = arrow::fixed_size_list(type, size);
212 }
213
214 void VisitRVecField(const rns::RRVecField& field) override
215 {
216 RootNTupleVisitor valueVisitor{};
217 auto valueField = getSubfields(field)[0];
218 valueField->AcceptVisitor(valueVisitor);
219 auto type = valueVisitor.datatype;
220 this->datatype = arrow::list(type);
221 }
222
223 void VisitField(const rns::RFieldBase& field) override
224 {
225 throw o2::framework::runtime_error_f("Unknown field %s with type %s", field.GetFieldName().c_str(), field.GetTypeName().c_str());
226 }
227
228#if __has_include(<ROOT/RFieldBase.hxx>)
229 void VisitInt32Field(const rns::RIntegralField<int>& field) override
230 {
231 this->datatype = arrow::int32();
232 }
233 void VisitInt8Field(const rns::RIntegralField<std::int8_t>& field) override
234 {
235 this->datatype = arrow::int8();
236 }
237
238 void VisitInt16Field(const rns::RIntegralField<std::int16_t>& field) override
239 {
240 this->datatype = arrow::int16();
241 }
242
243 void VisitUInt32Field(const rns::RIntegralField<std::uint32_t>& field) override
244 {
245 this->datatype = arrow::uint32();
246 }
247
248 void VisitUInt8Field(const rns::RIntegralField<std::uint8_t>& field) override
249 {
250 this->datatype = arrow::uint8();
251 }
252
253 void VisitUInt16Field(const rns::RIntegralField<std::uint16_t>& field) override
254 {
255 this->datatype = arrow::int16();
256 }
257#else
258 void VisitIntField(const rns::RField<int>& field) override
259 {
260 this->datatype = arrow::int32();
261 }
262 void VisitInt8Field(const ROOT::Experimental::RField<std::int8_t>& field) override
263 {
264 this->datatype = arrow::int8();
265 }
266
267 void VisitInt16Field(const ROOT::Experimental::RField<std::int16_t>& field) override
268 {
269 this->datatype = arrow::int16();
270 }
271
272 void VisitUInt32Field(const ROOT::Experimental::RField<std::uint32_t>& field) override
273 {
274 this->datatype = arrow::uint32();
275 }
276
277 void VisitUInt8Field(const ROOT::Experimental::RField<std::uint8_t>& field) override
278 {
279 this->datatype = arrow::uint8();
280 }
281
282 void VisitUInt16Field(const ROOT::Experimental::RField<std::uint16_t>& field) override
283 {
284 this->datatype = arrow::int16();
285 }
286#endif
287
288 void VisitBoolField(const rns::RField<bool>& field) override
289 {
290 this->datatype = arrow::boolean();
291 }
292
293 void VisitFloatField(const rns::RField<float>& field) override
294 {
295 this->datatype = arrow::float32();
296 }
297
298 void VisitDoubleField(const rns::RField<double>& field) override
299 {
300 this->datatype = arrow::float64();
301 }
302 std::shared_ptr<arrow::DataType> datatype;
303};
304} // namespace o2::framework
305
306auto arrowTypeFromRNTuple(rns::RFieldBase const& field, int size)
307{
309 field.AcceptVisitor(visitor);
310 return visitor.datatype;
311}
312
313namespace o2::framework
314{
315std::unique_ptr<rns::RFieldBase> rootFieldFromArrow(std::shared_ptr<arrow::Field> field, std::string name)
316{
317 using namespace ROOT::Experimental;
318 switch (field->type()->id()) {
319 case arrow::Type::BOOL:
320 return std::make_unique<rns::RField<bool>>(name);
321 case arrow::Type::UINT8:
322 return std::make_unique<rns::RField<uint8_t>>(name);
323 case arrow::Type::UINT16:
324 return std::make_unique<rns::RField<uint16_t>>(name);
325 case arrow::Type::UINT32:
326 return std::make_unique<rns::RField<uint32_t>>(name);
327 case arrow::Type::UINT64:
328 return std::make_unique<rns::RField<uint64_t>>(name);
329 case arrow::Type::INT8:
330 return std::make_unique<rns::RField<int8_t>>(name);
331 case arrow::Type::INT16:
332 return std::make_unique<rns::RField<int16_t>>(name);
333 case arrow::Type::INT32:
334 return std::make_unique<rns::RField<int32_t>>(name);
335 case arrow::Type::INT64:
336 return std::make_unique<rns::RField<int64_t>>(name);
337 case arrow::Type::FLOAT:
338 return std::make_unique<rns::RField<float>>(name);
339 case arrow::Type::DOUBLE:
340 return std::make_unique<rns::RField<double>>(name);
341 case arrow::Type::STRING:
342 return std::make_unique<rns::RField<std::string>>(name);
343 default:
344 throw runtime_error("Unsupported arrow column type");
345 }
346}
347
349{
350 std::shared_ptr<rns::RNTupleWriter> mWriter;
351 bool firstBatch = true;
352 std::vector<std::shared_ptr<arrow::Array>> valueArrays;
353 std::vector<std::shared_ptr<arrow::DataType>> valueTypes;
354 std::vector<size_t> valueCount;
355
356 public:
357 RNTupleFileWriter(std::shared_ptr<arrow::Schema> schema, std::shared_ptr<arrow::dataset::FileWriteOptions> options,
358 std::shared_ptr<arrow::io::OutputStream> destination,
359 arrow::fs::FileLocator destination_locator)
360 : FileWriter(schema, options, destination, destination_locator)
361 {
362 using namespace ROOT::Experimental;
363
364 auto model = rns::RNTupleModel::CreateBare();
365 // Let's create a model from the physical schema
366 for (auto i = 0u; i < schema->fields().size(); ++i) {
367 auto& field = schema->field(i);
368
369 // Construct all the needed branches.
370 switch (field->type()->id()) {
371 case arrow::Type::FIXED_SIZE_LIST: {
372 auto list = std::static_pointer_cast<arrow::FixedSizeListType>(field->type());
373 auto valueField = field->type()->field(0);
374 model->AddField(std::make_unique<rns::RArrayField>(field->name(), rootFieldFromArrow(valueField, "_0"), list->list_size()));
375 } break;
376 case arrow::Type::LIST: {
377 auto valueField = field->type()->field(0);
378 model->AddField(std::make_unique<rns::RRVecField>(field->name(), rootFieldFromArrow(valueField, "_0")));
379 } break;
380 default: {
381 model->AddField(rootFieldFromArrow(field, field->name()));
382 } break;
383 }
384 }
385 auto fileStream = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(destination_);
386 auto* file = dynamic_cast<TFile*>(fileStream->GetDirectory());
387 mWriter = rns::RNTupleWriter::Append(std::move(model), destination_locator_.path, *file, {});
388 }
389
390 arrow::Status Write(const std::shared_ptr<arrow::RecordBatch>& batch) override
391 {
392 if (firstBatch) {
393 firstBatch = false;
394 }
395
396 // Support writing empty tables
397 if (batch->columns().empty() || batch->num_rows() == 0) {
398 return arrow::Status::OK();
399 }
400
401 for (auto i = 0u; i < batch->columns().size(); ++i) {
402 auto column = batch->column(i);
403 auto& field = batch->schema()->field(i);
404
405 valueArrays.push_back(nullptr);
406 valueTypes.push_back(nullptr);
407 valueCount.push_back(1);
408
409 switch (field->type()->id()) {
410 case arrow::Type::FIXED_SIZE_LIST: {
411 auto list = std::static_pointer_cast<arrow::FixedSizeListArray>(column);
412 auto listType = std::static_pointer_cast<arrow::FixedSizeListType>(field->type());
413 if (field->type()->field(0)->type()->id() == arrow::Type::BOOL) {
414 auto boolArray = std::static_pointer_cast<arrow::BooleanArray>(list->values());
415 int64_t length = boolArray->length();
416 arrow::UInt8Builder builder;
417 auto ok = builder.Reserve(length);
418
419 for (int64_t i = 0; i < length; ++i) {
420 if (boolArray->IsValid(i)) {
421 // Expand each boolean value (true/false) to uint8 (1/0)
422 uint8_t value = boolArray->Value(i) ? 1 : 0;
423 auto ok = builder.Append(value);
424 } else {
425 // Append null for invalid entries
426 auto ok = builder.AppendNull();
427 }
428 }
429 valueArrays.back() = *builder.Finish();
430 valueTypes.back() = valueArrays.back()->type();
431 } else {
432 valueArrays.back() = list->values();
433 valueTypes.back() = field->type()->field(0)->type();
434 }
435 valueCount.back() = listType->list_size();
436 } break;
437 case arrow::Type::LIST: {
438 auto list = std::static_pointer_cast<arrow::ListArray>(column);
439 valueArrays.back() = list;
440 valueTypes.back() = field->type()->field(0)->type();
441 valueCount.back() = -1;
442 } break;
443 case arrow::Type::BOOL: {
444 // We unpack the array
445 auto boolArray = std::static_pointer_cast<arrow::BooleanArray>(column);
446 int64_t length = boolArray->length();
447 arrow::UInt8Builder builder;
448 auto ok = builder.Reserve(length);
449
450 for (int64_t i = 0; i < length; ++i) {
451 if (boolArray->IsValid(i)) {
452 // Expand each boolean value (true/false) to uint8 (1/0)
453 uint8_t value = boolArray->Value(i) ? 1 : 0;
454 auto ok = builder.Append(value);
455 } else {
456 // Append null for invalid entries
457 auto ok = builder.AppendNull();
458 }
459 }
460 valueArrays.back() = *builder.Finish();
461 valueTypes.back() = valueArrays.back()->type();
462 } break;
463 default:
464 valueArrays.back() = column;
465 valueTypes.back() = field->type();
466 break;
467 }
468 }
469
470 int64_t pos = 0;
471
472 auto entry = mWriter->CreateEntry();
473 std::vector<DPLFieldToken> tokens;
474 tokens.reserve(batch->num_columns());
475 std::vector<size_t> typeIds;
476 typeIds.reserve(batch->num_columns());
477
478 for (size_t ci = 0; ci < batch->num_columns(); ++ci) {
479 auto& field = batch->schema()->field(ci);
480 typeIds.push_back(batch->column(ci)->type()->id());
481 tokens.push_back(entry->GetToken(field->name()));
482 }
483
484 while (pos < batch->num_rows()) {
485 for (size_t ci = 0; ci < batch->num_columns(); ++ci) {
486 auto typeId = typeIds[ci];
487 auto token = tokens[ci];
488
489 switch (typeId) {
490 case arrow::Type::LIST: {
491 auto list = std::static_pointer_cast<arrow::ListArray>(valueArrays[ci]);
492 auto value_slice = list->value_slice(pos);
493
494 valueCount[ci] = value_slice->length();
495 auto bindValue = [&vc = valueCount, ci, token](auto array, std::unique_ptr<rns::REntry>& entry) -> void {
496 using value_type = std::decay_t<decltype(*array.get())>::value_type;
497 auto v = std::make_shared<ROOT::VecOps::RVec<value_type>>((value_type*)array->raw_values(), vc[ci]);
498 entry->BindValue(token, v);
499 };
500 switch (valueTypes[ci]->id()) {
501 case arrow::Type::FLOAT: {
502 bindValue(std::static_pointer_cast<arrow::FloatArray>(value_slice), entry);
503 } break;
504 case arrow::Type::DOUBLE: {
505 bindValue(std::static_pointer_cast<arrow::DoubleArray>(value_slice), entry);
506 } break;
507 case arrow::Type::INT8: {
508 bindValue(std::static_pointer_cast<arrow::Int8Array>(value_slice), entry);
509 } break;
510 case arrow::Type::INT16: {
511 bindValue(std::static_pointer_cast<arrow::Int16Array>(value_slice), entry);
512 } break;
513 case arrow::Type::INT32: {
514 bindValue(std::static_pointer_cast<arrow::Int32Array>(value_slice), entry);
515 } break;
516 case arrow::Type::INT64: {
517 bindValue(std::static_pointer_cast<arrow::Int64Array>(value_slice), entry);
518 } break;
519 case arrow::Type::UINT8: {
520 bindValue(std::static_pointer_cast<arrow::UInt8Array>(value_slice), entry);
521 } break;
522 case arrow::Type::UINT16: {
523 bindValue(std::static_pointer_cast<arrow::UInt16Array>(value_slice), entry);
524 } break;
525 case arrow::Type::UINT32: {
526 bindValue(std::static_pointer_cast<arrow::UInt32Array>(value_slice), entry);
527 } break;
528 case arrow::Type::UINT64: {
529 bindValue(std::static_pointer_cast<arrow::UInt64Array>(value_slice), entry);
530 } break;
531 default: {
532 throw runtime_error("Unsupported kind of VLA");
533 } break;
534 }
535 } break;
536 case arrow::Type::FIXED_SIZE_LIST: {
537 entry->BindRawPtr<void>(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueCount[ci] * valueTypes[ci]->byte_width()));
538 } break;
539 case arrow::Type::BOOL: {
540 // Not sure we actually need this
541 entry->BindRawPtr<bool>(token, (bool*)(valueArrays[ci]->data()->buffers[1]->data() + pos * 1));
542 } break;
543 default:
544 // By default we consider things scalars.
545 entry->BindRawPtr<void>(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueTypes[ci]->byte_width()));
546 break;
547 }
548 }
549 mWriter->Fill(*entry);
550 ++pos;
551 }
552 // mWriter->CommitCluster();
553
554 return arrow::Status::OK();
555 }
556
557 arrow::Future<>
558 FinishInternal() override
559 {
560 return {};
561 };
562};
563
564template <typename T>
565 requires requires(T const& m) { m.GetFieldZero(); }
566auto& getFieldZero(T const& m)
567{
568 return m.GetFieldZero();
569}
570
571template <typename T>
572 requires requires(T const& m) { m.GetConstFieldZero(); }
573auto& getFieldZero(T const& m)
574{
575 return m.GetConstFieldZero();
576}
577
578arrow::Result<std::shared_ptr<arrow::Schema>> RNTupleFileFormat::Inspect(const arrow::dataset::FileSource& source) const
579{
580
581 auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem());
582 // Actually get the TTree from the ROOT file.
583 auto objectHandler = fs->GetObjectHandler(source);
584 if (objectHandler->format->type_name() != this->type_name()) {
585 throw runtime_error_f("Unexpected kind of filesystem %s to handle payload %s.\n", source.filesystem()->type_name().c_str(), source.path().c_str());
586 }
587 // We know this is a RNTuple, so we can continue with the inspection.
588 auto rntuple = objectHandler->GetObjectAsOwner<rns::RNTuple>().release();
589
590#if __has_include(<ROOT/RFieldBase.hxx>)
591 auto reader = rns::RNTupleReader::Open(*rntuple);
592#else
593 auto reader = rns::RNTupleReader::Open(rntuple);
594#endif
595
596 auto& tupleField0 = getFieldZero(reader->GetModel());
597 std::vector<std::shared_ptr<arrow::Field>> fields;
598 for (auto& tupleField : getSubfields(tupleField0)) {
599 auto field = std::make_shared<arrow::Field>(tupleField->GetFieldName(), arrowTypeFromRNTuple(*tupleField, tupleField->GetValueSize()));
600 fields.push_back(field);
601 }
602
603 return std::make_shared<arrow::Schema>(fields);
604}
605
606arrow::Result<arrow::RecordBatchGenerator> RNTupleFileFormat::ScanBatchesAsync(
607 const std::shared_ptr<arrow::dataset::ScanOptions>& options,
608 const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const
609{
610 auto dataset_schema = options->dataset_schema;
611 auto ntupleFragment = std::dynamic_pointer_cast<RNTupleFileFragment>(fragment);
612
613 auto generator = [pool = options->pool, ntupleFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize,
614 &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future<std::shared_ptr<arrow::RecordBatch>> {
615 using namespace ROOT::Experimental;
616 std::vector<std::shared_ptr<arrow::Array>> columns;
617 std::vector<std::shared_ptr<arrow::Field>> fields = dataset_schema->fields();
618
619 int64_t rows = -1;
620 rns::RNTuple* rntuple = ntupleFragment->GetRNTuple();
621#if __has_include(<ROOT/RFieldBase.hxx>)
622 auto reader = rns::RNTupleReader::Open(*rntuple);
623#else
624 auto reader = rns::RNTupleReader::Open(rntuple);
625#endif
626 auto& model = reader->GetModel();
627 for (auto& physicalField : fields) {
628 auto bulk = model.CreateBulk(physicalField->name());
629
630 auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type());
631
632 auto& descriptor = reader->GetDescriptor();
633 auto totalEntries = reader->GetNEntries();
634
635 if (rows == -1) {
636 rows = totalEntries;
637 }
638 if (rows != totalEntries) {
639 throw runtime_error_f("Unmatching number of rows for branch %s", physicalField->name().c_str());
640 }
641 arrow::Status status;
642 int readEntries = 0;
643 std::shared_ptr<arrow::Array> array;
644 if (physicalField->type() == arrow::boolean() ||
645 (listType && physicalField->type()->field(0)->type() == arrow::boolean())) {
646 if (listType) {
647 std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
648 auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder);
649 if (!status.ok()) {
650 throw runtime_error("Cannot create value builder");
651 }
652 auto listBuilder = std::make_unique<arrow::FixedSizeListBuilder>(pool, std::move(builder), listType->list_size());
653 auto valueBuilder = listBuilder.get()->value_builder();
654 // boolean array special case: we need to use builder to create the bitmap
655 status = valueBuilder->Reserve(totalEntries * listType->list_size());
656 status &= listBuilder->Reserve(totalEntries);
657 if (!status.ok()) {
658 throw runtime_error("Failed to reserve memory for array builder");
659 }
660 auto clusterIt = descriptor.FindClusterId(0, 0);
661 // No adoption for now...
662 // bulk.AdoptBuffer(buffer, totalEntries)
663 while (clusterIt != rns::kInvalidDescriptorId) {
664 auto& index = descriptor.GetClusterDescriptor(clusterIt);
665 auto mask = std::make_unique<bool[]>(index.GetNEntries());
666 std::fill(mask.get(), mask.get() + index.GetNEntries(), true);
667 void* ptr = bulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries());
668 int readLast = index.GetNEntries();
669 readEntries += readLast;
670 status &= static_cast<arrow::BooleanBuilder*>(valueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(ptr), readLast * listType->list_size());
671 clusterIt = descriptor.FindNextClusterId(clusterIt);
672 }
673 status &= static_cast<arrow::FixedSizeListBuilder*>(listBuilder.get())->AppendValues(readEntries);
674 if (!status.ok()) {
675 throw runtime_error("Failed to append values to array");
676 }
677 status &= listBuilder->Finish(&array);
678 if (!status.ok()) {
679 throw runtime_error("Failed to create array");
680 }
681 } else if (listType == nullptr) {
682 std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
683 auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder);
684 if (!status.ok()) {
685 throw runtime_error("Cannot create builder");
686 }
687 auto valueBuilder = static_cast<arrow::BooleanBuilder*>(builder.get());
688 // boolean array special case: we need to use builder to create the bitmap
689 status = valueBuilder->Reserve(totalEntries);
690 if (!status.ok()) {
691 throw runtime_error("Failed to reserve memory for array builder");
692 }
693 auto clusterIt = descriptor.FindClusterId(0, 0);
694 while (clusterIt != rns::kInvalidDescriptorId) {
695 auto& index = descriptor.GetClusterDescriptor(clusterIt);
696 auto mask = std::make_unique<bool[]>(index.GetNEntries());
697 std::fill(mask.get(), mask.get() + index.GetNEntries(), true);
698 void* ptr = bulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries());
699 int readLast = index.GetNEntries();
700 readEntries += readLast;
701 status &= valueBuilder->AppendValues(reinterpret_cast<uint8_t const*>(ptr), readLast);
702 clusterIt = descriptor.FindNextClusterId(clusterIt);
703 }
704 if (!status.ok()) {
705 throw runtime_error("Failed to append values to array");
706 }
707 status &= valueBuilder->Finish(&array);
708 if (!status.ok()) {
709 throw runtime_error("Failed to create array");
710 }
711 }
712 } else {
713 // other types: use serialized read to build arrays directly.
714 auto typeSize = physicalField->type()->byte_width();
715 // FIXME: for now...
716 auto bytes = 0;
717 auto branchSize = bytes ? bytes : 1000000;
718 auto&& result = arrow::AllocateResizableBuffer(branchSize, pool);
719 if (!result.ok()) {
720 throw runtime_error("Cannot allocate values buffer");
721 }
722 std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
723 auto ptr = arrowValuesBuffer->mutable_data();
724 if (ptr == nullptr) {
725 throw runtime_error("Invalid buffer");
726 }
727
728 std::unique_ptr<TBufferFile> offsetBuffer = nullptr;
729
730 std::shared_ptr<arrow::Buffer> arrowOffsetBuffer;
731 std::span<int> offsets;
732 int size = 0;
733 uint32_t totalSize = 0;
734 int64_t listSize = 1;
735 if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type())) {
736 listSize = fixedSizeList->list_size();
737 typeSize = fixedSizeList->field(0)->type()->byte_width();
738 auto clusterIt = descriptor.FindClusterId(0, 0);
739 while (clusterIt != rns::kInvalidDescriptorId) {
740 auto& index = descriptor.GetClusterDescriptor(clusterIt);
741 auto mask = std::make_unique<bool[]>(index.GetNEntries());
742 std::fill(mask.get(), mask.get() + index.GetNEntries(), true);
743 void* inPtr = bulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries());
744
745 int readLast = index.GetNEntries();
746 if (listSize == -1) {
747 size = offsets[readEntries + readLast] - offsets[readEntries];
748 } else {
749 size = readLast * listSize;
750 }
751 readEntries += readLast;
752 memcpy(ptr, inPtr, size * typeSize);
753 ptr += (ptrdiff_t)(size * typeSize);
754 clusterIt = descriptor.FindNextClusterId(clusterIt);
755 }
756 } else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
757 listSize = -1;
758 typeSize = vlaListType->field(0)->type()->byte_width();
759 offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite, 4 * 1024 * 1024);
760 result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool);
761 if (!result.ok()) {
762 throw runtime_error("Cannot allocate offset buffer");
763 }
764 arrowOffsetBuffer = result.MoveValueUnsafe();
765
766 // Offset bulk
767 auto offsetBulk = model.CreateBulk(physicalField->name());
768 // Actual values are in a different place...
769 bulk = model.CreateBulk(physicalField->name());
770 auto clusterIt = descriptor.FindClusterId(0, 0);
771 auto* ptrOffset = reinterpret_cast<int*>(arrowOffsetBuffer->mutable_data());
772 auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
773 offsets = std::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
774
775 auto copyOffsets = [&arrowValuesBuffer, &pool, &ptrOffset, &ptr, &totalSize](auto inPtr, size_t total) {
776 using value_type = typename std::decay_t<decltype(*inPtr)>::value_type;
777 for (size_t i = 0; i < total; i++) {
778 *ptrOffset++ = totalSize;
779 totalSize += inPtr[i].size();
780 }
781 *ptrOffset = totalSize;
782 auto&& result = arrow::AllocateResizableBuffer(totalSize * sizeof(value_type), pool);
783 if (!result.ok()) {
784 throw runtime_error("Cannot allocate values buffer");
785 }
786 arrowValuesBuffer = result.MoveValueUnsafe();
787 ptr = (uint8_t*)(arrowValuesBuffer->mutable_data());
788 // Calculate the size of the buffer here.
789 for (size_t i = 0; i < total; i++) {
790 int vlaSizeInBytes = inPtr[i].size() * sizeof(value_type);
791 if (vlaSizeInBytes == 0) {
792 continue;
793 }
794 memcpy(ptr, inPtr[i].data(), vlaSizeInBytes);
795 ptr += vlaSizeInBytes;
796 }
797 };
798
799 while (clusterIt != rns::kInvalidDescriptorId) {
800 auto& index = descriptor.GetClusterDescriptor(clusterIt);
801 auto mask = std::make_unique<bool[]>(index.GetNEntries());
802 std::fill(mask.get(), mask.get() + index.GetNEntries(), true);
803 int readLast = index.GetNEntries();
804 switch (vlaListType->field(0)->type()->id()) {
805 case arrow::Type::FLOAT: {
806 copyOffsets((ROOT::Internal::VecOps::RVec<float>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
807 } break;
808 case arrow::Type::DOUBLE: {
809 copyOffsets((ROOT::Internal::VecOps::RVec<double>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
810 } break;
811 case arrow::Type::INT8: {
812 copyOffsets((ROOT::Internal::VecOps::RVec<int8_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
813 } break;
814 case arrow::Type::INT16: {
815 copyOffsets((ROOT::Internal::VecOps::RVec<int16_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
816 } break;
817 case arrow::Type::INT32: {
818 copyOffsets((ROOT::Internal::VecOps::RVec<int32_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
819 } break;
820 case arrow::Type::INT64: {
821 copyOffsets((ROOT::Internal::VecOps::RVec<int64_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
822 } break;
823 case arrow::Type::UINT8: {
824 copyOffsets((ROOT::Internal::VecOps::RVec<uint8_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
825 } break;
826 case arrow::Type::UINT16: {
827 copyOffsets((ROOT::Internal::VecOps::RVec<uint16_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
828 } break;
829 case arrow::Type::UINT32: {
830 copyOffsets((ROOT::Internal::VecOps::RVec<uint32_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
831 } break;
832 case arrow::Type::UINT64: {
833 copyOffsets((ROOT::Internal::VecOps::RVec<uint64_t>*)offsetBulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast);
834 } break;
835 default: {
836 throw runtime_error("Unsupported kind of VLA");
837 } break;
838 }
839
840 readEntries += readLast;
841 clusterIt = descriptor.FindNextClusterId(clusterIt);
842 }
843 } else {
844 auto clusterIt = descriptor.FindClusterId(0, 0);
845 while (clusterIt != rns::kInvalidDescriptorId) {
846 auto& index = descriptor.GetClusterDescriptor(clusterIt);
847 auto mask = std::make_unique<bool[]>(index.GetNEntries());
848 std::fill(mask.get(), mask.get() + index.GetNEntries(), true);
849 void* inPtr = bulk.ReadBulk(DPLLocalIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries());
850
851 int readLast = index.GetNEntries();
852 if (listSize == -1) {
853 size = offsets[readEntries + readLast] - offsets[readEntries];
854 } else {
855 size = readLast * listSize;
856 }
857 readEntries += readLast;
858 memcpy(ptr, inPtr, size * typeSize);
859 ptr += (ptrdiff_t)(size * typeSize);
860 clusterIt = descriptor.FindNextClusterId(clusterIt);
861 }
862 }
863 switch (listSize) {
864 case -1: {
865 auto vdata = std::make_shared<arrow::ArrayData>(physicalField->type()->field(0)->type(), totalSize, std::vector<std::shared_ptr<arrow::Buffer>>{nullptr, arrowValuesBuffer});
866 array = std::make_shared<arrow::ListArray>(physicalField->type(), readEntries, arrowOffsetBuffer, arrow::MakeArray(vdata));
867 } break;
868 case 1: {
869 totalSize = readEntries * listSize;
870 auto data = std::make_shared<arrow::ArrayData>(physicalField->type(), readEntries, std::vector<std::shared_ptr<arrow::Buffer>>{nullptr, arrowValuesBuffer});
871 array = arrow::MakeArray(data);
872
873 } break;
874 default: {
875 totalSize = readEntries * listSize;
876 auto vdata = std::make_shared<arrow::ArrayData>(physicalField->type()->field(0)->type(), totalSize, std::vector<std::shared_ptr<arrow::Buffer>>{nullptr, arrowValuesBuffer});
877 array = std::make_shared<arrow::FixedSizeListArray>(physicalField->type(), readEntries, arrow::MakeArray(vdata));
878 }
879 }
880 }
881 columns.push_back(array);
882 }
883
884 auto batch = arrow::RecordBatch::Make(dataset_schema, rows, columns);
885 return batch;
886 };
887
888 return generator;
889}
890
891arrow::Result<std::shared_ptr<arrow::dataset::FileWriter>> RNTupleFileFormat::MakeWriter(std::shared_ptr<arrow::io::OutputStream> destination,
892 std::shared_ptr<arrow::Schema> schema,
893 std::shared_ptr<arrow::dataset::FileWriteOptions> options,
894 arrow::fs::FileLocator destination_locator) const
895{
896 auto writer = std::make_shared<RNTupleFileWriter>(schema, options, destination, destination_locator);
897 return std::dynamic_pointer_cast<arrow::dataset::FileWriter>(writer);
898}
899
900arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> RNTupleFileFormat::MakeFragment(
901 arrow::dataset::FileSource source, arrow::compute::Expression partition_expression,
902 std::shared_ptr<arrow::Schema> physical_schema)
903{
904 std::shared_ptr<arrow::dataset::FileFormat> format = std::make_shared<RNTupleFileFormat>(mTotCompressedSize, mTotUncompressedSize);
905
906 auto fragment = std::make_shared<RNTupleFileFragment>(source, format,
907 partition_expression,
908 physical_schema);
909 return std::dynamic_pointer_cast<arrow::dataset::FileFragment>(fragment);
910}
911
913
914std::shared_ptr<arrow::dataset::FileWriteOptions>
916{
917 return std::make_shared<RNTupleFileWriteOptions>(shared_from_this());
918}
919
923 std::shared_ptr<o2::framework::RNTupleFileFormat> format = nullptr;
924};
925
928 {
929 auto context = new RNTuplePluginContext;
930 context->format = std::make_shared<o2::framework::RNTupleFileFormat>(context->totalCompressedSize, context->totalUncompressedSize);
931 return new RootArrowFactory{
932 .options = [context]() { return context->format->DefaultWriteOptions(); },
933 .format = [context]() { return context->format; },
934 .deferredOutputStreamer = [](std::shared_ptr<arrow::dataset::FileFragment> fragment, const std::shared_ptr<arrow::ResizableBuffer>& buffer) -> std::shared_ptr<arrow::io::OutputStream> {
935 auto treeFragment = std::dynamic_pointer_cast<RNTupleFileFragment>(fragment);
936 return std::make_shared<FairMQOutputStream>(buffer);
937 }};
938 }
939};
940
944} // namespace o2::framework
int32_t i
#define DEFINE_DPL_PLUGIN_INSTANCE(NAME, KIND)
Definition Plugins.h:112
#define DEFINE_DPL_PLUGINS_END
Definition Plugins.h:115
#define DEFINE_DPL_PLUGINS_BEGIN
Definition Plugins.h:107
auto arrowTypeFromRNTuple(rns::RFieldBase const &field, int size)
rns::REntry::RFieldToken DPLFieldToken
rns::RClusterIndex DPLLocalIndex
uint16_t pos
Definition RawData.h:3
TBranch * ptr
~RNTupleFileFormat() override=default
RNTupleFileFormat(size_t &totalCompressedSize, size_t &totalUncompressedSize)
std::shared_ptr< arrow::dataset::FileWriteOptions > DefaultWriteOptions() override
bool Equals(const FileFormat &other) const override
arrow::Result< std::shared_ptr< arrow::dataset::FileWriter > > MakeWriter(std::shared_ptr< arrow::io::OutputStream > destination, std::shared_ptr< arrow::Schema > schema, std::shared_ptr< arrow::dataset::FileWriteOptions > options, arrow::fs::FileLocator destination_locator) const override
arrow::Result< std::shared_ptr< arrow::dataset::FileFragment > > MakeFragment(arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, std::shared_ptr< arrow::Schema > physical_schema) override
arrow::Result< arrow::RecordBatchGenerator > ScanBatchesAsync(const std::shared_ptr< arrow::dataset::ScanOptions > &options, const std::shared_ptr< arrow::dataset::FileFragment > &fragment) const override
arrow::Result< bool > IsSupported(const arrow::dataset::FileSource &source) const override
arrow::Result< std::shared_ptr< arrow::Schema > > Inspect(const arrow::dataset::FileSource &source) const override
std::string type_name() const override
RNTupleFileFragment(arrow::dataset::FileSource source, std::shared_ptr< arrow::dataset::FileFormat > format, arrow::compute::Expression partition_expression, std::shared_ptr< arrow::Schema > physical_schema)
virtual rns::RNTuple * GetRNTuple(arrow::dataset::FileSource source)=0
RNTupleFileWriteOptions(std::shared_ptr< arrow::dataset::FileFormat > format)
arrow::Future FinishInternal() override
arrow::Status Write(const std::shared_ptr< arrow::RecordBatch > &batch) override
RNTupleFileWriter(std::shared_ptr< arrow::Schema > schema, std::shared_ptr< arrow::dataset::FileWriteOptions > options, std::shared_ptr< arrow::io::OutputStream > destination, arrow::fs::FileLocator destination_locator)
std::string type_name() const override
rns::RNTuple * GetRNTuple(arrow::dataset::FileSource) override
const GLfloat * m
Definition glcorearb.h:4066
GLuint64EXT * result
Definition glcorearb.h:5662
GLuint buffer
Definition glcorearb.h:655
GLuint entry
Definition glcorearb.h:5735
GLsizeiptr size
Definition glcorearb.h:659
GLuint GLsizei const GLuint const GLintptr * offsets
Definition glcorearb.h:2595
const GLdouble * v
Definition glcorearb.h:832
GLenum array
Definition glcorearb.h:4274
GLuint index
Definition glcorearb.h:781
GLuint const GLchar * name
Definition glcorearb.h:781
GLdouble f
Definition glcorearb.h:310
GLsizei GLsizei GLchar * source
Definition glcorearb.h:798
GLsizei const GLfloat * value
Definition glcorearb.h:819
GLint GLint GLsizei GLint GLenum GLenum type
Definition glcorearb.h:275
GLboolean * data
Definition glcorearb.h:298
GLuint GLsizei GLsizei * length
Definition glcorearb.h:790
GLsizei const GLchar *const * path
Definition glcorearb.h:3591
GLint GLuint mask
Definition glcorearb.h:291
GLint GLint GLsizei GLint GLenum format
Definition glcorearb.h:275
Defining PrimaryVertex explicitly as messageable.
Definition TFIDInfo.h:20
RuntimeErrorRef runtime_error(const char *)
auto getSubfields(T const &field)
std::unique_ptr< rns::RFieldBase > rootFieldFromArrow(std::shared_ptr< arrow::Field > field, std::string name)
auto & getFieldZero(T const &m)
RuntimeErrorRef runtime_error_f(const char *,...)
Definition list.h:40
std::shared_ptr< o2::framework::RNTupleFileFormat > format
std::function< std::shared_ptr< arrow::dataset::FileWriteOptions >()> options
void VisitIntField(const rns::RField< int > &field) override
void VisitUInt8Field(const ROOT::Experimental::RField< std::uint8_t > &field) override
void VisitRVecField(const rns::RRVecField &field) override
void VisitBoolField(const rns::RField< bool > &field) override
void VisitInt16Field(const ROOT::Experimental::RField< std::int16_t > &field) override
void VisitUInt32Field(const ROOT::Experimental::RField< std::uint32_t > &field) override
void VisitDoubleField(const rns::RField< double > &field) override
void VisitArrayField(const rns::RArrayField &field) override
void VisitField(const rns::RFieldBase &field) override
void VisitInt8Field(const ROOT::Experimental::RField< std::int8_t > &field) override
std::shared_ptr< arrow::DataType > datatype
void VisitUInt16Field(const ROOT::Experimental::RField< std::uint16_t > &field) override
void VisitFloatField(const rns::RField< float > &field) override
VectorOfTObjectPtrs other
std::vector< ReadoutWindowData > rows