diff --git a/Project.toml b/Project.toml index 655dfb1..f95ccc8 100644 --- a/Project.toml +++ b/Project.toml @@ -5,10 +5,12 @@ authors = ["Chris Elrod and contributors"] [deps] PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" [compat] Pkg = "1.10" PrecompileTools = "1" +SIMD = "3" SafeTestsets = "0.1" StableRNGs = "1" Test = "1.10" diff --git a/bench/simd_binary_sweep.jl b/bench/simd_binary_sweep.jl new file mode 100644 index 0000000..3368d38 --- /dev/null +++ b/bench/simd_binary_sweep.jl @@ -0,0 +1,206 @@ +# Bench sweep: `SIMDBinarySearch` vs `BinaryBracket` (`Base.searchsortedlast`) +# for single-query workloads across n × eltype × cache state. +# +# Strategy is opt-in only and ignores any hint. To make the comparison fair +# we drive each strategy with the same per-query loop — no batched API, no +# hint chaining — and measure ns per query. +# +# Two cache regimes: +# - hot: v stays resident; consecutive queries see warm cache lines. +# - cold: cycle through a working set of independent v vectors whose +# combined footprint is larger than LLC, so each v's first probe +# sees a cold cache line. Larger n covers more of the working set +# per single query and the steady-state cache pressure is what we +# measure. + +using FindFirstFunctions, StableRNGs, BenchmarkTools, Printf, Statistics + +const F = FindFirstFunctions + +# Strategies under test. +const STRATS = [ + ("SIMDBinary", F.SIMDBinarySearch()), + ("BinaryBracket", F.BinaryBracket()), +] + +# ---- query-loop kernels (no batched API, no hint) ------------------------- + +@inline function loop_last!(strat, v, qs, out) + @inbounds for i in eachindex(qs) + out[i] = searchsortedlast(strat, v, qs[i]) + end + return out +end + +@inline function loop_first!(strat, v, qs, out) + @inbounds for i in eachindex(qs) + out[i] = searchsortedfirst(strat, v, qs[i]) + end + return out +end + +# Cold-cache driver: cycle through a working set of independent v vectors +# whose combined bytes exceed LLC (default 256 MiB). Each query consumes one +# v from the working set; we sweep through often enough that the v we're +# about to use was last touched many MiB ago. +const COLD_WORKING_SET_BYTES = 256 * 1024 * 1024 + +@inline function loop_last_cold!(strat, vs, qs, out) + nv = length(vs) + @inbounds for i in eachindex(qs) + v = vs[mod1(i, nv)] + out[i] = searchsortedlast(strat, v, qs[i]) + end + return out +end + +# ---- timing helpers -------------------------------------------------------- + +function time_hot(strat, v, qs, out, reps = 7) + loop_last!(strat, v, qs, out) # warmup + best = typemax(Float64) + for _ in 1:reps + t = @elapsed loop_last!(strat, v, qs, out) + best = min(best, t) + end + return best * 1.0e9 / length(qs) +end + +function time_cold(strat, vs, qs, out, reps = 5) + # Cold-cache: each query reads from a different v in a working-set-larger- + # than-LLC pool, so probes that hit "v's data" almost always miss cache. + loop_last_cold!(strat, vs, qs, out) # warmup + best = typemax(Float64) + for _ in 1:reps + t = @elapsed loop_last_cold!(strat, vs, qs, out) + best = min(best, t) + end + return best * 1.0e9 / length(qs) +end + +# ---- workload generation --------------------------------------------------- + +function build_v(::Type{Float64}, n, seed) + return collect(range(1.0, Float64(n); length = n)) +end +function build_v(::Type{Int64}, n, seed) + return collect(Int64(1):Int64(n)) +end + +function build_queries(::Type{Float64}, v, m, seed) + rng = StableRNG(seed) + return rand(rng, m) .* (last(v) - first(v)) .+ first(v) +end +function build_queries(::Type{Int64}, v, m, seed) + rng = StableRNG(seed) + return rand(rng, Int64(first(v)):Int64(last(v)), m) +end + +# ---- sweep ----------------------------------------------------------------- + +function build_cold_working_set(::Type{T}, n, seed) where {T} + # Build enough independent vectors so the total bytes exceed LLC. + bytes_per_v = n * sizeof(T) + nv = max(2, cld(COLD_WORKING_SET_BYTES, bytes_per_v)) + rng = StableRNG(seed) + vs = Vector{Vector{T}}(undef, nv) + for i in 1:nv + # Each v is the same shape (1..n) but offset so queries hit + # well-defined positions regardless of which v is picked. + vs[i] = build_v(T, n, seed + i) + end + return vs +end + +function build_cold_queries(::Type{T}, vs, m, seed) where {T} + rng = StableRNG(seed) + # Choose a query uniformly across the common range of each v. + return [ + T == Float64 ? + T(rand(rng) * (length(vs[1]) - 1) + 1) : + T(rand(rng, 1:length(vs[1]))) + for _ in 1:m + ] +end + +function run_sweep() + ns = (256, 1024, 4096, 16_384, 65_536, 262_144, 1_048_576) + eltypes = (Float64, Int64) + # Number of queries per timing rep — large enough that per-query timing + # noise is small, small enough that the whole sweep finishes in minutes. + m_hot = 65_536 + m_cold = 4096 + + println("SIMDBinarySearch vs BinaryBracket — single-query sweep") + println("="^78) + println( + "Hot cache: $(m_hot) queries / rep; cold cache: $(m_cold) queries / rep" + ) + @printf( + "Cold cycles through a working set ≥ %d MiB so each v is cold.\n", + COLD_WORKING_SET_BYTES ÷ (1024 * 1024) + ) + println() + + rows = [] + + for T in eltypes + println("=== eltype = $T ===") + @printf( + "%9s | %28s | %28s\n", + "n", + "hot (ns/q, SIMD vs Base)", + "cold (ns/q, SIMD vs Base)" + ) + println("-"^78) + for n in ns + v = build_v(T, n, 1) + qs_hot = build_queries(T, v, m_hot, 2) + out_hot = Vector{Int}(undef, m_hot) + + # Cold: cycle through independent v's. + vs_cold = build_cold_working_set(T, n, 1000) + qs_cold = build_cold_queries(T, vs_cold, m_cold, 3) + out_cold = Vector{Int}(undef, m_cold) + + simd_hot = time_hot(F.SIMDBinarySearch(), v, qs_hot, out_hot) + base_hot = time_hot(F.BinaryBracket(), v, qs_hot, out_hot) + simd_cold = time_cold(F.SIMDBinarySearch(), vs_cold, qs_cold, out_cold) + base_cold = time_cold(F.BinaryBracket(), vs_cold, qs_cold, out_cold) + + push!( + rows, + (T, n, simd_hot, base_hot, simd_cold, base_cold) + ) + @printf( + "%9d | SIMD=%9.1f Base=%9.1f | SIMD=%9.1f Base=%9.1f\n", + n, simd_hot, base_hot, simd_cold, base_cold + ) + end + println() + end + + println() + println("Winner table (lower ns/q wins; tied = within 5%):") + println("="^78) + @printf( + "%-9s %-9s | %-12s %-12s | %-12s %-12s\n", + "eltype", "n", "hot winner", "ratio S/B", "cold winner", "ratio S/B" + ) + println("-"^78) + for (T, n, sh, bh, sc, bc) in rows + rh = sh / bh + rc = sc / bc + wh = rh < 0.95 ? "SIMD" : (rh > 1.05 ? "Base" : "tie") + wc = rc < 0.95 ? "SIMD" : (rc > 1.05 ? "Base" : "tie") + @printf( + "%-9s %-9d | %-12s %-12.2f | %-12s %-12.2f\n", + T, n, wh, rh, wc, rc + ) + end + return rows +end + +if !isinteractive() && abspath(PROGRAM_FILE) == @__FILE__ + rows = @time run_sweep() +end diff --git a/src/FindFirstFunctions.jl b/src/FindFirstFunctions.jl index 2c962c6..97b54c5 100644 --- a/src/FindFirstFunctions.jl +++ b/src/FindFirstFunctions.jl @@ -1,5 +1,7 @@ module FindFirstFunctions +using SIMD: SIMD + # Public API surface for `using FindFirstFunctions`. The strategy types are # zero-field singletons (except `GuesserHint` and `Auto`, which carry small # isbits payloads), so exporting them only adds names to the caller's @@ -11,7 +13,7 @@ export SearchStrategy, LinearScan, SIMDLinearScan, BracketGallop, ExpFromLeft, InterpolationSearch, BitInterpolationSearch, - BinaryBracket, BisectThenSIMD, + BinaryBracket, SIMDBinarySearch, BisectThenSIMD, GuesserHint, Auto, SearchProperties, Guesser, looks_linear, @@ -30,6 +32,7 @@ include("equality.jl") # findfirstequal + findfirstsortedequal include("strategies.jl") # SearchStrategy + concrete strategy types + SearchProperties + Auto include("search_properties.jl") # Linearity / NaN probes + populated SearchProperties constructor include("dispatch.jl") # Per-strategy searchsortedfirst/last methods + their internal helpers +include("simd_binary_search.jl") # SIMDBinarySearch (8-way SIMD-gather binary search) include("auto.jl") # Auto crossover constants + per-query Auto + Auto's batched helpers include("batched.jl") # Batched API + searchsortedrange + _batched! (incl Auto specialization) include("guesser.jl") # looks_linear + Guesser + GuesserHint dispatch diff --git a/src/simd_binary_search.jl b/src/simd_binary_search.jl new file mode 100644 index 0000000..96dc4c3 --- /dev/null +++ b/src/simd_binary_search.jl @@ -0,0 +1,229 @@ +# 8-way SIMD binary search for `DenseVector{Int64}` and `DenseVector{Float64}`. +# Each iteration loads 8 strided probes via `SIMD.vgather`, compares against +# the query, and reduces the result mask via `SIMD.bitmask` + `trailing_zeros` +# to pick the next subrange. The bracket shrinks by ~8× per step, giving an +# asymptotic ~log₈(n) iterations instead of log₂(n). +# +# Probe layout in `[lo, hi]` of length n = hi - lo + 1: +# p_k = lo + ((k - 1) * (n - 1)) ÷ 7 for k ∈ 1..8 +# so p_1 = lo, p_8 = hi, and the seven interior probes are evenly spaced. +# When n < 8 the probes can coincide; we fall back to a scalar bounded +# binary search on the inner loop for that case. + +# Per-iteration SIMD shrink step. Returns (new_lo, new_hi, done, answer) where +# `done` is true once the bracket has either collapsed or we've isolated the +# answer at a boundary. The polarity of `pred_gt` controls whether we're +# implementing searchsortedlast (predicate `v > x`) or searchsortedfirst +# (predicate `v >= x`). +@inline function _simd_bsearch_step_last( + v::DenseVector{T}, x::T, lo::Int, hi::Int, + ) where {T <: Union{Int64, Float64}} + # Compute the 8 probe indices. Using Int multiplication then division to + # keep the indices integer; the (n-1) factor + integer division means + # `p_1 == lo` and `p_8 == hi` exactly. + n = hi - lo + 1 + # 8 quantile probes spanning [lo, hi]. Stored 0-based offsets here so + # the SIMD.Vec{8,Int} ctor is one shuffle, not eight scalar adds. + o0 = 0 + o1 = ((n - 1) * 1) ÷ 7 + o2 = ((n - 1) * 2) ÷ 7 + o3 = ((n - 1) * 3) ÷ 7 + o4 = ((n - 1) * 4) ÷ 7 + o5 = ((n - 1) * 5) ÷ 7 + o6 = ((n - 1) * 6) ÷ 7 + o7 = n - 1 + # SIMD.jl's vgather wants 1-based indices. + idx = SIMD.Vec{8, Int}( + ( + lo + o0, lo + o1, lo + o2, lo + o3, + lo + o4, lo + o5, lo + o6, lo + o7, + ) + ) + vals = SIMD.vgather(v, idx) + mask = vals > x # Vec{8, Bool}: lane k is true iff v[p_k] > x + bm = SIMD.bitmask(mask) + if bm == 0x00 + # All 8 probes have v[p] <= x. The answer is >= p_8 = hi. Since hi is + # the current upper bracket bound, the answer is exactly hi (we know + # v[hi+1] > x from the bracket invariant, or hi is lastindex(v)). + return (lo, hi, true, hi) + end + tz = Int(trailing_zeros(bm)) # 0..7, index of first lane with v[p] > x + if tz == 0 + # v[lo] > x already → answer is lo - 1. + return (lo, hi, true, lo - 1) + end + # Lane tz is the first probe where v > x. Lane tz-1 had v <= x. The + # answer lives in [p_{tz-1}, p_tz - 1]. + new_lo = lo + (((n - 1) * (tz - 1)) ÷ 7) + new_hi = lo + (((n - 1) * tz) ÷ 7) - 1 + if new_lo > new_hi + # Adjacent probes — answer is new_lo (since v[new_lo] <= x and + # v[new_lo + 1] = v[p_tz] > x). + return (lo, hi, true, new_lo) + end + return (new_lo, new_hi, false, 0) +end + +# searchsortedfirst counterpart. Predicate: `v >= x`. +@inline function _simd_bsearch_step_first( + v::DenseVector{T}, x::T, lo::Int, hi::Int, + ) where {T <: Union{Int64, Float64}} + n = hi - lo + 1 + o0 = 0 + o1 = ((n - 1) * 1) ÷ 7 + o2 = ((n - 1) * 2) ÷ 7 + o3 = ((n - 1) * 3) ÷ 7 + o4 = ((n - 1) * 4) ÷ 7 + o5 = ((n - 1) * 5) ÷ 7 + o6 = ((n - 1) * 6) ÷ 7 + o7 = n - 1 + idx = SIMD.Vec{8, Int}( + ( + lo + o0, lo + o1, lo + o2, lo + o3, + lo + o4, lo + o5, lo + o6, lo + o7, + ) + ) + vals = SIMD.vgather(v, idx) + mask = vals >= x + bm = SIMD.bitmask(mask) + if bm == 0x00 + # All probes v < x; answer is > p_8 = hi → hi + 1. + return (lo, hi, true, hi + 1) + end + tz = Int(trailing_zeros(bm)) + if tz == 0 + # v[lo] >= x already → answer is lo. + return (lo, hi, true, lo) + end + # Lane tz first with v >= x; lane tz-1 had v < x. Answer lives in + # [p_{tz-1} + 1, p_tz]. + new_lo = lo + (((n - 1) * (tz - 1)) ÷ 7) + 1 + new_hi = lo + (((n - 1) * tz) ÷ 7) + if new_lo > new_hi + return (lo, hi, true, new_hi) + end + return (new_lo, new_hi, false, 0) +end + +# Threshold: below this length, do a scalar bounded binary search instead of +# the SIMD step. The probe-position math needs n >= 8 to keep all 8 lanes at +# distinct indices; below that the gather either reloads the same index +# (correct but wasteful) or risks a zero-stride boundary case in some Julia +# versions. Picking the threshold at 16 gives some headroom and matches the +# n where scalar binary search is still very fast (4 compares). +const SIMD_BSEARCH_BASECASE = 16 + +@inline function _simd_bsearch_last(v::DenseVector{T}, x::T) where {T <: Union{Int64, Float64}} + lo = firstindex(v) + hi = lastindex(v) + hi < lo && return lo - 1 + # Outer-bound checks short-circuit the common out-of-range queries. + @inbounds if x < v[lo] + return lo - 1 + end + @inbounds if x >= v[hi] + return hi + end + # Now v[lo] <= x < v[hi]; the answer is in [lo, hi - 1]. + hi -= 1 + while (hi - lo + 1) >= SIMD_BSEARCH_BASECASE + lo, hi, done, ans = _simd_bsearch_step_last(v, x, lo, hi) + done && return ans + end + # Basecase: scalar bounded binary search. Base.searchsortedlast accepts + # (v, x, lo, hi, order) overloads. + return searchsortedlast(v, x, lo, hi, Base.Order.Forward) +end + +@inline function _simd_bsearch_first(v::DenseVector{T}, x::T) where {T <: Union{Int64, Float64}} + lo = firstindex(v) + hi = lastindex(v) + hi < lo && return lo + @inbounds if x <= v[lo] + return lo + end + @inbounds if x > v[hi] + return hi + 1 + end + # Now v[lo] < x <= v[hi]; the answer is in [lo + 1, hi]. + lo += 1 + while (hi - lo + 1) >= SIMD_BSEARCH_BASECASE + lo, hi, done, ans = _simd_bsearch_step_first(v, x, lo, hi) + done && return ans + end + return searchsortedfirst(v, x, lo, hi, Base.Order.Forward) +end + +# =========================================================================== +# Dispatch — Int64 and Float64 specialisations +# =========================================================================== + +function Base.searchsortedlast( + ::SIMDBinarySearch, v::DenseVector{Int64}, x::Int64; + order::Base.Order.Ordering = Base.Order.Forward, + ) + order === Base.Order.Forward || + return searchsortedlast(v, x, order) + return _simd_bsearch_last(v, x) +end +function Base.searchsortedlast( + ::SIMDBinarySearch, v::DenseVector{Float64}, x::Float64; + order::Base.Order.Ordering = Base.Order.Forward, + ) + order === Base.Order.Forward || + return searchsortedlast(v, x, order) + return _simd_bsearch_last(v, x) +end +function Base.searchsortedfirst( + ::SIMDBinarySearch, v::DenseVector{Int64}, x::Int64; + order::Base.Order.Ordering = Base.Order.Forward, + ) + order === Base.Order.Forward || + return searchsortedfirst(v, x, order) + return _simd_bsearch_first(v, x) +end +function Base.searchsortedfirst( + ::SIMDBinarySearch, v::DenseVector{Float64}, x::Float64; + order::Base.Order.Ordering = Base.Order.Forward, + ) + order === Base.Order.Forward || + return searchsortedfirst(v, x, order) + return _simd_bsearch_first(v, x) +end + +# Strategy ignores any hint that is supplied. +Base.searchsortedlast( + s::SIMDBinarySearch, v::DenseVector{Int64}, x::Int64, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedlast(s, v, x; order = order) +Base.searchsortedlast( + s::SIMDBinarySearch, v::DenseVector{Float64}, x::Float64, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedlast(s, v, x; order = order) +Base.searchsortedfirst( + s::SIMDBinarySearch, v::DenseVector{Int64}, x::Int64, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedfirst(s, v, x; order = order) +Base.searchsortedfirst( + s::SIMDBinarySearch, v::DenseVector{Float64}, x::Float64, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedfirst(s, v, x; order = order) + +# Other eltypes / non-dense storage: fall back to BinaryBracket. +Base.searchsortedlast( + ::SIMDBinarySearch, v::AbstractVector, x; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedlast(BinaryBracket(), v, x; order = order) +Base.searchsortedfirst( + ::SIMDBinarySearch, v::AbstractVector, x; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedfirst(BinaryBracket(), v, x; order = order) +Base.searchsortedlast( + s::SIMDBinarySearch, v::AbstractVector, x, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedlast(BinaryBracket(), v, x; order = order) +Base.searchsortedfirst( + s::SIMDBinarySearch, v::AbstractVector, x, ::Integer; + order::Base.Order.Ordering = Base.Order.Forward, +) = searchsortedfirst(BinaryBracket(), v, x; order = order) diff --git a/src/strategies.jl b/src/strategies.jl index e70263c..5f73649 100644 --- a/src/strategies.jl +++ b/src/strategies.jl @@ -165,6 +165,29 @@ that is supplied. """ struct BinaryBracket <: SearchStrategy end +""" + SIMDBinarySearch <: SearchStrategy + +Single-query binary search that evaluates 8 candidate positions per iteration +via SIMD.jl `Vec{8, T}` gather + lane-wise compare. The bracket `[lo, hi]` is +divided into 8 segments by 7 internal probe positions plus 1 segment-boundary +probe; each step narrows the bracket by a factor of ~8 instead of the +factor-of-2 of standard binary search. Asymptotic cost is ~log₈(n) gather + +compare + bitmask operations, which for large `n` is about a third of the +log₂(n) compares of `BinaryBracket`. + +Specialised for `DenseVector{Int64}` and `DenseVector{Float64}`. Other +eltypes fall back to [`BinaryBracket`](@ref). The strategy ignores any hint +that is supplied — the bracket starts at `[firstindex(v), lastindex(v)]` on +every call. For batched sorted-query workloads use a hinted strategy +(`ExpFromLeft`, `BracketGallop`, `SIMDLinearScan`) instead. + +This strategy is opt-in only; `Auto` does not pick it. Whether it actually +beats scalar `BinaryBracket` depends on hardware (gather latency, vector +unit width) and `n`. See `bench/simd_binary_sweep.jl`. +""" +struct SIMDBinarySearch <: SearchStrategy end + """ BisectThenSIMD <: SearchStrategy diff --git a/test/runtests.jl b/test/runtests.jl index 45e04a7..830677a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -633,6 +633,162 @@ end end end + @safetestset "SIMDBinarySearch correctness" begin + using FindFirstFunctions, StableRNGs + F = FindFirstFunctions + + @testset "Int64 fuzz vs Base" begin + rng = StableRNG(4001) + for _ in 1:5_000 + n = rand(rng, 0:512) + v = sort!(rand(rng, Int64(-1000):Int64(1000), n)) + x = rand(rng, Int64(-1100):Int64(1100)) + @test searchsortedlast(F.SIMDBinarySearch(), v, x) == + searchsortedlast(v, x) + @test searchsortedfirst(F.SIMDBinarySearch(), v, x) == + searchsortedfirst(v, x) + end + end + + @testset "Float64 fuzz vs Base" begin + rng = StableRNG(4002) + for _ in 1:5_000 + n = rand(rng, 0:512) + v = sort!(randn(rng, n)) + x = (rand(rng) - 0.5) * 6 + @test searchsortedlast(F.SIMDBinarySearch(), v, x) == + searchsortedlast(v, x) + @test searchsortedfirst(F.SIMDBinarySearch(), v, x) == + searchsortedfirst(v, x) + end + end + + @testset "Multiple n covering basecase boundary" begin + rng = StableRNG(4003) + for n in ( + 0, 1, 2, 7, 8, 15, 16, 17, 31, 32, 64, 127, 128, + 256, 1023, 1024, 4095, 4096, + ) + v = sort!(randn(rng, n)) + isempty(v) && ( + @test searchsortedlast(F.SIMDBinarySearch(), v, 0.0) == 0; + @test searchsortedfirst(F.SIMDBinarySearch(), v, 0.0) == 1; + continue + ) + for x in ( + v[1] - 1, v[1], v[end], v[end] + 1, + (v[1] + v[end]) / 2, + ) + @test searchsortedlast(F.SIMDBinarySearch(), v, x) == + searchsortedlast(v, x) + @test searchsortedfirst(F.SIMDBinarySearch(), v, x) == + searchsortedfirst(v, x) + end + # Random fuzz at this n. + for _ in 1:20 + x = (rand(rng) - 0.5) * 6 + @test searchsortedlast(F.SIMDBinarySearch(), v, x) == + searchsortedlast(v, x) + @test searchsortedfirst(F.SIMDBinarySearch(), v, x) == + searchsortedfirst(v, x) + end + end + end + + @testset "Edge cases" begin + # Empty + @test searchsortedlast(F.SIMDBinarySearch(), Int64[], Int64(5)) == 0 + @test searchsortedfirst(F.SIMDBinarySearch(), Int64[], Int64(5)) == 1 + @test searchsortedlast(F.SIMDBinarySearch(), Float64[], 5.0) == 0 + @test searchsortedfirst(F.SIMDBinarySearch(), Float64[], 5.0) == 1 + # Single element + v1 = [42.0] + @test searchsortedlast(F.SIMDBinarySearch(), v1, 42.0) == 1 + @test searchsortedlast(F.SIMDBinarySearch(), v1, 41.0) == 0 + @test searchsortedlast(F.SIMDBinarySearch(), v1, 43.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), v1, 42.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), v1, 41.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), v1, 43.0) == 2 + # x outside range + v = collect(1.0:100.0) + @test searchsortedlast(F.SIMDBinarySearch(), v, -100.0) == 0 + @test searchsortedlast(F.SIMDBinarySearch(), v, 200.0) == 100 + @test searchsortedfirst(F.SIMDBinarySearch(), v, -100.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), v, 200.0) == 101 + # x at exact match + @test searchsortedlast(F.SIMDBinarySearch(), v, 50.0) == + searchsortedlast(v, 50.0) + @test searchsortedfirst(F.SIMDBinarySearch(), v, 50.0) == + searchsortedfirst(v, 50.0) + # Duplicates - small + vd = Float64[1.0, 2.0, 2.0, 2.0, 5.0] + @test searchsortedlast(F.SIMDBinarySearch(), vd, 2.0) == 4 + @test searchsortedfirst(F.SIMDBinarySearch(), vd, 2.0) == 2 + # Duplicates - large (exercises both base case and SIMD step) + vd_big = vcat(fill(1.0, 50), fill(2.0, 100), fill(5.0, 50)) + @test searchsortedlast(F.SIMDBinarySearch(), vd_big, 2.0) == 150 + @test searchsortedfirst(F.SIMDBinarySearch(), vd_big, 2.0) == 51 + @test searchsortedlast(F.SIMDBinarySearch(), vd_big, 3.0) == 150 + @test searchsortedfirst(F.SIMDBinarySearch(), vd_big, 3.0) == 151 + # Constant vector + vc = fill(3.0, 32) + @test searchsortedlast(F.SIMDBinarySearch(), vc, 3.0) == 32 + @test searchsortedlast(F.SIMDBinarySearch(), vc, 2.0) == 0 + @test searchsortedlast(F.SIMDBinarySearch(), vc, 4.0) == 32 + @test searchsortedfirst(F.SIMDBinarySearch(), vc, 3.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), vc, 2.0) == 1 + @test searchsortedfirst(F.SIMDBinarySearch(), vc, 4.0) == 33 + end + + @testset "Hint is ignored" begin + v = collect(1.0:100.0) + # Same answer regardless of hint + expected_last = searchsortedlast(v, 50.5) + expected_first = searchsortedfirst(v, 50.5) + for h in (1, 10, 50, 99, 100, -5, 1000) + @test searchsortedlast(F.SIMDBinarySearch(), v, 50.5, h) == + expected_last + @test searchsortedfirst(F.SIMDBinarySearch(), v, 50.5, h) == + expected_first + end + end + + @testset "Fallback: non-Int64/Float64 eltypes" begin + # Int32 falls back to BinaryBracket + v32 = Int32[1, 5, 10, 20, 50, 100, 200] + for x in (Int32(0), Int32(7), Int32(20), Int32(300)) + @test searchsortedlast(F.SIMDBinarySearch(), v32, x) == + searchsortedlast(v32, x) + @test searchsortedfirst(F.SIMDBinarySearch(), v32, x) == + searchsortedfirst(v32, x) + end + # Float32 same + v32f = Float32[1.0, 5.0, 10.0, 20.0, 50.0] + for x in (Float32(0.0), Float32(7.0), Float32(20.0), Float32(100.0)) + @test searchsortedlast(F.SIMDBinarySearch(), v32f, x) == + searchsortedlast(v32f, x) + end + # Non-numeric + vs = sort!(["alpha", "beta", "gamma", "delta", "epsilon"]) + @test searchsortedlast(F.SIMDBinarySearch(), vs, "gamma") == + searchsortedlast(vs, "gamma") + end + + @testset "Reverse order falls back" begin + v_rev = collect(Int64, 100:-1:1) + @test searchsortedlast( + F.SIMDBinarySearch(), v_rev, Int64(50); order = Base.Order.Reverse, + ) == searchsortedlast(v_rev, Int64(50), Base.Order.Reverse) + @test searchsortedfirst( + F.SIMDBinarySearch(), v_rev, Int64(50); order = Base.Order.Reverse, + ) == searchsortedfirst(v_rev, Int64(50), Base.Order.Reverse) + end + + @testset "Strategy hierarchy" begin + @test F.SIMDBinarySearch <: F.SearchStrategy + end + end + @safetestset "findequal + BisectThenSIMD" begin using FindFirstFunctions, StableRNGs F = FindFirstFunctions