Skip to content

Commit a547e8f

Browse files
authored
Add static_multiset::for_each and its OA impl (#506)
closes #499
1 parent 7803754 commit a547e8f

5 files changed

Lines changed: 432 additions & 1 deletion

File tree

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,164 @@ class open_addressing_ref_impl {
962962
}
963963
}
964964

965+
/**
966+
* @brief Executes a callback on every element in the container with key equivalent to the probe
967+
* key.
968+
*
969+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
970+
* `key` to the callback.
971+
*
972+
* @tparam ProbeKey Input type which is convertible to 'key_type'
973+
* @tparam CallbackOp Unary callback functor or device lambda
974+
*
975+
* @param key The key to search for
976+
* @param callback_op Function to call on every element found
977+
*/
978+
template <class ProbeKey, class CallbackOp>
979+
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
980+
{
981+
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
982+
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent());
983+
984+
while (true) {
985+
// TODO atomic_ref::load if insert operator is present
986+
auto const window_slots = this->storage_ref_[*probing_iter];
987+
988+
for (int32_t i = 0; i < window_size; ++i) {
989+
switch (
990+
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]))) {
991+
case detail::equal_result::EMPTY: {
992+
return;
993+
}
994+
case detail::equal_result::EQUAL: {
995+
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
996+
continue;
997+
}
998+
default: continue;
999+
}
1000+
}
1001+
++probing_iter;
1002+
}
1003+
}
1004+
1005+
/**
1006+
* @brief Executes a callback on every element in the container with key equivalent to the probe
1007+
* key.
1008+
*
1009+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
1010+
* `key` to the callback.
1011+
*
1012+
* @note This function uses cooperative group semantics, meaning that any thread may call the
1013+
* callback if it finds a matching element. If multiple elements are found within the same group,
1014+
* each thread with a match will call the callback with its associated element.
1015+
*
1016+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
1017+
*
1018+
* @tparam ProbeKey Input type which is convertible to 'key_type'
1019+
* @tparam CallbackOp Unary callback functor or device lambda
1020+
*
1021+
* @param group The Cooperative Group used to perform this operation
1022+
* @param key The key to search for
1023+
* @param callback_op Function to call on every element found
1024+
*/
1025+
template <class ProbeKey, class CallbackOp>
1026+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
1027+
ProbeKey const& key,
1028+
CallbackOp&& callback_op) const noexcept
1029+
{
1030+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent());
1031+
bool empty = false;
1032+
1033+
while (true) {
1034+
// TODO atomic_ref::load if insert operator is present
1035+
auto const window_slots = this->storage_ref_[*probing_iter];
1036+
1037+
for (int32_t i = 0; i < window_size and !empty; ++i) {
1038+
switch (
1039+
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]))) {
1040+
case detail::equal_result::EMPTY: {
1041+
empty = true;
1042+
continue;
1043+
}
1044+
case detail::equal_result::EQUAL: {
1045+
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
1046+
continue;
1047+
}
1048+
default: {
1049+
continue;
1050+
}
1051+
}
1052+
}
1053+
if (group.any(empty)) { return; }
1054+
1055+
++probing_iter;
1056+
}
1057+
}
1058+
1059+
/**
1060+
* @brief Executes a callback on every element in the container with key equivalent to the probe
1061+
* key and can additionally perform work that requires synchronizing the Cooperative Group
1062+
* performing this operation.
1063+
*
1064+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
1065+
* `key` to the callback.
1066+
*
1067+
* @note This function uses cooperative group semantics, meaning that any thread may call the
1068+
* callback if it finds a matching element. If multiple elements are found within the same group,
1069+
* each thread with a match will call the callback with its associated element.
1070+
*
1071+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
1072+
*
1073+
* @note The `sync_op` function can be used to perform work that requires synchronizing threads in
1074+
* `group` inbetween probing steps, where the number of probing steps performed between
1075+
* synchronization points is capped by `window_size * cg_size`. The functor will be called right
1076+
* after the current probing window has been traversed.
1077+
*
1078+
* @tparam ProbeKey Input type which is convertible to 'key_type'
1079+
* @tparam CallbackOp Unary callback functor or device lambda
1080+
* @tparam SyncOp Functor or device lambda which accepts the current `group` object
1081+
*
1082+
* @param group The Cooperative Group used to perform this operation
1083+
* @param key The key to search for
1084+
* @param callback_op Function to call on every element found
1085+
* @param sync_op Function that is allowed to synchronize `group` inbetween probing windows
1086+
*/
1087+
template <class ProbeKey, class CallbackOp, class SyncOp>
1088+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
1089+
ProbeKey const& key,
1090+
CallbackOp&& callback_op,
1091+
SyncOp&& sync_op) const noexcept
1092+
{
1093+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent());
1094+
bool empty = false;
1095+
1096+
while (true) {
1097+
// TODO atomic_ref::load if insert operator is present
1098+
auto const window_slots = this->storage_ref_[*probing_iter];
1099+
1100+
for (int32_t i = 0; i < window_size and !empty; ++i) {
1101+
switch (
1102+
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]))) {
1103+
case detail::equal_result::EMPTY: {
1104+
empty = true;
1105+
continue;
1106+
}
1107+
case detail::equal_result::EQUAL: {
1108+
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
1109+
continue;
1110+
}
1111+
default: {
1112+
continue;
1113+
}
1114+
}
1115+
}
1116+
sync_op(group);
1117+
if (group.any(empty)) { return; }
1118+
1119+
++probing_iter;
1120+
}
1121+
}
1122+
9651123
/**
9661124
* @brief Compares the content of the address `address` (old value) with the `expected` value and,
9671125
* only if they are the same, sets the content of `address` to `desired`.

include/cuco/detail/static_multiset/static_multiset_ref.inl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include <cooperative_groups.h>
2424

25+
#include <utility>
26+
2527
namespace cuco {
2628

2729
template <typename Key,
@@ -446,6 +448,114 @@ class operator_impl<
446448
}
447449
};
448450

451+
template <typename Key,
452+
cuda::thread_scope Scope,
453+
typename KeyEqual,
454+
typename ProbingScheme,
455+
typename StorageRef,
456+
typename... Operators>
457+
class operator_impl<
458+
op::for_each_tag,
459+
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
460+
using base_type = static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef>;
461+
using ref_type =
462+
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
463+
464+
static constexpr auto cg_size = base_type::cg_size;
465+
466+
public:
467+
/**
468+
* @brief Executes a callback on every element in the container with key equivalent to the probe
469+
* key.
470+
*
471+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
472+
* `key` to the callback.
473+
*
474+
* @tparam ProbeKey Input type which is convertible to 'key_type'
475+
* @tparam CallbackOp Unary callback functor or device lambda
476+
*
477+
* @param key The key to search for
478+
* @param callback_op Function to call on every element found
479+
*/
480+
template <class ProbeKey, class CallbackOp>
481+
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
482+
{
483+
// CRTP: cast `this` to the actual ref type
484+
auto const& ref_ = static_cast<ref_type const&>(*this);
485+
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
486+
}
487+
488+
/**
489+
* @brief Executes a callback on every element in the container with key equivalent to the probe
490+
* key.
491+
*
492+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
493+
* `key` to the callback.
494+
*
495+
* @note This function uses cooperative group semantics, meaning that any thread may call the
496+
* callback if it finds a matching element. If multiple elements are found within the same group,
497+
* each thread with a match will call the callback with its associated element.
498+
*
499+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
500+
*
501+
* @tparam ProbeKey Input type which is convertible to 'key_type'
502+
* @tparam CallbackOp Unary callback functor or device lambda
503+
*
504+
* @param group The Cooperative Group used to perform this operation
505+
* @param key The key to search for
506+
* @param callback_op Function to call on every element found
507+
*/
508+
template <class ProbeKey, class CallbackOp>
509+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
510+
ProbeKey const& key,
511+
CallbackOp&& callback_op) const noexcept
512+
{
513+
// CRTP: cast `this` to the actual ref type
514+
auto const& ref_ = static_cast<ref_type const&>(*this);
515+
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
516+
}
517+
518+
/**
519+
* @brief Executes a callback on every element in the container with key equivalent to the probe
520+
* key and can additionally perform work that requires synchronizing the Cooperative Group
521+
* performing this operation.
522+
*
523+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
524+
* `key` to the callback.
525+
*
526+
* @note This function uses cooperative group semantics, meaning that any thread may call the
527+
* callback if it finds a matching element. If multiple elements are found within the same group,
528+
* each thread with a match will call the callback with its associated element.
529+
*
530+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
531+
*
532+
* @note The `sync_op` function can be used to perform work that requires synchronizing threads in
533+
* `group` inbetween probing steps, where the number of probing steps performed between
534+
* synchronization points is capped by `window_size * cg_size`. The functor will be called right
535+
* after the current probing window has been traversed.
536+
*
537+
* @tparam ProbeKey Input type which is convertible to 'key_type'
538+
* @tparam CallbackOp Unary callback functor or device lambda
539+
* @tparam SyncOp Functor or device lambda which accepts the current `group` object
540+
*
541+
* @param group The Cooperative Group used to perform this operation
542+
* @param key The key to search for
543+
* @param callback_op Function to call on every element found
544+
* @param sync_op Function that is allowed to synchronize `group` inbetween probing windows
545+
*/
546+
template <class ProbeKey, class CallbackOp, class SyncOp>
547+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
548+
ProbeKey const& key,
549+
CallbackOp&& callback_op,
550+
SyncOp&& sync_op) const noexcept
551+
{
552+
// CRTP: cast `this` to the actual ref type
553+
auto const& ref_ = static_cast<ref_type const&>(*this);
554+
ref_.impl_.for_each(
555+
group, key, std::forward<CallbackOp>(callback_op), std::forward<SyncOp>(sync_op));
556+
}
557+
};
558+
449559
template <typename Key,
450560
cuda::thread_scope Scope,
451561
typename KeyEqual,

include/cuco/operator.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ struct count_tag {
6262
struct find_tag {
6363
} inline constexpr find; ///< `cuco::find` operator
6464

65+
/**
66+
* @brief `for_each` operator tag
67+
*/
68+
struct for_each_tag {
69+
} inline constexpr for_each; ///< `cuco::for_each` operator
70+
6571
} // namespace op
6672
} // namespace cuco
6773

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ ConfigureTest(STATIC_MULTISET_TEST
100100
static_multiset/count_test.cu
101101
static_multiset/custom_count_test.cu
102102
static_multiset/find_test.cu
103-
static_multiset/insert_test.cu)
103+
static_multiset/insert_test.cu
104+
static_multiset/for_each_test.cu)
104105

105106
###################################################################################################
106107
# - static_multimap tests -------------------------------------------------------------------------

0 commit comments

Comments
 (0)