diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index eee63b11ca1c..87d858727b51 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -795,6 +795,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_rank.cc compute/kernels/vector_replace.cc compute/kernels/vector_run_end_encode.cc + compute/kernels/vector_search_sorted.cc compute/kernels/vector_select_k.cc compute/kernels/vector_sort.cc compute/kernels/vector_statistics.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 1bf4de93520c..676b5b89d5a4 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -50,6 +50,7 @@ using compute::FilterOptions; using compute::NullPlacement; using compute::RankOptions; using compute::RankQuantileOptions; +using compute::SearchSortedOptions; template <> struct EnumTraits @@ -96,6 +97,21 @@ struct EnumTraits } }; template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "SearchSortedOptions::Side"; } + static std::string value_name(SearchSortedOptions::Side value) { + switch (value) { + case SearchSortedOptions::Left: + return "Left"; + case SearchSortedOptions::Right: + return "Right"; + } + return ""; + } +}; +template <> struct EnumTraits : BasicEnumTraits { @@ -137,6 +153,8 @@ static auto kRunEndEncodeOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); +static auto kSearchSortedOptionsType = GetFunctionOptionsType( + DataMember("side", &SearchSortedOptions::side)); static auto kSortOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &SortOptions::sort_keys), DataMember("null_placement", &SortOptions::null_placement)); @@ -196,6 +214,10 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; +SearchSortedOptions::SearchSortedOptions(SearchSortedOptions::Side side) + : FunctionOptions(internal::kSearchSortedOptionsType), side(side) {} +constexpr char SearchSortedOptions::kTypeName[]; + SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)), @@ -274,6 +296,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kDictionaryEncodeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kRunEndEncodeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSearchSortedOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); @@ -315,6 +338,11 @@ Result> SelectKUnstable(const Datum& datum, return result.make_array(); } +Result SearchSorted(const Datum& values, const Datum& needles, + const SearchSortedOptions& options, ExecContext* ctx) { + return CallFunction("search_sorted", {values, needles}, &options, ctx); +} + Result ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 159a787641ee..a662cad94ba4 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -102,6 +102,21 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { NullPlacement null_placement; }; +class ARROW_EXPORT SearchSortedOptions : public FunctionOptions { + public: + enum Side { + Left, + Right, + }; + + explicit SearchSortedOptions(Side side = Side::Left); + static constexpr const char kTypeName[] = "SearchSortedOptions"; + static SearchSortedOptions Defaults() { return SearchSortedOptions(); } + + /// Whether to return the leftmost or rightmost insertion point. + Side side; +}; + class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}, @@ -515,6 +530,27 @@ Result> SelectKUnstable(const Datum& datum, const SelectKOptions& options, ExecContext* ctx = NULLPTR); +/// \brief Find insertion indices that preserve sorted order. +/// +/// The `values` datum must be a plain array or run-end encoded array sorted in +/// ascending order. `needles` may be a scalar, plain array, or run-end encoded +/// array whose logical value type matches `values`. +/// +/// Nulls in `values` are supported when clustered entirely at the start or the +/// end of the sorted array. Non-null needles are matched only against the +/// non-null portion of `values`. Null needles yield null outputs. +/// +/// \param[in] values sorted array to search within +/// \param[in] needles scalar or array-like values to search for +/// \param[in] options selects left or right insertion semantics +/// \param[in] ctx the function execution context, optional +/// \return insertion indices as uint64 scalar or array +ARROW_EXPORT +Result SearchSorted( + const Datum& values, const Datum& needles, + const SearchSortedOptions& options = SearchSortedOptions::Defaults(), + ExecContext* ctx = NULLPTR); + /// \brief Return the indices that would sort an array. /// /// Perform an indirect sort of array. The output array will contain diff --git a/cpp/src/arrow/compute/initialize.cc b/cpp/src/arrow/compute/initialize.cc index d88835da04ac..ec531e8d490f 100644 --- a/cpp/src/arrow/compute/initialize.cc +++ b/cpp/src/arrow/compute/initialize.cc @@ -48,6 +48,7 @@ Status RegisterComputeKernels() { internal::RegisterVectorNested(registry); internal::RegisterVectorRank(registry); internal::RegisterVectorReplace(registry); + internal::RegisterVectorSearchSorted(registry); internal::RegisterVectorSelectK(registry); internal::RegisterVectorSort(registry); internal::RegisterVectorRunEndEncode(registry); diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 15955b5ef883..d07356aa6300 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -121,6 +121,13 @@ add_arrow_compute_test(vector_sort_test arrow_compute_kernels_testing arrow_compute_testing) +add_arrow_compute_test(vector_search_sorted_test + SOURCES + vector_search_sorted_test.cc + EXTRA_LINK_LIBS + arrow_compute_kernels_testing + arrow_compute_testing) + add_arrow_compute_test(vector_selection_test SOURCES vector_selection_test.cc @@ -141,6 +148,7 @@ add_arrow_compute_benchmark(vector_sort_benchmark) add_arrow_compute_benchmark(vector_partition_benchmark) add_arrow_compute_benchmark(vector_topk_benchmark) add_arrow_compute_benchmark(vector_replace_benchmark) +add_arrow_compute_benchmark(vector_search_sorted_benchmark) add_arrow_compute_benchmark(vector_selection_benchmark) # ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/meson.build b/cpp/src/arrow/compute/kernels/meson.build index fb682443783c..0a8c16f9766b 100644 --- a/cpp/src/arrow/compute/kernels/meson.build +++ b/cpp/src/arrow/compute/kernels/meson.build @@ -132,6 +132,7 @@ vector_kernel_benchmarks = [ 'vector_partition_benchmark', 'vector_topk_benchmark', 'vector_replace_benchmark', + 'vector_search_sorted_benchmark', 'vector_selection_benchmark', ] diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc new file mode 100644 index 000000000000..86c8bdce0c46 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_vector.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_primitive.h" +#include "arrow/array/array_run_end.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/util.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/kernels/vector_sort_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/registry_internal.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/ree_util.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute::internal { +namespace { + +const SearchSortedOptions* GetDefaultSearchSortedOptions() { + static const auto kDefaultSearchSortedOptions = SearchSortedOptions::Defaults(); + return &kDefaultSearchSortedOptions; +} + +const FunctionDoc search_sorted_doc( + "Find insertion indices for sorted input", + ("Return the index where each needle should be inserted in a sorted input array\n" + "to maintain ascending order.\n" + "\n" + "With side='left', returns the first suitable index (lower bound).\n" + "With side='right', returns the last suitable index (upper bound).\n" + "\n" + "The input array must already be sorted in ascending order. Null values in\n" + "the searched array are supported when clustered entirely at the start or\n" + "entirely at the end. Non-null needles are matched only against the non-null\n" + "portion of the searched array. Null needles emit nulls in the output."), + {"values", "needles"}, "SearchSortedOptions"); + +// This file implements search_sorted as a small pipeline that first normalizes +// Arrow input shapes and then runs one typed binary-search core on logical +// values. +// +// Plain arrays, run-end encoded arrays, and scalar needles are all +// adapted into the same accessor and visitor model so the search logic does +// not care about physical layout. +// +// After validation, the kernel isolates the contiguous non-null window of the searched +// values, because nulls are only supported when clustered at one end. +// Needles are then visited either as single values or as logical runs, and each non-null +// needle is resolved with a lower-bound or upper-bound binary search over the sorted +// non-null range. +// +// Output materialization is split by null handling: non-null-only needles write directly +// into a preallocated uint64 buffer, while nullable needles append null and non-null +// spans through a UInt64Builder. That builder path is optimized for repeated runs by +// bulk-filling reserved memory instead of appending one insertion index at a time. +// +// High-level flow: +// +// values datum +// | +// +--> ValidateSortedValuesInput +// | +// +--> LogicalType / FindNonNullValuesRange +// | +// +--> VisitValuesAccessor +// | +// +--> PlainArrayAccessor +// | +// `--> RunEndEncodedValuesAccessor +// +// needles datum +// | +// +--> ValidateNeedleInput +// | +// +--> DatumHasNulls +// | +// `--> VisitNeedles +// | +// +--> scalar needle -> one logical span +// | +// +--> plain array -> one span per element +// | +// `--> REE array -> one span per logical run +// +// normalized values accessor + normalized needle spans +// | +// `--> FindInsertionPoint +// | +// +--> side = left -> lower_bound semantics +// | +// `--> side = right -> upper_bound semantics +// +// result materialization +// | +// +--> no needle nulls +// | `--> MakeMutableUInt64Array +// | `--> fill output buffer directly +// | +// `--> nullable needles +// `--> UInt64Builder +// +--> AppendNulls for null runs +// `--> bulk fill + UnsafeAdvance for repeated indices +// +// A rough map of the file: +// +// [validation + type helpers] +// | +// [value accessors] +// | +// [needle visitors] +// | +// [typed search + output helpers] +// | +// [meta-function dispatch] +// + +#define VISIT_SEARCH_SORTED_TYPES(VISIT) \ + VISIT(BooleanType) \ + VISIT(Int8Type) \ + VISIT(Int16Type) \ + VISIT(Int32Type) \ + VISIT(Int64Type) \ + VISIT(UInt8Type) \ + VISIT(UInt16Type) \ + VISIT(UInt32Type) \ + VISIT(UInt64Type) \ + VISIT(FloatType) \ + VISIT(DoubleType) \ + VISIT(Date32Type) \ + VISIT(Date64Type) \ + VISIT(Time32Type) \ + VISIT(Time64Type) \ + VISIT(TimestampType) \ + VISIT(DurationType) \ + VISIT(BinaryType) \ + VISIT(StringType) \ + VISIT(LargeBinaryType) \ + VISIT(LargeStringType) \ + VISIT(BinaryViewType) \ + VISIT(StringViewType) + +template +using SearchValue = typename GetViewType::T; + +/// Comparator implementing Arrow's ascending-order semantics for supported types. +template +struct SearchSortedCompare { + using ValueType = SearchValue; + + int operator()(const ValueType& left, const ValueType& right) const { + return CompareTypeValues(left, right, SortOrder::Ascending, + NullPlacement::AtEnd); + } +}; + +/// Access logical values from a plain Arrow array. +template +class PlainArrayAccessor { + public: + using ArrayType = typename TypeTraits::ArrayType; + using ValueType = SearchValue; + + /// Build a typed accessor over a plain array payload. + explicit PlainArrayAccessor(const std::shared_ptr& array_data) + : array_(array_data) {} + + /// Return the logical length of the searched values. + int64_t length() const { return array_.length(); } + + /// Return the logical value at the given logical position. + ValueType Value(int64_t index) const { + return GetViewType::LogicalValue(array_.GetView(index)); + } + + private: + ArrayType array_; +}; + +/// Access logical values from a run-end encoded Arrow array. +template +class RunEndEncodedValuesAccessor { + public: + using ArrayType = typename TypeTraits::ArrayType; + using ValueType = SearchValue; + + /// Build a typed accessor over a run-end encoded payload. + explicit RunEndEncodedValuesAccessor(const RunEndEncodedArray& array) + : values_(array.values()->data()), array_span_(*array.data()), span_(array_span_) {} + + /// Return the logical length of the searched values. + int64_t length() const { return span_.length(); } + + /// Return the logical value at the given logical position. + ValueType Value(int64_t index) const { + const auto physical_index = span_.PhysicalIndex(index); + return GetViewType::LogicalValue(values_.GetView(physical_index)); + } + + private: + ArrayType values_; + ArraySpan array_span_; + ::arrow::ree_util::RunEndEncodedArraySpan span_; +}; + +struct NonNullValuesRange { + int64_t offset = 0; + int64_t length = 0; + + /// Return whether the range spans the full searched values input. + bool is_identity(int64_t full_length) const { + return (offset == 0) && (length == full_length); + } +}; + +/// Present a contiguous non-null slice of the searched values through the same +/// accessor interface as the original values container. +template +class NonNullValuesAccessor { + public: + /// Wrap the original accessor with the discovered non-null subrange. + explicit NonNullValuesAccessor(const ValuesAccessor& values, + const NonNullValuesRange& non_null_values_range) + : values_(values), + offset_(non_null_values_range.offset), + length_(non_null_values_range.length) {} + + /// Return the number of accessible non-null values. + int64_t length() const noexcept { return length_; } + + /// Return the value at the given index within the non-null subrange. + auto Value(int64_t index) const { return values_.Value(offset_ + index); } + + private: + const ValuesAccessor& values_; + int64_t offset_; + int64_t length_; +}; + +/// Return the logical type of an array, unwrapping run-end encoding when present. +inline const DataType& LogicalType(const ArrayData& array) { + const auto& type = *array.type; + if (type.id() == Type::RUN_END_ENCODED) { + return *checked_cast(type).value_type(); + } + return type; +} + +/// Return the logical type of a datum, unwrapping run-end encoding when present. +inline const DataType& LogicalType(const Datum& datum) { + if (datum.is_scalar()) { + return *datum.scalar()->type; + } + return LogicalType(*datum.array()); +} + +/// Return whether a scalar or array needle input contains any logical nulls. +inline bool DatumHasNulls(const Datum& datum) { + if (datum.is_scalar()) { + return !datum.scalar()->is_valid; + } + + const auto& array_data = datum.array(); + const bool has_nulls = array_data->GetNullCount() > 0; + if (array_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray run_end_encoded(array_data); + return has_nulls || (run_end_encoded.values()->null_count() != 0); + } + return has_nulls; +} + +/// Reject nested run-end encoded values. TODO: Support this case in the future if there +/// is demand for it. +inline Status ValidateRunEndEncodedLogicalValueType(const DataType& type, + const char* name) { + const auto& ree_type = checked_cast(type); + if (ree_type.value_type()->id() == Type::RUN_END_ENCODED) { + return Status::TypeError("Nested run-end encoded ", name, " are not supported"); + } + return Status::OK(); +} + +/// Compute the contiguous non-null window of the searched values. +/// +inline Result FindNonNullValuesRange(const ArrayData& values) { + NonNullValuesRange non_null_values_range{.offset = 0, .length = values.length}; + + const auto null_count = values.GetNullCount(); + if (null_count == 0) { + return non_null_values_range; + } + + const int64_t leading_null_count = [&] { + auto indices = std::ranges::views::iota(int64_t{0}, values.length); + auto it = + std::ranges::find_if_not(indices, [&](int64_t i) { return values.IsNull(i); }); + return it == indices.end() ? values.length : *it; + }(); + + if (leading_null_count == values.length) { + non_null_values_range.offset = values.length; + non_null_values_range.length = 0; + return non_null_values_range; + } + + if (leading_null_count > 0) { + if (leading_null_count != null_count) { + return Status::Invalid( + "search_sorted values with nulls must be clustered at the start or end."); + } + non_null_values_range.offset = leading_null_count; + non_null_values_range.length = values.length - leading_null_count; + return non_null_values_range; + } + + const int64_t trailing_null_count = [&] { + auto indices = std::ranges::views::iota(int64_t{0}, values.length); + auto it = std::ranges::find_if_not( + indices, [&](int64_t i) { return values.IsNull(values.length - 1 - i); }); + return it == indices.end() ? values.length : *it; + }(); + + if (trailing_null_count == 0 || (trailing_null_count != null_count)) { + return Status::Invalid( + "search_sorted values with nulls must be clustered at the start or end."); + } + + non_null_values_range.length = values.length - trailing_null_count; + return non_null_values_range; +} + +/// Validate the searched values input shape and supported encoding. +inline Status ValidateSortedValuesInput(const Datum& datum) { + if (!datum.is_array()) { + return Status::TypeError("search_sorted values must be an array"); + } + + const auto& type = *datum.type(); + if (type.id() == Type::RUN_END_ENCODED) { + return ValidateRunEndEncodedLogicalValueType(type, "values"); + } + + return Status::OK(); +} + +/// Validate the needles input shape and supported encoding. +/// Needles can be either a scalar or an array, but if an array is provided it must not +/// have nested run-end encoding since that is not currently supported. +inline Status ValidateNeedleInput(const Datum& datum) { + if (!(datum.is_array() || datum.is_scalar())) { + return Status::TypeError("search_sorted needles must be a scalar or array"); + } + + if (datum.is_array() && datum.type()->id() == Type::RUN_END_ENCODED) { + return ValidateRunEndEncodedLogicalValueType(*datum.type(), "needles"); + } + return Status::OK(); +} + +/// Perform a lower- or upper-bound binary search over already sorted values. +template +uint64_t FindInsertionPoint(const Accessor& sorted_values, + const SearchValue& needle, + SearchSortedOptions::Side side) { + SearchSortedCompare compare; + int64_t first = 0; + int64_t count = sorted_values.length(); + + // TODO(search_sorted): For fixed-width primitive haystacks, investigate a SIMD-friendly + // batched search path . + while (count > 0) { + const int64_t step = count / 2; + const int64_t it = first + step; + const bool advance = (side == SearchSortedOptions::Left) + ? compare(sorted_values.Value(it), needle) < 0 + : compare(needle, sorted_values.Value(it)) >= 0; + if (advance) { + first = it + 1; + count -= step + 1; + } else { + count = step; + } + } + return static_cast(first); +} + +/// Read a scalar needle without materializing a one-element array. +template +SearchValue ExtractScalarValue(const Scalar& scalar) { + using ScalarType = typename TypeTraits::ScalarType; + const auto& typed_scalar = checked_cast(scalar); + + if constexpr (std::is_base_of_v) { + return GetViewType::LogicalValue(typed_scalar.view()); + } else { + return GetViewType::LogicalValue(typed_scalar.value); + } +} + +/// Append the same insertion index repeatedly for a logical run of needles. +inline void AppendInsertionIndex(UInt64Builder& builder, uint64_t insertion_index, + int64_t count) { + if (count == 0) { + return; + } + DCHECK_LE(builder.length() + count, builder.capacity()); + std::fill_n(builder.GetMutableValue(builder.length()), count, insertion_index); + builder.UnsafeAdvance(count); +} + +/// Dispatch a run-end encoded array to the matching run-end physical type. +template +ReturnType DispatchRunEndEncodedByRunEndType(const RunEndEncodedArray& array, + const char* argument_name, + Visitor&& visitor) { + const auto& ree_type = checked_cast(*array.type()); + switch (ree_type.run_end_type()->id()) { + case Type::INT16: + return std::forward(visitor).template operator()(array); + case Type::INT32: + return std::forward(visitor).template operator()(array); + case Type::INT64: + return std::forward(visitor).template operator()(array); + default: + return ReturnType(Status::TypeError("Unsupported run-end type for search_sorted ", + argument_name, ": ", array.type()->ToString())); + } +} + +template +using VisitedNeedle = std::conditional_t>, + SearchValue>; + +/// Normalize a non-null logical needle into the visitor payload type. +template +VisitedNeedle MakeVisitedNeedle( + const SearchValue& needle) { + if constexpr (EmitNulls) { + return std::optional>(needle); + } else { + return needle; + } +} + +/// Read one logical needle value from a physical array position. +template +VisitedNeedle ReadVisitedNeedle(const ArrayType& array, + int64_t physical_index) { + if constexpr (EmitNulls) { + if (array.IsNull(physical_index)) { + return std::nullopt; + } + } + const auto needle = GetViewType::LogicalValue(array.GetView(physical_index)); + return MakeVisitedNeedle(needle); +} + +/// Visit each plain-array needle as a single-value logical span. +template +Status VisitArrayNeedles(const std::shared_ptr& needles_data, + Visitor&& visitor) { + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType array(needles_data); + for (int64_t index = 0; index < array.length(); ++index) { + RETURN_NOT_OK( + visitor(ReadVisitedNeedle(array, index), index, index + 1)); + } + return Status::OK(); +} + +/// Visit each run of a run-end encoded needle array as one logical span. +template +Status VisitRunEndEncodedNeedleRuns(const RunEndEncodedArray& needles, + Visitor&& visitor) { + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType values(needles.values()->data()); + ArraySpan array_span(*needles.data()); + ::arrow::ree_util::RunEndEncodedArraySpan span(array_span); + + for (auto it = span.begin(); !it.is_end(span); ++it) { + RETURN_NOT_OK( + visitor(ReadVisitedNeedle(values, it.index_into_array()), + it.logical_position(), it.run_end())); + } + return Status::OK(); +} + +/// Visit scalar, plain-array, or run-end encoded needles through a uniform +/// callback interface of [begin, end) logical spans. +template +Status VisitNeedles(const Datum& needles, Visitor&& visitor) { + if (needles.is_scalar()) { + if constexpr (EmitNulls) { + if (!needles.scalar()->is_valid) { + return visitor(std::optional>{}, 0, 1); + } + } + return visitor(MakeVisitedNeedle( + ExtractScalarValue(*needles.scalar())), + 0, 1); + } + + const auto& needle_data = needles.array(); + if (needle_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(needle_data); + return DispatchRunEndEncodedByRunEndType( + ree, "needles", + [&](const RunEndEncodedArray& run_end_encoded_needles) { + return VisitRunEndEncodedNeedleRuns( + run_end_encoded_needles, visitor); + }); + } + + return VisitArrayNeedles(needle_data, visitor); +} + +/// Search all needle values and write insertion indices into the preallocated output. +template +Status SearchNeedleValues(const ValuesAccessor& sorted_values, const Datum& needles, + SearchSortedOptions::Side side, uint64_t insertion_offset, + uint64_t* out) { + auto emit_search_result = [&](const SearchValue& needle, int64_t begin, + int64_t end) -> Status { + const auto insertion_index = + FindInsertionPoint(sorted_values, needle, side) + insertion_offset; + std::ranges::fill(std::span(out + begin, static_cast(end - begin)), + insertion_index); + return Status::OK(); + }; + + return VisitNeedles(needles, emit_search_result); +} + +/// Search needle values while emitting nulls for null needles. +template +Status AppendInsertionIndicesWithNulls(const ValuesAccessor& sorted_values, + const Datum& needles, + SearchSortedOptions::Side side, + uint64_t insertion_offset, + UInt64Builder& builder) { + auto emit_search_result = [&](const std::optional>& needle, + int64_t begin, int64_t end) -> Status { + const auto span_length = end - begin; + if (!needle.has_value()) { + return builder.AppendNulls(span_length); + } + const auto insertion_index = + FindInsertionPoint(sorted_values, *needle, side) + insertion_offset; + AppendInsertionIndex(builder, insertion_index, span_length); + return Status::OK(); + }; + + return VisitNeedles(needles, emit_search_result); +} + +/// Materialize output for scalar or array needles. +template +Result ComputeInsertionIndices(const ValuesAccessor& sorted_values, + const Datum& needles, + SearchSortedOptions::Side side, + uint64_t insertion_offset, ExecContext* ctx) { + if (needles.is_scalar()) { + auto scalar = needles.scalar(); + if (!scalar->is_valid) { + return Datum(std::make_shared()); + } + + const auto insertion_index = + FindInsertionPoint(sorted_values, + ExtractScalarValue(*scalar), side) + + insertion_offset; + return Datum(std::make_shared(insertion_index)); + } + + if (DatumHasNulls(needles)) { + UInt64Builder builder(ctx->memory_pool()); + ARROW_RETURN_NOT_OK(builder.Reserve(needles.length())); + RETURN_NOT_OK(AppendInsertionIndicesWithNulls(sorted_values, needles, side, + insertion_offset, builder)); + ARROW_ASSIGN_OR_RAISE(auto out, builder.Finish()); + return Datum(std::move(out)); + } + + ARROW_ASSIGN_OR_RAISE(auto out, + MakeMutableUInt64Array(needles.length(), ctx->memory_pool())); + auto* out_values = out->GetMutableValues(1); + + RETURN_NOT_OK(SearchNeedleValues(sorted_values, needles, side, + insertion_offset, out_values)); + return Datum(MakeArray(std::move(out))); +} + +// Main entry point for search_sorted over a single array of sorted values and scalar or +// array needles. Handles null presence in the needles and dispatches to the appropriate +// search implementation. +template +Result SearchWithAccessor(const ValuesAccessor& values_accessor, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, SearchSortedOptions::Side side, + ExecContext* ctx) { + if (non_null_values_range.is_identity(values_accessor.length())) { + return ComputeInsertionIndices(values_accessor, needles, side, + /*insertion_offset=*/0, ctx); + } + + NonNullValuesAccessor non_null_values(values_accessor, non_null_values_range); + return ComputeInsertionIndices( + non_null_values, needles, side, static_cast(non_null_values_range.offset), + ctx); +} + +// Meta-function implementation for the search_sorted public compute entrypoint. +template +Result VisitValuesAccessor(const std::shared_ptr& values_data, + Visitor&& visitor) { + if (values_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(values_data); + return DispatchRunEndEncodedByRunEndType>( + ree, "values", + [&](const RunEndEncodedArray& run_end_encoded_values) { + RunEndEncodedValuesAccessor values_accessor( + run_end_encoded_values); + return visitor(values_accessor); + }); + } + + PlainArrayAccessor values_accessor(values_data); + return visitor(values_accessor); +} + +/// Meta-function implementation for the search_sorted public compute entrypoint. +/// Validates input shapes and types, normalizes to logical value accessors, and +/// dispatches to the typed search implementation. +class SearchSortedMetaFunction : public MetaFunction { + public: + /// Construct the registry entry with default options and documentation. + SearchSortedMetaFunction() + : MetaFunction("search_sorted", Arity::Binary(), search_sorted_doc, + GetDefaultSearchSortedOptions()) {} + + /// Validate inputs, normalize options, and dispatch to the typed search implementation. + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + RETURN_NOT_OK(ValidateSortedValuesInput(args[0])); + RETURN_NOT_OK(ValidateNeedleInput(args[1])); + + const auto& values_type = LogicalType(args[0]); + const auto& needles_type = LogicalType(args[1]); + if (!values_type.Equals(needles_type)) { + return Status::TypeError( + "search_sorted arguments must have matching logical types, got ", + values_type.ToString(), " and ", needles_type.ToString()); + } + + const auto& values_array = args[0].array(); + ARROW_ASSIGN_OR_RAISE(auto non_null_values_range, + FindNonNullValuesRange(*values_array)); + auto result = DispatchByType(values_array, non_null_values_range, args[1], + static_cast(*options), ctx); + return result; + } + + private: + /// Dispatch the logical value type to the matching template specialization. + Result DispatchByType(const std::shared_ptr& values, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, const SearchSortedOptions& options, + ExecContext* ctx) const { + switch (LogicalType(values).id()) { +#define VISIT(TYPE) \ + case TYPE::type_id: \ + return DispatchHaystack(values, non_null_values_range, needles, options.side, \ + ctx); + VISIT_SEARCH_SORTED_TYPES(VISIT) +#undef VISIT + default: + break; + } + return Status::NotImplemented("search_sorted is not implemented for type ", + LogicalType(values).ToString()); + } + + /// Dispatch the physical representation of the searched values. + template + Result DispatchHaystack(const std::shared_ptr& values, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, SearchSortedOptions::Side side, + ExecContext* ctx) const { + return VisitValuesAccessor(values, [&](const auto& values_accessor) { + return SearchWithAccessor(values_accessor, non_null_values_range, + needles, side, ctx); + }); + } +}; + +} // namespace + +/// Register the search_sorted vector kernel in the global compute registry. +void RegisterVectorSearchSorted(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(std::make_shared())); +} + +} // namespace compute::internal +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc new file mode 100644 index 000000000000..71c78beb338b --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc @@ -0,0 +1,382 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/builder.h" +#include "arrow/compute/api_vector.h" +#include "arrow/datum.h" +#include "arrow/scalar.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x5EA4C42; +constexpr int64_t kNeedleToValueRatio = 4; +constexpr int64_t kValuesRunLength = 16; +constexpr int64_t kNeedlesRunLength = 8; +constexpr int32_t kStringMinLength = 8; +constexpr int32_t kStringMaxLength = 24; + +void SetSearchSortedQuickArgs(benchmark::internal::Benchmark* bench) { + bench->Unit(benchmark::kMicrosecond); + for (const auto size : std::vector{kL1Size, kL2Size}) { + bench->Arg(size); + } +} + +void SetSearchSortedArgs(benchmark::internal::Benchmark* bench) { + bench->Unit(benchmark::kMicrosecond); + for (const auto size : kMemorySizes) { + bench->Arg(size); + } +} + +int64_t Int64LengthFromBytes(int64_t size_bytes) { + return std::max(1, size_bytes / static_cast(sizeof(int64_t))); +} + +int64_t NeedleLengthFromBytes(int64_t size_bytes) { + return std::max(1, Int64LengthFromBytes(size_bytes) / kNeedleToValueRatio); +} + +int64_t StringLengthFromBytes(int64_t size_bytes) { + const int64_t average_length = (kStringMinLength + kStringMaxLength) / 2; + return std::max(1, size_bytes / average_length); +} + +std::shared_ptr BuildSortedInt64Values(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed); + const auto length = Int64LengthFromBytes(size_bytes); + const auto max_value = std::max(length / 8, 1); + + auto values = + std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); + std::vector data(values->raw_values(), + values->raw_values() + values->length()); + std::ranges::sort(data); + + Int64Builder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildInt64Needles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 1); + const auto length = NeedleLengthFromBytes(size_bytes); + const auto max_value = std::max(Int64LengthFromBytes(size_bytes) / 8, 1); + return std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); +} + +std::shared_ptr BuildSortedStringValues(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 2); + const auto length = StringLengthFromBytes(size_bytes); + auto values = std::static_pointer_cast( + rand.String(length, kStringMinLength, kStringMaxLength, 0.0)); + + std::vector data; + data.reserve(static_cast(values->length())); + for (int64_t index = 0; index < values->length(); ++index) { + data.push_back(values->GetString(index)); + } + std::ranges::sort(data); + + StringBuilder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildStringNeedles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 3); + const auto length = std::max(1, StringLengthFromBytes(size_bytes) / 4); + return std::static_pointer_cast( + rand.String(length, kStringMinLength, kStringMaxLength, 0.0)); +} + +std::shared_ptr BuildSortedBinaryValues(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 4); + const auto length = StringLengthFromBytes(size_bytes); + const auto unique = std::max(1, length / 8); + auto values = std::static_pointer_cast( + rand.BinaryWithRepeats(length, unique, kStringMinLength, kStringMaxLength, 0.0)); + + std::vector data; + data.reserve(static_cast(values->length())); + for (int64_t index = 0; index < values->length(); ++index) { + data.emplace_back(values->GetView(index)); + } + std::ranges::sort(data); + + BinaryBuilder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildBinaryNeedles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 5); + const auto length = std::max(1, StringLengthFromBytes(size_bytes) / 4); + const auto unique = std::max(1, length / 2); + return std::static_pointer_cast( + rand.BinaryWithRepeats(length, unique, kStringMinLength, kStringMaxLength, 0.0)); +} + +std::shared_ptr BuildRunHeavyInt64Values(int64_t logical_length, + int64_t run_length) { + Int64Builder builder; + ABORT_NOT_OK(builder.Reserve(logical_length)); + for (int64_t index = 0; index < logical_length; ++index) { + builder.UnsafeAppend(index / run_length); + } + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildRunHeavyInt64NeedlesWithNullRuns(int64_t logical_length, + int64_t run_length) { + std::vector data(static_cast(logical_length), 0); + std::vector is_valid(static_cast(logical_length), true); + + for (int64_t index = 0; index < logical_length; ++index) { + const int64_t run_index = index / run_length; + if (run_index % 4 == 0) { + is_valid[static_cast(index)] = false; + continue; + } + data[static_cast(index)] = run_index / 2; + } + + Int64Builder builder; + ABORT_NOT_OK(builder.AppendValues(data, is_valid)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildInt64NeedlesWithNullRuns(int64_t size_bytes, + int64_t run_length) { + return BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), + run_length); +} + +std::shared_ptr BuildRunEndEncodedInt64Values(int64_t size_bytes, + int64_t run_length) { + auto values = BuildRunHeavyInt64Values(Int64LengthFromBytes(size_bytes), run_length); + return RunEndEncode(Datum(values), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); +} + +std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, + int64_t run_length) { + auto needles = BuildRunHeavyInt64Values(NeedleLengthFromBytes(size_bytes), run_length); + return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); +} + +std::shared_ptr BuildRunEndEncodedInt64NeedlesWithNullRuns(int64_t size_bytes, + int64_t run_length) { + auto needles = BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), + run_length); + return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); +} + +void SetBenchmarkCounters(benchmark::State& state, const Datum& values, + const Datum& needles) { + const auto values_length = values.length(); + const auto needles_length = needles.length(); + state.counters["values_length"] = static_cast(values_length); + state.counters["needles_length"] = static_cast(needles_length); + state.SetItemsProcessed(state.iterations() * needles_length); +} + +void RunSearchSortedBenchmark(benchmark::State& state, const Datum& values, + const Datum& needles, SearchSortedOptions::Side side) { + const SearchSortedOptions options(side); + for (auto _ : state) { + auto result = SearchSorted(values, needles, options); + ABORT_NOT_OK(result.status()); + benchmark::DoNotOptimize(result.ValueUnsafe()); + } + SetBenchmarkCounters(state, values, needles); +} + +static void BM_SearchSortedInt64ArrayNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedInt64Values(state.range(0))); + const Datum needles(BuildInt64Needles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedInt64ScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedInt64Values(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles(std::make_shared(values_array->Value(scalar_index))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedValues(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles(BuildInt64Needles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedValuesAndNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles(BuildRunEndEncodedInt64Needles(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedInt64NeedlesWithNullRuns(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedInt64Values(state.range(0))); + const Datum needles(BuildInt64NeedlesWithNullRuns(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedNeedlesWithNullRuns( + benchmark::State& state, SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles( + BuildRunEndEncodedInt64NeedlesWithNullRuns(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedStringArrayNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedStringValues(state.range(0))); + const Datum needles(BuildStringNeedles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedStringScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedStringValues(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles( + std::make_shared(values_array->GetString(scalar_index))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedBinaryScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedBinaryValues(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles( + std::make_shared(std::string(values_array->GetView(scalar_index)))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedInt64ArrayNeedlesQuick(benchmark::State& state, + SearchSortedOptions::Side side) { + BM_SearchSortedInt64ArrayNeedles(state, side); +} + +static void BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick( + benchmark::State& state, SearchSortedOptions::Side side) { + BM_SearchSortedRunEndEncodedValuesAndNeedles(state, side); +} + +static void BM_SearchSortedInt64NeedlesWithNullRunsQuick(benchmark::State& state, + SearchSortedOptions::Side side) { + BM_SearchSortedInt64NeedlesWithNullRuns(state, side); +} + +static void BM_SearchSortedRunEndEncodedNeedlesWithNullRunsQuick( + benchmark::State& state, SearchSortedOptions::Side side) { + BM_SearchSortedRunEndEncodedNeedlesWithNullRuns(state, side); +} + +// Primitive-array and REE cases are the main baselines for the kernel TODOs around +// SIMD batched search, vectorized REE writeback, and future parallel needle traversal. + +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedles, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedles, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValues, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValues, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); + +// String and binary scalar cases specifically exercise the direct scalar fast path that +// avoids boxing a scalar needle into a temporary one-element array. +BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedQuickArgs); + +// Lightweight L1/L2 regressions keep a fast local loop for future optimization work. +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedlesQuick, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRunsQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRunsQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc new file mode 100644 index 000000000000..cc41286dc8e1 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc @@ -0,0 +1,409 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include + +#include "arrow/compute/api.h" +#include "arrow/compute/kernels/test_util_internal.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace { + +Result> REEFromJSON(const std::shared_ptr& ree_type, + const std::string& json) { + auto ree_type_ptr = checked_cast(ree_type.get()); + auto array = ArrayFromJSON(ree_type_ptr->value_type(), json); + ARROW_ASSIGN_OR_RAISE( + auto datum, RunEndEncode(array, RunEndEncodeOptions{ree_type_ptr->run_end_type()})); + return datum.make_array(); +} + +void CheckSimpleSearchSorted(const std::shared_ptr& type, + const std::string& values_json, + const std::string& needles_json, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto values = ArrayFromJSON(type, values_json); + auto needles = ArrayFromJSON(type, needles_json); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), expected_left_json), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), expected_right_json), *right.make_array()); +} + +void CheckSimpleScalarSearchSorted(const std::shared_ptr& type, + const std::string& values_json, + const std::string& needle_json, uint64_t expected_left, + uint64_t expected_right) { + auto values = ArrayFromJSON(type, values_json); + auto needle = ScalarFromJSON(type, needle_json); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needle), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needle), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(left.is_scalar()); + ASSERT_TRUE(right.is_scalar()); + ASSERT_EQ(checked_cast(*left.scalar()).value, expected_left); + ASSERT_EQ(checked_cast(*right.scalar()).value, expected_right); +} + +struct SearchSortedSmokeCase { + std::string name; + std::shared_ptr type; + std::string values_json; + std::string needles_json; + std::string expected_left_json; + std::string expected_right_json; + std::string scalar_needle_json; + uint64_t expected_scalar_left; + uint64_t expected_scalar_right; +}; + +std::vector SupportedTypeSmokeCases() { + return { + {"Boolean", boolean(), "[false, false, true, true]", "[false, true]", "[0, 2]", + "[2, 4]", "true", 2, 4}, + {"Int8", int8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int16", int16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int32", int32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int64", int64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt8", uint8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt16", uint16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt32", uint32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt64", uint64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Float32", float32(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, + {"Float64", float64(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, + {"Date32", date32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Date64", date64(), "[86400000, 259200000, 259200000, 432000000]", + "[0, 259200000, 345600000, 518400000]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "259200000", 1, 3}, + {"Time32", time32(TimeUnit::SECOND), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Time64", time64(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Timestamp", timestamp(TimeUnit::SECOND), + R"(["1970-01-02", "1970-01-04", "1970-01-04", "1970-01-06"])", + R"(["1970-01-01", "1970-01-04", "1970-01-05", "1970-01-07"])", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", R"("1970-01-04")", 1, 3}, + {"Duration", duration(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, + {"Binary", binary(), R"(["aa", "bb", "bb", "dd"])", R"(["a", "bb", "bc", "z"])", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"String", utf8(), R"(["aa", "bb", "bb", "dd"])", R"(["a", "bb", "bc", "z"])", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"LargeBinary", large_binary(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"LargeString", large_utf8(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"BinaryView", binary_view(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"StringView", utf8_view(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + }; +} + +class SearchSortedSupportedTypesTest + : public ::testing::TestWithParam {}; + +TEST(SearchSorted, BasicLeftRight) { + auto values = ArrayFromJSON(int64(), "[100, 200, 200, 300, 300]"); + auto needles = ArrayFromJSON(int64(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 3, 5]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 3, 3, 5]"), *right.make_array()); +} + +TEST(SearchSorted, ScalarNeedle) { + auto values = ArrayFromJSON(int32(), "[1, 3, 5, 7]"); + + ASSERT_OK_AND_ASSIGN( + auto result, SearchSorted(Datum(values), Datum(std::make_shared(5)), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast(*result.scalar()).value, 3); +} + +TEST(SearchSorted, ScalarStringNeedle) { + auto values = ArrayFromJSON(utf8(), R"(["aa", "bb", "bb", "cc"])"); + + ASSERT_OK_AND_ASSIGN( + auto result, + SearchSorted(Datum(values), Datum(std::make_shared("bb")), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast(*result.scalar()).value, 3); +} + +TEST(SearchSorted, EmptyHaystack) { + auto values = ArrayFromJSON(int16(), "[]"); + auto needles = ArrayFromJSON(int16(), "[1, 2, 3]"); + + ASSERT_OK_AND_ASSIGN(auto result, SearchSorted(Datum(values), Datum(needles))); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 0]"), *result.make_array()); +} + +TEST(SearchSorted, ValuesWithLeadingNulls) { + auto values = ArrayFromJSON(int32(), "[null, 200, 300, 300]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[1, 1, 2, 4]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[1, 2, 2, 4]"), *right.make_array()); +} + +TEST(SearchSorted, ValuesAllNull) { + auto values = ArrayFromJSON(int32(), "[null, null, null]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, null]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[3, 3, null]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[3, 3, null]"), *right.make_array()); +} + +TEST(SearchSorted, ValuesWithTrailingNulls) { + auto values = ArrayFromJSON(int32(), "[200, 300, 300, null, null]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 1, 3]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 1, 3]"), *right.make_array()); +} + +TEST(SearchSorted, NullNeedlesEmitNull) { + auto values = ArrayFromJSON(int32(), "[null, 200, 300, 300]"); + auto needles = ArrayFromJSON(int32(), "[null, 50, 200, null, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 1, 1, null, 4]"), + *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 1, 2, null, 4]"), + *right.make_array()); + + ASSERT_OK_AND_ASSIGN(auto scalar_result, + SearchSorted(Datum(values), Datum(std::make_shared()), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_TRUE(scalar_result.is_scalar()); + ASSERT_FALSE(scalar_result.scalar()->is_valid); + ASSERT_TRUE(scalar_result.scalar()->type->Equals(uint64())); +} + +TEST(SearchSorted, RejectUnclusteredNullValues) { + auto values = ArrayFromJSON(int32(), "[null, 1, null, 3]"); + auto needles = ArrayFromJSON(int32(), "[2]"); + + ASSERT_RAISES(Invalid, SearchSorted(Datum(values), Datum(needles))); +} + +TEST(SearchSorted, RunEndEncodedNulls) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[null, null, 2, 4, 4]")); + auto needles_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_needles, + REEFromJSON(needles_type, "[null, null, 1, 4, 4, null, 8]")); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(ree_values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, null, 2, 3, 3, null, 5]"), + *result.make_array()); +} + +TEST(SearchSorted, RunEndEncodedNeedlesWithNullRuns) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN( + auto ree_needles, + REEFromJSON(needles_type, "[null, null, 0, 0, 0, 1, 1, 4, 4, 4, null, 9, 9]")); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual( + *ArrayFromJSON(uint64(), "[null, null, 0, 0, 0, 0, 0, 3, 3, 3, null, 5, 5]"), + *left.make_array()); + AssertArraysEqual( + *ArrayFromJSON(uint64(), "[null, null, 0, 0, 0, 2, 2, 3, 3, 3, null, 5, 5]"), + *right.make_array()); +} + +TEST(SearchSorted, RunEndEncodedAllNullValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[null, null, null, null]")); + auto needles = ArrayFromJSON(int32(), "[null, 1, 8]"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 4, 4]"), *result.make_array()); +} + +TEST(SearchSorted, RejectMismatchedTypes) { + auto values = ArrayFromJSON(int32(), "[1, 2, 3]"); + auto needles = ArrayFromJSON(int64(), "[2]"); + + ASSERT_RAISES(TypeError, SearchSorted(Datum(values), Datum(needles))); +} + +TEST(SearchSorted, RunEndEncodedValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, REEFromJSON(values_type, "[1, 1, 1, 3, 3, 5]")); + auto needles = ArrayFromJSON(int32(), "[0, 1, 2, 3, 4, 5, 6]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 3, 3, 5, 5, 6]"), + *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 3, 3, 5, 5, 6, 6]"), + *right.make_array()); +} + +TEST(SearchSorted, RunEndEncodedNeedles) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_needles, + REEFromJSON(needles_type, "[0, 0, 1, 1, 4, 4, 9]")); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 2, 2, 3, 3, 5]"), + *result.make_array()); +} + +TEST(SearchSorted, SlicedRunEndEncodedValues) { + auto values_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[0, 0, 1, 1, 1, 4, 4, 9]")); + auto sliced = ree_values->Slice(2, 5); + auto needles = ArrayFromJSON(int32(), "[0, 1, 2, 4, 9]"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(sliced), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 3, 3, 5]"), *result.make_array()); +} + +TEST(SearchSorted, BinaryValues) { + auto values = ArrayFromJSON(utf8(), R"(["aa", "bb", "bb", "cc"])"); + auto needles = ArrayFromJSON(utf8(), R"(["a", "bb", "bc", "z"])"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 3, 4]"), *result.make_array()); +} + +TEST_P(SearchSortedSupportedTypesTest, ArraySmoke) { + const auto& param = GetParam(); + CheckSimpleSearchSorted(param.type, param.values_json, param.needles_json, + param.expected_left_json, param.expected_right_json); +} + +TEST_P(SearchSortedSupportedTypesTest, ScalarSmoke) { + const auto& param = GetParam(); + CheckSimpleScalarSearchSorted(param.type, param.values_json, param.scalar_needle_json, + param.expected_scalar_left, param.expected_scalar_right); +} + +INSTANTIATE_TEST_SUITE_P(SupportedTypes, SearchSortedSupportedTypesTest, + ::testing::ValuesIn(SupportedTypeSmokeCases()), + [](const ::testing::TestParamInfo& info) { + return info.param.name; + }); + +} // namespace +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 5b9d7f8d608f..d457f6fb95c0 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -50,6 +50,7 @@ void RegisterVectorHash(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorRank(FunctionRegistry* registry); void RegisterVectorReplace(FunctionRegistry* registry); +void RegisterVectorSearchSorted(FunctionRegistry* registry); void RegisterVectorSelectK(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index fe7f11af6ff8..3f6c59fe7d97 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -590,6 +590,7 @@ if needs_compute 'compute/kernels/vector_rank.cc', 'compute/kernels/vector_replace.cc', 'compute/kernels/vector_run_end_encode.cc', + 'compute/kernels/vector_search_sorted.cc', 'compute/kernels/vector_select_k.cc', 'compute/kernels/vector_sort.cc', 'compute/kernels/vector_statistics.cc', diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index e4092af70cde..d16a123b8028 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1873,6 +1873,8 @@ in the respective option classes. +-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ | sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SortOptions` | \(1) \(6) | +-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ +| search_sorted | Binary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SearchSortedOptions` | \(8) | ++-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ * \(1) The output is an array of indices into the input, that define a @@ -1901,6 +1903,13 @@ in the respective option classes. * \(7) The output is an array of indices into the input, that define a non-stable sort of the input. +* \(8) The first argument must be sorted in ascending order. If it contains + nulls, they must be clustered entirely at the start or the end, and non-null + needles are matched only against the non-null portion. The second argument + may be a scalar, array, or run-end encoded array. Null needles yield null + outputs. Both arguments must have the same logical type. A scalar needle + yields a UInt64 scalar; otherwise the result is a UInt64 array. + .. _cpp-compute-vector-structural-transforms: Structural transforms diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 137b034d6ffc..bf3d7ae1cb59 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2074,6 +2074,14 @@ cdef CNullPlacement unwrap_null_placement(null_placement) except *: _raise_invalid_function_option(null_placement, "null placement") +cdef CSearchSortedSide unwrap_search_sorted_side(side) except *: + if side == "left": + return CSearchSortedSide_Left + elif side == "right": + return CSearchSortedSide_Right + _raise_invalid_function_option(side, "search sorted side") + + cdef class _PartitionNthOptions(FunctionOptions): def _set_options(self, pivot, null_placement): self.wrapped.reset(new CPartitionNthOptions( @@ -2243,6 +2251,27 @@ class ArraySortOptions(_ArraySortOptions): self._set_options(order, null_placement) +cdef class _SearchSortedOptions(FunctionOptions): + def _set_options(self, side): + self.wrapped.reset(new CSearchSortedOptions( + unwrap_search_sorted_side(side))) + + +class SearchSortedOptions(_SearchSortedOptions): + """ + Options for the `search_sorted` function. + + Parameters + ---------- + side : str, default "left" + Whether to return the leftmost or rightmost insertion point. + Accepted values are "left", "right". + """ + + def __init__(self, side="left"): + self._set_options(side) + + cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): self.wrapped.reset(new CSortOptions( diff --git a/python/pyarrow/_compute_docstrings.py b/python/pyarrow/_compute_docstrings.py index 079f00db7d92..e5189f063b7f 100644 --- a/python/pyarrow/_compute_docstrings.py +++ b/python/pyarrow/_compute_docstrings.py @@ -42,6 +42,41 @@ ] """ +function_doc_additions["search_sorted"] = """ + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> values = pa.array([1, 1, 3, 5, 8]) + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9])) + + [ + 0, + 0, + 3, + 5 + ] + >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) + >>> pc.search_sorted( + ... with_nulls, pa.array([50, 200, None, 400], type=pa.int64()) + ... ) + + [ + 1, + 1, + null, + 4 + ] + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9]), side="right") + + [ + 0, + 2, + 3, + 5 + ] + """ + function_doc_additions["mode"] = """ Examples -------- diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8177948aaebc..0a2e231d189e 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -68,6 +68,7 @@ RoundToMultipleOptions, ScalarAggregateOptions, ScatterOptions, + SearchSortedOptions, SelectKOptions, SetLookupOptions, SkewOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e96a7d84696d..0b67cf0c382c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2781,6 +2781,16 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CSortOrder order CNullPlacement null_placement + cdef enum CSearchSortedSide \ + "arrow::compute::SearchSortedOptions::Side": + CSearchSortedSide_Left "arrow::compute::SearchSortedOptions::Left" + CSearchSortedSide_Right "arrow::compute::SearchSortedOptions::Right" + + cdef cppclass CSearchSortedOptions \ + "arrow::compute::SearchSortedOptions"(CFunctionOptions): + CSearchSortedOptions(CSearchSortedSide side) + CSearchSortedSide side + cdef cppclass CSortKey" arrow::compute::SortKey": CSortKey(CFieldRef target, CSortOrder order) CFieldRef target diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 8c3b09f612cc..35e6813a8a9b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -3021,6 +3021,99 @@ def test_array_sort_indices(): pc.array_sort_indices(arr, order="nonscending") +def test_search_sorted(): + values = pa.array([1, 1, 3, 5, 8]) + needles = pa.array([0, 1, 3, 4, 5, 8, 9]) + + expected_left = pa.array([0, 0, 2, 3, 3, 4, 5], type=pa.uint64()) + expected_right = pa.array([0, 2, 3, 3, 4, 5, 5], type=pa.uint64()) + + assert pc.search_sorted(values, needles).equals(expected_left) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, "right").equals(expected_right) + assert pc.search_sorted( + values, needles, options=pc.SearchSortedOptions(side="right") + ).equals(expected_right) + + assert pc.search_sorted(values, pa.scalar(5, type=pa.int64())).as_py() == 3 + assert pc.search_sorted( + values, pa.scalar(5, type=pa.int64()), side="right" + ).as_py() == 4 + + +def test_search_sorted_null_values(): + needles = pa.array([50, 200, 250, 400], type=pa.int64()) + + values = pa.array([None, 200, 300, 300], type=pa.int64()) + expected_left = pa.array([1, 1, 2, 4], type=pa.uint64()) + expected_right = pa.array([1, 2, 2, 4], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + values = pa.array([200, 300, 300, None, None], type=pa.int64()) + expected_left = pa.array([0, 0, 1, 3], type=pa.uint64()) + expected_right = pa.array([0, 1, 1, 3], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + +def test_search_sorted_null_needles_emit_null(): + values = pa.array([None, 200, 300, 300], type=pa.int64()) + needles = pa.array([None, 50, 200, None, 400], type=pa.int64()) + + expected_left = pa.array([None, 1, 1, None, 4], type=pa.uint64()) + expected_right = pa.array([None, 1, 2, None, 4], type=pa.uint64()) + + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + scalar_result = pc.search_sorted(values, pa.scalar(None, type=pa.int64())) + assert scalar_result.as_py() is None + + +def test_search_sorted_run_end_encoded(): + run_ends = pa.array([2, 3, 4, 5], type=pa.int16()) + encoded_values = pa.array([1, 3, 5, 8], type=pa.int64()) + values = pa.RunEndEncodedArray.from_arrays(run_ends, encoded_values) + needles = pa.array([0, 1, 3, 4, 5, 8, 9], type=pa.int64()) + + expected_left = pa.array([0, 0, 2, 3, 3, 4, 5], type=pa.uint64()) + assert pc.search_sorted(values, needles).equals(expected_left) + + ree_needles = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 4, 6], type=pa.int16()), + pa.array([1, 4, 9], type=pa.int64()) + ) + expected_right = pa.array([2, 2, 3, 3, 5, 5], type=pa.uint64()) + assert pc.search_sorted(values, ree_needles, side="right").equals( + expected_right + ) + + +def test_search_sorted_run_end_encoded_nulls(): + values = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 3, 5], type=pa.int16()), + pa.array([None, 2, 4], type=pa.int64()) + ) + needles = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 3, 5, 6], type=pa.int16()), + pa.array([None, 1, 4, None], type=pa.int64()) + ) + + expected = pa.array([None, None, 2, 3, 3, None], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected) + + +def test_search_sorted_errors(): + values = pa.array([1, 1, 3, 5, 8]) + + with pytest.raises(ValueError, match='"middle" is not a valid search sorted side'): + pc.search_sorted(values, pa.array([1]), side="middle") + + with pytest.raises(pa.ArrowInvalid, match="clustered at the start or end"): + pc.search_sorted(pa.array([None, 1, None], type=pa.int64()), pa.array([1])) + + def test_sort_indices_array(): arr = pa.array([1, 2, None, 0]) result = pc.sort_indices(arr)