diff --git a/Framework/Core/CMakeLists.txt b/Framework/Core/CMakeLists.txt index 43571526855cc..32cb44e2619cf 100644 --- a/Framework/Core/CMakeLists.txt +++ b/Framework/Core/CMakeLists.txt @@ -143,6 +143,7 @@ o2_add_library(Framework src/Array2D.cxx src/Variant.cxx src/VariantJSONHelpers.cxx + src/ExpressionJSONHelpers.cxx src/VariantPropertyTreeHelpers.cxx src/WorkflowCustomizationHelpers.cxx src/WorkflowHelpers.cxx diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index b9b97bfa5ca9c..5d5408a638a9a 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -1293,7 +1293,14 @@ concept with_ccdb_urls = requires { }; template -concept with_base_table = not_void>::metadata::base_table_t>; +concept with_base_table = requires { + typename aod::MetadataTrait>::metadata::base_table_t; +}; + +template +concept with_expression_pack = requires { + typename T::expression_pack_t{}; +}; template os1, size_t N2, std::array os2> consteval bool is_compatible() diff --git a/Framework/Core/include/Framework/AnalysisHelpers.h b/Framework/Core/include/Framework/AnalysisHelpers.h index 842263cd75abc..6b9aa957f6d4f 100644 --- a/Framework/Core/include/Framework/AnalysisHelpers.h +++ b/Framework/Core/include/Framework/AnalysisHelpers.h @@ -26,6 +26,11 @@ #include "Framework/Traits.h" #include +namespace o2::framework { +std::string serializeProjectors(std::vector& projectors); +std::string serializeSchema(std::shared_ptr& schema); +} + namespace o2::soa { template @@ -97,6 +102,32 @@ constexpr auto getCCDBMetadata() -> std::vector { return {}; } + +template +constexpr auto getExpressionMetadata() -> std::vector +{ + using expression_pack_t = T::expression_pack_t; + + auto projectors = [](framework::pack) -> std::vector { + std::vector result; + (result.emplace_back(std::move(C::Projector())), ...); + return result; + }(expression_pack_t{}); + + auto schema = 
std::make_shared(o2::soa::createFieldsFromColumns(expression_pack_t{})); + + auto json = framework::serializeProjectors(projectors); + return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}}, + framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(schema), {"\"\""}}}; +} + +template + requires(!soa::with_expression_pack) +constexpr auto getExpressionMetadata() -> std::vector +{ + return {}; +} + } // namespace template @@ -107,6 +138,8 @@ constexpr auto tableRef2InputSpec() metadata.insert(metadata.end(), m.begin(), m.end()); auto ccdbMetadata = getCCDBMetadata>::metadata>(); metadata.insert(metadata.end(), ccdbMetadata.begin(), ccdbMetadata.end()); + auto p = getExpressionMetadata>::metadata>(); + metadata.insert(metadata.end(), p.begin(), p.end()); return framework::InputSpec{ o2::aod::label(), diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 5a889e9ae26ec..e08bf8db52bb4 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -110,6 +110,8 @@ std::string upcastTo(atype::type f); /// An expression tree node corresponding to a literal value struct LiteralNode { + using var_t = LiteralValue::stored_type; + LiteralNode() : value{-1}, type{atype::INT32} @@ -120,7 +122,12 @@ struct LiteralNode { { } - using var_t = LiteralValue::stored_type; + LiteralNode(var_t v, atype::type t) + : value{v}, + type{t} + { + } + var_t value; atype::type type = atype::NA; }; @@ -617,6 +624,12 @@ inline Node ncfg(T defaultValue, std::string path) struct Filter { Filter() = default; + Filter(std::unique_ptr&& ptr) + { + node = std::move(ptr); + (void)designateSubtrees(node.get()); + } + Filter(Node&& node_) : node{std::make_unique(std::forward(node_))} { (void)designateSubtrees(node.get()); @@ -624,7 +637,6 @@ struct Filter { Filter(Filter&& other) : node{std::forward>(other.node)} { 
- (void)designateSubtrees(node.get()); } Filter(std::string const& input_) : input{input_} {} diff --git a/Framework/Core/src/AODReaderHelpers.cxx b/Framework/Core/src/AODReaderHelpers.cxx index 2587b8e4ca03a..ad5984d65080b 100644 --- a/Framework/Core/src/AODReaderHelpers.cxx +++ b/Framework/Core/src/AODReaderHelpers.cxx @@ -19,6 +19,7 @@ #include "Framework/CallbackService.h" #include "Framework/EndOfStreamContext.h" #include "Framework/DataSpecUtils.h" +#include "Framework/ExpressionJSONHelpers.h" #include @@ -44,28 +45,6 @@ auto setEOSCallback(InitContext& ic) }); } -template -static inline auto doExtractOriginal(framework::pack, ProcessingContext& pc) -{ - if constexpr (sizeof...(Ts) == 1) { - return pc.inputs().get(aod::MetadataTrait>>::metadata::tableLabel())->asArrowTable(); - } else { - return std::vector{pc.inputs().get(aod::MetadataTrait::metadata::tableLabel())->asArrowTable()...}; - } -} - -template -static inline auto extractOriginalsTuple(framework::pack, ProcessingContext& pc) -{ - return std::make_tuple(extractTypedOriginal(pc)...); -} - -template -static inline auto extractOriginalsVector(framework::pack, ProcessingContext& pc) -{ - return std::vector{extractOriginal(pc)...}; -} - template refs> static inline auto extractOriginals(ProcessingContext& pc) { @@ -156,12 +135,32 @@ auto make_spawn(InputSpec const& input, ProcessingContext& pc) (typename metadata_t::expression_pack_t{}); return o2::framework::spawner(extractOriginals(pc), input.binding.c_str(), projectors.data(), projector, schema); } + +struct Spawnable { + std::vector projectors; + std::vector labels; + std::shared_ptr schema; + + Spawnable(InputSpec const& spec) + { + auto loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& spc){ return spc.name.compare("projectors") == 0; }); + std::stringstream iws(loc->defaultValue.get()); + projectors = ExpressionJSONHelpers::read(iws); + for (auto& i : spec.metadata) { + if (i.name.starts_with("input:")) { + 
labels.emplace_back(i.name.substr(6)); + } + } + } +}; } // namespace AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(std::vector& requested) { return AlgorithmSpec::InitCallback{[requested](InitContext& /*ic*/) { - return [requested](ProcessingContext& pc) { + std::vector spawnables; + + return [requested, spawnables](ProcessingContext& pc) { auto outputs = pc.outputs(); // spawn tables for (auto& input : requested) { diff --git a/Framework/Core/src/AnalysisHelpers.cxx b/Framework/Core/src/AnalysisHelpers.cxx index c0f804b47f5af..4f78cc42f3f98 100644 --- a/Framework/Core/src/AnalysisHelpers.cxx +++ b/Framework/Core/src/AnalysisHelpers.cxx @@ -9,6 +9,7 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. #include "Framework/ExpressionHelpers.h" +#include "ExpressionJSONHelpers.h" namespace o2::framework { @@ -26,4 +27,18 @@ void initializePartitionCaches(std::set const& hashes, std::shared_ptr gfilter = framework::expressions::createFilter(schema, framework::expressions::makeCondition(tree)); } } + +std::string serializeProjectors(std::vector& projectors) +{ + std::stringstream osm; + ExpressionJSONHelpers::write(osm, projectors); + return osm.str(); +} + +std::string serializeSchema(std::shared_ptr& schema) +{ + std::stringstream osm; + ArrowJSONHelpers::write(osm, schema); + return osm.str(); +} } // namespace o2::framework diff --git a/Framework/Core/src/ExpressionJSONHelpers.cxx b/Framework/Core/src/ExpressionJSONHelpers.cxx new file mode 100644 index 0000000000000..9f98eaddc56ce --- /dev/null +++ b/Framework/Core/src/ExpressionJSONHelpers.cxx @@ -0,0 +1,824 @@ +// Copyright 2019-2025 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. 
+// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +#include "ExpressionJSONHelpers.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "Framework/VariantHelpers.h" + +namespace o2::framework +{ +namespace +{ +using nodes = expressions::Node::self_t; +enum struct Nodes : int { + NLITERAL = 0, + NBINDING = 1, + NOP = 2, + NNPH = 3, + NCOND = 4, + NPAR = 5 +}; + +enum struct ToWrite { + FULL, + LEFT, + RIGHT, + COND, + POP +}; + +struct Entry { + expressions::Node* ptr = nullptr; + ToWrite toWrite = ToWrite::FULL; +}; + +std::array validKeys{ + "projectors", + "kind", + "binding", + "index", + "arrow_type", + "value", + "hash", + "operation", + "left", + "right", + "condition"}; + +struct ExpressionReader : public rapidjson::BaseReaderHandler, ExpressionReader> { + using Ch = rapidjson::UTF8<>::Ch; + using SizeType = rapidjson::SizeType; + + enum struct State { + IN_START, // global start + IN_LIST, // opening brace of the list + IN_ROOT, // after encountering the opening of the expression object + IN_LEFT, // in "left" key - subexpression + IN_RIGHT, // in "right" key - subexpression + IN_COND, // in "condition" key - subexpression + IN_NODE_LITERAL, // in literal node + IN_NODE_BINDING, // in binding node + IN_NODE_OP, // in operation node + IN_NODE_CONDITIONAL, // in conditional node + IN_ERROR // generic error state + }; + + std::stack states; + std::stack path; + std::ostringstream debug; + + std::vector result; + + std::unique_ptr rootNode = nullptr; + std::unique_ptr node = nullptr; + expressions::LiteralValue::stored_type value; + atype::type type; + Nodes kind; + std::string binding; + BasicOp operation; + uint32_t hash; + 
size_t index; + + std::string previousKey; + std::string currentKey; + + ExpressionReader() + { + debug << ">>> Start" << std::endl; + states.push(State::IN_START); + } + + bool StartArray() + { + debug << "StartArray()" << std::endl; + if (states.top() == State::IN_START) { + states.push(State::IN_LIST); + return true; + } + states.push(State::IN_ERROR); + return false; + } + + bool EndArray(SizeType) + { + debug << "EndArray()" << std::endl; + if (states.top() == State::IN_LIST) { + states.pop(); + return true; + } + states.push(State::IN_ERROR); + return false; + } + + bool Key(const Ch* str, SizeType, bool) + { + debug << "Key(" << str << ")" << std::endl; + previousKey = currentKey; + currentKey = str; + if (std::find(validKeys.begin(), validKeys.end(), currentKey) == validKeys.end()) { + states.push(State::IN_ERROR); + return false; + } + + if (states.top() == State::IN_START) { + if (currentKey.compare("projectors") == 0) { + return true; + } + } + + if (states.top() == State::IN_ROOT) { + if (currentKey.compare("kind") == 0) { + return true; + } else { + states.push(State::IN_ERROR); // should start from root node + return false; + } + } + + if (states.top() == State::IN_LEFT || states.top() == State::IN_RIGHT || states.top() == State::IN_COND) { + if (currentKey.compare("kind") == 0) { + return true; + } + } + + if (states.top() == State::IN_NODE_LITERAL || states.top() == State::IN_NODE_OP || states.top() == State::IN_NODE_BINDING || states.top() == State::IN_NODE_CONDITIONAL) { + if (currentKey.compare("index") == 0) { + return true; + } + if (currentKey.compare("left") == 0) { + // this is the point where the node header is parsed and we can create it + // create a new node instance here and set a pointer to it in a parent (current stack top), based on its state + // push the new node into the stack with LEFT state + switch (states.top()) { + case State::IN_NODE_LITERAL: + node = std::make_unique(expressions::LiteralNode{value, type}); + break; + case 
State::IN_NODE_BINDING: + node = std::make_unique(expressions::BindingNode{hash, type}, binding); + break; + case State::IN_NODE_OP: + node = std::make_unique(expressions::OpNode{operation}, expressions::LiteralNode{-1}); + break; + case State::IN_NODE_CONDITIONAL: + node = std::make_unique(expressions::ConditionalNode{}, expressions::LiteralNode{-1}, expressions::LiteralNode{-1}, expressions::LiteralNode{true}); + break; + default: + states.push(State::IN_ERROR); + return false; + } + + if (path.empty()) { + rootNode = std::move(node); + path.emplace(rootNode.get(), ToWrite::LEFT); + } else { + auto* n = path.top().ptr; + switch (path.top().toWrite) { + case ToWrite::LEFT: + n->left = std::move(node); + path.top().toWrite = ToWrite::RIGHT; + path.emplace(n->left.get(), ToWrite::LEFT); + break; + case ToWrite::RIGHT: + n->right = std::move(node); + path.top().toWrite = ToWrite::COND; + path.emplace(n->right.get(), ToWrite::LEFT); + break; + case ToWrite::COND: + n->condition = std::move(node); + path.pop(); + path.emplace(n->condition.get(), ToWrite::LEFT); + break; + default: + states.push(State::IN_ERROR); + return false; + } + } + + states.push(State::IN_LEFT); + return true; + } + if (currentKey.compare("right") == 0) { + if (states.top() == State::IN_LEFT) { + states.pop(); + } + // move the stack state of the node to RIGHT state + path.top().toWrite = ToWrite::RIGHT; + states.push(State::IN_RIGHT); + return true; + } + if (currentKey.compare("condition") == 0) { + if (states.top() == State::IN_RIGHT) { + states.pop(); + } + // move the stack state of the node to COND state + path.top().toWrite = ToWrite::COND; + states.push(State::IN_COND); + return true; + } + } + + if (states.top() == State::IN_NODE_LITERAL) { + if (currentKey.compare("arrow_type") == 0 || currentKey.compare("value") == 0) { + return true; + } + } + + if (states.top() == State::IN_NODE_BINDING) { + if (currentKey.compare("binding") == 0 || currentKey.compare("hash") == 0 || 
currentKey.compare("arrow_type") == 0) { + return true; + } + } + + if (states.top() == State::IN_NODE_OP) { + if (currentKey.compare("operation") == 0) { + return true; + } + } + + debug << ">>> Unrecognized" << std::endl; + states.push(State::IN_ERROR); + return false; + } + + bool StartObject() + { + // opening brace encountered + debug << "StartObject()" << std::endl; + // the first opening brace in the input + if (states.top() == State::IN_START) { + return true; + } + // the opening of an expression + if (states.top() == State::IN_LIST) { + states.push(State::IN_ROOT); + return true; + } + // if we are looking at subexpression + if (states.top() == State::IN_LEFT || states.top() == State::IN_RIGHT || states.top() == State::IN_COND) { // ready to start a new node + return true; + } + // no other object starts are expected + states.push(State::IN_ERROR); + return false; + } + + bool EndObject(SizeType) + { + // closing brace encountered + debug << "EndObject()" << std::endl; + // we are closing up an expression + if (states.top() == State::IN_NODE_LITERAL || states.top() == State::IN_NODE_OP || states.top() == State::IN_NODE_BINDING || states.top() == State::IN_NODE_CONDITIONAL) { // finalize node + // finalize the current node and pop it from the stack (the pointers should be already set) + states.pop(); + // subexpression + if (states.top() == State::IN_LEFT || states.top() == State::IN_RIGHT || states.top() == State::IN_COND) { + states.pop(); + return true; + } + + // expression + if (states.top() == State::IN_ROOT) { + result.emplace_back(std::move(rootNode)); + states.pop(); + return true; + } + } + + // we are closing the outermost object (the list itself is closed in EndArray) + if (states.top() == State::IN_START) { + return true; + } + // no other object ends are expected + states.push(State::IN_ERROR); + return false; + } + + bool Null() + { + // null value + debug << "Null()" << std::endl; + // the subexpression can be empty + if (states.top() == State::IN_LEFT || states.top() == State::IN_RIGHT || 
states.top() == State::IN_COND) { + // empty node, nothing to do + // move the path state to the next + if (path.top().toWrite == ToWrite::LEFT) { + path.top().toWrite = ToWrite::RIGHT; + } else if (path.top().toWrite == ToWrite::RIGHT) { + path.top().toWrite = ToWrite::COND; + } else if (path.top().toWrite == ToWrite::COND) { + path.pop(); + } + + states.pop(); + return true; + } + states.push(State::IN_ERROR); // no other contexts allow null + return false; + } + + bool Bool(bool b) + { + debug << "Bool(" << b << ")" << std::endl; + // can be a value in a literal node + if (states.top() == State::IN_NODE_LITERAL && currentKey.compare("value") == 0) { + value = b; + return true; + } + states.push(State::IN_ERROR); // no other contexts allow booleans + return false; + } + + bool Int(int i) + { + debug << "Int(" << i << ")" << std::endl; + // can be a value in a literal node + if (states.top() == State::IN_NODE_LITERAL && currentKey.compare("value") == 0) { // literal + switch (type) { + case atype::INT8: + value = (int8_t)i; + break; + case atype::INT16: + value = (int16_t)i; + break; + case atype::INT32: + value = i; + break; + case atype::UINT8: + value = (uint8_t)i; + break; + case atype::UINT16: + value = (uint16_t)i; + break; + case atype::UINT32: + value = (uint32_t)i; // store the uint32_t alternative so the variant matches arrow_type; also undoes the Uint() -> Int() wrap + break; + case atype::UINT64: + value = (uint64_t)(uint32_t)i; // values below 2^32 arrive via Uint() wrapped to int; go through uint32_t to avoid sign extension + break; + case atype::INT64: + value = (int64_t)i; // NOTE(review): values in [2^31, 2^32) arrive from Uint() already wrapped negative - needs handling in Uint() + break; + default: + states.push(State::IN_ERROR); + return false; + } + return true; + } + // can be a node kind designator + if (states.top() == State::IN_ROOT || states.top() == State::IN_LEFT || states.top() == State::IN_RIGHT || states.top() == State::IN_COND) { + if (currentKey.compare("kind") == 0) { + kind = (Nodes)i; + switch (kind) { + case Nodes::NLITERAL: + case Nodes::NNPH: + case Nodes::NPAR: { + states.push(State::IN_NODE_LITERAL); + debug << ">>> Literal node" << std::endl; + return true; + } + case Nodes::NBINDING: { + states.push(State::IN_NODE_BINDING); + debug << ">>> Binding 
node" << std::endl; + return true; + } + case Nodes::NOP: { + states.push(State::IN_NODE_OP); + debug << ">>> Operation node" << std::endl; + return true; + } + case Nodes::NCOND: { + states.push(State::IN_NODE_CONDITIONAL); + debug << ">>> Conditional node" << std::endl; + return true; + } + } + } + } + // can be node index + if (states.top() == State::IN_NODE_BINDING || states.top() == State::IN_NODE_CONDITIONAL || states.top() == State::IN_NODE_LITERAL || states.top() == State::IN_NODE_OP) { + if (currentKey.compare("index") == 0) { + index = (size_t)i; + return true; + } + } + // can be a node type designator + if (states.top() == State::IN_NODE_LITERAL || states.top() == State::IN_NODE_BINDING) { + if (currentKey.compare("arrow_type") == 0) { + type = (atype::type)i; + return true; + } + } + // can be a node operation designator + if (states.top() == State::IN_NODE_OP && currentKey.compare("operation") == 0) { + operation = (BasicOp)i; + return true; + } + states.push(State::IN_ERROR); // no other contexts allow ints + return false; + } + + bool Uint(unsigned i) + { + debug << "Uint(" << i << ")" << std::endl; + // can be node hash + if (states.top() == State::IN_NODE_BINDING && currentKey.compare("hash") == 0) { + hash = i; + return true; + } + // any positive value will be first read as unsigned, however the actual type is determined by node's arrow_type + debug << ">> falling back to Int" << std::endl; + return Int(i); + } + + bool Int64(int64_t i) + { + debug << "Int64(" << i << ")" << std::endl; + // can only be a literal node value + if (states.top() == State::IN_NODE_LITERAL && currentKey.compare("value") == 0) { + value = i; + return true; + } + states.push(State::IN_ERROR); // no other contexts allow int64s + return false; + } + + bool Uint64(uint64_t i) + { + debug << "Uint64(" << i << ")" << std::endl; + // can only be a literal node value + if (states.top() == State::IN_NODE_LITERAL && currentKey.compare("value") == 0) { + value = i; + return true; 
+ } + states.push(State::IN_ERROR); // no other contexts allow uints + return false; + } + + bool Double(double d) + { + debug << "Double(" << d << ")" << std::endl; + // can only be a literal node value + if (states.top() == State::IN_NODE_LITERAL) { + switch (type) { + case atype::FLOAT: + value = (float)d; + break; + case atype::DOUBLE: + value = d; + break; + default: + states.push(State::IN_ERROR); + return false; + } + return true; + } + states.push(State::IN_ERROR); // no other contexts allow doubles + return false; + } + + bool String(const Ch* str, SizeType, bool) + { + debug << "String(" << str << ")" << std::endl; + // can only be a binding node + if (states.top() == State::IN_NODE_BINDING && currentKey.compare("binding") == 0) { + binding = str; + return true; + } + states.push(State::IN_ERROR); // no strings are expected + return false; + } +}; +} // namespace + +std::vector o2::framework::ExpressionJSONHelpers::read(std::istream& s) +{ + rapidjson::Reader reader; + rapidjson::IStreamWrapper isw(s); + ExpressionReader ereader; + bool ok = reader.Parse(isw, ereader); + + if (!ok) { + throw framework::runtime_error_f("Cannot parse serialized Expression, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset()); + } + return std::move(ereader.result); +} + +namespace +{ +void writeNodeHeader(rapidjson::Writer& w, expressions::Node const* node) +{ + w.Key("kind"); + w.Int((int)node->self.index()); + w.Key("index"); + w.Uint64(node->index); + std::visit(overloaded{ + [&w](expressions::LiteralNode const& node) { + w.Key("arrow_type"); + w.Int(node.type); + w.Key("value"); + std::visit(overloaded{ + [&w](bool v) { w.Bool(v); }, + [&w](float v) { w.Double(v); }, + [&w](double v) { w.Double(v); }, + [&w](uint8_t v) { w.Uint(v); }, + [&w](uint16_t v) { w.Uint(v); }, + [&w](uint32_t v) { w.Uint(v); }, + [&w](uint64_t v) { w.Uint64(v); }, + [&w](int8_t v) { w.Int(v); }, + [&w](int16_t v) { w.Int(v); }, + [&w](int 
v) { w.Int(v); }, + [&w](int64_t v) { w.Int64(v); }}, + node.value); + }, + [&w](expressions::BindingNode const& node) { + w.Key("binding"); + w.String(node.name); + w.Key("hash"); + w.Uint(node.hash); + w.Key("arrow_type"); + w.Int(node.type); + }, + [&w](expressions::OpNode const& node) { + w.Key("operation"); + w.Int(node.op); + }, + [](expressions::ConditionalNode const&) { + }}, + node->self); +} + +void writeExpression(rapidjson::Writer& w, expressions::Node* n) +{ + std::stack path; + path.emplace(n, ToWrite::FULL); + while (!path.empty()) { + auto& top = path.top(); + + if (top.toWrite == ToWrite::FULL) { + w.StartObject(); + writeNodeHeader(w, top.ptr); + top.toWrite = ToWrite::LEFT; + continue; + } + + if (top.toWrite == ToWrite::LEFT) { + w.Key("left"); + top.toWrite = ToWrite::RIGHT; + auto* left = top.ptr->left.get(); + if (left != nullptr) { + path.emplace(left, ToWrite::FULL); + } else { + w.Null(); + } + continue; + } + + if (top.toWrite == ToWrite::RIGHT) { + w.Key("right"); + top.toWrite = ToWrite::COND; + auto* right = top.ptr->right.get(); + if (right != nullptr) { + path.emplace(right, ToWrite::FULL); + } else { + w.Null(); + } + continue; + } + + if (top.toWrite == ToWrite::COND) { + w.Key("condition"); + top.toWrite = ToWrite::POP; + auto* cond = top.ptr->condition.get(); + if (cond != nullptr) { + path.emplace(cond, ToWrite::FULL); + } else { + w.Null(); + } + continue; + } + + if (top.toWrite == ToWrite::POP) { + w.EndObject(); + path.pop(); + continue; + } + } +} +} // namespace + +void o2::framework::ExpressionJSONHelpers::write(std::ostream& o, std::vector& projectors) +{ + rapidjson::OStreamWrapper osw(o); + rapidjson::Writer w(osw); + w.StartObject(); + w.Key("projectors"); + w.StartArray(); + for (auto& p : projectors) { + writeExpression(w, p.node.get()); + } + w.EndArray(); + w.EndObject(); +} + +namespace { +struct SchemaReader : public rapidjson::BaseReaderHandler, SchemaReader> { + using Ch = rapidjson::UTF8<>::Ch; + using 
SizeType = rapidjson::SizeType; + + enum struct State { + IN_START, + IN_LIST, + IN_FIELD, + IN_ERROR + }; + + std::stack states; + std::ostringstream debug; + + std::shared_ptr schema = nullptr; + std::vector> fields; + + std::string currentKey; + + std::string name; + atype::type type; + + SchemaReader() + { + debug << ">>> Start" << std::endl; + states.push(State::IN_START); + } + + bool StartArray() + { + debug << "Starting array" << std::endl; + if (states.top() == State::IN_START && currentKey.compare("fields") == 0) { + states.push(State::IN_LIST); + return true; + } + states.push(State::IN_ERROR); + return false; + } + + bool EndArray(SizeType) + { + debug << "Ending array" << std::endl; + if (states.top() == State::IN_LIST) { + //finalize schema + schema = std::make_shared(fields); + states.pop(); + return true; + } + states.push(State::IN_ERROR); + return false; + } + + bool Key(const Ch* str, SizeType, bool) + { + debug << "Key(" << str << ")" << std::endl; + currentKey = str; + if (states.top() == State::IN_START) { + if (currentKey.compare("fields") == 0) { + return true; + } + } + + if (states.top() == State::IN_FIELD) { + if (currentKey.compare("name") == 0) { + return true; + } + if (currentKey.compare("type") == 0) { + return true; + } + } + + states.push(State::IN_ERROR); + return false; + } + + bool StartObject() + { + debug << "StartObject()" << std::endl; + if (states.top() == State::IN_START) { + return true; + } + + if (states.top() == State::IN_LIST) { + states.push(State::IN_FIELD); + return true; + } + + states.push(State::IN_ERROR); + return false; + } + + bool EndObject(SizeType) + { + debug << "EndObject()" << std::endl; + if (states.top() == State::IN_FIELD) { + states.pop(); + // add a field + fields.emplace_back(std::make_shared(name, expressions::concreteArrowType(type))); + return true; + } + + if (states.top() == State::IN_START) { + return true; + } + + states.push(State::IN_ERROR); + return false; + } + + bool Uint(unsigned i) + 
{ + debug << "Uint(" << i << ")" << std::endl; + if (states.top() == State::IN_FIELD) { + if (currentKey.compare("type") == 0) { + type = (atype::type)i; + return true; + } + } + + states.push(State::IN_ERROR); + return false; + } + + bool String(const Ch* str, SizeType, bool) + { + debug << "String(" << str << ")" << std::endl; + if (states.top() == State::IN_FIELD) { + if (currentKey.compare("name") == 0) { + name = str; + return true; + } + } + + states.push(State::IN_ERROR); + return false; + } + + bool Int(int i) { + debug << "Int(" << i << ")" << std::endl; + return Uint(i); + } + +}; +} + +std::shared_ptr o2::framework::ArrowJSONHelpers::read(std::istream& s) +{ + rapidjson::Reader reader; + rapidjson::IStreamWrapper isw(s); + SchemaReader sreader; + + bool ok = reader.Parse(isw, sreader); + + if(!ok) { + throw framework::runtime_error_f("Cannot parse serialized Expression, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset()); + } + return sreader.schema; +} + +namespace { +void writeSchema(rapidjson::Writer& w, arrow::Schema* schema) +{ + for (auto& f : schema->fields()) { + w.StartObject(); + w.Key("name"); + w.String(f->name().c_str()); + w.Key("type"); + w.Int(f->type()->id()); + w.EndObject(); + } +} +} + +void o2::framework::ArrowJSONHelpers::write(std::ostream& o, std::shared_ptr& schema) +{ + rapidjson::OStreamWrapper osw(o); + rapidjson::Writer w(osw); + w.StartObject(); + w.Key("fields"); + w.StartArray(); + writeSchema(w, schema.get()); + w.EndArray(); + w.EndObject(); +} + +} // namespace o2::framework diff --git a/Framework/Core/src/ExpressionJSONHelpers.h b/Framework/Core/src/ExpressionJSONHelpers.h new file mode 100644 index 0000000000000..ed4c51c58d5c2 --- /dev/null +++ b/Framework/Core/src/ExpressionJSONHelpers.h @@ -0,0 +1,29 @@ +// Copyright 2019-2025 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. 
+// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +#ifndef FRAMEWORK_EXPRESSIONJSONHELPERS_H +#define FRAMEWORK_EXPRESSIONJSONHELPERS_H + +#include "Framework/Expressions.h" + +namespace o2::framework +{ +struct ExpressionJSONHelpers { + static std::vector read(std::istream& s); + static void write(std::ostream& o, std::vector& projectors); +}; + +struct ArrowJSONHelpers { + static std::shared_ptr read(std::istream& s); + static void write(std::ostream& o, std::shared_ptr& schema); +}; +} // namespace o2::framework + +#endif // FRAMEWORK_EXPRESSIONJSONHELPERS_H diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index 4c6fc51795ca8..5df34e5aac14a 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -12,6 +12,7 @@ #include "Framework/Configurable.h" #include "Framework/ExpressionHelpers.h" #include "Framework/AnalysisDataModel.h" +#include "../src/ExpressionJSONHelpers.h" #include #include #include @@ -391,3 +392,52 @@ TEST_CASE("TestStringExpressionsParsing") REQUIRE(tree1c->ToString() == tree2c->ToString()); } + +TEST_CASE("TestExpressionSerialization") +{ + Filter f = o2::aod::track::signed1Pt > 0.f && ifnode(nabs(o2::aod::track::eta) < 1.0f, nabs(o2::aod::track::x) > 2.0f, nabs(o2::aod::track::y) > 3.0f); + Projector p = -1.f * nlog(ntan(o2::constants::math::PIQuarter - 0.5f * natan(o2::aod::fwdtrack::tgl))); + + std::vector projectors; + projectors.emplace_back(std::move(f)); + projectors.emplace_back(std::move(p)); + + std::stringstream osm; + ExpressionJSONHelpers::write(osm, projectors); + + std::stringstream ism; + 
ism.str(osm.str()); + auto ps = ExpressionJSONHelpers::read(ism); + + auto s1 = createOperations(projectors[0]); + auto s2 = createOperations(ps[0]); + auto schemaf = std::make_shared(std::vector{o2::aod::track::Eta::asArrowField(), o2::aod::track::Signed1Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField()}); + auto t1 = createExpressionTree(s1, schemaf); + auto t2 = createExpressionTree(s2, schemaf); + REQUIRE(t1->ToString() == t2->ToString()); + + auto s12 = createOperations(projectors[1]); + auto s22 = createOperations(ps[1]); + auto schemap = std::make_shared(std::vector{o2::aod::fwdtrack::Tgl::asArrowField()}); + auto t12 = createExpressionTree(s12, schemap); + auto t22 = createExpressionTree(s22, schemap); + REQUIRE(t12->ToString() == t22->ToString()); + + osm.clear(); + osm.str(""); + ArrowJSONHelpers::write(osm, schemaf); + + ism.clear(); + ism.str(osm.str()); + auto newSchemaf = ArrowJSONHelpers::read(ism); + REQUIRE(schemaf->ToString() == newSchemaf->ToString()); + + osm.clear(); + osm.str(""); + ArrowJSONHelpers::write(osm, schemap); + + ism.clear(); + ism.str(osm.str()); + auto newSchemap = ArrowJSONHelpers::read(ism); + REQUIRE(schemap->ToString() == newSchemap->ToString()); +}