diff --git a/CHANGELOG.md b/CHANGELOG.md index cb9623f..6157fe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 4.0.0 +- **Breaking**: Make `delay` passed into `par_iter_bp_delayed` a strong type `Delay(pub usize)` to + reduce potential for bugs. +- **Breaking**: Encapsulate parallel iterators in new `PaddedIt { it, padding: usize }` type with `.map`, `.advance`, and `.collect_into` functions. +- Make `intrinsics::transpose` public for use in `collect_and_dedup` in `simd_minimizers`. + ## 3.2.1 - Add `Seq::read_{revcomp}_kmer_u128` with more tests - Fix bug in `revcomp_u128` diff --git a/Cargo.toml b/Cargo.toml index 6f29827..07c64b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ pyo3 = { version = "0.25", features = ["extension-module"], optional = true } [features] # Also needed for tests. default = ["rand"] + +# Hides the `simd` warnings when neither AVX2 nor NEON is detected. +scalar = [] diff --git a/README.md b/README.md index 1548ab1..f473d6f 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,12 @@ crate was developed: This library supports AVX2 and NEON instruction sets. Make sure to set `RUSTFLAGS="-C target-cpu=native"` when compiling to use the instruction sets available on your architecture. - RUSTFLAGS="-C target-cpu=native" cargo run --release +``` sh +RUSTFLAGS="-C target-cpu=native" cargo run --release +``` +Enable the `-F scalar` feature flag to fall back to a scalar implementation with +reduced performance. ## Usage example diff --git a/src/ascii.rs b/src/ascii.rs index bed7132..d0f4f4a 100644 --- a/src/ascii.rs +++ b/src/ascii.rs @@ -1,4 +1,4 @@ -use crate::{intrinsics::transpose, packed_seq::read_slice}; +use crate::{intrinsics::transpose, packed_seq::read_slice, padded_it::ChunkIt}; use super::*; @@ -71,13 +71,13 @@ impl Seq<'_> for &[u8] { /// Iter the ASCII characters. #[inline(always)] - fn iter_bp(self) -> impl ExactSizeIterator + Clone { + fn iter_bp(self) -> impl ExactSizeIterator { self.iter().copied() } /// Iter the ASCII characters in parallel. #[inline(always)] - fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator + Clone, usize) { + fn par_iter_bp(self, context: usize) -> PaddedIt> { let num_kmers = self.len().saturating_sub(context - 1); let n = num_kmers.div_ceil(L); let padding = L * n - num_kmers; @@ -112,15 +112,15 @@ impl Seq<'_> for &[u8] { }, ); - (it, padding) + PaddedIt { it, padding } } #[inline(always)] fn par_iter_bp_delayed( self, context: usize, - delay: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay): Delay, + ) -> PaddedIt> { assert!( delay < usize::MAX / 2, "Delay={} should be >=0.", @@ -185,16 +185,16 @@ impl Seq<'_> for &[u8] { }, ); - (it, padding) + PaddedIt { it, padding } } #[inline(always)] fn par_iter_bp_delayed_2( self, context: usize, - delay1: usize, - delay2: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay1): Delay, + Delay(delay2): Delay, + ) -> PaddedIt> { assert!(delay1 <= delay2, "Delay1 must be at most delay2."); let num_kmers = self.len().saturating_sub(context - 1); @@ -266,7 +266,7 @@ impl Seq<'_> for &[u8] { }, ); - (it, padding) + PaddedIt { it, padding } } // TODO: This is not very optimized. diff --git a/src/ascii_seq.rs b/src/ascii_seq.rs index 0164e6c..1183676 100644 --- a/src/ascii_seq.rs +++ b/src/ascii_seq.rs @@ -1,4 +1,4 @@ -use crate::{intrinsics::transpose, packed_seq::read_slice}; +use crate::{intrinsics::transpose, packed_seq::read_slice, padded_it::ChunkIt}; use super::*; @@ -240,7 +240,7 @@ impl<'s> Seq<'s> for AsciiSeq<'s> { /// /// NOTE: This is only efficient on x86_64 with `BMI2` support for `pext`. #[inline(always)] - fn iter_bp(self) -> impl ExactSizeIterator + Clone { + fn iter_bp(self) -> impl ExactSizeIterator { #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] { let mut cache = 0; @@ -274,7 +274,7 @@ impl<'s> Seq<'s> for AsciiSeq<'s> { /// Iterate the basepairs in the sequence in 8 parallel streams, assuming values in `0..4`. #[inline(always)] - fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator + Clone, usize) { + fn par_iter_bp(self, context: usize) -> PaddedIt> { let num_kmers = self.len().saturating_sub(context - 1); let n = num_kmers.div_ceil(L); let padding = L * n - num_kmers; @@ -312,15 +312,15 @@ impl<'s> Seq<'s> for AsciiSeq<'s> { }, ); - (it, padding) + PaddedIt { it, padding } } #[inline(always)] fn par_iter_bp_delayed( self, context: usize, - delay: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay): Delay, + ) -> PaddedIt> { assert!( delay < usize::MAX / 2, "Delay={} should be >=0.", @@ -388,16 +388,16 @@ impl<'s> Seq<'s> for AsciiSeq<'s> { }, ); - (it, padding) + PaddedIt { it, padding } } #[inline(always)] fn par_iter_bp_delayed_2( self, context: usize, - delay1: usize, - delay2: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay1): Delay, + Delay(delay2): Delay, + ) -> PaddedIt> { assert!(delay1 <= delay2, "Delay1 must be at most delay2."); let num_kmers = self.len().saturating_sub(context - 1); @@ -472,7 +472,7 @@ impl<'s> Seq<'s> for AsciiSeq<'s> { }, ); - (it, padding) + PaddedIt { it, padding } } // TODO: This is not very optimized. diff --git a/src/intrinsics/transpose.rs b/src/intrinsics/transpose.rs index 32ceab6..400d178 100644 --- a/src/intrinsics/transpose.rs +++ b/src/intrinsics/transpose.rs @@ -3,7 +3,7 @@ use wide::u32x4; use wide::u32x8 as S; -/// Transpose a matrix of 8 SIMD vectors. +/// Transpose an 8x8 matrix of 8 `u32x8` SIMD elements. /// // TODO: Investigate other transpose functions mentioned there? pub fn transpose(m: [S; 8]) -> [S; 8] { diff --git a/src/lib.rs b/src/lib.rs index 0820117..d98783d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,10 @@ //! But instead of just returning a single character, they also return a second (and third) character, that is `delay` positions _behind_ the new character (at index `idx - delay`). //! This way, k-mers can be enumerated by setting `delay=k` and then mapping e.g. `|(add, remove)| kmer = (kmer<<2) ^ add ^ (remove << (2*k))`. //! +//! #### Collect +//! +//! Use [`PaddedIt::collect`] and [`PaddedIt::collect_into`] to collect the values returned by a parallel iterator over `u32x8` into a flat `Vec`. +//! //! ## Example //! //! ``` @@ -78,19 +82,42 @@ //! // Iterate over 8 chunks at the same time. //! let seq = b"AAAACCTTGGTTACTG"; // plain ASCII sequence //! // chunks: ^ ^ ^ ^ ^ ^ ^ ^ -//! let (par_iter, padding) = seq.as_slice().par_iter_bp(1); -//! let mut par_iter_u8 = par_iter.map(|x| x.as_array_ref().map(|c| c as u8)); +//! // the `1` argument indicates a 'context' length of 1, +//! // since we're just iterating single characters. +//! let par_iter = seq.as_slice().par_iter_bp(1); +//! let mut par_iter_u8 = par_iter.it.map(|x| x.as_array_ref().map(|c| c as u8)); //! assert_eq!(par_iter_u8.next(), Some(*b"AACTGTAT")); //! assert_eq!(par_iter_u8.next(), Some(*b"AACTGTCG")); //! assert_eq!(par_iter_u8.next(), None); +//! +//! let bases: Vec = seq.as_slice().par_iter_bp(1).collect(); +//! let bases: Vec = bases.into_iter().map(|x| x as u8).collect(); +//! assert_eq!(bases, seq); +//! +//! // With context=3, the chunks overlap by 2 characters, +//! // which can be skipped using `advance`. +//! let bases: Vec = seq.as_slice().par_iter_bp(3).advance(2).collect(); +//! let bases: Vec = bases.into_iter().map(|x| x as u8).collect(); +//! assert_eq!(bases, &seq[2..]); //! ``` //! //! ## Feature flags //! - `epserde` enables `derive(epserde::Epserde)` for `PackedSeqVec` and `AsciiSeqVec`, and adds its `SerializeInner` and `DeserializeInner` traits to `SeqVec`. //! - `pyo3` enables `derive(pyo3::pyclass)` for `PackedSeqVec` and `AsciiSeqVec`. +#[cfg(not(any( + doc, + debug_assertions, + target_feature = "avx2", + target_feature = "neon", + feature = "scalar" +)))] +compile_error!( + "Packed-seq uses AVX2 or NEON SIMD instructions. Compile using `-C target-cpu=native` to get the expected performance. Silence this error using the `scalar` feature." +); + /// Functions with architecture-specific implementations. -mod intrinsics { +pub mod intrinsics { mod transpose; pub use transpose::transpose; } @@ -100,6 +127,7 @@ mod traits; mod ascii; mod ascii_seq; mod packed_seq; +mod padded_it; #[cfg(test)] mod test; @@ -114,7 +142,8 @@ pub use packed_seq::{ complement_base, complement_base_simd, complement_char, pack_char, unpack_base, }; pub use packed_seq::{PackedSeq, PackedSeqVec}; -pub use traits::{Seq, SeqVec}; +pub use padded_it::{Advance, ChunkIt, PaddedIt}; +pub use traits::{Delay, Seq, SeqVec}; // For internal use only. use core::array::from_fn; diff --git a/src/packed_seq.rs b/src/packed_seq.rs index 7af7767..a9b3e42 100644 --- a/src/packed_seq.rs +++ b/src/packed_seq.rs @@ -1,6 +1,6 @@ use traits::Seq; -use crate::intrinsics::transpose; +use crate::{intrinsics::transpose, padded_it::ChunkIt}; use super::*; @@ -212,7 +212,10 @@ impl<'s> Seq<'s> for PackedSeq<'s> { /// Panics if `self` is longer than 64 characters. #[inline(always)] fn as_u128(&self) -> u128 { - assert!(self.len() <= 61, "Sequences >61 long cannot be read with a single unaligned u128 read."); + assert!( + self.len() <= 61, + "Sequences >61 long cannot be read with a single unaligned u128 read." + ); debug_assert!(self.seq.len() <= 17); let mask = u128::MAX >> (128 - 2 * self.len()); @@ -300,29 +303,29 @@ impl<'s> Seq<'s> for PackedSeq<'s> { } #[inline(always)] - fn iter_bp(self) -> impl ExactSizeIterator + Clone { + fn iter_bp(self) -> impl ExactSizeIterator { assert!(self.len <= self.seq.len() * 4); let this = self.normalize(); // read u64 at a time? let mut byte = 0; - let mut it = (0..this.len + this.offset).map( - #[inline(always)] - move |i| { - if i % 4 == 0 { - byte = this.seq[i / 4]; - } - // Shift byte instead of i? - (byte >> (2 * (i % 4))) & 0b11 - }, - ); - it.by_ref().take(this.offset).for_each(drop); - it + (0..this.len + this.offset) + .map( + #[inline(always)] + move |i| { + if i % 4 == 0 { + byte = this.seq[i / 4]; + } + // Shift byte instead of i? + (byte >> (2 * (i % 4))) & 0b11 + }, + ) + .advance(this.offset) } #[inline(always)] - fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator + Clone, usize) { + fn par_iter_bp(self, context: usize) -> PaddedIt> { #[cfg(target_endian = "big")] panic!("Big endian architectures are not supported."); @@ -354,31 +357,31 @@ impl<'s> Seq<'s> for PackedSeq<'s> { } else { n + context + o - 1 }; - let mut it = (0..par_len).map( - #[inline(always)] - move |i| { - if i % 16 == 0 { - if i % 128 == 0 { - // Read a u256 for each lane containing the next 128 characters. - let data: [u32x8; 8] = from_fn( - #[inline(always)] - |lane| read_slice(this.seq, offsets[lane] + (i / 4)), - ); - *buf = transpose(data); + let it = (0..par_len) + .map( + #[inline(always)] + move |i| { + if i % 16 == 0 { + if i % 128 == 0 { + // Read a u256 for each lane containing the next 128 characters. + let data: [u32x8; 8] = from_fn( + #[inline(always)] + |lane| read_slice(this.seq, offsets[lane] + (i / 4)), + ); + *buf = transpose(data); + } + cur = buf[(i % 128) / 16]; } - cur = buf[(i % 128) / 16]; - } - // Extract the last 2 bits of each character. - let chars = cur & S::splat(0x03); - // Shift remaining characters to the right. - cur = cur >> S::splat(2); - chars - }, - ); - // Drop the first few chars. - it.by_ref().take(o).for_each(drop); + // Extract the last 2 bits of each character. + let chars = cur & S::splat(0x03); + // Shift remaining characters to the right. + cur = cur >> S::splat(2); + chars + }, + ) + .advance(o); - (it, padding) + PaddedIt { it, padding } } /// NOTE: When `self` starts does not start at a byte boundary, the @@ -387,8 +390,8 @@ impl<'s> Seq<'s> for PackedSeq<'s> { fn par_iter_bp_delayed( self, context: usize, - delay: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay): Delay, + ) -> PaddedIt> { #[cfg(target_endian = "big")] panic!("Big endian architectures are not supported."); @@ -433,51 +436,52 @@ impl<'s> Seq<'s> for PackedSeq<'s> { } else { n + context + o - 1 }; - let mut it = (0..par_len).map( - #[inline(always)] - move |i| { - if i % 16 == 0 { - if i % 128 == 0 { - // Read a u256 for each lane containing the next 128 characters. - let data: [u32x8; 8] = from_fn( - #[inline(always)] - |lane| read_slice(this.seq, offsets[lane] + (i / 4)), - ); - unsafe { - *TryInto::<&mut [u32x8; 8]>::try_into( - buf.get_unchecked_mut(write_idx..write_idx + 8), - ) - .unwrap_unchecked() = transpose(data); - } - if i == 0 { - // Mask out chars before the offset. - let elem = !((1u32 << (2 * o)) - 1); - let mask = S::splat(elem); - buf[write_idx] &= mask; + let it = (0..par_len) + .map( + #[inline(always)] + move |i| { + if i % 16 == 0 { + if i % 128 == 0 { + // Read a u256 for each lane containing the next 128 characters. + let data: [u32x8; 8] = from_fn( + #[inline(always)] + |lane| read_slice(this.seq, offsets[lane] + (i / 4)), + ); + unsafe { + *TryInto::<&mut [u32x8; 8]>::try_into( + buf.get_unchecked_mut(write_idx..write_idx + 8), + ) + .unwrap_unchecked() = transpose(data); + } + if i == 0 { + // Mask out chars before the offset. + let elem = !((1u32 << (2 * o)) - 1); + let mask = S::splat(elem); + buf[write_idx] &= mask; + } } + upcoming = buf[write_idx]; + write_idx += 1; + write_idx &= buf_mask; } - upcoming = buf[write_idx]; - write_idx += 1; - write_idx &= buf_mask; - } - if i % 16 == delay % 16 { - unsafe { assert_unchecked(read_idx < buf.len()) }; - upcoming_d = buf[read_idx]; - read_idx += 1; - read_idx &= buf_mask; - } - // Extract the last 2 bits of each character. - let chars = upcoming & S::splat(0x03); - let chars_d = upcoming_d & S::splat(0x03); - // Shift remaining characters to the right. - upcoming = upcoming >> S::splat(2); - upcoming_d = upcoming_d >> S::splat(2); - (chars, chars_d) - }, - ); - it.by_ref().take(o).for_each(drop); + if i % 16 == delay % 16 { + unsafe { assert_unchecked(read_idx < buf.len()) }; + upcoming_d = buf[read_idx]; + read_idx += 1; + read_idx &= buf_mask; + } + // Extract the last 2 bits of each character. + let chars = upcoming & S::splat(0x03); + let chars_d = upcoming_d & S::splat(0x03); + // Shift remaining characters to the right. + upcoming = upcoming >> S::splat(2); + upcoming_d = upcoming_d >> S::splat(2); + (chars, chars_d) + }, + ) + .advance(o); - (it, padding) + PaddedIt { it, padding } } /// NOTE: When `self` starts does not start at a byte boundary, the @@ -486,9 +490,9 @@ impl<'s> Seq<'s> for PackedSeq<'s> { fn par_iter_bp_delayed_2( self, context: usize, - delay1: usize, - delay2: usize, - ) -> (impl ExactSizeIterator + Clone, usize) { + Delay(delay1): Delay, + Delay(delay2): Delay, + ) -> PaddedIt> { #[cfg(target_endian = "big")] panic!("Big endian architectures are not supported."); @@ -528,59 +532,60 @@ impl<'s> Seq<'s> for PackedSeq<'s> { } else { n + context + o - 1 }; - let mut it = (0..par_len).map( - #[inline(always)] - move |i| { - if i % 16 == 0 { - if i % 128 == 0 { - // Read a u256 for each lane containing the next 128 characters. - let data: [u32x8; 8] = from_fn( - #[inline(always)] - |lane| read_slice(this.seq, offsets[lane] + (i / 4)), - ); - unsafe { - *TryInto::<&mut [u32x8; 8]>::try_into( - buf.get_unchecked_mut(write_idx..write_idx + 8), - ) - .unwrap_unchecked() = transpose(data); - } - if i == 0 { - // Mask out chars before the offset. - let elem = !((1u32 << (2 * o)) - 1); - let mask = S::splat(elem); - buf[write_idx] &= mask; + let it = (0..par_len) + .map( + #[inline(always)] + move |i| { + if i % 16 == 0 { + if i % 128 == 0 { + // Read a u256 for each lane containing the next 128 characters. + let data: [u32x8; 8] = from_fn( + #[inline(always)] + |lane| read_slice(this.seq, offsets[lane] + (i / 4)), + ); + unsafe { + *TryInto::<&mut [u32x8; 8]>::try_into( + buf.get_unchecked_mut(write_idx..write_idx + 8), + ) + .unwrap_unchecked() = transpose(data); + } + if i == 0 { + // Mask out chars before the offset. + let elem = !((1u32 << (2 * o)) - 1); + let mask = S::splat(elem); + buf[write_idx] &= mask; + } } + upcoming = buf[write_idx]; + write_idx += 1; + write_idx &= buf_mask; } - upcoming = buf[write_idx]; - write_idx += 1; - write_idx &= buf_mask; - } - if i % 16 == delay1 % 16 { - unsafe { assert_unchecked(read_idx1 < buf.len()) }; - upcoming_d1 = buf[read_idx1]; - read_idx1 += 1; - read_idx1 &= buf_mask; - } - if i % 16 == delay2 % 16 { - unsafe { assert_unchecked(read_idx2 < buf.len()) }; - upcoming_d2 = buf[read_idx2]; - read_idx2 += 1; - read_idx2 &= buf_mask; - } - // Extract the last 2 bits of each character. - let chars = upcoming & S::splat(0x03); - let chars_d1 = upcoming_d1 & S::splat(0x03); - let chars_d2 = upcoming_d2 & S::splat(0x03); - // Shift remaining characters to the right. - upcoming = upcoming >> S::splat(2); - upcoming_d1 = upcoming_d1 >> S::splat(2); - upcoming_d2 = upcoming_d2 >> S::splat(2); - (chars, chars_d1, chars_d2) - }, - ); - it.by_ref().take(o).for_each(drop); - - (it, padding) + if i % 16 == delay1 % 16 { + unsafe { assert_unchecked(read_idx1 < buf.len()) }; + upcoming_d1 = buf[read_idx1]; + read_idx1 += 1; + read_idx1 &= buf_mask; + } + if i % 16 == delay2 % 16 { + unsafe { assert_unchecked(read_idx2 < buf.len()) }; + upcoming_d2 = buf[read_idx2]; + read_idx2 += 1; + read_idx2 &= buf_mask; + } + // Extract the last 2 bits of each character. + let chars = upcoming & S::splat(0x03); + let chars_d1 = upcoming_d1 & S::splat(0x03); + let chars_d2 = upcoming_d2 & S::splat(0x03); + // Shift remaining characters to the right. + upcoming = upcoming >> S::splat(2); + upcoming_d1 = upcoming_d1 >> S::splat(2); + upcoming_d2 = upcoming_d2 >> S::splat(2); + (chars, chars_d1, chars_d2) + }, + ) + .advance(o); + + PaddedIt { it, padding } } /// Compares 29 characters at a time. diff --git a/src/padded_it.rs b/src/padded_it.rs new file mode 100644 index 0000000..feadc69 --- /dev/null +++ b/src/padded_it.rs @@ -0,0 +1,106 @@ +use crate::intrinsics::transpose; +use std::mem::transmute; +use wide::u32x8; + +/// Trait alias for iterators over multiple chunks in parallel, typically over `u32x8`. +pub trait ChunkIt: ExactSizeIterator {} +impl> ChunkIt for I {} + +/// An iterator over values in multiple SIMD lanes, with a certain amount of `padding` at the end. +/// +/// This type is returned by functions like [`crate::Seq::par_iter_bp`]. +/// It usally contains an iterator over e.g. `u32x8` values or `(u32x8, u32x8)` tuples, +pub struct PaddedIt { + pub it: I, + pub padding: usize, +} + +/// Extension trait to advance an iterator by `n` steps. +/// Used to skip e.g. the first `k-1` values of an iterator over k-mer hasher. +pub trait Advance { + fn advance(self, n: usize) -> Self; +} +impl Advance for I { + /// Advance the iterator by `n` steps, consuming the first `n` values. + #[inline(always)] + fn advance(mut self, n: usize) -> Self { + self.by_ref().take(n).for_each(drop); + self + } +} + +impl PaddedIt { + /// Apply `f` to each element. + #[inline(always)] + pub fn map(self, f: impl FnMut(T) -> T2) -> PaddedIt> + where + I: ChunkIt, + { + PaddedIt { + it: self.it.map(f), + padding: self.padding, + } + } + + /// Advance the iterator by `n` steps, consuming the first `n` values (of each lane). + #[inline(always)] + pub fn advance(mut self, n: usize) -> PaddedIt> + where + I: ChunkIt, + { + self.it = self.it.advance(n); + self + } +} + +impl> PaddedIt { + /// Collect all values of a padded `u32x8`-iterator into a flat vector. + /// Prefer `collect_into` to avoid repeated allocations. + pub fn collect(self) -> Vec { + let mut v = vec![]; + self.collect_into(&mut v); + v + } + + /// Collect all values of a padded `u32x8`-iterator into a flat vector. + /// + /// Implemented by taking 8 elements from each stream, and transposing this SIMD-matrix before writing out the results. + /// The `tail` is appended at the end. + #[inline(always)] + pub fn collect_into(self, out_vec: &mut Vec) { + let PaddedIt { it, padding } = self; + let len = it.len(); + out_vec.resize(len * 8, 0); + + let mut m = [unsafe { transmute([0; 8]) }; 8]; + let mut i = 0; + it.for_each(|x| { + m[i % 8] = x; + if i % 8 == 7 { + let t = transpose(m); + for j in 0..8 { + unsafe { + *out_vec + .get_unchecked_mut(j * len + 8 * (i / 8)..) + .split_first_chunk_mut::<8>() + .unwrap() + .0 = transmute(t[j]); + } + } + } + i += 1; + }); + + // Manually write the unfinished parts of length k=i%8. + let t = transpose(m); + let k = i % 8; + for j in 0..8 { + unsafe { + out_vec[j * len + 8 * (i / 8)..j * len + 8 * (i / 8) + k] + .copy_from_slice(&transmute::<_, [u32; 8]>(t[j])[..k]); + } + } + + out_vec.resize(out_vec.len() - padding, 0); + } +} diff --git a/src/test.rs b/src/test.rs index bd53df2..f556cd5 100644 --- a/src/test.rs +++ b/src/test.rs @@ -167,7 +167,9 @@ fn pack_word() { #[test] fn pack_u128() { - let packed = PackedSeqVec::from_ascii(b"ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT"); + let packed = PackedSeqVec::from_ascii( + b"ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT", + ); let slice = packed.slice(0..1); assert_eq!(slice.as_u128(), 0b0000000000000000); let slice = packed.slice(1..2); @@ -336,15 +338,15 @@ fn iter_bp() { #[test] fn par_iter_bp() { let s = PackedSeqVec::from_ascii(b"ACGTAACCGGTTAAACCCGGGTTTAAAAAAAAACGT"); - let (head, padding) = s.as_slice().par_iter_bp(1); - let head = head.collect::>(); + let PaddedIt { it, padding } = s.as_slice().par_iter_bp(1); + let it = it.collect::>(); fn f(x: &[u8; 8]) -> u32x8 { let x = x.map(|x| pack_char(x) as u32); u32x8::from(x) } assert_eq!(padding, 8 * 8 - s.len()); assert_eq!( - head, + it, vec![ f(b"AGCAAAAA"), f(b"CGCACAAA"), @@ -361,8 +363,8 @@ fn par_iter_bp() { #[test] fn par_iter_bp_delayed0() { let s = PackedSeqVec::from_ascii(b"ACGTAACCGGTTAAACCCGGGTTTAAAAAAAAACGT"); - let (head, padding) = s.as_slice().par_iter_bp_delayed(1, 0); - let head = head.collect::>(); + let PaddedIt { it, padding } = s.as_slice().par_iter_bp_delayed(1, Delay(0)); + let it = it.collect::>(); fn f(x: &[u8; 8], y: &[u8; 8]) -> (u32x8, u32x8) { let x = x.map(|x| pack_char(x) as u32); let y = y.map(|x| pack_char(x) as u32); @@ -370,7 +372,7 @@ fn par_iter_bp_delayed0() { } assert_eq!(padding, 8 * 8 - s.len()); assert_eq!( - head, + it, vec![ f(b"AGCAAAAA", b"AGCAAAAA"), f(b"CGCACAAA", b"CGCACAAA"), @@ -387,8 +389,8 @@ fn par_iter_bp_delayed0() { #[test] fn par_iter_bp_delayed1() { let s = PackedSeqVec::from_ascii(b"ACGTAACCGGTTAAACCCGGGTTTAAAAAAAAACGT"); - let (head, padding) = s.as_slice().par_iter_bp_delayed(1, 1); - let head = head.collect::>(); + let PaddedIt { it, padding } = s.as_slice().par_iter_bp_delayed(1, Delay(1)); + let it = it.collect::>(); fn f(x: &[u8; 8], y: &[u8; 8]) -> (u32x8, u32x8) { let x = x.map(|x| pack_char(x) as u32); let y = y.map(|x| pack_char(x) as u32); @@ -396,7 +398,7 @@ fn par_iter_bp_delayed1() { } assert_eq!(padding, 8 * 8 - s.len()); assert_eq!( - head, + it, vec![ f(b"AGCAAAAA", b"AAAAAAAA"), f(b"CGCACAAA", b"AGCAAAAA"), @@ -439,38 +441,38 @@ fn par_iter_bp_fuzz() { let context = random_range(1..=512.min(len).max(1)); eprintln!("CONTEXT: {context:?}"); - let (head, padding) = s.par_iter_bp(context); - let head = head.collect::>(); + let PaddedIt { it, padding } = s.par_iter_bp(context); + let it = it.collect::>(); fn f(x: &[u8; 8]) -> u32x8 { let x = x.map(|x| pack_char(x) as u32); u32x8::from(x) } - let head_len = head.len(); - eprintln!("par it len {head_len}"); + let it_len = it.len(); + eprintln!("par it len {it_len}"); eprintln!("padding: {padding}"); // Test padding len. - assert_eq!(8 * head_len, len + 7 * (context - 1) + padding); + assert_eq!(8 * it_len, len + 7 * (context - 1) + padding); assert!(padding < 32); // Test context overlap. for i in 0..7 { for j in 0..context - 1 { assert_eq!( - head[head_len - (context - 1) + j].as_array_ref()[i], - head[j].as_array_ref()[i + 1], + it[it_len - (context - 1) + j].as_array_ref()[i], + it[j].as_array_ref()[i + 1], "Context check failed at {i} {j}" ); } } - let stride = head_len - (context - 1); + let stride = it_len - (context - 1); eprintln!("stride {stride}"); assert_eq!( - head, - (0..head_len) + it, + (0..it_len) .map(|i| { f(&from_fn(|j| get(seq.0, i + stride * j))) }) .collect::>() ); @@ -499,61 +501,60 @@ fn par_iter_bp_delayed_fuzz() { // let context = random_range(1..=512.min(len).max(1)); let context = 1; - let delay = random_range(0..512); - eprintln!("LEN {len} CONTEXT {context} DELAY {delay}"); - let (head, padding) = s.par_iter_bp_delayed(context, delay); + let delay = Delay(random_range(0..512)); + let PaddedIt { it, padding } = s.par_iter_bp_delayed(context, delay); eprintln!("padding: {padding}"); - let head = head.collect::>(); + let it = it.collect::>(); fn f(x: &[u8; 8], y: &[u8; 8]) -> (u32x8, u32x8) { let x = x.map(|x| pack_char(x) as u32); let y = y.map(|x| pack_char(x) as u32); (u32x8::from(x), u32x8::from(y)) } - let head_len = head.len(); - eprintln!("par it len {head_len}"); + let it_len = it.len(); + eprintln!("par it len {it_len}"); eprintln!("padding: {padding}"); // Test padding len. - assert_eq!(8 * head_len, len + 7 * (context - 1) + padding); + assert_eq!(8 * it_len, len + 7 * (context - 1) + padding); assert!(padding < 32); // Test context overlap. for i in 0..7 { for j in 0..context - 1 { assert_eq!( - head[head_len - (context - 1) + j].0.as_array_ref()[i], - head[j].0.as_array_ref()[i + 1], + it[it_len - (context - 1) + j].0.as_array_ref()[i], + it[j].0.as_array_ref()[i + 1], "Context check failed at {i} {j}" ); } } - let stride = head_len - (context - 1); + let stride = it_len - (context - 1); eprintln!("stride {stride}"); - let ans = (0..head_len) + let ans = (0..it_len) .map(|i| { f( &from_fn(|j| get(seq.0, i + stride * j)), &from_fn(|j| { - if i < delay { + if i < delay.0 { b'A' } else { - get(seq.0, (i + stride * j).wrapping_sub(delay)) + get(seq.0, (i + stride * j).wrapping_sub(delay.0)) } }), ) }) .collect::>(); - if head != ans { - for (i, (x, y)) in head.iter().zip(ans.iter()).enumerate() { + if it != ans { + for (i, (x, y)) in it.iter().zip(ans.iter()).enumerate() { if x != y { - eprintln!("head {i} {x:?} != {y:?}"); + eprintln!("it {i} {x:?} != {y:?}"); } } } - assert!(head == ans); + assert!(it == ans); } } @@ -579,9 +580,10 @@ fn par_iter_bp_delayed2_fuzz() { let delay = random_range(0..512); let delay2 = random_range(delay..=512); eprintln!("LEN {len} CONTEXT {context} DELAY {delay}"); - let (head, padding) = s.par_iter_bp_delayed_2(context, delay, delay2); + let PaddedIt { it, padding } = + s.par_iter_bp_delayed_2(context, Delay(delay), Delay(delay2)); eprintln!("padding: {padding}"); - let head = head.collect::>(); + let it = it.collect::>(); fn f(x: &[u8; 8], y: &[u8; 8], z: &[u8; 8]) -> (u32x8, u32x8, u32x8) { let x = x.map(|x| pack_char(x) as u32); let y = y.map(|x| pack_char(x) as u32); @@ -589,29 +591,29 @@ fn par_iter_bp_delayed2_fuzz() { (u32x8::from(x), u32x8::from(y), u32x8::from(z)) } - let head_len = head.len(); - eprintln!("par it len {head_len}"); + let it_len = it.len(); + eprintln!("par it len {it_len}"); eprintln!("padding: {padding}"); // Test padding len. - assert_eq!(8 * head_len, len + 7 * (context - 1) + padding); + assert_eq!(8 * it_len, len + 7 * (context - 1) + padding); assert!(padding < 32); // Test context overlap. for i in 0..7 { for j in 0..context - 1 { assert_eq!( - head[head_len - (context - 1) + j].0.as_array_ref()[i], - head[j].0.as_array_ref()[i + 1], + it[it_len - (context - 1) + j].0.as_array_ref()[i], + it[j].0.as_array_ref()[i + 1], "Context check failed at {i} {j}" ); } } - let stride = head_len - (context - 1); + let stride = it_len - (context - 1); eprintln!("stride {stride}"); - let ans = (0..head_len) + let ans = (0..it_len) .map(|i| { f( &from_fn(|j| get(seq.0, i + stride * j)), @@ -632,22 +634,22 @@ fn par_iter_bp_delayed2_fuzz() { ) }) .collect::>(); - if head != ans { - for (i, (x, y)) in head.iter().zip(ans.iter()).enumerate() { + if it != ans { + for (i, (x, y)) in it.iter().zip(ans.iter()).enumerate() { if x != y { - eprintln!("head {i} {x:?} != {y:?}"); + eprintln!("it {i} {x:?} != {y:?}"); } } } - assert!(head == ans); + assert!(it == ans); } } #[test] fn par_iter_bp_delayed01() { let s = PackedSeqVec::from_ascii(b"ACGTAACCGGTTAAACCCGGGTTTAAAAAAAAACGT"); - let (head, padding) = s.as_slice().par_iter_bp_delayed_2(1, 0, 1); - let head = head.collect::>(); + let PaddedIt { it, padding } = s.as_slice().par_iter_bp_delayed_2(1, Delay(0), Delay(1)); + let it = it.collect::>(); fn f(x: &[u8; 8], y: &[u8; 8], z: &[u8; 8]) -> (u32x8, u32x8, u32x8) { let x = x.map(|x| pack_char(x) as u32); let y = y.map(|x| pack_char(x) as u32); @@ -656,7 +658,7 @@ fn par_iter_bp_delayed01() { } assert_eq!(padding, 8 * 8 - s.len()); assert_eq!( - head, + it, vec![ f(b"AGCAAAAA", b"AGCAAAAA", b"AAAAAAAA"), f(b"CGCACAAA", b"CGCACAAA", b"AGCAAAAA"), diff --git a/src/traits.rs b/src/traits.rs index 9651b2b..d947f8a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,7 +1,13 @@ +use crate::{ChunkIt, PaddedIt}; + use super::u32x8; use mem_dbg::{MemDbg, MemSize}; use std::ops::Range; +/// Strong type indicating the delay passed to [`Seq::par_iter_bp_delayed`] and [`Seq::par_iter_bp_delayed_2`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Delay(pub usize); + /// A non-owned slice of characters. /// /// The represented character values are expected to be in `[0, 2^b)`, @@ -99,7 +105,7 @@ pub trait Seq<'s>: Copy + Eq + Ord { } /// Iterate over the `b`-bit characters of the sequence. - fn iter_bp(self) -> impl ExactSizeIterator + Clone; + fn iter_bp(self) -> impl ExactSizeIterator; /// Iterate over 8 chunks of `b`-bit characters of the sequence in parallel. /// @@ -110,7 +116,7 @@ pub trait Seq<'s>: Copy + Eq + Ord { /// When `context>1`, consecutive chunks overlap by `context-1` bases. /// /// Expected to be implemented using SIMD instructions. - fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator + Clone, usize); + fn par_iter_bp(self, context: usize) -> PaddedIt>; /// Iterate over 8 chunks of the sequence in parallel, returning two characters offset by `delay` positions. /// @@ -126,8 +132,8 @@ pub trait Seq<'s>: Copy + Eq + Ord { fn par_iter_bp_delayed( self, context: usize, - delay: usize, - ) -> (impl ExactSizeIterator + Clone, usize); + delay: Delay, + ) -> PaddedIt>; /// Iterate over 8 chunks of the sequence in parallel, returning three characters: /// the char added, the one `delay` positions before, and the one `delay2` positions before. @@ -146,12 +152,9 @@ pub trait Seq<'s>: Copy + Eq + Ord { fn par_iter_bp_delayed_2( self, context: usize, - delay1: usize, - delay2: usize, - ) -> ( - impl ExactSizeIterator + Clone, - usize, - ); + delay1: Delay, + delay2: Delay, + ) -> PaddedIt>; /// Compare and return the LCP of the two sequences. fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize);