diff --git a/include/rabitqlib/index/estimator.hpp b/include/rabitqlib/index/estimator.hpp index b3f87ec..05bc47b 100644 --- a/include/rabitqlib/index/estimator.hpp +++ b/include/rabitqlib/index/estimator.hpp @@ -31,24 +31,60 @@ inline void split_batch_estdist( float* ip_x0_qr, bool use_hacc ) { + constexpr size_t kSafeChunkDim = 1024; ConstBatchDataMap cur_batch(batch_data, padded_dim); RowMajorArray accu_arr(1, fastscan::kBatchSize); if (use_hacc) { std::array accu_res; - fastscan::accumulate_hacc( - cur_batch.bin_code(), q_obj.lut(), accu_res.data(), padded_dim - ); for (size_t i = 0; i < fastscan::kBatchSize; ++i) { - accu_arr.data()[i] = accu_res[i]; + accu_arr.data()[i] = 0; + } + + const uint8_t* codes_ptr = + reinterpret_cast(cur_batch.bin_code()); + const uint8_t* lut_ptr = reinterpret_cast(q_obj.lut()); + size_t remaining_dim = padded_dim; + + while (remaining_dim > kSafeChunkDim) { + fastscan::accumulate_hacc( + codes_ptr, lut_ptr, accu_res.data(), kSafeChunkDim + ); + codes_ptr += kSafeChunkDim << 2; + lut_ptr += kSafeChunkDim << 3; + for (size_t i = 0; i < fastscan::kBatchSize; ++i) { + accu_arr.data()[i] += accu_res[i]; + } + remaining_dim -= kSafeChunkDim; + } + + fastscan::accumulate_hacc(codes_ptr, lut_ptr, accu_res.data(), remaining_dim); + for (size_t i = 0; i < fastscan::kBatchSize; ++i) { + accu_arr.data()[i] += accu_res[i]; } } else { std::array accu_res; - fastscan::accumulate( - cur_batch.bin_code(), q_obj.lut(), accu_res.data(), padded_dim - ); + const uint8_t* codes_ptr = + reinterpret_cast(cur_batch.bin_code()); + const uint8_t* lut_ptr = reinterpret_cast(q_obj.lut()); for (size_t i = 0; i < fastscan::kBatchSize; ++i) { - accu_arr.data()[i] = accu_res[i]; + accu_arr.data()[i] = 0; + } + + size_t remaining_dim = padded_dim; + while (remaining_dim > kSafeChunkDim) { + fastscan::accumulate(codes_ptr, lut_ptr, accu_res.data(), kSafeChunkDim); + codes_ptr += kSafeChunkDim << 2; + lut_ptr += kSafeChunkDim << 2; + for (size_t i = 0; i < fastscan::kBatchSize; ++i) { + accu_arr.data()[i] += accu_res[i]; + } + remaining_dim -= kSafeChunkDim; + } + + fastscan::accumulate(codes_ptr, lut_ptr, accu_res.data(), remaining_dim); + for (size_t i = 0; i < fastscan::kBatchSize; ++i) { + accu_arr.data()[i] += accu_res[i]; } } @@ -70,6 +106,7 @@ inline void split_batch_estdist( low_dist_arr = est_dist_arr - f_error_arr * q_obj.g_error(); } + /** * @brief Use ex-data bits to get more accurate distance * @@ -188,4 +225,4 @@ inline void split_single_estdist( low_dist = est_dist - (cur_bin.f_error() * g_error); }; -} // namespace rabitqlib \ No newline at end of file +} // namespace rabitqlib