diff --git a/Cargo.toml b/Cargo.toml index 312f46d..ad992f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [ "crates/*", "crates/bpe/benchmarks", - "crates/bpe/tests", + "crates/bpe/tests" ] resolver = "2" @@ -11,4 +11,4 @@ resolver = "2" debug = true [profile.release] -debug = true \ No newline at end of file +debug = true diff --git a/crates/hriblt/Cargo.toml b/crates/hriblt/Cargo.toml new file mode 100644 index 0000000..91f8585 --- /dev/null +++ b/crates/hriblt/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "hriblt" +version = "0.1.0" +edition = "2024" +description = "Algorithm for rateless set reconciliation" +repository = "https://github.com/github/rust-gems" +license = "MIT" +keywords = ["set-reconciliation", "sync", "algorithm", "probabilistic"] +categories = ["algorithms", "data-structures", "mathematics", "science"] + +[dependencies] +thiserror = "2" diff --git a/crates/hriblt/README.md b/crates/hriblt/README.md new file mode 100644 index 0000000..6b93d8b --- /dev/null +++ b/crates/hriblt/README.md @@ -0,0 +1,55 @@ +# Hierarchical Rateless Bloom Lookup Tables + +A novel algorithm for computing the symmetric difference between sets where the amount of data shared is proportional to the size of the difference in the sets rather than proportional to the overall size. + +## Usage + +Add the library to your `Cargo.toml` file. + +```toml +[dependencies] +hriblt = "0.1" +``` + +Create two encoding sessions, one containing your data, and another containing the counter-parties data. This counterparty data might have been sent to you over a network for example. + +The following example attempts to reconcile the differences between two sets of `u64` integers, and is done from the perspective of "Bob", who has recieved some symbols from "Alice". + +```rust +use hriblt::{DecodingSession, EncodingSession, DefaultHashFunctions}; +// On Alice's computer + +// Alice creates an encoding session... 
+let mut alice_encoding_session = EncodingSession::::new(DefaultHashFunctions, 0..128); + +// And adds her data to that session, in this case the numbers from 0 to 10. +for i in 0..=10 { + alice_encoding_session.insert(i); +} + +// On Bob's computer + +// Bob creates his encoding session, note that the range **must** be the same as Alice's +let mut bob_encoding_session = EncodingSession::::new(DefaultHashFunctions, 0..128); + +// Bob adds his data, the numbers from 5 to 15. +for i in 5..=15 { + bob_encoding_session.insert(i); +} + +// "Subtract" Bob's coded symbols from Alice's, the remaining symbols will be the symmetric +// difference between the two sets, iff we can decode them. This is a commutative function so you +// could also subtract Alice's symbols from Bob's and it would still work. +let merged_sessions = alice_encoding_session.merge(bob_encoding_session, true); + +let decoding_session = DecodingSession::from_encoding(merged_sessions); + +assert!(decoding_session.is_done()); + +let mut diff = decoding_session.into_decoded_iter().map(|v| v.into_value()).collect::>(); + +diff.sort(); + +assert_eq!(diff, [0, 1, 2, 3, 4, 11, 12, 13, 14, 15]); + +``` diff --git a/crates/hriblt/docs/assets/coded-symbol-multiplier.png b/crates/hriblt/docs/assets/coded-symbol-multiplier.png new file mode 100644 index 0000000..84d3b69 Binary files /dev/null and b/crates/hriblt/docs/assets/coded-symbol-multiplier.png differ diff --git a/crates/hriblt/docs/hashing_functions.md b/crates/hriblt/docs/hashing_functions.md new file mode 100644 index 0000000..8bede64 --- /dev/null +++ b/crates/hriblt/docs/hashing_functions.md @@ -0,0 +1,21 @@ +# Hash Functions + +This library has a trait, `HashFunctions` which is used to create the hashes required to place your symbol into the range of coded symbols. + +The following documentation provides more details on this trait in particular. How and why this is done is explained in the `overview.md` documentation. 
+
+## Hash stability
+
+When using HRIBLT in production systems it is important to consider the stability of your hash functions.
+
+We provide a `DefaultHashFunctions` type which is a wrapper around the `DefaultHasher` type provided by the Rust standard library. Though the seed for this function is fixed, it should be noted that the hashes produced by this type are *not* guaranteed to be stable across different versions of the Rust standard library. As such, you should not use this type for any situation where clients might potentially be running on a binary built with an unspecified version of Rust.
+
+We recommend you implement your own `HashFunctions` implementation with a stable hash function.
+
+## Hash value hashing trick
+
+If the value you're inserting into the encoding session is a high entropy random value, such as a cryptographic hash digest, you can recycle the bytes in that value to produce the coded symbol indexing hashes, instead of hashing that value again. This results in a constant-factor speed up.
+
+For example if you were trying to find the difference between two sets of documents, instead of each coded symbol being the whole document it could instead just be a SHA1 hash of the document content. Since each SHA1 digest has 20 bytes of high entropy bits, instead of hashing this value five times again to produce the five coded symbol indices we can simply slice out five `u32` values from the digest itself.
+
+This is a useful trick because hash values are often used as IDs for documents during set reconciliation since they are a fixed size, making serialization easy.
diff --git a/crates/hriblt/docs/sizing.md b/crates/hriblt/docs/sizing.md
new file mode 100644
index 0000000..912f803
--- /dev/null
+++ b/crates/hriblt/docs/sizing.md
@@ -0,0 +1,17 @@
+# Sizing your HRIBLT
+
+Because the HRIBLT is rateless, it is possible to append additional data in order to make decoding possible. 
That is, it does not need to be sized in advance like a standard invertible bloom lookup table. + +Regardless, there are some advantages to getting the size of your decoding session correct the first time. An example might be if you're performing set reconciliation over some RPC and you want to minimise the number of round trips it takes to perform a decode. + +## Coded Symbol Multiplier + +The number of coded symbols required to find the difference between two sets is proportional to the difference between the two sets. The following chart shows the relationship between the number of coded symbols required to decode HRIBLT and the size of the diff. Note that the size of the base set (before diffs were added) was fixed. + +`y = len(coded_symbols) / diff_size` + +![Coded symbol multiplier](./assets/coded-symbol-multiplier.png) + +For small diffs, the number of coded symbols required per value is larger, after a difference of approximately 100 values the coefficient settles on around 1.3 to 1.4. + +You can use this chart, combined with an estimate of the diff size (perhaps from a `geo_filter`) to increase the probability that you will have a successful decode after a single round-trip while also minimising the amount of data sent. diff --git a/crates/hriblt/src/coded_symbol.rs b/crates/hriblt/src/coded_symbol.rs new file mode 100644 index 0000000..3aab3ca --- /dev/null +++ b/crates/hriblt/src/coded_symbol.rs @@ -0,0 +1,127 @@ +use crate::{Encodable, HashFunctions, index_for_seed, indices}; + +/// Represents a coded symbol in the invertible bloom filter table. +/// In some of the literature this is referred to as a "cell" or "bucket". +/// It includes a checksum to verify whether the instance represents a pure value. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub struct CodedSymbol { + /// Values aggregated by XOR operation. 
+ pub value: T, + /// We repurpose the two least significant bits of the checksum: + /// - The least significant bit is a one bit counter which is incremented for each entity. + /// This bit must be set when there is a single entity represented by this hash. + /// - The second least significant bit indicates whether the entity is a deletion or insertion. + pub checksum: u64, +} + +impl Default for CodedSymbol { + fn default() -> Self { + CodedSymbol { + value: T::zero(), + checksum: 0, + } + } +} + +impl From<(T, u64)> for CodedSymbol { + fn from(tuple: (T, u64)) -> Self { + Self { + value: tuple.0, + checksum: tuple.1, + } + } +} + +impl CodedSymbol { + /// Creates a new coded symbol with the given hash and deletion flag. + pub(crate) fn new>(state: &S, hash: T, deletion: bool) -> Self { + let mut checksum = state.check_sum(&hash); + checksum |= 1; // Add a single bit counter + if deletion { + checksum = checksum.wrapping_neg(); + } + CodedSymbol { + value: hash, + checksum, + } + } + + /// Merges another coded symbol into this one. + pub(crate) fn add(&mut self, other: &CodedSymbol, negate: bool) { + self.value.xor(other.value); + if negate { + self.checksum = self.checksum.wrapping_sub(other.checksum); + } else { + self.checksum = self.checksum.wrapping_add(other.checksum); + } + } + + /// Checks whether this coded symbol is pure, i.e., whether it represents a single entity + /// A pure coded symbol must satisfy the following conditions: + /// - The 1-bit counter must be 1 or -1 (which are both represented by the bit being set) + /// - The checksum must match the checksum of the value. + /// - The indices of the value must match the index of this coded symbol. 
+ pub(crate) fn is_pure>( + &self, + state: &S, + i: usize, + len: usize, + ) -> (bool, usize) { + if self.checksum & 1 == 0 { + return (false, 0); + } + let multiplicity = indices_contains(state, &self.value, len, i); + if multiplicity != 1 { + return (false, 0); + } + let checksum = state.check_sum(&self.value) | 1; + if checksum == self.checksum || checksum.wrapping_neg() == self.checksum { + (true, 0) + } else { + let required_bits = self + .checksum + .wrapping_sub(checksum) + .leading_zeros() + .max(self.checksum.wrapping_add(checksum).leading_zeros()) + as usize; + (false, required_bits) + } + } + + /// Checks whether this coded symbol is zero, i.e., whether it represents no entity. + pub(crate) fn is_zero(&self) -> bool { + self.checksum == 0 && self.value == T::zero() + } + + /// Checks whether this coded symbol represents a deletion. + pub(crate) fn is_deletion>(&self, state: &S) -> bool { + let checksum = state.check_sum(&self.value) | 1; + checksum != self.checksum + } +} + +/// This function checks efficiently whether the given index is contained in the indices. +/// +/// Note: we have constructed the indices such that we can determine from the last 5 bits +/// which hash function would map to this index. Therefore, we only need to check against +/// a single hash function and not all 5! +/// The only exception is for very small indices (0..32) or if the index is a multiple of 32. +/// +/// The function returns the multiplicity, i.e. how many indices hit this particular index. +/// Thereby, it takes into account whether the value is stored negated or not. 
+fn indices_contains( + state: &impl HashFunctions, + value: &T, + stream_len: usize, + i: usize, +) -> i32 { + if stream_len > 32 && i % 32 != 0 { + let seed = i % 4; + let j = index_for_seed(state, value, stream_len, seed as u32); + if i == j { 1 } else { 0 } + } else { + indices(state, value, stream_len) + .map(|j| if j == i { 1 } else { 0 }) + .sum() + } +} diff --git a/crates/hriblt/src/decoded_value.rs b/crates/hriblt/src/decoded_value.rs new file mode 100644 index 0000000..a7aac07 --- /dev/null +++ b/crates/hriblt/src/decoded_value.rs @@ -0,0 +1,31 @@ +/// A value that has been found by the set reconciliation algorithm. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord)] +pub enum DecodedValue { + /// A value that has been added + Addition(T), + /// A value that has been removed + Deletion(T), +} + +impl DecodedValue { + /// Consume this `DecodedValue` to return the value + pub fn into_value(self) -> T { + match self { + DecodedValue::Addition(v) => v, + DecodedValue::Deletion(v) => v, + } + } + + /// Borrow the value within this decoded value. + pub fn value(&self) -> &T { + match self { + DecodedValue::Addition(v) => v, + DecodedValue::Deletion(v) => v, + } + } + + /// Returns true if this decoded value is a deletion + pub fn is_deletion(&self) -> bool { + matches!(self, DecodedValue::Deletion(_)) + } +} diff --git a/crates/hriblt/src/decoding_session.rs b/crates/hriblt/src/decoding_session.rs new file mode 100644 index 0000000..37bfe6e --- /dev/null +++ b/crates/hriblt/src/decoding_session.rs @@ -0,0 +1,221 @@ +use std::collections::HashSet; + +use crate::{ + Encodable, HashFunctions, coded_symbol::CodedSymbol, decoded_value::DecodedValue, + encoding_session::EncodingSession, error::SetReconciliationError, parent, +}; + +/// A session for decoding a stream of hashes. +#[derive(Clone)] +pub struct DecodingSession> { + /// The encoded stream of hashes. + /// All recovered coded symbols have been removed from this stream. 
+ /// If decoded failed, then one can simply append more data and continue decoding. + encoded: EncodingSession, + /// All the recovered coded symbols. + recovered: Vec>, + /// Don't decode the same coded symbol multiple times (it might occur up to 5 times in the stream!). + visited: HashSet, + /// Tracks the number of non-zero coded symbol up to the split point. + pub(crate) non_zero: isize, + + /// For statistical purposes: this number informs how many bits of the + /// checksum were required to identify pure coded symbols. + pub(crate) required_bits: usize, +} + +impl> DecodingSession { + /// Create a new decoding session with a given seed. + pub fn new(state: H) -> Self { + DecodingSession { + recovered: vec![], + encoded: EncodingSession::new(state, 0..0), + visited: HashSet::default(), + non_zero: 0, + required_bits: 0, + } + } + + fn from_encoding_unchecked(merged: EncodingSession) -> Self { + let mut me = DecodingSession { + recovered: vec![], + encoded: merged, + visited: HashSet::default(), + non_zero: 0, + required_bits: 0, + }; + // We work here with a non-hierarchical stream. 
+ me.encoded.move_split_point(me.encoded.range.end); + me.non_zero = me + .encoded + .coded_symbols + .iter() + .filter(|e| !e.is_zero()) + .count() as isize; + let mut j = me.recovered.len(); + let len = me.encoded.coded_symbols.len(); + for i in (0..len).rev() { + if me.non_zero == 0 { + break; + } + let (is_pure, required_bits) = me.encoded.is_pure(i); + if is_pure && !me.visited.contains(&me.encoded.coded_symbols[i].value) { + me.visited.insert(me.encoded.coded_symbols[i].value); + me.recovered.push(me.encoded.coded_symbols[i]); + } + if !is_pure && required_bits > me.required_bits { + me.required_bits = required_bits; + } + while j < me.recovered.len() { + let entity = me.recovered[j]; + let (changes, required_bits) = + me.encoded.add_to_stream(&entity, true, i + 1..len, |e, k| { + assert!(k > i); + if !me.visited.contains(&e.value) { + me.visited.insert(e.value); + me.recovered.push(e); + } + }); + me.non_zero += changes; + me.required_bits = me.required_bits.max(required_bits); + j += 1; + } + } + me + } + + /// This is a faster version for decoding the initial stream. + /// It processes this stream from back to front without going through the hierarchical representation. + /// The other procedure needs to execute roughly one additional `is_pure` test when unrolling the hierarchy + /// which this procedure avoids. + /// Additionally, this procedure can save on average another 50% of is_pure tests, since it won't waste time + /// on the highly packed hierarchy levels where we don't expect to find any pure values. + /// + /// Panics if the encoding session is not the beginning of a stream (e.g. the range is `0..n`) + pub fn from_encoding(merged: EncodingSession) -> Self { + assert_eq!(merged.range.start, 0); + Self::from_encoding_unchecked(merged) + } + + /// This is a faster version for decoding the initial stream. + /// It processes this stream from back to front without going through the hierarchical representation. 
+ /// The other procedure needs to execute roughly one additional `is_pure` test when unrolling the hierarchy + /// which this procedure avoids. + /// Additionally, this procedure can save on average another 50% of is_pure tests, since it won't waste time + /// on the highly packed hierarchy levels where we don't expect to find any pure values. + /// + /// Returns an error if the encoding session is not the beginning of a stream (e.g. the range is `0..n`) + pub fn try_from_encoding( + merged: EncodingSession, + ) -> Result { + if merged.range.start != 0 { + return Err(SetReconciliationError::NotInitialRange); + } + Ok(Self::from_encoding_unchecked(merged)) + } + + fn append_unchecked(&mut self, mut merged: EncodingSession) { + // Apply all the reconstructed entities to the new part of the stream. + for entity in &self.recovered { + merged.add_entity_inner(entity, true); + } + assert_eq!(self.encoded.split, self.encoded.range.end); + self.encoded.append(merged); + // Now continue decoding starting with the newly arrived data. + let mut j = self.recovered.len(); + for i in self.encoded.split..self.encoded.range.end { + // Undo hierarchy manually here, since we also need to count non-zero entries along the way. 
+ let ii = parent(i); + if i > 0 { + let tmp = self.encoded.coded_symbols[i]; + if !self.encoded.is_zero(ii) { + self.non_zero -= 1; + } + self.encoded.coded_symbols[ii].add(&tmp, true); + if !self.encoded.is_zero(ii) { + self.non_zero += 1; + } + } + self.encoded.split = i + 1; + if !self.encoded.is_zero(i) { + self.non_zero += 1; + } + for l in [i, ii] { + let (is_pure, required_bits) = self.encoded.is_pure(l); + if is_pure && !self.visited.contains(&self.encoded.coded_symbols[l].value) { + self.visited.insert(self.encoded.coded_symbols[l].value); + self.recovered.push(self.encoded.coded_symbols[l]); + } + if !is_pure && required_bits > self.required_bits { + self.required_bits = required_bits; + } + } + while j < self.recovered.len() { + let entity = self.recovered[j]; + let (changes, required_bits) = + self.encoded.add_to_stream(&entity, true, 0..i, |e, k| { + assert!(k < i); + if !self.visited.contains(&e.value) { + self.visited.insert(e.value); + self.recovered.push(e); + } + }); + self.non_zero += changes; + self.required_bits = self.required_bits.max(required_bits); + j += 1; + } + if self.non_zero == 0 { + // At this point everything should be decoded... + // We could in theory check that all remaining coded symbols are zero. + break; + } + } + } + + /// Appends the next chunk of coded symbols to the decoding session. + /// This should only be called if decoding was not yet completed. + /// Panics if the encoding session is not a contiguous range with `self`. + pub fn append(&mut self, merged: EncodingSession) { + assert_eq!(self.encoded.range.end, merged.range.start); + self.append_unchecked(merged); + } + + /// Appends the next chunk of coded symbols to the decoding session. + /// This should only be called if decoding was not yet completed. + /// Returns an error if the encoding session is not a contiguous range with `self`. 
+ pub fn try_append( + &mut self, + merged: EncodingSession, + ) -> Result<(), SetReconciliationError> { + if self.encoded.range.end != merged.range.start { + return Err(SetReconciliationError::NonContiguousRanges); + } + self.append_unchecked(merged); + Ok(()) + } + + /// Returns whether decoding has successfully finished. + pub fn is_done(&self) -> bool { + !self.encoded.coded_symbols.is_empty() && self.non_zero == 0 + } + + /// Returns the number of coded symbols that were consumed during the decoding process. + pub fn consumed_coded_symbols(&self) -> usize { + self.encoded.split + } + + /// Extract the decoded entities from the session. + /// Only call when `is_done()` returns true. + pub fn into_decoded_iter(self) -> impl Iterator> { + // We have decoded the stream successfully. + // Now we can return the decoded entities. + let hasher = self.encoded.hasher; + self.recovered.into_iter().map(move |e| { + if e.is_deletion(&hasher) { + DecodedValue::Deletion(e.value) + } else { + DecodedValue::Addition(e.value) + } + }) + } +} diff --git a/crates/hriblt/src/encoding_session.rs b/crates/hriblt/src/encoding_session.rs new file mode 100644 index 0000000..69538f3 --- /dev/null +++ b/crates/hriblt/src/encoding_session.rs @@ -0,0 +1,284 @@ +use std::ops::Range; + +use crate::{ + Encodable, HashFunctions, coded_symbol::CodedSymbol, error::SetReconciliationError, indices, + parent, +}; + +/// A session for encoding a stream of values. +/// This session can be used to merge multiple streams together, append more coded symbol from another session, +/// or extract coded symbols from the stream for decoding. +#[derive(Clone, PartialEq, Eq)] +pub struct EncodingSession> { + /// The hashing functions used for mapping values to indices. + pub(crate) hasher: H, + /// The range of the rateless stream which are encoded by this session. + /// This way it is possible to just encode a subset if needed. 
+    pub(crate) range: Range<usize>,
+    /// The coded symbols for the range
+    pub(crate) coded_symbols: Vec<CodedSymbol<T>>,
+    /// Starting at this point, the stream is represented by a hierarchy!
+    /// We use a somewhat unique representation where coded symbols from 0..split are represented
+    /// as a "normal" invertible bloom filter table and subsequent coded symbols are represented
+    /// in a hierarchical way.
+    /// The nice property of this representation is that we can switch back and forth between the two.
+    /// This is useful, since certain operations are faster in one representation than the other.
+    ///
+    /// The hierarchical representation can be thought of as a rateless set reconciliation stream.
+    pub(crate) split: usize,
+}
+
+impl<T: Encodable, H: HashFunctions<T>> EncodingSession<T, H> {
+    /// Create a new encoding session with a given seed and range.
+    pub fn new(state: H, range: Range<usize>) -> Self {
+        EncodingSession {
+            hasher: state,
+            coded_symbols: vec![CodedSymbol::default(); range.len()],
+            split: range.end, // We start with a non-hierarchical stream.
+            range,
+        }
+    }
+
+    /// Create an EncodingSession from a vector of coded symbols.
+    ///
+    /// Panics if the split is out of range or if the length of the vector
+    /// and the length of the range differ.
+    pub fn from_coded_symbols(
+        state: H,
+        coded_symbols: Vec<CodedSymbol<T>>,
+        range: Range<usize>,
+        split: usize,
+    ) -> Self {
+        assert!(split >= range.start && split <= range.end);
+        assert_eq!(coded_symbols.len(), range.len());
+        EncodingSession {
+            hasher: state,
+            coded_symbols,
+            split,
+            range,
+        }
+    }
+
+    /// Create an EncodingSession from a vector of coded symbols.
+    ///
+    /// Returns an error if the split is out of range or if the length of the vector
+    /// and the length of the range differ. 
+    pub fn try_from_coded_symbols(
+        state: H,
+        coded_symbols: Vec<CodedSymbol<T>>,
+        range: Range<usize>,
+        split: usize,
+    ) -> Result<Self, SetReconciliationError> {
+        if split < range.start || split > range.end {
+            return Err(SetReconciliationError::SplitOutOfRange);
+        }
+        if coded_symbols.len() != range.len() {
+            return Err(SetReconciliationError::RangeLengthMismatch);
+        }
+        Ok(EncodingSession {
+            hasher: state,
+            coded_symbols,
+            split,
+            range,
+        })
+    }
+
+    /// Adds an entity to the encoding session.
+    pub fn insert(&mut self, entity: T) {
+        let check_hash = CodedSymbol::new(&self.hasher, entity, false);
+        self.add_entity_inner(&check_hash, false);
+    }
+
+    /// Adds multiple entities to the encoding session.
+    pub fn extend(&mut self, entities: impl Iterator<Item = T>) {
+        for entity in entities {
+            self.insert(entity);
+        }
+    }
+
+    /// Returns the encoded rateless stream.
+    /// Don't forget to either move the split point to the desired place or communicate it
+    /// with the receiver of this data! Otherwise, the stream cannot be processed correctly!
+    pub fn into_coded_symbols(self) -> impl Iterator<Item = CodedSymbol<T>> {
+        self.coded_symbols.into_iter()
+    }
+
+    fn append_unchecked(&mut self, mut other: EncodingSession<T, H>) {
+        other.move_split_point(self.range.end);
+        self.coded_symbols.append(&mut other.coded_symbols);
+        self.range.end = other.range.end;
+    }
+
+    /// Appends another encoded stream to this session.
+    /// The function will automatically adapt the split point of the second stream, so that
+    /// it is compatible with the first one.
+    pub fn append(&mut self, other: EncodingSession<T, H>) {
+        assert_eq!(self.hasher, other.hasher);
+        assert_eq!(self.range.end, other.range.start);
+        self.append_unchecked(other);
+    }
+
+    /// Attempt to append another encoding session onto this one returning an error
+    /// if the hashers do not match or if the ranges are not contiguous. 
+ pub fn try_append( + &mut self, + other: EncodingSession, + ) -> Result<(), SetReconciliationError> { + if self.hasher != other.hasher { + return Err(SetReconciliationError::MismatchedHasher); + } + if self.range.end != other.range.start { + return Err(SetReconciliationError::NonContiguousRanges); + } + self.append_unchecked(other); + Ok(()) + } + + /// Call this function to extract the next `n` many coded symbols from the current session. + /// Note: this will move the split point, such that the next extraction will also be fast. + /// Inserting elements will however become slower if the split point is not moved to the end again! + pub fn split_off(&mut self, n: usize) -> EncodingSession { + let split = (self.range.start + n).min(self.range.end); + assert!( + split > self.range.start, + "split {split} must be greater than start {}", + self.range.start + ); + if split < self.split { + self.move_split_point(split); + } + let mut rest = EncodingSession { + hasher: self.hasher, + range: split..self.range.end, + coded_symbols: self.coded_symbols.split_off(split - self.range.start), + split, + }; + self.range.end = split; + std::mem::swap(self, &mut rest); + rest + } + + fn merge_unchecked(mut self, other: EncodingSession, negated: bool) -> Self { + self.coded_symbols + .iter_mut() + .zip(other.coded_symbols) + .for_each(|(a, b)| a.add(&b, negated)); + + self + } + + /// Merge another encoding session representing the same range into this one. + /// This is needed in case of sharding or parallel processing. + /// It should also be used to combine the data from two parties in order to determine + /// the symmetric difference between their sets. + /// + /// The `negated` parameter indicates whether the values in the other session should be negated. 
+ pub fn merge(self, other: EncodingSession, negated: bool) -> Self { + assert_eq!(self.range, other.range); + assert_eq!(self.hasher, other.hasher); + self.merge_unchecked(other, negated) + } + + /// Attempt to merge an encoding session into this one. If the hashers or ranges differ then + /// this will fail with an error result. + pub fn try_merge( + self, + other: EncodingSession, + negated: bool, + ) -> Result { + if self.hasher != other.hasher { + return Err(SetReconciliationError::MismatchedHasher); + } + if self.range != other.range { + return Err(SetReconciliationError::MismatchedRanges); + } + Ok(self.merge_unchecked(other, negated)) + } + + /// Helper function which adds an entity to all the indices determined by the hash functions. + /// It also works, when the stream is represented partially or fully hierarchically! + /// Calls a functor for each changed position which might now represent a single entity. + /// Counts how many elements changed to zero/non-zero. + pub(crate) fn add_to_stream( + &mut self, + entity: &CodedSymbol, + negated: bool, + range: Range, + mut f: impl FnMut(CodedSymbol, usize), + ) -> (isize, usize) { + let mut changes = 0; + let mut required_bits = 0; + for mut i in indices(&self.hasher, &entity.value, self.range.end) { + while i >= self.split && i >= self.range.start && i > 0 { + self.coded_symbols[i - self.range.start].add(entity, negated); + i = parent(i); + } + if i >= self.range.start { + if self.is_zero(i) { + changes += 1; + } + self.coded_symbols[i - self.range.start].add(entity, negated); + if self.is_zero(i) { + changes -= 1; + } else if range.contains(&i) { + let (is_pure, bits) = self.is_pure(i); + if is_pure { + let tmp = self.coded_symbols[i - self.range.start]; + f(tmp, i); + } else { + required_bits = required_bits.max(bits); + } + } + } + } + (changes, required_bits) + } + + /// Returns true if this is a pure coded symbol. 
+ pub(crate) fn is_pure(&self, i: usize) -> (bool, usize) { + self.coded_symbols[i - self.range.start].is_pure(&self.hasher, i, self.split) + } + + /// Returns true if no value is present in this coded symbol. + pub(crate) fn is_zero(&self, i: usize) -> bool { + self.coded_symbols[i - self.range.start].is_zero() + } + + /// The caller has to ensure that the same seed was used to construct the entity! + pub(crate) fn add_entity_inner(&mut self, entity: &CodedSymbol, negated: bool) { + self.add_to_stream(entity, negated, 0..0, |_, _| {}); + } + + /// Helper function to move the split point to a desired position. + /// For fast insertion operations, the split point should be at the end of the represented range. + /// For fast extraction operations, the split point should be at the beginning. + pub(crate) fn move_split_point(&mut self, new_split: usize) { + assert!(new_split <= self.range.end && new_split >= self.range.start); + assert!( + self.split <= self.range.end && self.split >= self.range.start, + "{} {:?}", + self.split, + self.range + ); + while self.split < new_split { + if self.split > 0 { + let i = parent(self.split); + if i >= self.range.start { + let tmp = self.coded_symbols[self.split - self.range.start]; + self.coded_symbols[i - self.range.start].add(&tmp, true); + } + } + self.split += 1; + } + while self.split > new_split { + self.split -= 1; + if self.split > 0 { + let i = parent(self.split); + if i >= self.range.start { + let tmp = self.coded_symbols[self.split - self.range.start]; + self.coded_symbols[i - self.range.start].add(&tmp, false); + } + } + } + } +} diff --git a/crates/hriblt/src/error.rs b/crates/hriblt/src/error.rs new file mode 100644 index 0000000..a36570d --- /dev/null +++ b/crates/hriblt/src/error.rs @@ -0,0 +1,22 @@ +/// Errors raised during set reconciliation +#[derive(thiserror::Error, Debug)] +pub enum SetReconciliationError { + /// Expected hashers to match. 
+ #[error("coded symbol hasher mismatched")] + MismatchedHasher, + /// The provided split does not lie within the provided range + #[error("split does not lie within the provided range")] + SplitOutOfRange, + /// Expected a provided range to follow on from the current range. + #[error("provided coded symbol range did not follow on from previous range")] + NonContiguousRanges, + /// Expected ranges to be the same. + #[error("provided coded symbol range did not match the previous range")] + MismatchedRanges, + /// The range of the provided coded symbols did not begin at zero. + #[error("provided coded symbol range did not start from zero")] + NotInitialRange, + /// The length of the range and the length of the coded symbols do not match. + #[error("number of coded symbols did not match the length of the provided range")] + RangeLengthMismatch, +} diff --git a/crates/hriblt/src/lib.rs b/crates/hriblt/src/lib.rs new file mode 100644 index 0000000..0616a74 --- /dev/null +++ b/crates/hriblt/src/lib.rs @@ -0,0 +1,400 @@ +//! This module implements a set reconciliation algorithm using XOR-based hashes. +//! +//! The core algorithm is based on the idea of set similarity sketching where pure hashes are identified +//! by testing whether one of the 5 hash functions would map the candidate value back to the index +//! of the value. Taking advantage of this necessary condition reduces the number of necessary checksum +//! bits by a bit. +//! +//! To make this algorithm rateless, we simply observe that any given IBLT becomes another IBLT if +//! we XOR the hashes of the upper half of the hashes with the lower half. +//! This process can also be realized by mapping one code symbol at a time into the lower half, such +//! that we can generate ANY number of coded symbols. Since this procedure is revertable when the next coded symbol +//! is provided by the client, one ends up with a rateless set reconciliation algorithm. +//! +//! 
The decoding algorithm essentially reconstructs in each iteration an IBLT with one more coded symbol +//! and then tries to decode the two coded symbols which have been modified by the expansion operation. +//! +//! This algorithm has similar properties to the "Practical Rateless Set Reconciliation" algorithm. +//! Main differences are: +//! * We only use 5 hash functions instead of log(n). +//! * As a result we only require 5 independent hash functions instead of log(n) many. +//! * The amount of data being transferred is comparable to the one in the paper. +//! * The chance that two documents collide on all hash functions is higher, but still very low, +//! since we utilize 5 hash functions instead of just 3. +//! * There is no complicated math involved. In the paper it is important that the exact same computation +//! is performed by both sides or the scheme will fall apart. (I.e. any kind of math optimizations must be disabled +//! and a stable math library must be used!) +//! * Encoding/decoding is faster due to the fixed number of hash functions and the simpler operations. +//! * Since we have a fixed number of hash functions, we can utilize the coded symbol index as +//! an additional condition. In fact, we need to compute just a single hash function (on average). +mod coded_symbol; +mod decoded_value; +mod decoding_session; +mod encoding_session; +mod error; + +use std::{ + fmt::Debug, + hash::{DefaultHasher, Hash, Hasher}, + ops::BitXorAssign, +}; + +pub use coded_symbol::CodedSymbol; +pub use decoded_value::DecodedValue; +pub use decoding_session::DecodingSession; +pub use encoding_session::EncodingSession; +pub use error::SetReconciliationError; + +/// Computes independent hash functions. +pub trait HashFunctions: Eq + Copy + Debug { + /// Hashes the given value with the n-th hash function. + /// This trait should provide 5 independent hash functions! + fn hash(&self, value: &T, n: u32) -> u32; + /// Computes a checksum. 
Note this checksum must become invalid
+    /// under xor operations!
+    fn check_sum(&self, value: &T) -> u64;
+}
+
+/// Hasher builder implementing equality operation of seed value.
+#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
+pub struct DefaultHashFunctions;
+
+impl HashFunctions for DefaultHashFunctions {
+    fn hash(&self, value: &T, n: u32) -> u32 {
+        let mut hasher = DefaultHasher::new();
+        n.hash(&mut hasher);
+        value.hash(&mut hasher);
+        hasher.finish() as u32
+    }
+
+    fn check_sum(&self, value: &T) -> u64 {
+        let mut hasher = DefaultHasher::new();
+        value.hash(&mut hasher);
+        hasher.finish()
+    }
+}
+
+/// Trait for value types that can be used in the set reconciliation algorithm.
+pub trait Encodable: Copy + Eq + PartialEq + Hash {
+    /// Returns a zero value for this type which is usually the default value.
+    fn zero() -> Self;
+    /// Xor the current value with another value of the same type.
+    /// This must not strictly be an XOR operation. We only require that applying the operation twice
+    /// returns the original value.
+    fn xor(&mut self, other: Self);
+}
+
+impl Encodable for T {
+    fn zero() -> Self {
+        Self::default()
+    }
+
+    fn xor(&mut self, other: Self) {
+        *self ^= other;
+    }
+}
+
+fn indices(
+    builder: &impl HashFunctions,
+    value: &T,
+    stream_len: usize,
+) -> impl Iterator {
+    (0..5).map(move |seed| index_for_seed(builder, value, stream_len, seed))
+}
+
+/// This function computes with 5 distinct hash functions 5 indices to which a value maps.
+/// Essentially, we are "fighting" here two contradicting requirements:
+/// - More hash functions with larger partitions reduce the probability that two values map
+///   to exactly the same indices for all hash functions! In this situation, we can decode the stream.
+/// - Fewer (ideally 3) hash functions with larger partitions lead to a higher chance to
+///   find a pure value in the stream, i.e. the stream can be decoded with fewer coded symbols. 
+///
+/// After testing various schemes, I settled for this one which uses 4 equally sized partitions,
+/// one for each of the first 4 hash functions. This ensures that the first 4 hash functions
+/// map to distinct indices.
+/// The last hash function is used to reduce the probability of hash collisions further without
+/// reducing the chance to find pure values in the stream.
+///
+/// Note: we want an odd number of hash functions, so that collapsing the stream to a single coded symbol
+/// (or small number of coded symbols) won't erase the value information.
+///
+/// Returns the index of the coded symbol to which the value maps for the given seed.
+fn index_for_seed(
+    builder: &impl HashFunctions,
+    value: &T,
+    stream_len: usize,
+    seed: u32,
+) -> usize {
+    let mut hash = builder.hash(value, seed);
+    let mask = 31;
+    let mut lsb = hash & mask;
+    hash -= lsb;
+    if seed == 4 {
+        lsb = 0;
+    } else {
+        lsb &= !3;
+        lsb += seed;
+    }
+    hash += lsb;
+    hash_to_index(hash, stream_len)
+}
+
+/// Determines the parent index in the rateless/hierarchical representation.
+fn parent(i: usize) -> usize {
+    i - ((i + 1).next_power_of_two() / 2)
+}
+
+/// Function to map a hash value into a correct index for a given number of coded symbols.
+/// This function is compatible with the above `parent` function in the sense that
+/// repeatedly applying `parent` until the index is less than `n` will yield the same result! 
+fn hash_to_index(hash: u32, n: usize) -> usize { + let power_of_two = (n as u32).next_power_of_two(); + let hash = hash % power_of_two; + let res = if hash >= n as u32 { + hash - power_of_two / 2 + } else { + hash + }; + res as usize +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, fmt::Debug, time::Instant}; + + use crate::{ + DefaultHashFunctions, EncodingSession, decoded_value::DecodedValue, + decoding_session::DecodingSession, hash_to_index, parent, + }; + + /// This test ensures that the parent and hash_to_index functions are consistent to each other! + #[test] + fn test_parent() { + for i in 0..1000 { + let mut ii = i; + while ii > 0 { + let jj = parent(ii); + for k in jj + 1..=ii { + assert_eq!(hash_to_index(i as u32, k), jj, "{i} {ii} {k}"); + } + assert_eq!(hash_to_index(i as u32, jj + 1), jj, "{i} {ii} {jj}"); + ii = jj; + } + } + } + + /// Test that moving the split point works correctly. + /// This test simply goes through all possible size and split point combinations and compares + /// the results when the split point is changed before or after inserting elements. + /// In both cases, the result should be the same! + #[test] + fn test_move_split() { + let state = DefaultHashFunctions; + for n in 1..64 { + for s in 0..=n { + let mut encoding1 = EncodingSession::new(state, 0..n); + let mut encoding2 = EncodingSession::new(state, 0..n); + encoding2.move_split_point(s); + for k in 1..100 { + encoding1.insert(k); + encoding2.insert(k); + } + encoding1.move_split_point(s); + assert_eq!( + encoding1.coded_symbols, encoding2.coded_symbols, + "n: {n}, s: {s}" + ); + } + } + } + + /// Test encoding and decoding of a large stream. 
+ #[test] + fn test_single_stream() { + let mut stats = Stats::default(); + let mut bits = Stats::default(); + let mut encoding_time = Stats::default(); + let mut decoding_time = Stats::default(); + let mut deocding_time_fast = Stats::default(); + for i in 0..10 { + let items = 100000; + let entries: Vec<_> = (1u64..=items).collect(); + let state = DefaultHashFunctions; + let start = Instant::now(); + let mut stream1 = EncodingSession::new(state, 0..items as usize * 15 / 10); + stream1.extend(entries.iter().cloned()); + encoding_time.add(start.elapsed().as_secs_f32()); + + let start = Instant::now(); + let mut decoding_session = DecodingSession::new(state); + decoding_session.append(stream1.clone()); + assert!( + decoding_session.is_done(), + "{} {i}", + decoding_session.non_zero, + ); + stats.add(decoding_session.consumed_coded_symbols() as f32); + bits.add(decoding_session.required_bits as f32); + decoding_time.add(start.elapsed().as_secs_f32()); + let mut decoded: Vec<_> = decoding_session + .into_decoded_iter() + .map(|decoded_value| { + let DecodedValue::Addition(e) = decoded_value else { + panic!("Value was deleted, but expected added"); + }; + e + }) + .collect(); + decoded.sort(); + assert_eq!(decoded.len(), items as usize); + assert_eq!(decoded, entries); + + let start = Instant::now(); + let decoding_session = DecodingSession::from_encoding(stream1); + assert!(decoding_session.is_done()); + deocding_time_fast.add(start.elapsed().as_secs_f32()); + let decoded2: Vec<_> = decoding_session.into_decoded_iter().collect(); + assert_eq!(decoded2.len(), items as usize); + } + println!("stream size: {stats:?}"); + println!("required bits: {bits:?}"); + println!("encoding time: {encoding_time:?}"); + println!("decoding time: {decoding_time:?}"); + println!("decoding time fast: {deocding_time_fast:?}"); + } + + /// Test that splitting an encoding session and reassembling it into a decoding session works. 
+ /// The comparison must be done in the hierarchical representation which is enforced by + /// calling move_split_point. + /// Note: we abort the loop, once the stream was successfully decoded, since otherwise some + /// assertion would trigger :) + #[test] + fn test_splitting_of_decoding() { + let state = DefaultHashFunctions; + let mut stream1 = EncodingSession::new(state, 0..200); + let items = 100; + stream1.extend(1..=items); + + // Test that decoding would actually work... + let mut decoding_session = DecodingSession::new(state); + decoding_session.append(stream1.clone()); + assert!(decoding_session.is_done()); + + let mut expected = stream1.clone(); + expected.move_split_point(0); + let expected: Vec<_> = expected.into_coded_symbols().collect(); + let mut decoding_session = DecodingSession::new(state); + for i in 0.. { + let mut got = stream1.split_off(10); + decoding_session.append(got.clone()); + assert_eq!(got.range.start, i * 10); + got.move_split_point(i * 10); + let got: Vec<_> = got.into_coded_symbols().collect(); + assert_eq!(got, expected[i * 10..(i + 1) * 10]); + if decoding_session.is_done() { + break; + } + } + assert_eq!(decoding_session.into_decoded_iter().count(), items); + } + + #[derive(Default)] + struct Stats { + sum: f32, + sum2: f32, + cnt: f32, + max: f32, + } + + impl Stats { + fn add(&mut self, v: f32) { + self.sum += v; + self.sum2 += v * v; + self.cnt += 1.0; + if v > self.max { + self.max = v; + } + } + + fn finish(&self) -> (f32, f32, f32) { + let mean = self.sum / self.cnt; + let var = self.sum2 / self.cnt - mean * mean; + (mean, var.max(0.0).sqrt() / mean, self.max) + } + } + + impl Debug for Stats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (mean, stddev, max) = self.finish(); + write!(f, "{mean},{stddev},{max}") + } + } + + #[test] + fn test_merging() { + let state = DefaultHashFunctions; + let mut stream1 = EncodingSession::new(state, 0..200); + stream1.extend(0..20); + let mut stream2 = 
EncodingSession::new(state, 0..200); + stream2.extend(10..30); + let merged = stream1.merge(stream2, true); + let decoding_session = DecodingSession::from_encoding(merged); + assert!(decoding_session.is_done()); + let mut decoded: Vec<_> = decoding_session.into_decoded_iter().collect(); + decoded.sort(); + assert_eq!( + &decoded[0..10], + (0..10).map(DecodedValue::Addition).collect::>() + ); + assert_eq!( + &decoded[10..20], + (20..30).map(DecodedValue::Deletion).collect::>() + ); + } + + #[test] + fn test_statistics() { + let state = DefaultHashFunctions; + let mut stats = HashMap::new(); + + for start in 0..100 { + let mut stream = EncodingSession::new(state, 0..1500); + stream.move_split_point(0); + for value in 1u64..1000 { + stream.insert(value); + if ((value as f32).log2() * 8.0).floor() + != ((value as f32 + 1.0).log2() * 8.0).floor() + { + let mut decoding_session = DecodingSession::new(state); + decoding_session.append(stream.clone()); + assert!(decoding_session.is_done(), "start: {start}, value: {value}"); + { + // It is in theory possible that the fast decoding requires a longer stream than the + // incremental decoding. This can happen when collisions cancel out in the incremental decoding + // case. This test shows that this is extremely unlikely to happen. + let decoding_session = DecodingSession::from_encoding( + stream + .clone() + .split_off(decoding_session.consumed_coded_symbols()), + ); + assert!(decoding_session.is_done(), "start: {start}, value: {value}"); + } + let s = stats + .entry(value) + .or_insert((Stats::default(), Stats::default())); + s.0.add(decoding_session.consumed_coded_symbols() as f32); + s.1.add(decoding_session.required_bits as f32); + } + } + } + let mut stats: Vec<_> = stats.into_iter().collect(); + stats.sort_by_key(|(value, _)| *value); + for (value, (stat, bits)) in stats { + println!("{value},{stat:?},{bits:?}"); + } + } +} + +#[doc = include_str!("../README.md")] +#[cfg(doctest)] +pub struct ReadmeDocTests;