//! # SimdSketch
//!
//! This library provides two types of sequence sketches:
//! - the classic bottom-`s` sketch;
//! - the newer bucket sketch, returning the smallest hash in each of `s` buckets.
//!
//! See the corresponding [blogpost](https://curiouscoding.nl/posts/simd-sketch/) for more background and an evaluation.
//!
//! ## Hash function
//! All internal hashes are 32 bits. Either a forward-only hash or a
//! reverse-complement-aware (canonical) hash can be used.
//!
//! **TODO:** Currently we use (canonical) ntHash. This causes some hash collisions
//! for `k <= 16`, [which could be avoided](https://curiouscoding.nl/posts/nthash/#is-nthash-injective-on-kmers).
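//!
//! To illustrate what reverse-complement awareness means, the sketch below hashes a k-mer and
//! its reverse complement and keeps the smaller of the two values, so both strands of the DNA
//! receive the same hash. This is *not* ntHash: the 2-bit encoding, the multiplicative mixer and
//! the names `encode`, `revcomp_code`, `hash32` and `canonical_hash` are hypothetical and only
//! serve as an illustration (assuming `k <= 32` and ACGT-only input).
//!
//! ```rust
//! /// Hypothetical 2-bit encoding (A=0, C=1, G=2, T=3); assumes k <= 32.
//! fn encode(kmer: &[u8]) -> u64 {
//!     kmer.iter().fold(0u64, |acc, &c| {
//!         let x = match c {
//!             b'A' => 0u64,
//!             b'C' => 1,
//!             b'G' => 2,
//!             b'T' => 3,
//!             _ => panic!("not an ACGT character"),
//!         };
//!         (acc << 2) | x
//!     })
//! }
//!
//! /// 2-bit encoding of the reverse complement of an encoded k-mer.
//! fn revcomp_code(code: u64, k: usize) -> u64 {
//!     let mut rc = 0u64;
//!     let mut x = code;
//!     for _ in 0..k {
//!         // Take the last base first; the complement of code c is 3 - c (A<->T, C<->G).
//!         rc = (rc << 2) | (3 - (x & 3));
//!         x >>= 2;
//!     }
//!     rc
//! }
//!
//! /// Hypothetical 32-bit mixer: multiply by an odd constant and keep the high bits.
//! fn hash32(code: u64) -> u32 {
//!     (code.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> 32) as u32
//! }
//!
//! /// Canonical hash: identical for a k-mer and its reverse complement.
//! fn canonical_hash(kmer: &[u8]) -> u32 {
//!     let fwd = encode(kmer);
//!     let rc = revcomp_code(fwd, kmer.len());
//!     hash32(fwd).min(hash32(rc))
//! }
//! ```
//!
//! A forward-only variant of this toy hash would simply be `hash32(encode(kmer))`.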
//!
//! ## BucketSketch
//! For the classic bottom sketch, evaluating the similarity is slow because the
//! two sorted lists must be merged.
//!
//! The bucket sketch solves this by partitioning the hashes into `s` partitions.
//! Previous methods partition into ranges of size `u32::MAX/s`, but here we
//! partition by remainder mod `s` instead.
//!
//! The sketch stores the smallest hash for each remainder.
//! To compute the similarity, we can simply use the Hamming distance between
//! two sketches, which is significantly faster.
//!
//! The bucket sketch similarity has a very strong one-to-one correlation with the classic bottom sketch.
//!
//! **TODO:** A drawback of this method is that some buckets may remain empty
//! when the input sequences are not long enough. In that case, _densification_
//! could be applied, but this is not currently implemented. If you need this, please reach out.
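//!
//! The bucket construction and the Hamming-based comparison can be sketched on a plain list of
//! 32-bit hashes as follows. This is a simplified scalar illustration: it uses `%` instead of
//! fast-mod, marks empty buckets with `u32::MAX`, and the function names are hypothetical rather
//! than this crate's API.
//!
//! ```rust
//! /// Keep the smallest hash for each remainder `hash % s`; empty buckets stay `u32::MAX`.
//! fn bucket_sketch(hashes: &[u32], s: usize) -> Vec<u32> {
//!     let mut buckets = vec![u32::MAX; s];
//!     for &h in hashes {
//!         let b = h as usize % s;
//!         buckets[b] = buckets[b].min(h);
//!     }
//!     buckets
//! }
//!
//! /// Fraction of buckets holding the same minimum, i.e. 1 - (Hamming distance / s).
//! fn bucket_similarity(a: &[u32], b: &[u32]) -> f32 {
//!     assert_eq!(a.len(), b.len());
//!     let equal = a.iter().zip(b).filter(|(x, y)| x == y).count();
//!     equal as f32 / a.len() as f32
//! }
//! ```
//!
//! Note that in this sketch two empty buckets compare as equal, which is one way empty buckets
//! distort the estimate for short inputs.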
//!
//! ## Jaccard similarity
//! For the bottom sketch, we conceptually estimate similarity as follows:
//! 1. Find the smallest `s` distinct k-mer hashes in the union of two sketches.
//! 2. Return the fraction of these k-mers that occur in both sketches.
//!
//! For the bucket sketch, we simply return the fraction of partitions that have
//! the same smallest k-mer hash for both sequences.
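//!
//! The two steps for the bottom sketch can be written as a single merge-style walk. This minimal
//! scalar sketch assumes both inputs are sorted, deduplicated lists of exactly `s` hashes (the
//! name `bottom_similarity` is hypothetical). It visits the union in increasing order, stops after
//! `s` distinct hashes, and counts how many of them appear in both lists:
//!
//! ```rust
//! fn bottom_similarity(a: &[u32], b: &[u32]) -> f32 {
//!     assert_eq!(a.len(), b.len());
//!     let s = a.len();
//!     let (mut i, mut j, mut shared) = (0, 0, 0);
//!     // Each step consumes the next smallest distinct hash of the union.
//!     for _ in 0..s {
//!         if a[i] == b[j] {
//!             shared += 1;
//!             i += 1;
//!             j += 1;
//!         } else if a[i] < b[j] {
//!             i += 1;
//!         } else {
//!             j += 1;
//!         }
//!     }
//!     shared as f32 / s as f32
//! }
//! ```
//!
//! Since each step advances `i` and `j` by at most one, both indices stay below `s` inside the loop.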
//!
//! ## b-bit sketches
//!
//! Instead of storing the full 32-bit hashes, it suffices to store only the low `b` bits of each value.
//! For the bucket sketch, the stored value is the hash divided by `s`, since the remainder is already implied by the bucket index.
//! In practice, `b=8` is usually fine.
//! When extra-fast comparisons are needed, use `b=1` in combination with a 3 to 4x larger `s`.
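//!
//! Truncating to `b` bits means that unrelated buckets still agree with probability about `2^-b`.
//! One way to correct for this is `(f - 2^-b) / (1 - 2^-b)`, where `f` is the observed fraction
//! of equal values; for `b = 1` this reduces to `2f - 1`, which matches what this crate computes
//! for 1-bit sketches. The sketch below shows the 1-bit packing and comparison; the helper names
//! are hypothetical, and the general `corrected` helper is an illustration, not crate API.
//!
//! ```rust
//! /// Pack bit 0 of each value into `u64` words, 64 values per word (assumes len % 64 == 0).
//! fn pack_b1(vals: &[u32]) -> Vec<u64> {
//!     assert_eq!(vals.len() % 64, 0);
//!     vals.chunks_exact(64)
//!         .map(|xs| xs.iter().enumerate().fold(0u64, |w, (i, &x)| w | (((x & 1) as u64) << i)))
//!         .collect()
//! }
//!
//! /// Fraction of equal bits, corrected for the 1/2 chance that two random bits agree.
//! fn b1_similarity(a: &[u64], b: &[u64]) -> f32 {
//!     assert_eq!(a.len(), b.len());
//!     let equal: u32 = a.iter().zip(b).map(|(x, y)| (x ^ y).count_zeros()).sum();
//!     let f = equal as f32 / (64 * a.len()) as f32;
//!     2.0 * f - 1.0
//! }
//!
//! /// General b-bit correction: (f - 2^-b) / (1 - 2^-b).
//! fn corrected(f: f32, b: u32) -> f32 {
//!     let p = 0.5f32.powi(b as i32);
//!     (f - p) / (1.0 - p)
//! }
//! ```
//!
//! Comparing two 1-bit sketches then costs one XOR and one popcount per 64 buckets.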
//!
//! ## Usage
//!
//! The main entrypoint of this library is the [`Sketcher`] object.
//! Construct it in either the forward or canonical variant, and give `k`, `s`, and `b`.
//! Then call either [`Sketcher::bottom_sketch`] or [`Sketcher::sketch`] on it, and use the
//! `similarity` functions on the returned [`BottomSketch`] and [`BucketSketch`] objects.
//!
//! ```
//! use packed_seq::SeqVec;
//!
//! let k = 31;   // k-mer length.
//! let s = 8192; // Sample 8192 hashes.
//! let b = 8;    // Store the bottom 8 bits of each hash.
//!
//! // Use `new_rc` for a canonical (reverse-complement aware) hash.
//! // `new_fwd` uses a plain forward hash instead.
//! let sketcher = simd_sketch::Sketcher::new_rc(k, s, b);
//!
//! // Generate two random sequences of 2M characters.
//! let n = 2_000_000;
//! let seq1 = packed_seq::PackedSeqVec::random(n);
//! let seq2 = packed_seq::PackedSeqVec::random(n);
//!
//! // Bottom-sketch variant
//!
//! let sketch1: simd_sketch::BottomSketch = sketcher.bottom_sketch(seq1.as_slice());
//! let sketch2: simd_sketch::BottomSketch = sketcher.bottom_sketch(seq2.as_slice());
//!
//! // Value between 0 and 1, estimating the fraction of shared k-mers.
//! let similarity: f32 = sketch1.similarity(&sketch2);
//!
//! // Bucket sketch variant
//!
//! let sketch1: simd_sketch::BucketSketch = sketcher.sketch(seq1.as_slice());
//! let sketch2: simd_sketch::BucketSketch = sketcher.sketch(seq2.as_slice());
//!
//! // Value between 0 and 1, estimating the fraction of shared k-mers.
//! let similarity: f32 = sketch1.similarity(&sketch2);
//! ```
//!
//! **TODO:** There is currently no support for merging sketches, or for
//! sketching multiple sequences into one sketch. It's not hard; I just need to find a good API.
//! Please reach out if you're interested in this.
//!
//! **TODO:** If you would like a binary instead of a library, again, please reach out :)
//!
//! ## Implementation notes
//!
//! This library works by partitioning the input sequence into 8 chunks
//! and processing those in parallel using SIMD.
//! This is based on the [`packed-seq`](../packed_seq/index.html) and [`simd-minimizers`](../simd_minimizers/index.html) crates.
//!
//! For the bottom sketch, the largest hash should be around `target = u32::MAX * s / n` (ignoring duplicates).
//! To keep the collection branch-free, we first collect all hashes up to `bound = factor * target`,
//! where `factor` starts at 2.
//! Then we sort and deduplicate the collected hashes. If at least `s` remain, we return the bottom `s`.
//! Otherwise, we increase `factor` by roughly 1.5x (`factor += (factor + 1) / 2`) and retry.
//! The factor is cached in the [`Sketcher`] to make sketching multiple genomes more efficient.
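//!
//! The collect, sort, and retry loop described above can be sketched in scalar form as below. It
//! assumes the k-mer hashes are already available as a slice (the crate instead computes them on
//! the fly with SIMD ntHash) and uses the number of hashes as `n`; `bottom_s` is a hypothetical
//! name, and the `factor` update mirrors the increment used in [`Sketcher::bottom_sketch`].
//!
//! ```rust
//! /// Return the `s` smallest distinct hashes, padded with `u32::MAX` if there are fewer.
//! fn bottom_s(hashes: &[u32], s: usize, factor: &mut usize) -> Vec<u32> {
//!     assert!(*factor >= 1, "factor must be positive for the retry loop to make progress");
//!     let n = hashes.len().max(1);
//!     // Expected value of the s-th smallest hash, assuming uniformly random hashes.
//!     let target = (u32::MAX as usize / n).saturating_mul(s);
//!     loop {
//!         let bound = target.saturating_mul(*factor).min(u32::MAX as usize) as u32;
//!         // Keep only the (few) hashes below the bound, then sort and deduplicate them.
//!         let mut out: Vec<u32> = hashes.iter().copied().filter(|&h| h < bound).collect();
//!         out.sort_unstable();
//!         out.dedup();
//!         if bound == u32::MAX || out.len() >= s {
//!             out.resize(s, u32::MAX);
//!             return out;
//!         }
//!         // Too few survivors: grow the bound by roughly 1.5x and try again.
//!         *factor += (*factor + 1) / 2;
//!     }
//! }
//! ```
//!
//! Because `factor` is passed by mutable reference, a later call on a similar input starts from
//! the enlarged value and usually needs only one pass, which is the point of caching it.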
//!
//! For the bucket sketch, we use the same approach and increase the factor until we find a k-mer hash in every bucket.
//! By the coupon-collector argument, this needs to collect a fraction of around `s * ln(s) / n` of the hashes, rather than `s / n`.
//! In practice this doesn't matter much, as hashing all input k-mers is the bottleneck,
//! and bucketing the small sample of hashes is relatively fast.
//!
//! For the bucket sketch, we assign each element to its bucket via its remainder modulo `s`.
//! We compute this efficiently using [fast-mod](https://github.com/lemire/fastmod/blob/master/include/fastmod.h).
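//!
//! The fast-mod trick replaces the division in `hash % s` by multiplications with a precomputed
//! 64-bit constant `M = floor((2^64 - 1) / s) + 1`. The standalone sketch below follows
//! `fastmod_u32` and `fastdiv_u32` from the linked header (the crate's private `FM32` computes
//! the same values); the function names here are illustrative, and the divisor must satisfy
//! `d >= 2`, since for `d = 1` the constant overflows.
//!
//! ```rust
//! /// Precompute M for a divisor d >= 2.
//! fn compute_m(d: u32) -> u64 {
//!     u64::MAX / d as u64 + 1
//! }
//!
//! /// `h % d` without a division instruction.
//! fn fastmod(h: u32, m: u64, d: u32) -> u32 {
//!     let lowbits = m.wrapping_mul(h as u64);
//!     ((lowbits as u128 * d as u128) >> 64) as u32
//! }
//!
//! /// `h / d` without a division instruction.
//! fn fastdiv(h: u32, m: u64) -> u32 {
//!     ((m as u128 * h as u128) >> 64) as u32
//! }
//! ```
//!
//! As in the crate, the quotient `fastdiv(h)` is what gets stored per bucket: the remainder
//! `fastmod(h)` is already encoded by the bucket index.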
//!
//! ## Performance
//!
//! Sketching a 3GB human genome takes around 2 seconds with this library
//! (once the scaling factor is large enough to avoid a second pass).
//! That's typically a few times faster than parsing a FASTA file.
//!
//! [BinDash](https://github.com/zhaoxiaofei/bindash) instead takes 180s (90x
//! more) when running on a single thread.
//!
//! Comparing sketches is relatively fast, but can become a bottleneck when there are many input sequences,
//! since the number of comparisons grows quadratically. In this case, prefer the bucket sketch.
//! As an example, when sketching 5MB bacterial genomes using `s=10000`, each sketch takes 4ms.
//! Comparing two sketches takes 1.6µs.
//! This starts to be the dominant factor when the number of input sequences is more than 5000.
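//!
//! The 5000 threshold follows from a back-of-the-envelope calculation with the timings above:
//! `N` genomes take `N * 4ms` to sketch and `N * (N - 1) / 2 * 1.6µs` to compare all pairs.
//! The two are equal when `(N - 1) / 2 * 1.6µs = 4ms`, i.e. at `N = 5001`. The hypothetical
//! helper below redoes this estimate for other per-sketch and per-comparison timings; it only
//! multiplies the quoted numbers and does not measure anything.
//!
//! ```rust
//! /// Total (sketching, all-pairs comparison) time in seconds for `n` sequences.
//! fn estimate_seconds(n: u64, sketch_s: f64, compare_s: f64) -> (f64, f64) {
//!     let sketching = n as f64 * sketch_s;
//!     // Number of unordered pairs: n * (n - 1) / 2.
//!     let pairs = n * n.saturating_sub(1) / 2;
//!     let comparing = pairs as f64 * compare_s;
//!     (sketching, comparing)
//! }
//! ```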

mod intrinsics;

use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

use packed_seq::{u32x8, Seq};
use simd_minimizers::private::nthash::NtHasher;
use tracing::debug;

/// Sketch values truncated to their low `b` bits; `B1` packs 64 one-bit values into each `u64`.
enum BitSketch {
    B32(Vec<u32>),
    B16(Vec<u16>),
    B8(Vec<u8>),
    B1(Vec<u64>),
}

impl BitSketch {
    fn new(b: usize, vals: Vec<u32>) -> Self {
        match b {
            32 => BitSketch::B32(vals),
            16 => BitSketch::B16(vals.into_iter().map(|x| x as u16).collect()),
            8 => BitSketch::B8(vals.into_iter().map(|x| x as u8).collect()),
            1 => BitSketch::B1({
                // Pack the lowest bit of 64 consecutive values into one u64 word.
                assert_eq!(vals.len() % 64, 0);
                vals.chunks_exact(64)
                    .map(|xs| {
                        xs.iter()
                            .enumerate()
                            .fold(0u64, |bits, (i, x)| bits | (((x & 1) as u64) << i))
                    })
                    .collect()
            }),
            _ => panic!("Unsupported bit width. Must be 1 or 8 or 16 or 32."),
        }
    }
}

/// A sketch containing the `s` smallest k-mer hashes.
pub struct BottomSketch {
    rc: bool,
    k: usize,
    b: usize,
    bottom: Vec<u32>,
}

impl BottomSketch {
    /// Compute the similarity between two `BottomSketch`es.
    ///
    /// Walks the sorted union of both sketches until `s` distinct hashes have been seen,
    /// and returns the fraction of those that occur in both sketches.
    pub fn similarity(&self, other: &Self) -> f32 {
        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        assert_eq!(self.b, other.b);
        let a = &self.bottom;
        let b = &other.bottom;
        assert_eq!(a.len(), b.len());
        let mut intersection_size = 0;
        let mut union_size = 0;
        let mut i = 0;
        let mut j = 0;
        // Branch-free merge: advance the side(s) holding the smaller hash.
        while union_size < a.len() {
            intersection_size += (a[i] == b[j]) as usize;
            let di = (a[i] <= b[j]) as usize;
            let dj = (a[i] >= b[j]) as usize;
            i += di;
            j += dj;
            union_size += 1;
        }

        intersection_size as f32 / a.len() as f32
    }
}

/// A sketch containing the smallest k-mer hash for each remainder mod `s`.
pub struct BucketSketch {
    rc: bool,
    k: usize,
    b: usize,
    buckets: BitSketch,
}

impl BucketSketch {
    /// Compute the similarity between two `BucketSketch`es.
    pub fn similarity(&self, other: &Self) -> f32 {
        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        assert_eq!(self.b, other.b);
        match (&self.buckets, &other.buckets) {
            (BitSketch::B32(a), BitSketch::B32(b)) => Self::inner_similarity(a, b),
            (BitSketch::B16(a), BitSketch::B16(b)) => Self::inner_similarity(a, b),
            (BitSketch::B8(a), BitSketch::B8(b)) => Self::inner_similarity(a, b),
            (BitSketch::B1(a), BitSketch::B1(b)) => Self::b1_similarity(a, b),
            _ => panic!("Bit width mismatch"),
        }
    }

    /// Fraction of buckets whose stored values are equal.
    fn inner_similarity<T: Eq>(a: &[T], b: &[T]) -> f32 {
        assert_eq!(a.len(), b.len());
        std::iter::zip(a, b)
            .map(|(a, b)| (a == b) as u32)
            .sum::<u32>() as f32
            / a.len() as f32
    }

    /// Fraction `f` of equal bits, corrected for random agreement: two unrelated
    /// bits match with probability 1/2, so `2f - 1` estimates the similarity.
    fn b1_similarity(a: &[u64], b: &[u64]) -> f32 {
        assert_eq!(a.len(), b.len());
        let f = std::iter::zip(a, b)
            .map(|(a, b)| (*a ^ *b).count_zeros())
            .sum::<u32>() as f32
            / (64 * a.len()) as f32;
        2. * f - 1.
    }
}

/// An object containing the sketch parameters.
///
/// Contains internal state to optimize the implementation when sketching multiple similar sequences.
pub struct Sketcher {
    rc: bool,
    k: usize,
    s: usize,
    b: usize,

    /// Multiplier on the expected hash threshold; grows when a pass collects too few hashes.
    factor: AtomicUsize,
}

impl Sketcher {
    /// Default sketcher that is very fast at comparisons, but 20% slower at sketching.
    /// Use for >= 50000 seqs; also a safe default when input sequences are > 500'000 characters.
    ///
    /// When sequences are < 100'000 characters, inaccuracies may occur due to empty buckets.
    pub fn default(k: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s: 32768,
            b: 1,
            factor: 2.into(),
        }
    }

    /// Default sketcher that is fast at sketching, but somewhat slower at comparisons.
    /// Use for <= 5000 seqs, or when input sequences are < 100'000 characters.
    pub fn default_fast_sketching(k: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s: 8192,
            b: 8,
            factor: 2.into(),
        }
    }

    /// Construct a new forward-only `Sketcher` object.
    pub fn new_fwd(k: usize, s: usize, b: usize) -> Self {
        Sketcher {
            rc: false,
            k,
            s,
            b,
            factor: 2.into(),
        }
    }

    /// Construct a new reverse-complement-aware `Sketcher` object.
    pub fn new_rc(k: usize, s: usize, b: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s,
            b,
            factor: 2.into(),
        }
    }
}

impl Sketcher {
    /// Return the `s` smallest `u32` k-mer hashes.
    /// Prefer [`Sketcher::sketch`] instead, which is much faster and just as
    /// accurate when input sequences are not too short.
    pub fn bottom_sketch<'s, S: Seq<'s>>(&self, seq: S) -> BottomSketch {
        // Iterate all k-mers and compute 32-bit ntHashes.
        let n = seq.len();
        let mut out = vec![];
        loop {
            // Expected value of the s-th smallest hash, scaled by the cached factor.
            let target = u32::MAX as usize / n * self.s;
            let bound =
                (target.saturating_mul(self.factor.load(SeqCst))).min(u32::MAX as usize) as u32;

            self.collect_up_to_bound(seq, bound, &mut out);

            if bound == u32::MAX || out.len() >= self.s {
                out.sort_unstable();
                out.dedup();
                if bound == u32::MAX || out.len() >= self.s {
                    out.resize(self.s, u32::MAX);

                    break BottomSketch {
                        rc: self.rc,
                        k: self.k,
                        b: self.b,
                        bottom: out,
                    };
                }
            }
            // Too few distinct hashes below the bound: grow the factor by ~1.5x and retry.
            self.factor
                .fetch_add((self.factor.load(SeqCst) + 1) / 2, SeqCst);
            debug!("Increase factor to {}", self.factor.load(SeqCst));
        }
    }

    /// s-buckets sketch. Splits the hashes into `s` buckets and returns the smallest hash per bucket.
    /// Buckets are determined via the remainder mod `s`.
    pub fn sketch<'s, S: Seq<'s>>(&self, seq: S) -> BucketSketch {
        // Iterate all k-mers and compute 32-bit ntHashes.
        let n = seq.len();
        let mut out = vec![];
        let mut buckets = vec![u32::MAX; self.s];
        loop {
            let target = u32::MAX as usize / n * self.s;
            let bound =
                (target.saturating_mul(self.factor.load(SeqCst))).min(u32::MAX as usize) as u32;

            self.collect_up_to_bound(seq, bound, &mut out);

            if bound == u32::MAX || out.len() >= self.s {
                let m = FM32::new(self.s as u32);
                for &hash in &out {
                    let bucket = m.fastmod(hash);
                    buckets[bucket] = buckets[bucket].min(hash);
                }
                let mut empty = 0;
                for &x in &buckets {
                    if x == u32::MAX {
                        empty += 1;
                    }
                }
                if bound == u32::MAX || empty == 0 {
                    break BucketSketch {
                        rc: self.rc,
                        k: self.k,
                        b: self.b,
                        // Store the quotient `hash / s`; the remainder is implied by the bucket index.
                        buckets: BitSketch::new(
                            self.b,
                            buckets.into_iter().map(|x| m.fastdiv(x) as u32).collect(),
                        ),
                    };
                }
            }
            // Some buckets are still empty: grow the factor by ~1.5x and retry.
            self.factor
                .fetch_add((self.factor.load(SeqCst) + 1) / 2, SeqCst);
            debug!("Increase factor to {}", self.factor.load(SeqCst));
        }
    }

    fn collect_up_to_bound<'s, S: Seq<'s>>(&self, seq: S, bound: u32, out: &mut Vec<u32>) {
        if self.rc {
            collect_up_to_bound_generic::<true, S>(seq, self.k, bound, out);
        } else {
            collect_up_to_bound_generic::<false, S>(seq, self.k, bound, out);
        }
    }
}

/// Collect all k-mer hashes up to `bound` into `out` (which is cleared first).
fn collect_up_to_bound_generic<'s, const RC: bool, S: Seq<'s>>(
    seq: S,
    k: usize,
    bound: u32,
    out: &mut Vec<u32>,
) {
    let simd_bound = u32x8::splat(bound);

    let (hashes_head, hashes_tail) =
        simd_minimizers::private::nthash::nthash_seq_simd::<RC, S, NtHasher>(seq, k, 1);

    out.clear();
    let mut write_idx = 0;
    for hashes in hashes_head {
        let mask = hashes.cmp_lt(simd_bound);
        // Ensure room for up to 8 more values before the unchecked SIMD append.
        if write_idx + 8 >= out.len() {
            out.resize(write_idx * 3 / 2 + 8, 0);
        }
        unsafe { intrinsics::append_from_mask(hashes, mask, out, &mut write_idx) };
    }

    // Drop the unused slack at the end.
    out.resize(write_idx, 0);

    for hash in hashes_tail {
        if hash <= bound {
            out.push(hash);
        }
    }
}

/// FastMod32, using the low 32 bits of the hash. Requires `d > 1`.
/// Taken from <https://github.com/lemire/fastmod/blob/master/include/fastmod.h>.
#[derive(Copy, Clone, Debug)]
struct FM32 {
    d: u64,
    m: u64,
}
impl FM32 {
    fn new(d: u32) -> Self {
        Self {
            d: d as u64,
            m: u64::MAX / d as u64 + 1,
        }
    }
    /// Compute `h % d`.
    fn fastmod(self, h: u32) -> usize {
        let lowbits = self.m.wrapping_mul(h as u64);
        ((lowbits as u128 * self.d as u128) >> 64) as usize
    }
    /// Compute `h / d`.
    fn fastdiv(self, h: u32) -> usize {
        ((self.m as u128 * h as u128) >> 64) as u32 as usize
    }
}

#[cfg(test)]
#[test]
fn test() {
    use packed_seq::SeqVec;
    let b = 16;

    let k = 31;
    for n in 31..100 {
        let s = n - k + 1;
        let seq = packed_seq::PackedSeqVec::random(n);
        let sketcher = crate::Sketcher::new_fwd(k, s, b);
        let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
        assert_eq!(bottom.len(), s);
        assert!(bottom.is_sorted());

        let s = s.min(10);
        let seq = packed_seq::PackedSeqVec::random(n);
        let sketcher = crate::Sketcher::new_fwd(k, s, b);
        let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
        assert_eq!(bottom.len(), s);
        assert!(bottom.is_sorted());
    }
}

#[cfg(test)]
#[test]
fn rc() {
    use packed_seq::SeqVec;

    let b = 32;
    for k in (0..10).map(|_| rand::random_range(1..100)) {
        for n in (0..10).map(|_| rand::random_range(k..1000)) {
            for s in (0..10).map(|_| rand::random_range(0..n - k + 1)) {
                let seq = packed_seq::AsciiSeqVec::random(n);
                let sketcher = crate::Sketcher::new_rc(k, s, b);
                let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
                assert_eq!(bottom.len(), s);
                assert!(bottom.is_sorted());

                let seq_rc = packed_seq::AsciiSeqVec::from_ascii(
                    &seq.seq
                        .iter()
                        .rev()
                        .map(|c| packed_seq::complement_char(*c))
                        .collect::<Vec<_>>(),
                );

                let bottom_rc = sketcher.bottom_sketch(seq_rc.as_slice()).bottom;
                assert_eq!(bottom, bottom_rc);
            }
        }
    }
}