From c9d53555f04dabf6b151961dc9780411eb795b11 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:59:45 +0200 Subject: [PATCH 1/8] Implement benchmarks and tests for SearchSorted functionality - Added a new benchmark file `vector_search_sorted_benchmark.cc` to evaluate the performance of the SearchSorted function for various data types including Int64, String, and Binary. - Created a comprehensive test suite in `vector_search_sorted_test.cc` to validate the correctness of SearchSorted across different scenarios, including handling of null values, scalar needles, and run-end encoded arrays. - Ensured that the benchmarks cover both left and right search options, as well as edge cases like empty arrays and arrays with leading/trailing nulls. --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_vector.cc | 28 + cpp/src/arrow/compute/api_vector.h | 36 + cpp/src/arrow/compute/initialize.cc | 1 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 8 + cpp/src/arrow/compute/kernels/meson.build | 1 + .../compute/kernels/vector_search_sorted.cc | 645 ++++++++++++++++++ .../kernels/vector_search_sorted_benchmark.cc | 297 ++++++++ .../kernels/vector_search_sorted_test.cc | 282 ++++++++ cpp/src/arrow/compute/registry_internal.h | 1 + cpp/src/arrow/meson.build | 1 + docs/source/cpp/compute.rst | 9 + python/pyarrow/_compute.pyx | 29 + python/pyarrow/_compute_docstrings.py | 33 + python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 10 + python/pyarrow/tests/test_compute.py | 93 +++ 17 files changed, 1476 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/vector_search_sorted.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index eee63b11ca1c..8585b6aaec7f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -795,6 +795,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_rank.cc compute/kernels/vector_replace.cc compute/kernels/vector_run_end_encode.cc + compute/kernels/vector_search_sorted.cc compute/kernels/vector_select_k.cc compute/kernels/vector_sort.cc compute/kernels/vector_statistics.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 1bf4de93520c..689788c3847d 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -50,6 +50,7 @@ using compute::FilterOptions; using compute::NullPlacement; using compute::RankOptions; using compute::RankQuantileOptions; +using compute::SearchSortedOptions; template <> struct EnumTraits @@ -96,6 +97,21 @@ struct EnumTraits } }; template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "SearchSortedOptions::Side"; } + static std::string value_name(SearchSortedOptions::Side value) { + switch (value) { + case SearchSortedOptions::Left: + return "Left"; + case SearchSortedOptions::Right: + return "Right"; + } + return ""; + } +}; +template <> struct EnumTraits : BasicEnumTraits { @@ -137,6 +153,8 @@ static auto kRunEndEncodeOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); +static auto kSearchSortedOptionsType = GetFunctionOptionsType( + DataMember("side", &SearchSortedOptions::side)); static auto kSortOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &SortOptions::sort_keys), DataMember("null_placement", &SortOptions::null_placement)); @@ -196,6 +214,10 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; +SearchSortedOptions::SearchSortedOptions(SearchSortedOptions::Side side) + : FunctionOptions(internal::kSearchSortedOptionsType), side(side) {} +constexpr char SearchSortedOptions::kTypeName[]; + SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)), @@ -274,6 +296,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kDictionaryEncodeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kRunEndEncodeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSearchSortedOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); @@ -315,6 +338,11 @@ Result> SelectKUnstable(const Datum& datum, return result.make_array(); } +Result SearchSorted(const Datum& values, const Datum& needles, + const SearchSortedOptions& options, ExecContext* ctx) { + return CallFunction("search_sorted", {values, needles}, &options, ctx); +} + Result ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 159a787641ee..2d003726804f 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -102,6 +102,21 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { NullPlacement null_placement; }; +class ARROW_EXPORT SearchSortedOptions : public FunctionOptions { + public: + enum Side { + Left, + Right, + }; + + explicit SearchSortedOptions(Side side = Side::Left); + static constexpr const char kTypeName[] = "SearchSortedOptions"; + static SearchSortedOptions Defaults() { return SearchSortedOptions(); } + + /// Whether to return the leftmost or rightmost insertion point. + Side side; +}; + class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}, @@ -515,6 +530,27 @@ Result> SelectKUnstable(const Datum& datum, const SelectKOptions& options, ExecContext* ctx = NULLPTR); +/// \brief Find insertion indices that preserve sorted order. +/// +/// The `values` datum must be a plain array or run-end encoded array sorted in +/// ascending order. `needles` may be a scalar, plain array, or run-end encoded +/// array whose logical value type matches `values`. +/// +/// Nulls in `values` are supported when clustered entirely at the start or the +/// end of the sorted array. Non-null needles are matched only against the +/// non-null portion of `values`. Null needles yield null outputs. +/// +/// \param[in] values sorted array to search within +/// \param[in] needles scalar or array-like values to search for +/// \param[in] options selects left or right insertion semantics +/// \param[in] ctx the function execution context, optional +/// \return insertion indices as uint64 scalar or array +ARROW_EXPORT +Result SearchSorted( + const Datum& values, const Datum& needles, + const SearchSortedOptions& options = SearchSortedOptions::Defaults(), + ExecContext* ctx = NULLPTR); + /// \brief Return the indices that would sort an array. /// /// Perform an indirect sort of array. The output array will contain diff --git a/cpp/src/arrow/compute/initialize.cc b/cpp/src/arrow/compute/initialize.cc index d88835da04ac..ec531e8d490f 100644 --- a/cpp/src/arrow/compute/initialize.cc +++ b/cpp/src/arrow/compute/initialize.cc @@ -48,6 +48,7 @@ Status RegisterComputeKernels() { internal::RegisterVectorNested(registry); internal::RegisterVectorRank(registry); internal::RegisterVectorReplace(registry); + internal::RegisterVectorSearchSorted(registry); internal::RegisterVectorSelectK(registry); internal::RegisterVectorSort(registry); internal::RegisterVectorRunEndEncode(registry); diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 15955b5ef883..d07356aa6300 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -121,6 +121,13 @@ add_arrow_compute_test(vector_sort_test arrow_compute_kernels_testing arrow_compute_testing) +add_arrow_compute_test(vector_search_sorted_test + SOURCES + vector_search_sorted_test.cc + EXTRA_LINK_LIBS + arrow_compute_kernels_testing + arrow_compute_testing) + add_arrow_compute_test(vector_selection_test SOURCES vector_selection_test.cc @@ -141,6 +148,7 @@ add_arrow_compute_benchmark(vector_sort_benchmark) add_arrow_compute_benchmark(vector_partition_benchmark) add_arrow_compute_benchmark(vector_topk_benchmark) add_arrow_compute_benchmark(vector_replace_benchmark) +add_arrow_compute_benchmark(vector_search_sorted_benchmark) add_arrow_compute_benchmark(vector_selection_benchmark) # ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/meson.build b/cpp/src/arrow/compute/kernels/meson.build index fb682443783c..0a8c16f9766b 100644 --- a/cpp/src/arrow/compute/kernels/meson.build +++ b/cpp/src/arrow/compute/kernels/meson.build @@ -132,6 +132,7 @@ vector_kernel_benchmarks = [ 'vector_partition_benchmark', 'vector_topk_benchmark', 'vector_replace_benchmark', + 'vector_search_sorted_benchmark', 'vector_selection_benchmark', ] diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc new file mode 100644 index 000000000000..6df52f931fda --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -0,0 +1,645 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_vector.h" + +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_primitive.h" +#include "arrow/array/array_run_end.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/util.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/kernels/vector_sort_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/registry_internal.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/ree_util.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute::internal { +namespace { + +const SearchSortedOptions* GetDefaultSearchSortedOptions() { + static const auto kDefaultSearchSortedOptions = SearchSortedOptions::Defaults(); + return &kDefaultSearchSortedOptions; +} + +const FunctionDoc search_sorted_doc( + "Find insertion indices for sorted input", + ("Return the index where each needle should be inserted in a sorted input array\n" + "to maintain ascending order.\n" + "\n" + "With side='left', returns the first suitable index (lower bound).\n" + "With side='right', returns the last suitable index (upper bound).\n" + "\n" + "The input array must already be sorted in ascending order. Null values in\n" + "the searched array are supported when clustered entirely at the start or\n" + "entirely at the end. Non-null needles are matched only against the non-null\n" + "portion of the searched array. Null needles emit nulls in the output."), + {"values", "needles"}, "SearchSortedOptions"); + +#define VISIT_SEARCH_SORTED_TYPES(VISIT) \ + VISIT(BooleanType) \ + VISIT(Int8Type) \ + VISIT(Int16Type) \ + VISIT(Int32Type) \ + VISIT(Int64Type) \ + VISIT(UInt8Type) \ + VISIT(UInt16Type) \ + VISIT(UInt32Type) \ + VISIT(UInt64Type) \ + VISIT(FloatType) \ + VISIT(DoubleType) \ + VISIT(Date32Type) \ + VISIT(Date64Type) \ + VISIT(Time32Type) \ + VISIT(Time64Type) \ + VISIT(TimestampType) \ + VISIT(DurationType) \ + VISIT(BinaryType) \ + VISIT(StringType) \ + VISIT(LargeBinaryType) \ + VISIT(LargeStringType) \ + VISIT(BinaryViewType) \ + VISIT(StringViewType) + +template +using SearchValue = typename GetViewType::T; + +/// Comparator implementing Arrow's ascending-order semantics for supported types. +template +struct SearchSortedCompare { + using ValueType = SearchValue; + + int operator()(const ValueType& left, const ValueType& right) const { + return CompareTypeValues(left, right, SortOrder::Ascending, + NullPlacement::AtEnd); + } +}; + +/// Access logical values from a plain Arrow array. +template +class PlainArrayAccessor { + public: + using ArrayType = typename TypeTraits::ArrayType; + using ValueType = SearchValue; + + /// Build a typed accessor over a plain array payload. + explicit PlainArrayAccessor(const Array& array) : array_(array.data()) {} + + /// Return the logical length of the searched values. + int64_t length() const { return array_.length(); } + + /// Return the logical value at the given logical position. + ValueType Value(int64_t index) const { + return GetViewType::LogicalValue(array_.GetView(index)); + } + + private: + ArrayType array_; +}; + +/// Access logical values from a run-end encoded Arrow array. +template +class RunEndEncodedValuesAccessor { + public: + using ArrayType = typename TypeTraits::ArrayType; + using ValueType = SearchValue; + + /// Build a typed accessor over a run-end encoded payload. + explicit RunEndEncodedValuesAccessor(const RunEndEncodedArray& array) + : values_(array.values()->data()), array_span_(*array.data()), span_(array_span_) {} + + /// Return the logical length of the searched values. + int64_t length() const { return span_.length(); } + + /// Return the logical value at the given logical position. + ValueType Value(int64_t index) const { + const auto physical_index = span_.PhysicalIndex(index); + return GetViewType::LogicalValue(values_.GetView(physical_index)); + } + + private: + ArrayType values_; + ArraySpan array_span_; + ::arrow::ree_util::RunEndEncodedArraySpan span_; +}; + +struct NonNullValuesRange { + int64_t offset = 0; + int64_t length = 0; + + /// Return whether the range spans the full searched values input. + bool is_identity(int64_t full_length) const { + return offset == 0 && length == full_length; + } +}; + +/// Present a contiguous non-null slice of the searched values through the same +/// accessor interface as the original values container. +template +class NonNullValuesAccessor { + public: + /// Wrap the original accessor with the discovered non-null subrange. + explicit NonNullValuesAccessor(const ValuesAccessor& values, + const NonNullValuesRange& non_null_values_range) + : values_(values), + offset_(non_null_values_range.offset), + length_(non_null_values_range.length) {} + + /// Return the number of accessible non-null values. + int64_t length() const { return length_; } + + /// Return the value at the given index within the non-null subrange. + auto Value(int64_t index) const { return values_.Value(offset_ + index); } + + private: + const ValuesAccessor& values_; + int64_t offset_; + int64_t length_; +}; + +/// Return the logical type of an array, unwrapping run-end encoding when present. +inline const DataType& LogicalType(const Array& array) { + const auto& type = *array.type(); + if (type.id() == Type::RUN_END_ENCODED) { + return *checked_cast(type).value_type(); + } + return type; +} + +/// Return the logical type of a datum, unwrapping run-end encoding when present. +inline const DataType& LogicalType(const Datum& datum) { + if (datum.is_scalar()) { + return *datum.scalar()->type; + } + return LogicalType(*datum.make_array()); +} + +/// Return whether a scalar or array needle input contains any logical nulls. +inline bool DatumHasNulls(const Datum& datum) { + if (datum.is_scalar()) { + return !datum.scalar()->is_valid; + } + + auto array = datum.make_array(); + const bool has_nulls = array->null_count() > 0; + if (array->type_id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray run_end_encoded(array->data()); + return run_end_encoded.values()->null_count() != 0 || has_nulls; + } + return has_nulls; +} + +/// Reject nested run-end encoded values. TODO: Support this case in the future if there +/// is demand for it. +inline Status ValidateRunEndEncodedLogicalValueType(const DataType& type, + const char* name) { + const auto& ree_type = checked_cast(type); + if (ree_type.value_type()->id() == Type::RUN_END_ENCODED) { + return Status::TypeError("Nested run-end encoded ", name, " are not supported"); + } + return Status::OK(); +} + +/// Compute the contiguous non-null window of the searched values. +/// +inline Result FindNonNullValuesRange(const Array& values) { + NonNullValuesRange non_null_values_range{.offset = 0, .length = values.length()}; + + const auto null_count = values.null_count(); + if (null_count == 0) { + return non_null_values_range; + } + + int64_t leading_null_count = 0; + while (leading_null_count < values.length() && values.IsNull(leading_null_count)) { + ++leading_null_count; + } + + if (leading_null_count == values.length()) { + non_null_values_range.offset = values.length(); + non_null_values_range.length = 0; + return non_null_values_range; + } + + if (leading_null_count > 0) { + if (leading_null_count != null_count) { + return Status::Invalid( + "search_sorted values with nulls must be clustered at the start or end"); + } + non_null_values_range.offset = leading_null_count; + non_null_values_range.length = values.length() - leading_null_count; + return non_null_values_range; + } + + int64_t trailing_null_count = 0; + while (trailing_null_count < values.length() && + values.IsNull(values.length() - 1 - trailing_null_count)) { + ++trailing_null_count; + } + + if (trailing_null_count == 0 || trailing_null_count != null_count) { + return Status::Invalid( + "search_sorted values with nulls must be clustered at the start or end"); + } + + non_null_values_range.length = values.length() - trailing_null_count; + return non_null_values_range; +} + +/// Validate the searched values input shape and supported encoding. +inline Status ValidateSortedValuesInput(const Datum& datum) { + if (!datum.is_array()) { + return Status::TypeError("search_sorted values must be an array"); + } + + const auto& type = *datum.type(); + if (type.id() == Type::RUN_END_ENCODED) { + return ValidateRunEndEncodedLogicalValueType(type, "values"); + } + + return Status::OK(); +} + +/// Validate the needles input shape and supported encoding. +inline Status ValidateNeedleInput(const Datum& datum) { + if (!(datum.is_array() || datum.is_scalar())) { + return Status::TypeError("search_sorted needles must be a scalar or array"); + } + + if (datum.is_array() && datum.type()->id() == Type::RUN_END_ENCODED) { + return ValidateRunEndEncodedLogicalValueType(*datum.type(), "needles"); + } + return Status::OK(); +} + +/// Perform a lower- or upper-bound binary search over already sorted values. +template +uint64_t FindInsertionPoint(const Accessor& sorted_values, + const SearchValue& needle, + SearchSortedOptions::Side side) { + SearchSortedCompare compare; + int64_t first = 0; + int64_t count = sorted_values.length(); + + // TODO(search_sorted): For fixed-width primitive haystacks, investigate a SIMD-friendly + // batched search path . + while (count > 0) { + const int64_t step = count / 2; + const int64_t it = first + step; + const bool advance = side == SearchSortedOptions::Left + ? compare(sorted_values.Value(it), needle) < 0 + : compare(needle, sorted_values.Value(it)) >= 0; + if (advance) { + first = it + 1; + count -= step + 1; + } else { + count = step; + } + } + return static_cast(first); +} + +/// Read a scalar needle without materializing a one-element array. +template +SearchValue ExtractScalarValue(const Scalar& scalar) { + using ScalarType = typename TypeTraits::ScalarType; + const auto& typed_scalar = checked_cast(scalar); + + if constexpr (std::is_base_of_v) { + return GetViewType::LogicalValue(typed_scalar.view()); + } else { + return GetViewType::LogicalValue(typed_scalar.value); + } +} + +/// Append the same insertion index repeatedly for a logical run of needles. +inline void AppendInsertionIndex(UInt64Builder& builder, uint64_t insertion_index, + int64_t count) { + if (count == 0) { + return; + } + DCHECK_LE(builder.length() + count, builder.capacity()); + std::fill_n(builder.GetMutableValue(builder.length()), count, insertion_index); + builder.UnsafeAdvance(count); +} + +/// Dispatch a run-end encoded array to the matching run-end physical type. +template +ReturnType DispatchRunEndEncodedByRunEndType(const RunEndEncodedArray& array, + const char* argument_name, + Visitor&& visitor) { + const auto& ree_type = checked_cast(*array.type()); + switch (ree_type.run_end_type()->id()) { + case Type::INT16: + return std::forward(visitor).template operator()(array); + case Type::INT32: + return std::forward(visitor).template operator()(array); + case Type::INT64: + return std::forward(visitor).template operator()(array); + default: + return ReturnType(Status::TypeError("Unsupported run-end type for search_sorted ", + argument_name, ": ", array.type()->ToString())); + } +} + +template +using VisitedNeedle = std::conditional_t>, + SearchValue>; + +/// Normalize a non-null logical needle into the visitor payload type. +template +VisitedNeedle MakeVisitedNeedle( + const SearchValue& needle) { + if constexpr (EmitNulls) { + return std::optional>(needle); + } else { + return needle; + } +} + +/// Read one logical needle value from a physical array position. +template +VisitedNeedle ReadVisitedNeedle(const ArrayType& array, + int64_t physical_index) { + if constexpr (EmitNulls) { + if (array.IsNull(physical_index)) { + return std::nullopt; + } + } + const auto needle = GetViewType::LogicalValue(array.GetView(physical_index)); + return MakeVisitedNeedle(needle); +} + +/// Visit each plain-array needle as a single-value logical span. +template +Status VisitArrayNeedles(const Array& needles, Visitor&& visitor) { + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType array(needles.data()); + for (int64_t index = 0; index < array.length(); ++index) { + RETURN_NOT_OK( + visitor(ReadVisitedNeedle(array, index), index, index + 1)); + } + return Status::OK(); +} + +/// Visit each run of a run-end encoded needle array as one logical span. +template +Status VisitRunEndEncodedNeedleRuns(const RunEndEncodedArray& needles, + Visitor&& visitor) { + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType values(needles.values()->data()); + ArraySpan array_span(*needles.data()); + ::arrow::ree_util::RunEndEncodedArraySpan span(array_span); + + for (auto it = span.begin(); !it.is_end(span); ++it) { + const auto physical_index = it.index_into_array(); + RETURN_NOT_OK(visitor(ReadVisitedNeedle(values, physical_index), + it.logical_position(), it.run_end())); + } + return Status::OK(); +} + +/// Visit scalar, plain-array, or run-end encoded needles through a uniform +/// callback interface of [begin, end) logical spans. +template +Status VisitNeedles(const Datum& needles, Visitor&& visitor) { + if (needles.is_scalar()) { + if constexpr (EmitNulls) { + if (!needles.scalar()->is_valid) { + return visitor(std::optional>{}, 0, 1); + } + } + return visitor(MakeVisitedNeedle( + ExtractScalarValue(*needles.scalar())), + 0, 1); + } + + auto needle_array = needles.make_array(); + if (needle_array->type_id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(needle_array->data()); + return DispatchRunEndEncodedByRunEndType( + ree, "needles", + [&](const RunEndEncodedArray& run_end_encoded_needles) { + return VisitRunEndEncodedNeedleRuns( + run_end_encoded_needles, visitor); + }); + } + + return VisitArrayNeedles(*needle_array, visitor); +} + +/// Search all needle values and write insertion indices into the preallocated output. +template +Status SearchNeedleValues(const ValuesAccessor& sorted_values, const Datum& needles, + SearchSortedOptions::Side side, uint64_t insertion_offset, + uint64_t* out) { + auto emit_search_result = [&](const SearchValue& needle, int64_t begin, + int64_t end) -> Status { + const auto insertion_index = + FindInsertionPoint(sorted_values, needle, side) + insertion_offset; + std::ranges::fill(std::span(out + begin, static_cast(end - begin)), + insertion_index); + return Status::OK(); + }; + + return VisitNeedles(needles, emit_search_result); +} + +/// Search needle values while emitting nulls for null needles. +template +Status AppendInsertionIndicesWithNulls(const ValuesAccessor& sorted_values, + const Datum& needles, + SearchSortedOptions::Side side, + uint64_t insertion_offset, + UInt64Builder& builder) { + auto emit_search_result = [&](const std::optional>& needle, + int64_t begin, int64_t end) -> Status { + const auto span_length = end - begin; + if (!needle.has_value()) { + return builder.AppendNulls(span_length); + } + const auto insertion_index = + FindInsertionPoint(sorted_values, *needle, side) + insertion_offset; + AppendInsertionIndex(builder, insertion_index, span_length); + return Status::OK(); + }; + + return VisitNeedles(needles, emit_search_result); +} + +/// Materialize output for scalar or array needles. +template +Result ComputeInsertionIndices(const ValuesAccessor& sorted_values, + const Datum& needles, + SearchSortedOptions::Side side, + uint64_t insertion_offset, ExecContext* ctx) { + if (needles.is_scalar() && !needles.scalar()->is_valid) { + return Datum(std::make_shared()); + } + + if (needles.is_scalar()) { + const auto insertion_index = + FindInsertionPoint( + sorted_values, ExtractScalarValue(*needles.scalar()), side) + + insertion_offset; + return Datum(std::make_shared(insertion_index)); + } + + if (DatumHasNulls(needles)) { + UInt64Builder builder(ctx->memory_pool()); + ARROW_RETURN_NOT_OK(builder.Reserve(needles.length())); + RETURN_NOT_OK(AppendInsertionIndicesWithNulls(sorted_values, needles, side, + insertion_offset, builder)); + ARROW_ASSIGN_OR_RAISE(auto out, builder.Finish()); + return Datum(std::move(out)); + } + + ARROW_ASSIGN_OR_RAISE(auto out, + MakeMutableUInt64Array(needles.length(), ctx->memory_pool())); + auto* out_values = out->GetMutableValues(1); + + RETURN_NOT_OK(SearchNeedleValues(sorted_values, needles, side, + insertion_offset, out_values)); + return Datum(MakeArray(std::move(out))); +} + +// Main entry point for search_sorted over a single array of sorted values and scalar or +// array needles. Handles null presence in the needles and dispatches to the appropriate +// search implementation. +template +Result SearchWithAccessor(const ValuesAccessor& values_accessor, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, SearchSortedOptions::Side side, + ExecContext* ctx) { + if (non_null_values_range.is_identity(values_accessor.length())) { + return ComputeInsertionIndices(values_accessor, needles, side, + /*insertion_offset=*/0, ctx); + } + + NonNullValuesAccessor non_null_values(values_accessor, non_null_values_range); + return ComputeInsertionIndices( + non_null_values, needles, side, static_cast(non_null_values_range.offset), + ctx); +} + +// Meta-function implementation for the search_sorted public compute entrypoint. +template +Result VisitValuesAccessor(const Array& values, Visitor&& visitor) { + if (values.type_id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(values.data()); + return DispatchRunEndEncodedByRunEndType>( + ree, "values", + [&](const RunEndEncodedArray& run_end_encoded_values) { + RunEndEncodedValuesAccessor values_accessor( + run_end_encoded_values); + return visitor(values_accessor); + }); + } + + PlainArrayAccessor values_accessor(values); + return visitor(values_accessor); +} + +/// Meta-function implementation for the search_sorted public compute entrypoint. +class SearchSortedMetaFunction : public MetaFunction { + public: + /// Construct the registry entry with default options and documentation. + SearchSortedMetaFunction() + : MetaFunction("search_sorted", Arity::Binary(), search_sorted_doc, + GetDefaultSearchSortedOptions()) {} + + /// Validate inputs, normalize options, and dispatch to the typed search implementation. + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + RETURN_NOT_OK(ValidateSortedValuesInput(args[0])); + RETURN_NOT_OK(ValidateNeedleInput(args[1])); + + const auto& values_type = LogicalType(args[0]); + const auto& needles_type = LogicalType(args[1]); + if (!values_type.Equals(needles_type)) { + return Status::TypeError( + "search_sorted arguments must have matching logical types, got ", + values_type.ToString(), " and ", needles_type.ToString()); + } + + auto values_array = args[0].make_array(); + ARROW_ASSIGN_OR_RAISE(auto non_null_values_range, + FindNonNullValuesRange(*values_array)); + auto result = DispatchByType(*values_array, non_null_values_range, args[1], + static_cast(*options), ctx); + return result; + } + + private: + /// Dispatch the logical value type to the matching template specialization. + Result DispatchByType(const Array& values, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, const SearchSortedOptions& options, + ExecContext* ctx) const { + switch (LogicalType(values).id()) { +#define VISIT(TYPE) \ + case TYPE::type_id: \ + return DispatchHaystack(values, non_null_values_range, needles, options.side, \ + ctx); + VISIT_SEARCH_SORTED_TYPES(VISIT) +#undef VISIT + default: + break; + } + return Status::NotImplemented("search_sorted is not implemented for type ", + LogicalType(values).ToString()); + } + + /// Dispatch the physical representation of the searched values. + template + Result DispatchHaystack(const Array& values, + const NonNullValuesRange& non_null_values_range, + const Datum& needles, SearchSortedOptions::Side side, + ExecContext* ctx) const { + return VisitValuesAccessor(values, [&](const auto& values_accessor) { + return SearchWithAccessor(values_accessor, non_null_values_range, + needles, side, ctx); + }); + } +}; + +} // namespace + +/// Register the search_sorted vector kernel in the global compute registry. +void RegisterVectorSearchSorted(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(std::make_shared())); +} + +} // namespace compute::internal +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc new file mode 100644 index 000000000000..7898fe7dd59d --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/builder.h" +#include "arrow/compute/api_vector.h" +#include "arrow/datum.h" +#include "arrow/scalar.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x5EA4C42; +constexpr int64_t kNeedleToValueRatio = 4; +constexpr int64_t kValuesRunLength = 16; +constexpr int64_t kNeedlesRunLength = 8; +constexpr int32_t kStringMinLength = 8; +constexpr int32_t kStringMaxLength = 24; + +void SetSearchSortedQuickArgs(benchmark::internal::Benchmark* bench) { + bench->Unit(benchmark::kMicrosecond); + for (const auto size : std::vector{kL1Size, kL2Size}) { + bench->Arg(size); + } +} + +void SetSearchSortedArgs(benchmark::internal::Benchmark* bench) { + bench->Unit(benchmark::kMicrosecond); + for (const auto size : kMemorySizes) { + bench->Arg(size); + } +} + +int64_t Int64LengthFromBytes(int64_t size_bytes) { + return std::max(1, size_bytes / static_cast(sizeof(int64_t))); +} + +int64_t NeedleLengthFromBytes(int64_t size_bytes) { + return std::max(1, Int64LengthFromBytes(size_bytes) / kNeedleToValueRatio); +} + +int64_t StringLengthFromBytes(int64_t size_bytes) { + const int64_t average_length = (kStringMinLength + kStringMaxLength) / 2; + return std::max(1, size_bytes / average_length); +} + +std::shared_ptr BuildSortedInt64Values(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed); + const auto length = Int64LengthFromBytes(size_bytes); + const auto max_value = std::max(length / 8, 1); + + auto values = std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); + std::vector data(values->raw_values(), values->raw_values() + values->length()); + std::ranges::sort(data); + + Int64Builder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildInt64Needles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 1); + const auto length = NeedleLengthFromBytes(size_bytes); + const auto max_value = std::max(Int64LengthFromBytes(size_bytes) / 8, 1); + return std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); +} + +std::shared_ptr BuildSortedStringValues(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 2); + const auto length = StringLengthFromBytes(size_bytes); + auto values = std::static_pointer_cast( + rand.String(length, kStringMinLength, kStringMaxLength, 0.0)); + + std::vector data; + data.reserve(static_cast(values->length())); + for (int64_t index = 0; index < values->length(); ++index) { + data.push_back(values->GetString(index)); + } + std::ranges::sort(data); + + StringBuilder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildStringNeedles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 3); + const auto length = std::max(1, StringLengthFromBytes(size_bytes) / 4); + return std::static_pointer_cast( + rand.String(length, kStringMinLength, kStringMaxLength, 0.0)); +} + +std::shared_ptr BuildSortedBinaryValues(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 4); + const auto length = StringLengthFromBytes(size_bytes); + const auto unique = std::max(1, length / 8); + auto values = std::static_pointer_cast( + rand.BinaryWithRepeats(length, unique, kStringMinLength, kStringMaxLength, 0.0)); + + std::vector data; + data.reserve(static_cast(values->length())); + for (int64_t index = 0; index < values->length(); ++index) { + data.emplace_back(values->GetView(index)); + } + std::ranges::sort(data); + + BinaryBuilder builder; + ABORT_NOT_OK(builder.AppendValues(data)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildBinaryNeedles(int64_t size_bytes) { + random::RandomArrayGenerator rand(kSeed + 5); + const auto length = std::max(1, StringLengthFromBytes(size_bytes) / 4); + const auto unique = std::max(1, length / 2); + return std::static_pointer_cast( + rand.BinaryWithRepeats(length, unique, kStringMinLength, kStringMaxLength, 0.0)); +} + +std::shared_ptr BuildRunHeavyInt64Values(int64_t logical_length, + int64_t run_length) { + Int64Builder builder; + ABORT_NOT_OK(builder.Reserve(logical_length)); + for (int64_t index = 0; index < logical_length; ++index) { + builder.UnsafeAppend(index / run_length); + } + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildRunEndEncodedInt64Values(int64_t size_bytes, int64_t run_length) { + auto values = BuildRunHeavyInt64Values(Int64LengthFromBytes(size_bytes), run_length); + return RunEndEncode(Datum(values), RunEndEncodeOptions{int32()}).ValueOrDie().make_array(); +} + +std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, int64_t run_length) { + auto needles = BuildRunHeavyInt64Values(NeedleLengthFromBytes(size_bytes), run_length); + return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); +} + +void SetBenchmarkCounters(benchmark::State& state, const Datum& values, const Datum& needles) { + const auto values_length = values.length(); + const auto needles_length = needles.length(); + state.counters["values_length"] = static_cast(values_length); + state.counters["needles_length"] = static_cast(needles_length); + state.SetItemsProcessed(state.iterations() * needles_length); +} + +void RunSearchSortedBenchmark(benchmark::State& state, const Datum& values, + const Datum& needles, SearchSortedOptions::Side side) { + const SearchSortedOptions options(side); + for (auto _ : state) { + auto result = SearchSorted(values, needles, options); + ABORT_NOT_OK(result.status()); + benchmark::DoNotOptimize(result.ValueUnsafe()); + } + SetBenchmarkCounters(state, values, needles); +} + +static void BM_SearchSortedInt64ArrayNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedInt64Values(state.range(0))); + const Datum needles(BuildInt64Needles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedInt64ScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedInt64Values(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles(std::make_shared(values_array->Value(scalar_index))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedValues(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles(BuildInt64Needles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedValuesAndNeedles( + benchmark::State& state, SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles(BuildRunEndEncodedInt64Needles(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedStringArrayNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedStringValues(state.range(0))); + const Datum needles(BuildStringNeedles(state.range(0))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedStringScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedStringValues(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles(std::make_shared(values_array->GetString(scalar_index))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedBinaryScalarNeedle(benchmark::State& state, + SearchSortedOptions::Side side) { + const auto values_array = BuildSortedBinaryValues(state.range(0)); + const auto scalar_index = values_array->length() / 2; + const Datum values(values_array); + const Datum needles(std::make_shared( + std::string(values_array->GetView(scalar_index)))); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedInt64ArrayNeedlesQuick(benchmark::State& state, + SearchSortedOptions::Side side) { + BM_SearchSortedInt64ArrayNeedles(state, side); +} + +static void BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick( + benchmark::State& state, SearchSortedOptions::Side side) { + BM_SearchSortedRunEndEncodedValuesAndNeedles(state, side); +} + +// Primitive-array and REE cases are the main baselines for the kernel TODOs around +// SIMD batched search, vectorized REE writeback, and future parallel needle traversal. + +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedles, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedles, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64ScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValues, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValues, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); + +// String and binary scalar cases specifically exercise the direct scalar fast path that +// avoids boxing a scalar needle into a temporary one-element array. +BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, right, SearchSortedOptions::Right) + ->Apply(SetSearchSortedQuickArgs); + +// Lightweight L1/L2 regressions keep a fast local loop for future optimization work. +BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedlesQuick, left, SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc new file mode 100644 index 000000000000..709cbda76994 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include + +#include "arrow/compute/api.h" +#include "arrow/compute/kernels/test_util_internal.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace { + +Result> REEFromJSON(const std::shared_ptr& ree_type, + const std::string& json) { + auto ree_type_ptr = checked_cast(ree_type.get()); + auto array = ArrayFromJSON(ree_type_ptr->value_type(), json); + ARROW_ASSIGN_OR_RAISE( + auto datum, RunEndEncode(array, RunEndEncodeOptions{ree_type_ptr->run_end_type()})); + return datum.make_array(); +} + +TEST(SearchSorted, BasicLeftRight) { + auto values = ArrayFromJSON(int64(), "[100, 200, 200, 300, 300]"); + auto needles = ArrayFromJSON(int64(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 3, 5]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 3, 3, 5]"), *right.make_array()); +} + +TEST(SearchSorted, ScalarNeedle) { + auto values = ArrayFromJSON(int32(), "[1, 3, 5, 7]"); + + ASSERT_OK_AND_ASSIGN( + auto result, + SearchSorted(Datum(values), Datum(std::make_shared(5)), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast(*result.scalar()).value, 3); +} + +TEST(SearchSorted, ScalarStringNeedle) { + auto values = ArrayFromJSON(utf8(), R"(["aa", "bb", "bb", "cc"])"); + + ASSERT_OK_AND_ASSIGN( + auto result, + SearchSorted(Datum(values), Datum(std::make_shared("bb")), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast(*result.scalar()).value, 3); +} + +TEST(SearchSorted, EmptyHaystack) { + auto values = ArrayFromJSON(int16(), "[]"); + auto needles = ArrayFromJSON(int16(), "[1, 2, 3]"); + + ASSERT_OK_AND_ASSIGN(auto result, SearchSorted(Datum(values), Datum(needles))); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 0]"), *result.make_array()); +} + +TEST(SearchSorted, ValuesWithLeadingNulls) { + auto values = ArrayFromJSON(int32(), "[null, 200, 300, 300]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[1, 1, 2, 4]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[1, 2, 2, 4]"), *right.make_array()); +} + +TEST(SearchSorted, ValuesAllNull) { + auto values = ArrayFromJSON(int32(), "[null, null, null]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, null]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[3, 3, null]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[3, 3, null]"), *right.make_array()); +} + +TEST(SearchSorted, ValuesWithTrailingNulls) { + auto values = ArrayFromJSON(int32(), "[200, 300, 300, null, null]"); + auto needles = ArrayFromJSON(int32(), "[50, 200, 250, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 1, 3]"), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 1, 3]"), *right.make_array()); +} + +TEST(SearchSorted, NullNeedlesEmitNull) { + auto values = ArrayFromJSON(int32(), "[null, 200, 300, 300]"); + auto needles = ArrayFromJSON(int32(), "[null, 50, 200, null, 400]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 1, 1, null, 4]"), + *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 1, 2, null, 4]"), + *right.make_array()); + + ASSERT_OK_AND_ASSIGN(auto scalar_result, + SearchSorted(Datum(values), Datum(std::make_shared()), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_TRUE(scalar_result.is_scalar()); + ASSERT_FALSE(scalar_result.scalar()->is_valid); + ASSERT_TRUE(scalar_result.scalar()->type->Equals(uint64())); +} + +TEST(SearchSorted, RejectUnclusteredNullValues) { + auto values = ArrayFromJSON(int32(), "[null, 1, null, 3]"); + auto needles = ArrayFromJSON(int32(), "[2]"); + + ASSERT_RAISES(Invalid, SearchSorted(Datum(values), Datum(needles))); +} + +TEST(SearchSorted, RunEndEncodedNulls) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, REEFromJSON(values_type, "[null, null, 2, 4, 4]")); + auto needles_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_needles, + REEFromJSON(needles_type, "[null, null, 1, 4, 4, null, 8]")); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(ree_values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, null, 2, 3, 3, null, 5]"), + *result.make_array()); +} + +TEST(SearchSorted, RunEndEncodedNeedlesWithNullRuns) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN( + auto ree_needles, + REEFromJSON(needles_type, "[null, null, 0, 0, 0, 1, 1, 4, 4, 4, null, 9, 9]")); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual( + *ArrayFromJSON(uint64(), "[null, null, 0, 0, 0, 0, 0, 3, 3, 3, null, 5, 5]"), + *left.make_array()); + AssertArraysEqual( + *ArrayFromJSON(uint64(), "[null, null, 0, 0, 0, 2, 2, 3, 3, 3, null, 5, 5]"), + *right.make_array()); +} + +TEST(SearchSorted, RunEndEncodedAllNullValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[null, null, null, null]")); + auto needles = ArrayFromJSON(int32(), "[null, 1, 8]"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[null, 4, 4]"), *result.make_array()); +} + +TEST(SearchSorted, RejectMismatchedTypes) { + auto values = ArrayFromJSON(int32(), "[1, 2, 3]"); + auto needles = ArrayFromJSON(int64(), "[2]"); + + ASSERT_RAISES(TypeError, SearchSorted(Datum(values), Datum(needles))); +} + +TEST(SearchSorted, RunEndEncodedValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, REEFromJSON(values_type, "[1, 1, 1, 3, 3, 5]")); + auto needles = ArrayFromJSON(int32(), "[0, 1, 2, 3, 4, 5, 6]"); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(ree_values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 3, 3, 5, 5, 6]"), + *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 3, 3, 5, 5, 6, 6]"), + *right.make_array()); +} + +TEST(SearchSorted, RunEndEncodedNeedles) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_needles, + REEFromJSON(needles_type, "[0, 0, 1, 1, 4, 4, 9]")); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(values), Datum(ree_needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 2, 2, 3, 3, 5]"), + *result.make_array()); +} + +TEST(SearchSorted, SlicedRunEndEncodedValues) { + auto values_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[0, 0, 1, 1, 1, 4, 4, 9]")); + auto sliced = ree_values->Slice(2, 5); + auto needles = ArrayFromJSON(int32(), "[0, 1, 2, 4, 9]"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(sliced), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 0, 3, 3, 5]"), *result.make_array()); +} + +TEST(SearchSorted, BinaryValues) { + auto values = ArrayFromJSON(utf8(), R"(["aa", "bb", "bb", "cc"])"); + auto needles = ArrayFromJSON(utf8(), R"(["a", "bb", "bc", "z"])"); + + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 3, 4]"), *result.make_array()); +} + +} // namespace +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 5b9d7f8d608f..d457f6fb95c0 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -50,6 +50,7 @@ void RegisterVectorHash(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorRank(FunctionRegistry* registry); void RegisterVectorReplace(FunctionRegistry* registry); +void RegisterVectorSearchSorted(FunctionRegistry* registry); void RegisterVectorSelectK(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index fe7f11af6ff8..3f6c59fe7d97 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -590,6 +590,7 @@ if needs_compute 'compute/kernels/vector_rank.cc', 'compute/kernels/vector_replace.cc', 'compute/kernels/vector_run_end_encode.cc', + 'compute/kernels/vector_search_sorted.cc', 'compute/kernels/vector_select_k.cc', 'compute/kernels/vector_sort.cc', 'compute/kernels/vector_statistics.cc', diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index e4092af70cde..d16a123b8028 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1873,6 +1873,8 @@ in the respective option classes. +-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ | sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SortOptions` | \(1) \(6) | +-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ +| search_sorted | Binary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SearchSortedOptions` | \(8) | ++-----------------------+------------+---------------------------------------------------------+-------------------+-------------------------------+----------------+ * \(1) The output is an array of indices into the input, that define a @@ -1901,6 +1903,13 @@ in the respective option classes. * \(7) The output is an array of indices into the input, that define a non-stable sort of the input. +* \(8) The first argument must be sorted in ascending order. If it contains + nulls, they must be clustered entirely at the start or the end, and non-null + needles are matched only against the non-null portion. The second argument + may be a scalar, array, or run-end encoded array. Null needles yield null + outputs. Both arguments must have the same logical type. A scalar needle + yields a UInt64 scalar; otherwise the result is a UInt64 array. + .. _cpp-compute-vector-structural-transforms: Structural transforms diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 137b034d6ffc..bf3d7ae1cb59 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2074,6 +2074,14 @@ cdef CNullPlacement unwrap_null_placement(null_placement) except *: _raise_invalid_function_option(null_placement, "null placement") +cdef CSearchSortedSide unwrap_search_sorted_side(side) except *: + if side == "left": + return CSearchSortedSide_Left + elif side == "right": + return CSearchSortedSide_Right + _raise_invalid_function_option(side, "search sorted side") + + cdef class _PartitionNthOptions(FunctionOptions): def _set_options(self, pivot, null_placement): self.wrapped.reset(new CPartitionNthOptions( @@ -2243,6 +2251,27 @@ class ArraySortOptions(_ArraySortOptions): self._set_options(order, null_placement) +cdef class _SearchSortedOptions(FunctionOptions): + def _set_options(self, side): + self.wrapped.reset(new CSearchSortedOptions( + unwrap_search_sorted_side(side))) + + +class SearchSortedOptions(_SearchSortedOptions): + """ + Options for the `search_sorted` function. + + Parameters + ---------- + side : str, default "left" + Whether to return the leftmost or rightmost insertion point. + Accepted values are "left", "right". + """ + + def __init__(self, side="left"): + self._set_options(side) + + cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): self.wrapped.reset(new CSortOptions( diff --git a/python/pyarrow/_compute_docstrings.py b/python/pyarrow/_compute_docstrings.py index 079f00db7d92..39cd39875c10 100644 --- a/python/pyarrow/_compute_docstrings.py +++ b/python/pyarrow/_compute_docstrings.py @@ -42,6 +42,39 @@ ] """ +function_doc_additions["search_sorted"] = """ + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> values = pa.array([1, 1, 3, 5, 8]) + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9])) + + [ + 0, + 0, + 3, + 5 + ] + >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) + >>> pc.search_sorted(with_nulls, pa.array([50, 200, None, 400], type=pa.int64())) + + [ + 1, + 1, + null, + 4 + ] + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9]), side="right") + + [ + 0, + 2, + 3, + 5 + ] + """ + function_doc_additions["mode"] = """ Examples -------- diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8177948aaebc..0a2e231d189e 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -68,6 +68,7 @@ RoundToMultipleOptions, ScalarAggregateOptions, ScatterOptions, + SearchSortedOptions, SelectKOptions, SetLookupOptions, SkewOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e96a7d84696d..0b67cf0c382c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2781,6 +2781,16 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CSortOrder order CNullPlacement null_placement + cdef enum CSearchSortedSide \ + "arrow::compute::SearchSortedOptions::Side": + CSearchSortedSide_Left "arrow::compute::SearchSortedOptions::Left" + CSearchSortedSide_Right "arrow::compute::SearchSortedOptions::Right" + + cdef cppclass CSearchSortedOptions \ + "arrow::compute::SearchSortedOptions"(CFunctionOptions): + CSearchSortedOptions(CSearchSortedSide side) + CSearchSortedSide side + cdef cppclass CSortKey" arrow::compute::SortKey": CSortKey(CFieldRef target, CSortOrder order) CFieldRef target diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 8c3b09f612cc..35e6813a8a9b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -3021,6 +3021,99 @@ def test_array_sort_indices(): pc.array_sort_indices(arr, order="nonscending") +def test_search_sorted(): + values = pa.array([1, 1, 3, 5, 8]) + needles = pa.array([0, 1, 3, 4, 5, 8, 9]) + + expected_left = pa.array([0, 0, 2, 3, 3, 4, 5], type=pa.uint64()) + expected_right = pa.array([0, 2, 3, 3, 4, 5, 5], type=pa.uint64()) + + assert pc.search_sorted(values, needles).equals(expected_left) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, "right").equals(expected_right) + assert pc.search_sorted( + values, needles, options=pc.SearchSortedOptions(side="right") + ).equals(expected_right) + + assert pc.search_sorted(values, pa.scalar(5, type=pa.int64())).as_py() == 3 + assert pc.search_sorted( + values, pa.scalar(5, type=pa.int64()), side="right" + ).as_py() == 4 + + +def test_search_sorted_null_values(): + needles = pa.array([50, 200, 250, 400], type=pa.int64()) + + values = pa.array([None, 200, 300, 300], type=pa.int64()) + expected_left = pa.array([1, 1, 2, 4], type=pa.uint64()) + expected_right = pa.array([1, 2, 2, 4], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + values = pa.array([200, 300, 300, None, None], type=pa.int64()) + expected_left = pa.array([0, 0, 1, 3], type=pa.uint64()) + expected_right = pa.array([0, 1, 1, 3], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + +def test_search_sorted_null_needles_emit_null(): + values = pa.array([None, 200, 300, 300], type=pa.int64()) + needles = pa.array([None, 50, 200, None, 400], type=pa.int64()) + + expected_left = pa.array([None, 1, 1, None, 4], type=pa.uint64()) + expected_right = pa.array([None, 1, 2, None, 4], type=pa.uint64()) + + assert pc.search_sorted(values, needles, side="left").equals(expected_left) + assert pc.search_sorted(values, needles, side="right").equals(expected_right) + + scalar_result = pc.search_sorted(values, pa.scalar(None, type=pa.int64())) + assert scalar_result.as_py() is None + + +def test_search_sorted_run_end_encoded(): + run_ends = pa.array([2, 3, 4, 5], type=pa.int16()) + encoded_values = pa.array([1, 3, 5, 8], type=pa.int64()) + values = pa.RunEndEncodedArray.from_arrays(run_ends, encoded_values) + needles = pa.array([0, 1, 3, 4, 5, 8, 9], type=pa.int64()) + + expected_left = pa.array([0, 0, 2, 3, 3, 4, 5], type=pa.uint64()) + assert pc.search_sorted(values, needles).equals(expected_left) + + ree_needles = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 4, 6], type=pa.int16()), + pa.array([1, 4, 9], type=pa.int64()) + ) + expected_right = pa.array([2, 2, 3, 3, 5, 5], type=pa.uint64()) + assert pc.search_sorted(values, ree_needles, side="right").equals( + expected_right + ) + + +def test_search_sorted_run_end_encoded_nulls(): + values = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 3, 5], type=pa.int16()), + pa.array([None, 2, 4], type=pa.int64()) + ) + needles = pa.RunEndEncodedArray.from_arrays( + pa.array([2, 3, 5, 6], type=pa.int16()), + pa.array([None, 1, 4, None], type=pa.int64()) + ) + + expected = pa.array([None, None, 2, 3, 3, None], type=pa.uint64()) + assert pc.search_sorted(values, needles, side="left").equals(expected) + + +def test_search_sorted_errors(): + values = pa.array([1, 1, 3, 5, 8]) + + with pytest.raises(ValueError, match='"middle" is not a valid search sorted side'): + pc.search_sorted(values, pa.array([1]), side="middle") + + with pytest.raises(pa.ArrowInvalid, match="clustered at the start or end"): + pc.search_sorted(pa.array([None, 1, None], type=pa.int64()), pa.array([1])) + + def test_sort_indices_array(): arr = pa.array([1, 2, None, 0]) result = pc.sort_indices(arr) From a47c6e53b6fe7ec176341216aba193b8052f8664 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:52:51 +0200 Subject: [PATCH 2/8] Refactor vector_search_sorted kernel to use ArrayData and add benchmarks for needles with null runs --- .../compute/kernels/vector_search_sorted.cc | 69 +++++++++-------- .../kernels/vector_search_sorted_benchmark.cc | 77 +++++++++++++++++++ 2 files changed, 113 insertions(+), 33 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index 6df52f931fda..2b59958aec6e 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -111,7 +111,8 @@ class PlainArrayAccessor { using ValueType = SearchValue; /// Build a typed accessor over a plain array payload. - explicit PlainArrayAccessor(const Array& array) : array_(array.data()) {} + explicit PlainArrayAccessor(const std::shared_ptr& array_data) + : array_(array_data) {} /// Return the logical length of the searched values. int64_t length() const { return array_.length(); } @@ -186,8 +187,8 @@ class NonNullValuesAccessor { }; /// Return the logical type of an array, unwrapping run-end encoding when present. -inline const DataType& LogicalType(const Array& array) { - const auto& type = *array.type(); +inline const DataType& LogicalType(const ArrayData& array) { + const auto& type = *array.type; if (type.id() == Type::RUN_END_ENCODED) { return *checked_cast(type).value_type(); } @@ -199,7 +200,7 @@ inline const DataType& LogicalType(const Datum& datum) { if (datum.is_scalar()) { return *datum.scalar()->type; } - return LogicalType(*datum.make_array()); + return LogicalType(*datum.array()); } /// Return whether a scalar or array needle input contains any logical nulls. @@ -208,10 +209,10 @@ inline bool DatumHasNulls(const Datum& datum) { return !datum.scalar()->is_valid; } - auto array = datum.make_array(); - const bool has_nulls = array->null_count() > 0; - if (array->type_id() == Type::RUN_END_ENCODED) { - RunEndEncodedArray run_end_encoded(array->data()); + const auto& array_data = datum.array(); + const bool has_nulls = array_data->GetNullCount() > 0; + if (array_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray run_end_encoded(array_data); return run_end_encoded.values()->null_count() != 0 || has_nulls; } return has_nulls; @@ -230,21 +231,21 @@ inline Status ValidateRunEndEncodedLogicalValueType(const DataType& type, /// Compute the contiguous non-null window of the searched values. /// -inline Result FindNonNullValuesRange(const Array& values) { - NonNullValuesRange non_null_values_range{.offset = 0, .length = values.length()}; +inline Result FindNonNullValuesRange(const ArrayData& values) { + NonNullValuesRange non_null_values_range{.offset = 0, .length = values.length}; - const auto null_count = values.null_count(); + const auto null_count = values.GetNullCount(); if (null_count == 0) { return non_null_values_range; } int64_t leading_null_count = 0; - while (leading_null_count < values.length() && values.IsNull(leading_null_count)) { + while (leading_null_count < values.length && values.IsNull(leading_null_count)) { ++leading_null_count; } - if (leading_null_count == values.length()) { - non_null_values_range.offset = values.length(); + if (leading_null_count == values.length) { + non_null_values_range.offset = values.length; non_null_values_range.length = 0; return non_null_values_range; } @@ -255,13 +256,13 @@ inline Result FindNonNullValuesRange(const Array& values) { "search_sorted values with nulls must be clustered at the start or end"); } non_null_values_range.offset = leading_null_count; - non_null_values_range.length = values.length() - leading_null_count; + non_null_values_range.length = values.length - leading_null_count; return non_null_values_range; } int64_t trailing_null_count = 0; - while (trailing_null_count < values.length() && - values.IsNull(values.length() - 1 - trailing_null_count)) { + while (trailing_null_count < values.length && + values.IsNull(values.length - 1 - trailing_null_count)) { ++trailing_null_count; } @@ -270,7 +271,7 @@ inline Result FindNonNullValuesRange(const Array& values) { "search_sorted values with nulls must be clustered at the start or end"); } - non_null_values_range.length = values.length() - trailing_null_count; + non_null_values_range.length = values.length - trailing_null_count; return non_null_values_range; } @@ -400,10 +401,11 @@ VisitedNeedle ReadVisitedNeedle(const ArrayType& array, /// Visit each plain-array needle as a single-value logical span. template -Status VisitArrayNeedles(const Array& needles, Visitor&& visitor) { +Status VisitArrayNeedles(const std::shared_ptr& needles_data, + Visitor&& visitor) { using ArrayType = typename TypeTraits::ArrayType; - ArrayType array(needles.data()); + ArrayType array(needles_data); for (int64_t index = 0; index < array.length(); ++index) { RETURN_NOT_OK( visitor(ReadVisitedNeedle(array, index), index, index + 1)); @@ -444,9 +446,9 @@ Status VisitNeedles(const Datum& needles, Visitor&& visitor) { 0, 1); } - auto needle_array = needles.make_array(); - if (needle_array->type_id() == Type::RUN_END_ENCODED) { - RunEndEncodedArray ree(needle_array->data()); + const auto& needle_data = needles.array(); + if (needle_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(needle_data); return DispatchRunEndEncodedByRunEndType( ree, "needles", [&](const RunEndEncodedArray& run_end_encoded_needles) { @@ -455,7 +457,7 @@ Status VisitNeedles(const Datum& needles, Visitor&& visitor) { }); } - return VisitArrayNeedles(*needle_array, visitor); + return VisitArrayNeedles(needle_data, visitor); } /// Search all needle values and write insertion indices into the preallocated output. @@ -554,9 +556,10 @@ Result SearchWithAccessor(const ValuesAccessor& values_accessor, // Meta-function implementation for the search_sorted public compute entrypoint. template -Result VisitValuesAccessor(const Array& values, Visitor&& visitor) { - if (values.type_id() == Type::RUN_END_ENCODED) { - RunEndEncodedArray ree(values.data()); +Result VisitValuesAccessor(const std::shared_ptr& values_data, + Visitor&& visitor) { + if (values_data->type->id() == Type::RUN_END_ENCODED) { + RunEndEncodedArray ree(values_data); return DispatchRunEndEncodedByRunEndType>( ree, "values", [&](const RunEndEncodedArray& run_end_encoded_values) { @@ -566,7 +569,7 @@ Result VisitValuesAccessor(const Array& values, Visitor&& visitor) { }); } - PlainArrayAccessor values_accessor(values); + PlainArrayAccessor values_accessor(values_data); return visitor(values_accessor); } @@ -593,17 +596,17 @@ class SearchSortedMetaFunction : public MetaFunction { values_type.ToString(), " and ", needles_type.ToString()); } - auto values_array = args[0].make_array(); + const auto& values_array = args[0].array(); ARROW_ASSIGN_OR_RAISE(auto non_null_values_range, - FindNonNullValuesRange(*values_array)); - auto result = DispatchByType(*values_array, non_null_values_range, args[1], + FindNonNullValuesRange(*values_array)); + auto result = DispatchByType(values_array, non_null_values_range, args[1], static_cast(*options), ctx); return result; } private: /// Dispatch the logical value type to the matching template specialization. - Result DispatchByType(const Array& values, + Result DispatchByType(const std::shared_ptr& values, const NonNullValuesRange& non_null_values_range, const Datum& needles, const SearchSortedOptions& options, ExecContext* ctx) const { @@ -623,7 +626,7 @@ class SearchSortedMetaFunction : public MetaFunction { /// Dispatch the physical representation of the searched values. template - Result DispatchHaystack(const Array& values, + Result DispatchHaystack(const std::shared_ptr& values, const NonNullValuesRange& non_null_values_range, const Datum& needles, SearchSortedOptions::Side side, ExecContext* ctx) const { diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc index 7898fe7dd59d..c40a06d058f6 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc @@ -152,6 +152,31 @@ std::shared_ptr BuildRunHeavyInt64Values(int64_t logical_length, return std::static_pointer_cast(builder.Finish().ValueOrDie()); } +std::shared_ptr BuildRunHeavyInt64NeedlesWithNullRuns(int64_t logical_length, + int64_t run_length) { + std::vector data(static_cast(logical_length), 0); + std::vector is_valid(static_cast(logical_length), true); + + for (int64_t index = 0; index < logical_length; ++index) { + const int64_t run_index = index / run_length; + if (run_index % 4 == 0) { + is_valid[static_cast(index)] = false; + continue; + } + data[static_cast(index)] = run_index / 2; + } + + Int64Builder builder; + ABORT_NOT_OK(builder.AppendValues(data, is_valid)); + return std::static_pointer_cast(builder.Finish().ValueOrDie()); +} + +std::shared_ptr BuildInt64NeedlesWithNullRuns(int64_t size_bytes, + int64_t run_length) { + return BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), + run_length); +} + std::shared_ptr BuildRunEndEncodedInt64Values(int64_t size_bytes, int64_t run_length) { auto values = BuildRunHeavyInt64Values(Int64LengthFromBytes(size_bytes), run_length); return RunEndEncode(Datum(values), RunEndEncodeOptions{int32()}).ValueOrDie().make_array(); @@ -164,6 +189,15 @@ std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, int64_ .make_array(); } +std::shared_ptr BuildRunEndEncodedInt64NeedlesWithNullRuns(int64_t size_bytes, + int64_t run_length) { + auto needles = + BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), run_length); + return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); +} + void SetBenchmarkCounters(benchmark::State& state, const Datum& values, const Datum& needles) { const auto values_length = values.length(); const auto needles_length = needles.length(); @@ -213,6 +247,21 @@ static void BM_SearchSortedRunEndEncodedValuesAndNeedles( RunSearchSortedBenchmark(state, values, needles, side); } +static void BM_SearchSortedInt64NeedlesWithNullRuns(benchmark::State& state, + SearchSortedOptions::Side side) { + const Datum values(BuildSortedInt64Values(state.range(0))); + const Datum needles(BuildInt64NeedlesWithNullRuns(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + +static void BM_SearchSortedRunEndEncodedNeedlesWithNullRuns( + benchmark::State& state, SearchSortedOptions::Side side) { + const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); + const Datum needles( + BuildRunEndEncodedInt64NeedlesWithNullRuns(state.range(0), kNeedlesRunLength)); + RunSearchSortedBenchmark(state, values, needles, side); +} + static void BM_SearchSortedStringArrayNeedles(benchmark::State& state, SearchSortedOptions::Side side) { const Datum values(BuildSortedStringValues(state.range(0))); @@ -249,6 +298,16 @@ static void BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick( BM_SearchSortedRunEndEncodedValuesAndNeedles(state, side); } +static void BM_SearchSortedInt64NeedlesWithNullRunsQuick( + benchmark::State& state, SearchSortedOptions::Side side) { + BM_SearchSortedInt64NeedlesWithNullRuns(state, side); +} + +static void BM_SearchSortedRunEndEncodedNeedlesWithNullRunsQuick( + benchmark::State& state, SearchSortedOptions::Side side) { + BM_SearchSortedRunEndEncodedNeedlesWithNullRuns(state, side); +} + // Primitive-array and REE cases are the main baselines for the kernel TODOs around // SIMD batched search, vectorized REE writeback, and future parallel needle traversal. @@ -270,6 +329,18 @@ BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, left, BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, right, SearchSortedOptions::Right) ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, right, + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, left, SearchSortedOptions::Left) ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, right, SearchSortedOptions::Right) @@ -289,9 +360,15 @@ BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, right, SearchSortedOptions: // Lightweight L1/L2 regressions keep a fast local loop for future optimization work. BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedlesQuick, left, SearchSortedOptions::Left) ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRunsQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick, left, SearchSortedOptions::Left) ->Apply(SetSearchSortedQuickArgs); +BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRunsQuick, left, + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); } // namespace compute } // namespace arrow \ No newline at end of file From f62b93d5764c765f2dc4de8fae60886d6807249f Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:33:34 +0200 Subject: [PATCH 3/8] Enhance documentation for search_sorted kernel with detailed implementation overview and flow --- .../compute/kernels/vector_search_sorted.cc | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index 2b59958aec6e..a05086a0049c 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -64,6 +64,88 @@ const FunctionDoc search_sorted_doc( "portion of the searched array. Null needles emit nulls in the output."), {"values", "needles"}, "SearchSortedOptions"); +// This file implements search_sorted as a small pipeline that first normalizes +// Arrow input shapes and then runs one typed binary-search core on logical +// values. +// +// Plain arrays, run-end encoded arrays, and scalar needles are all +// adapted into the same accessor and visitor model so the search logic does +// not care about physical layout. +// +// After validation, the kernel isolates the contiguous non-null window of the searched +// values, because nulls are only supported when clustered at one end. +// Needles are then visited either as single values or as logical runs, and each non-null +// needle is resolved with a lower-bound or upper-bound binary search over the sorted +// non-null range. +// +// Output materialization is split by null handling: non-null-only needles write directly +// into a preallocated uint64 buffer, while nullable needles append null and non-null +// spans through a UInt64Builder. That builder path is optimized for repeated runs by +// bulk-filling reserved memory instead of appending one insertion index at a time. +// +// High-level flow: +// +// values datum +// | +// +--> ValidateSortedValuesInput +// | +// +--> LogicalType / FindNonNullValuesRange +// | +// +--> VisitValuesAccessor +// | +// +--> PlainArrayAccessor +// | +// `--> RunEndEncodedValuesAccessor +// +// needles datum +// | +// +--> ValidateNeedleInput +// | +// +--> DatumHasNulls +// | +// `--> VisitNeedles +// | +// +--> scalar needle -> one logical span +// | +// +--> plain array -> one span per element +// | +// `--> REE array -> one span per logical run +// +// normalized values accessor + normalized needle spans +// | +// `--> FindInsertionPoint +// | +// +--> side = left -> lower_bound semantics +// | +// `--> side = right -> upper_bound semantics +// +// result materialization +// | +// +--> no needle nulls +// | `--> MakeMutableUInt64Array +// | `--> fill output buffer directly +// | +// `--> nullable needles +// `--> UInt64Builder +// +--> AppendNulls for null runs +// `--> bulk fill + UnsafeAdvance for repeated indices +// +// A rough map of the file: +// +// [validation + type helpers] +// | +// [value accessors] +// | +// [needle visitors] +// | +// [typed search + output helpers] +// | +// [meta-function dispatch] +// +// The file follows that layout: validation and type helpers first, then value +// accessors, then needle visitors, then typed search and output helpers, and +// finally the meta-function dispatch that selects the Arrow type and accessor. + #define VISIT_SEARCH_SORTED_TYPES(VISIT) \ VISIT(BooleanType) \ VISIT(Int8Type) \ From b826e499effa326bd8877854390d5e17ded0ef24 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Fri, 3 Apr 2026 09:57:49 +0200 Subject: [PATCH 4/8] Refactor vector_search_sorted kernel to improve null handling and utilize ranges for leading/trailing null counts --- .../compute/kernels/vector_search_sorted.cc | 49 +++++++++++-------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index a05086a0049c..e6dbd817ef24 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -295,7 +296,7 @@ inline bool DatumHasNulls(const Datum& datum) { const bool has_nulls = array_data->GetNullCount() > 0; if (array_data->type->id() == Type::RUN_END_ENCODED) { RunEndEncodedArray run_end_encoded(array_data); - return run_end_encoded.values()->null_count() != 0 || has_nulls; + return has_nulls || (run_end_encoded.values()->null_count() != 0); } return has_nulls; } @@ -321,10 +322,12 @@ inline Result FindNonNullValuesRange(const ArrayData& values return non_null_values_range; } - int64_t leading_null_count = 0; - while (leading_null_count < values.length && values.IsNull(leading_null_count)) { - ++leading_null_count; - } + const int64_t leading_null_count = [&] { + auto indices = std::ranges::views::iota(int64_t{0}, values.length); + auto it = + std::ranges::find_if_not(indices, [&](int64_t i) { return values.IsNull(i); }); + return it == indices.end() ? values.length : *it; + }(); if (leading_null_count == values.length) { non_null_values_range.offset = values.length; @@ -342,13 +345,14 @@ inline Result FindNonNullValuesRange(const ArrayData& values return non_null_values_range; } - int64_t trailing_null_count = 0; - while (trailing_null_count < values.length && - values.IsNull(values.length - 1 - trailing_null_count)) { - ++trailing_null_count; - } + const int64_t trailing_null_count = [&] { + auto indices = std::ranges::views::iota(int64_t{0}, values.length); + auto it = std::ranges::find_if_not( + indices, [&](int64_t i) { return values.IsNull(values.length - 1 - i); }); + return it == indices.end() ? values.length : *it; + }(); - if (trailing_null_count == 0 || trailing_null_count != null_count) { + if (trailing_null_count == 0 || (trailing_null_count != null_count)) { return Status::Invalid( "search_sorted values with nulls must be clustered at the start or end"); } @@ -372,6 +376,8 @@ inline Status ValidateSortedValuesInput(const Datum& datum) { } /// Validate the needles input shape and supported encoding. +/// Needles can be either a scalar or an array, but if an array is provided it must not +/// have nested run-end encoding since that is not currently supported. inline Status ValidateNeedleInput(const Datum& datum) { if (!(datum.is_array() || datum.is_scalar())) { return Status::TypeError("search_sorted needles must be a scalar or array"); @@ -506,9 +512,9 @@ Status VisitRunEndEncodedNeedleRuns(const RunEndEncodedArray& needles, ::arrow::ree_util::RunEndEncodedArraySpan span(array_span); for (auto it = span.begin(); !it.is_end(span); ++it) { - const auto physical_index = it.index_into_array(); - RETURN_NOT_OK(visitor(ReadVisitedNeedle(values, physical_index), - it.logical_position(), it.run_end())); + RETURN_NOT_OK( + visitor(ReadVisitedNeedle(values, it.index_into_array()), + it.logical_position(), it.run_end())); } return Status::OK(); } @@ -587,14 +593,15 @@ Result ComputeInsertionIndices(const ValuesAccessor& sorted_values, const Datum& needles, SearchSortedOptions::Side side, uint64_t insertion_offset, ExecContext* ctx) { - if (needles.is_scalar() && !needles.scalar()->is_valid) { - return Datum(std::make_shared()); - } - if (needles.is_scalar()) { + auto scalar = needles.scalar(); + if (!scalar->is_valid) { + return Datum(std::make_shared()); + } + const auto insertion_index = - FindInsertionPoint( - sorted_values, ExtractScalarValue(*needles.scalar()), side) + + FindInsertionPoint(sorted_values, + ExtractScalarValue(*scalar), side) + insertion_offset; return Datum(std::make_shared(insertion_index)); } @@ -656,6 +663,8 @@ Result VisitValuesAccessor(const std::shared_ptr& values_data, } /// Meta-function implementation for the search_sorted public compute entrypoint. +/// Validates input shapes and types, normalizes to logical value accessors, and +/// dispatches to the typed search implementation. class SearchSortedMetaFunction : public MetaFunction { public: /// Construct the registry entry with default options and documentation. From 4460a03b927de3cd59236f6ef56cb7fde9604212 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:25:47 +0200 Subject: [PATCH 5/8] Refactor search_sorted kernel: improve error messages and add comprehensive tests for supported types --- .../compute/kernels/vector_search_sorted.cc | 9 +- .../kernels/vector_search_sorted_test.cc | 136 ++++++++++++++++++ 2 files changed, 139 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index e6dbd817ef24..2624d900527a 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -143,9 +143,6 @@ const FunctionDoc search_sorted_doc( // | // [meta-function dispatch] // -// The file follows that layout: validation and type helpers first, then value -// accessors, then needle visitors, then typed search and output helpers, and -// finally the meta-function dispatch that selects the Arrow type and accessor. #define VISIT_SEARCH_SORTED_TYPES(VISIT) \ VISIT(BooleanType) \ @@ -338,7 +335,7 @@ inline Result FindNonNullValuesRange(const ArrayData& values if (leading_null_count > 0) { if (leading_null_count != null_count) { return Status::Invalid( - "search_sorted values with nulls must be clustered at the start or end"); + "search_sorted values with nulls must be clustered at the start or end."); } non_null_values_range.offset = leading_null_count; non_null_values_range.length = values.length - leading_null_count; @@ -354,7 +351,7 @@ inline Result FindNonNullValuesRange(const ArrayData& values if (trailing_null_count == 0 || (trailing_null_count != null_count)) { return Status::Invalid( - "search_sorted values with nulls must be clustered at the start or end"); + "search_sorted values with nulls must be clustered at the start or end."); } non_null_values_range.length = values.length - trailing_null_count; @@ -403,7 +400,7 @@ uint64_t FindInsertionPoint(const Accessor& sorted_values, while (count > 0) { const int64_t step = count / 2; const int64_t it = first + step; - const bool advance = side == SearchSortedOptions::Left + const bool advance = (side == SearchSortedOptions::Left) ? compare(sorted_values.Value(it), needle) < 0 : compare(needle, sorted_values.Value(it)) >= 0; if (advance) { diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc index 709cbda76994..8752ce018701 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc @@ -17,6 +17,7 @@ #include #include +#include #include @@ -40,6 +41,121 @@ Result> REEFromJSON(const std::shared_ptr& ree_ return datum.make_array(); } +void CheckSimpleSearchSorted(const std::shared_ptr& type, + const std::string& values_json, + const std::string& needles_json, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto values = ArrayFromJSON(type, values_json); + auto needles = ArrayFromJSON(type, needles_json); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needles), + SearchSortedOptions(SearchSortedOptions::Right))); + + AssertArraysEqual(*ArrayFromJSON(uint64(), expected_left_json), *left.make_array()); + AssertArraysEqual(*ArrayFromJSON(uint64(), expected_right_json), *right.make_array()); +} + +void CheckSimpleScalarSearchSorted(const std::shared_ptr& type, + const std::string& values_json, + const std::string& needle_json, + uint64_t expected_left, + uint64_t expected_right) { + auto values = ArrayFromJSON(type, values_json); + auto needle = ScalarFromJSON(type, needle_json); + + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(Datum(values), Datum(needle), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(Datum(values), Datum(needle), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(left.is_scalar()); + ASSERT_TRUE(right.is_scalar()); + ASSERT_EQ(checked_cast(*left.scalar()).value, expected_left); + ASSERT_EQ(checked_cast(*right.scalar()).value, expected_right); +} + +struct SearchSortedSmokeCase { + std::string name; + std::shared_ptr type; + std::string values_json; + std::string needles_json; + std::string expected_left_json; + std::string expected_right_json; + std::string scalar_needle_json; + uint64_t expected_scalar_left; + uint64_t expected_scalar_right; +}; + +std::vector SupportedTypeSmokeCases() { + return { + {"Boolean", boolean(), "[false, false, true, true]", "[false, true]", + "[0, 2]", "[2, 4]", "true", 2, 4}, + {"Int8", int8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Int16", int16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Int32", int32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Int64", int64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"UInt8", uint8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"UInt16", uint16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"UInt32", uint32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"UInt64", uint64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Float32", float32(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, + {"Float64", float64(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, + {"Date32", date32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Date64", date64(), "[86400000, 259200000, 259200000, 432000000]", + "[0, 259200000, 345600000, 518400000]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "259200000", 1, 3}, + {"Time32", time32(TimeUnit::SECOND), "[1, 3, 3, 5]", "[0, 3, 4, 6]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, + {"Time64", time64(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, + {"Timestamp", timestamp(TimeUnit::SECOND), + R"(["1970-01-02", "1970-01-04", "1970-01-04", "1970-01-06"])", + R"(["1970-01-01", "1970-01-04", "1970-01-05", "1970-01-07"])", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("1970-01-04")", 1, 3}, + {"Duration", duration(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, + {"Binary", binary(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + {"String", utf8(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + {"LargeBinary", large_binary(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + {"LargeString", large_utf8(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + {"BinaryView", binary_view(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + {"StringView", utf8_view(), R"(["aa", "bb", "bb", "dd"])", + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + R"("bb")", 1, 3}, + }; +} + +class SearchSortedSupportedTypesTest + : public ::testing::TestWithParam {}; + TEST(SearchSorted, BasicLeftRight) { auto values = ArrayFromJSON(int64(), "[100, 200, 200, 300, 300]"); auto needles = ArrayFromJSON(int64(), "[50, 200, 250, 400]"); @@ -277,6 +393,26 @@ TEST(SearchSorted, BinaryValues) { AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 1, 3, 4]"), *result.make_array()); } +TEST_P(SearchSortedSupportedTypesTest, ArraySmoke) { + const auto& param = GetParam(); + CheckSimpleSearchSorted(param.type, param.values_json, param.needles_json, + param.expected_left_json, param.expected_right_json); +} + +TEST_P(SearchSortedSupportedTypesTest, ScalarSmoke) { + const auto& param = GetParam(); + CheckSimpleScalarSearchSorted(param.type, param.values_json, param.scalar_needle_json, + param.expected_scalar_left, + param.expected_scalar_right); +} + +INSTANTIATE_TEST_SUITE_P( + SupportedTypes, SearchSortedSupportedTypesTest, + ::testing::ValuesIn(SupportedTypeSmokeCases()), + [](const ::testing::TestParamInfo& info) { + return info.param.name; + }); + } // namespace } // namespace compute } // namespace arrow \ No newline at end of file From 4ec630ec557cd6972c54c4bb1f3a8f7cfffb79dd Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Wed, 8 Apr 2026 10:13:48 +0200 Subject: [PATCH 6/8] Formatting --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/compute/api_vector.cc | 2 +- cpp/src/arrow/compute/api_vector.h | 6 +- .../compute/kernels/vector_search_sorted.cc | 2 +- .../kernels/vector_search_sorted_benchmark.cc | 80 +++++++------ .../kernels/vector_search_sorted_test.cc | 105 ++++++++---------- python/pyarrow/_compute_docstrings.py | 4 +- 7 files changed, 101 insertions(+), 100 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 8585b6aaec7f..87d858727b51 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -795,7 +795,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_rank.cc compute/kernels/vector_replace.cc compute/kernels/vector_run_end_encode.cc - compute/kernels/vector_search_sorted.cc + compute/kernels/vector_search_sorted.cc compute/kernels/vector_select_k.cc compute/kernels/vector_sort.cc compute/kernels/vector_statistics.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 689788c3847d..676b5b89d5a4 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -154,7 +154,7 @@ static auto kArraySortOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); static auto kSearchSortedOptionsType = GetFunctionOptionsType( - DataMember("side", &SearchSortedOptions::side)); + DataMember("side", &SearchSortedOptions::side)); static auto kSortOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &SortOptions::sort_keys), DataMember("null_placement", &SortOptions::null_placement)); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 2d003726804f..a662cad94ba4 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -547,9 +547,9 @@ Result> SelectKUnstable(const Datum& datum, /// \return insertion indices as uint64 scalar or array ARROW_EXPORT Result SearchSorted( - const Datum& values, const Datum& needles, - const SearchSortedOptions& options = SearchSortedOptions::Defaults(), - ExecContext* ctx = NULLPTR); + const Datum& values, const Datum& needles, + const SearchSortedOptions& options = SearchSortedOptions::Defaults(), + ExecContext* ctx = NULLPTR); /// \brief Return the indices that would sort an array. /// diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index 2624d900527a..32ee35fdfdda 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -733,4 +733,4 @@ void RegisterVectorSearchSorted(FunctionRegistry* registry) { } } // namespace compute::internal -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc index c40a06d058f6..71c78beb338b 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_benchmark.cc @@ -74,8 +74,10 @@ std::shared_ptr BuildSortedInt64Values(int64_t size_bytes) { const auto length = Int64LengthFromBytes(size_bytes); const auto max_value = std::max(length / 8, 1); - auto values = std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); - std::vector data(values->raw_values(), values->raw_values() + values->length()); + auto values = + std::static_pointer_cast(rand.Int64(length, 0, max_value, 0.0)); + std::vector data(values->raw_values(), + values->raw_values() + values->length()); std::ranges::sort(data); Int64Builder builder; @@ -177,12 +179,16 @@ std::shared_ptr BuildInt64NeedlesWithNullRuns(int64_t size_bytes, run_length); } -std::shared_ptr BuildRunEndEncodedInt64Values(int64_t size_bytes, int64_t run_length) { +std::shared_ptr BuildRunEndEncodedInt64Values(int64_t size_bytes, + int64_t run_length) { auto values = BuildRunHeavyInt64Values(Int64LengthFromBytes(size_bytes), run_length); - return RunEndEncode(Datum(values), RunEndEncodeOptions{int32()}).ValueOrDie().make_array(); + return RunEndEncode(Datum(values), RunEndEncodeOptions{int32()}) + .ValueOrDie() + .make_array(); } -std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, int64_t run_length) { +std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, + int64_t run_length) { auto needles = BuildRunHeavyInt64Values(NeedleLengthFromBytes(size_bytes), run_length); return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) .ValueOrDie() @@ -190,15 +196,16 @@ std::shared_ptr BuildRunEndEncodedInt64Needles(int64_t size_bytes, int64_ } std::shared_ptr BuildRunEndEncodedInt64NeedlesWithNullRuns(int64_t size_bytes, - int64_t run_length) { - auto needles = - BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), run_length); + int64_t run_length) { + auto needles = BuildRunHeavyInt64NeedlesWithNullRuns(NeedleLengthFromBytes(size_bytes), + run_length); return RunEndEncode(Datum(needles), RunEndEncodeOptions{int32()}) - .ValueOrDie() - .make_array(); + .ValueOrDie() + .make_array(); } -void SetBenchmarkCounters(benchmark::State& state, const Datum& values, const Datum& needles) { +void SetBenchmarkCounters(benchmark::State& state, const Datum& values, + const Datum& needles) { const auto values_length = values.length(); const auto needles_length = needles.length(); state.counters["values_length"] = static_cast(values_length); @@ -240,8 +247,8 @@ static void BM_SearchSortedRunEndEncodedValues(benchmark::State& state, RunSearchSortedBenchmark(state, values, needles, side); } -static void BM_SearchSortedRunEndEncodedValuesAndNeedles( - benchmark::State& state, SearchSortedOptions::Side side) { +static void BM_SearchSortedRunEndEncodedValuesAndNeedles(benchmark::State& state, + SearchSortedOptions::Side side) { const Datum values(BuildRunEndEncodedInt64Values(state.range(0), kValuesRunLength)); const Datum needles(BuildRunEndEncodedInt64Needles(state.range(0), kNeedlesRunLength)); RunSearchSortedBenchmark(state, values, needles, side); @@ -274,7 +281,8 @@ static void BM_SearchSortedStringScalarNeedle(benchmark::State& state, const auto values_array = BuildSortedStringValues(state.range(0)); const auto scalar_index = values_array->length() / 2; const Datum values(values_array); - const Datum needles(std::make_shared(values_array->GetString(scalar_index))); + const Datum needles( + std::make_shared(values_array->GetString(scalar_index))); RunSearchSortedBenchmark(state, values, needles, side); } @@ -283,8 +291,8 @@ static void BM_SearchSortedBinaryScalarNeedle(benchmark::State& state, const auto values_array = BuildSortedBinaryValues(state.range(0)); const auto scalar_index = values_array->length() / 2; const Datum values(values_array); - const Datum needles(std::make_shared( - std::string(values_array->GetView(scalar_index)))); + const Datum needles( + std::make_shared(std::string(values_array->GetView(scalar_index)))); RunSearchSortedBenchmark(state, values, needles, side); } @@ -298,8 +306,8 @@ static void BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick( BM_SearchSortedRunEndEncodedValuesAndNeedles(state, side); } -static void BM_SearchSortedInt64NeedlesWithNullRunsQuick( - benchmark::State& state, SearchSortedOptions::Side side) { +static void BM_SearchSortedInt64NeedlesWithNullRunsQuick(benchmark::State& state, + SearchSortedOptions::Side side) { BM_SearchSortedInt64NeedlesWithNullRuns(state, side); } @@ -330,17 +338,17 @@ BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedles, right, SearchSortedOptions::Right) ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, left, - SearchSortedOptions::Left) - ->Apply(SetSearchSortedArgs); + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRuns, right, - SearchSortedOptions::Right) - ->Apply(SetSearchSortedArgs); + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, left, - SearchSortedOptions::Left) - ->Apply(SetSearchSortedArgs); + SearchSortedOptions::Left) + ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRuns, right, - SearchSortedOptions::Right) - ->Apply(SetSearchSortedArgs); + SearchSortedOptions::Right) + ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, left, SearchSortedOptions::Left) ->Apply(SetSearchSortedArgs); BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, right, SearchSortedOptions::Right) @@ -349,26 +357,26 @@ BENCHMARK_CAPTURE(BM_SearchSortedStringArrayNeedles, right, SearchSortedOptions: // String and binary scalar cases specifically exercise the direct scalar fast path that // avoids boxing a scalar needle into a temporary one-element array. BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, left, SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedStringScalarNeedle, right, SearchSortedOptions::Right) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, left, SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedBinaryScalarNeedle, right, SearchSortedOptions::Right) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); // Lightweight L1/L2 regressions keep a fast local loop for future optimization work. BENCHMARK_CAPTURE(BM_SearchSortedInt64ArrayNeedlesQuick, left, SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedInt64NeedlesWithNullRunsQuick, left, SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedValuesAndNeedlesQuick, left, - SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + SearchSortedOptions::Left) + ->Apply(SetSearchSortedQuickArgs); BENCHMARK_CAPTURE(BM_SearchSortedRunEndEncodedNeedlesWithNullRunsQuick, left, SearchSortedOptions::Left) - ->Apply(SetSearchSortedQuickArgs); + ->Apply(SetSearchSortedQuickArgs); } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc index 8752ce018701..cc41286dc8e1 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc @@ -62,8 +62,7 @@ void CheckSimpleSearchSorted(const std::shared_ptr& type, void CheckSimpleScalarSearchSorted(const std::shared_ptr& type, const std::string& values_json, - const std::string& needle_json, - uint64_t expected_left, + const std::string& needle_json, uint64_t expected_left, uint64_t expected_right) { auto values = ArrayFromJSON(type, values_json); auto needle = ScalarFromJSON(type, needle_json); @@ -95,61 +94,55 @@ struct SearchSortedSmokeCase { std::vector SupportedTypeSmokeCases() { return { - {"Boolean", boolean(), "[false, false, true, true]", "[false, true]", - "[0, 2]", "[2, 4]", "true", 2, 4}, - {"Int8", int8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"Int16", int16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"Int32", int32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"Int64", int64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"UInt8", uint8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"UInt16", uint16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"UInt32", uint32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, - {"UInt64", uint64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, + {"Boolean", boolean(), "[false, false, true, true]", "[false, true]", "[0, 2]", + "[2, 4]", "true", 2, 4}, + {"Int8", int8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int16", int16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int32", int32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"Int64", int64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt8", uint8(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt16", uint16(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt32", uint32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, + {"UInt64", uint64(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, {"Float32", float32(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, {"Float64", float64(), "[1.0, 3.0, 3.0, 5.0]", "[0.0, 3.0, 4.0, 6.0]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3.0", 1, 3}, - {"Date32", date32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "3", 1, 3}, + {"Date32", date32(), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "3", 1, 3}, {"Date64", date64(), "[86400000, 259200000, 259200000, 432000000]", - "[0, 259200000, 345600000, 518400000]", "[0, 1, 3, 4]", - "[0, 3, 3, 4]", "259200000", 1, 3}, - {"Time32", time32(TimeUnit::SECOND), "[1, 3, 3, 5]", "[0, 3, 4, 6]", - "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, - {"Time64", time64(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", - "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, + "[0, 259200000, 345600000, 518400000]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", + "259200000", 1, 3}, + {"Time32", time32(TimeUnit::SECOND), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, + {"Time64", time64(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", "3", 1, 3}, {"Timestamp", timestamp(TimeUnit::SECOND), R"(["1970-01-02", "1970-01-04", "1970-01-04", "1970-01-06"])", - R"(["1970-01-01", "1970-01-04", "1970-01-05", "1970-01-07"])", - "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("1970-01-04")", 1, 3}, + R"(["1970-01-01", "1970-01-04", "1970-01-05", "1970-01-07"])", "[0, 1, 3, 4]", + "[0, 3, 3, 4]", R"("1970-01-04")", 1, 3}, {"Duration", duration(TimeUnit::NANO), "[1, 3, 3, 5]", "[0, 3, 4, 6]", "[0, 1, 3, 4]", "[0, 3, 3, 4]", "3", 1, 3}, - {"Binary", binary(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, - {"String", utf8(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, + {"Binary", binary(), R"(["aa", "bb", "bb", "dd"])", R"(["a", "bb", "bc", "z"])", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, + {"String", utf8(), R"(["aa", "bb", "bb", "dd"])", R"(["a", "bb", "bc", "z"])", + "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, {"LargeBinary", large_binary(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, {"LargeString", large_utf8(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, {"BinaryView", binary_view(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, {"StringView", utf8_view(), R"(["aa", "bb", "bb", "dd"])", - R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", - R"("bb")", 1, 3}, + R"(["a", "bb", "bc", "z"])", "[0, 1, 3, 4]", "[0, 3, 3, 4]", R"("bb")", 1, 3}, }; } @@ -175,9 +168,8 @@ TEST(SearchSorted, ScalarNeedle) { auto values = ArrayFromJSON(int32(), "[1, 3, 5, 7]"); ASSERT_OK_AND_ASSIGN( - auto result, - SearchSorted(Datum(values), Datum(std::make_shared(5)), - SearchSortedOptions(SearchSortedOptions::Right))); + auto result, SearchSorted(Datum(values), Datum(std::make_shared(5)), + SearchSortedOptions(SearchSortedOptions::Right))); ASSERT_TRUE(result.is_scalar()); ASSERT_EQ(checked_cast(*result.scalar()).value, 3); @@ -281,7 +273,8 @@ TEST(SearchSorted, RejectUnclusteredNullValues) { TEST(SearchSorted, RunEndEncodedNulls) { auto values_type = run_end_encoded(int16(), int32()); - ASSERT_OK_AND_ASSIGN(auto ree_values, REEFromJSON(values_type, "[null, null, 2, 4, 4]")); + ASSERT_OK_AND_ASSIGN(auto ree_values, + REEFromJSON(values_type, "[null, null, 2, 4, 4]")); auto needles_type = run_end_encoded(int16(), int32()); ASSERT_OK_AND_ASSIGN(auto ree_needles, REEFromJSON(needles_type, "[null, null, 1, 4, 4, null, 8]")); @@ -402,17 +395,15 @@ TEST_P(SearchSortedSupportedTypesTest, ArraySmoke) { TEST_P(SearchSortedSupportedTypesTest, ScalarSmoke) { const auto& param = GetParam(); CheckSimpleScalarSearchSorted(param.type, param.values_json, param.scalar_needle_json, - param.expected_scalar_left, - param.expected_scalar_right); + param.expected_scalar_left, param.expected_scalar_right); } -INSTANTIATE_TEST_SUITE_P( - SupportedTypes, SearchSortedSupportedTypesTest, - ::testing::ValuesIn(SupportedTypeSmokeCases()), - [](const ::testing::TestParamInfo& info) { - return info.param.name; - }); +INSTANTIATE_TEST_SUITE_P(SupportedTypes, SearchSortedSupportedTypesTest, + ::testing::ValuesIn(SupportedTypeSmokeCases()), + [](const ::testing::TestParamInfo& info) { + return info.param.name; + }); } // namespace } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/python/pyarrow/_compute_docstrings.py b/python/pyarrow/_compute_docstrings.py index 39cd39875c10..42a51558696d 100644 --- a/python/pyarrow/_compute_docstrings.py +++ b/python/pyarrow/_compute_docstrings.py @@ -57,7 +57,9 @@ 5 ] >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) - >>> pc.search_sorted(with_nulls, pa.array([50, 200, None, 400], type=pa.int64())) + >>> pc.search_sorted( + ... with_nulls, pa.array([50, 200, None, 400], type=pa.int64()) + ... ) [ 1, From 56146c21ce8e6f804d3bd557a5f318a6a06e7c41 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Wed, 8 Apr 2026 13:46:27 +0200 Subject: [PATCH 7/8] Refactor vector_search_sorted kernel: enhance readability and add noexcept to length method --- .../compute/kernels/vector_search_sorted.cc | 4 +- python/pyarrow/_compute_docstrings.py | 66 +++++++++---------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc index 32ee35fdfdda..86c8bdce0c46 100644 --- a/cpp/src/arrow/compute/kernels/vector_search_sorted.cc +++ b/cpp/src/arrow/compute/kernels/vector_search_sorted.cc @@ -238,7 +238,7 @@ struct NonNullValuesRange { /// Return whether the range spans the full searched values input. bool is_identity(int64_t full_length) const { - return offset == 0 && length == full_length; + return (offset == 0) && (length == full_length); } }; @@ -255,7 +255,7 @@ class NonNullValuesAccessor { length_(non_null_values_range.length) {} /// Return the number of accessible non-null values. - int64_t length() const { return length_; } + int64_t length() const noexcept { return length_; } /// Return the value at the given index within the non-null subrange. auto Value(int64_t index) const { return values_.Value(offset_ + index); } diff --git a/python/pyarrow/_compute_docstrings.py b/python/pyarrow/_compute_docstrings.py index 42a51558696d..9fb05c7a6442 100644 --- a/python/pyarrow/_compute_docstrings.py +++ b/python/pyarrow/_compute_docstrings.py @@ -43,39 +43,39 @@ """ function_doc_additions["search_sorted"] = """ - Examples - -------- - >>> import pyarrow as pa - >>> import pyarrow.compute as pc - >>> values = pa.array([1, 1, 3, 5, 8]) - >>> pc.search_sorted(values, pa.array([0, 1, 4, 9])) - - [ - 0, - 0, - 3, - 5 - ] - >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) - >>> pc.search_sorted( - ... with_nulls, pa.array([50, 200, None, 400], type=pa.int64()) - ... ) - - [ - 1, - 1, - null, - 4 - ] - >>> pc.search_sorted(values, pa.array([0, 1, 4, 9]), side="right") - - [ - 0, - 2, - 3, - 5 - ] - """ + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> values = pa.array([1, 1, 3, 5, 8]) + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9])) + + [ + 0, + 0, + 3, + 5 + ] + >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) + >>> pc.search_sorted( + ... with_nulls, pa.array([50, 200, None, 400], type=pa.int64()) + ... ) + + [ + 1, + 1, + null, + 4 + ] + >>> pc.search_sorted(values, pa.array([0, 1, 4, 9]), side="right") + + [ + 0, + 2, + 3, + 5 + ] + """ function_doc_additions["mode"] = """ Examples From 9def02cf5cba920ca2c5dd2493b915120112a662 Mon Sep 17 00:00:00 2001 From: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:24:17 +0200 Subject: [PATCH 8/8] Refactor search_sorted documentation: adjust indentation for clarity --- python/pyarrow/_compute_docstrings.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/pyarrow/_compute_docstrings.py b/python/pyarrow/_compute_docstrings.py index 9fb05c7a6442..e5189f063b7f 100644 --- a/python/pyarrow/_compute_docstrings.py +++ b/python/pyarrow/_compute_docstrings.py @@ -51,10 +51,10 @@ >>> pc.search_sorted(values, pa.array([0, 1, 4, 9])) [ - 0, - 0, - 3, - 5 + 0, + 0, + 3, + 5 ] >>> with_nulls = pa.array([None, 200, 300, 300], type=pa.int64()) >>> pc.search_sorted( @@ -62,18 +62,18 @@ ... ) [ - 1, - 1, - null, - 4 + 1, + 1, + null, + 4 ] >>> pc.search_sorted(values, pa.array([0, 1, 4, 9]), side="right") [ - 0, - 2, - 3, - 5 + 0, + 2, + 3, + 5 ] """