//! # SimdSketch
//!
//! This library provides two types of sequence sketches:
//! - the classic bottom-`s` sketch;
//! - the newer bucket sketch, returning the smallest hash in each of `s` buckets.
//!
//! See the corresponding [blogpost](https://curiouscoding.nl/posts/simd-sketch/) for more background and an evaluation.
//!
//! ## Hash function
//! All internal hashes are 32 bits. Either a forward-only hash or a
//! reverse-complement-aware (canonical) hash can be used.
//!
//! **TODO:** Currently we use (canonical) ntHash. This causes some hash collisions
//! for `k <= 16`, [which could be avoided](https://curiouscoding.nl/posts/nthash/#is-nthash-injective-on-kmers).
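//!
//! To illustrate the difference: a canonical hash gives a k-mer and its reverse complement
//! the same value, whereas a forward-only hash generally does not. The minimal sketch below
//! assumes a hypothetical 2-bit encoding and a toy multiply-shift mixer; it is *not* ntHash,
//! and not necessarily how this library canonicalizes, it only demonstrates the property.
//!
//! ```rust
//! // Hypothetical 2-bit encoding (an assumption of this sketch): A=0, C=1, G=2, T=3,
//! // so the complement of a base `x` is `3 - x`.
//! fn encode(base: u8) -> u64 {
//!     match base {
//!         b'A' => 0,
//!         b'C' => 1,
//!         b'G' => 2,
//!         b'T' => 3,
//!         _ => panic!("only ACGT is supported in this sketch"),
//!     }
//! }
//!
//! // Pack a k-mer (k <= 32) into a u64, first base in the highest bits.
//! fn pack(kmer: &[u8]) -> u64 {
//!     kmer.iter().fold(0, |acc, &c| (acc << 2) | encode(c))
//! }
//!
//! // Pack the reverse complement: read backwards and complement each base.
//! fn pack_rc(kmer: &[u8]) -> u64 {
//!     kmer.iter().rev().fold(0, |acc, &c| (acc << 2) | (3 - encode(c)))
//! }
//!
//! // Toy 64-to-32-bit mixer (multiply by an odd constant, keep the high half). Not ntHash.
//! fn mix(x: u64) -> u32 {
//!     (x.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> 32) as u32
//! }
//!
//! fn forward_hash(kmer: &[u8]) -> u32 {
//!     mix(pack(kmer))
//! }
//!
//! // Canonical: hash the smaller of the k-mer and its reverse complement.
//! fn canonical_hash(kmer: &[u8]) -> u32 {
//!     mix(pack(kmer).min(pack_rc(kmer)))
//! }
//! ```
//!
//! For example, `ACGTT` and its reverse complement `AACGT` get the same `canonical_hash`,
//! but different `forward_hash` values.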
//!
//! ## BucketSketch
//! For the classic bottom sketch, evaluating the similarity is slow because the two
//! sorted lists of hashes must be merged.
//!
//! The bucket sketch solves this by partitioning the hashes into `s` partitions.
//! Previous methods partition into ranges of size `u32::MAX/s`, but here we
//! partition by remainder mod `s` instead.
//!
//! The sketch stores the smallest hash for each remainder.
//! To compute the similarity, we simply count the buckets in which the two sketches agree
//! (i.e., one minus their normalized Hamming distance), which is significantly faster.
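//!
//! A minimal scalar sketch of this idea, assuming the k-mer hashes are already computed.
//! Here `u32::MAX` marks an empty bucket; the function names are hypothetical and the
//! library's SIMD implementation differs.
//!
//! ```rust
//! // Smallest hash per bucket, where the bucket of `h` is `h % s`.
//! fn bucket_sketch(hashes: &[u32], s: usize) -> Vec<u32> {
//!     let mut buckets = vec![u32::MAX; s]; // u32::MAX = empty bucket
//!     for &h in hashes {
//!         let i = h as usize % s;
//!         buckets[i] = buckets[i].min(h);
//!     }
//!     buckets
//! }
//!
//! // Fraction of buckets holding the same hash in both sketches
//! // (one minus the normalized Hamming distance). Note that two empty
//! // buckets also compare equal here.
//! fn bucket_similarity(a: &[u32], b: &[u32]) -> f32 {
//!     assert_eq!(a.len(), b.len());
//!     let equal = a.iter().zip(b).filter(|(x, y)| x == y).count();
//!     equal as f32 / a.len() as f32
//! }
//! ```
//!
//! For example, with `s = 4` the hashes `[10, 3, 7, 4]` give the sketch `[4, u32::MAX, 10, 3]`.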
//!
//! The bucket sketch similarity correlates very closely, nearly one-to-one, with the classic bottom-sketch similarity.
//!
//! **TODO:** A drawback of this method is that some buckets may remain empty
//! when the input sequences are not long enough. In that case, _densification_
//! could be applied, but this is not currently implemented. If you need this, please reach out.
//!
//! ## Jaccard similarity
//! For the bottom sketch, we conceptually estimate similarity as follows:
//! 1. Find the smallest `s` distinct k-mer hashes in the union of two sketches.
//! 2. Return the fraction of these k-mers that occur in both sketches.
//!
//! For the bucket sketch, we simply return the fraction of partitions that have
//! the same smallest hash (and thus, barring collisions, the same k-mer) for both sequences.
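//!
//! The bottom-sketch estimate can be sketched as a merge of two sorted, deduplicated
//! hash lists that stops after `s` distinct values of the union. This is a minimal scalar
//! illustration with a hypothetical function name, not the library's branch-free implementation.
//!
//! ```rust
//! // Among the `s` smallest distinct hashes of the union of `a` and `b`
//! // (both sorted and deduplicated), return the fraction present in both.
//! fn bottom_similarity(a: &[u32], b: &[u32], s: usize) -> f32 {
//!     let (mut i, mut j) = (0, 0);
//!     let (mut shared, mut seen) = (0usize, 0usize);
//!     while seen < s && (i < a.len() || j < b.len()) {
//!         match (a.get(i), b.get(j)) {
//!             // Same hash in both lists: it is shared.
//!             (Some(x), Some(y)) if x == y => {
//!                 shared += 1;
//!                 i += 1;
//!                 j += 1;
//!             }
//!             // Otherwise advance past the smaller hash.
//!             (Some(x), Some(y)) if x < y => i += 1,
//!             (Some(_), Some(_)) => j += 1,
//!             (Some(_), None) => i += 1,
//!             (None, _) => j += 1,
//!         }
//!         seen += 1;
//!     }
//!     if seen == 0 { 0.0 } else { shared as f32 / seen as f32 }
//! }
//! ```
//!
//! For `a = [1, 3, 5, 7]`, `b = [1, 2, 5, 8]` and `s = 4`, the union's four smallest hashes
//! are `1, 2, 3, 5`, of which two are shared, giving `0.5`.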
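//!
//! ## b-bit sketches
//!
//! Instead of storing the full 32-bit hashes, it is sufficient to store only the low `b` bits of each hash.
//! Supported values are `b = 1, 8, 16, 32`, and `b = 1` requires `s` to be a multiple of 64.
//! Currently `b` only affects the bucket sketch, which stores the low `b` bits of `hash / s`
//! (the remainder is already implied by the bucket index); the bottom sketch always keeps full 32-bit hashes.
//! In practice, `b=8` is usually fine.
//! When extra fast comparisons are needed, use `b=1` in combination with a 3 to 4x larger `s`.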
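//!
//! Truncation makes unrelated buckets agree by chance with probability about `2^-b`.
//! Assuming the true similarity is `j` and chance matches are independent, the observed match
//! fraction is `f = j + (1 - j) * 2^-b`, which can be inverted as below. For `b = 1` this gives
//! `j = 2f - 1`, the correction the library applies to 1-bit sketches; the helper names are
//! hypothetical, and the library does not apply such a correction for `b = 8, 16, 32`.
//!
//! ```rust
//! // Keep only the low `b` bits of each value (requires b < 32).
//! fn low_bits(values: &[u32], b: u32) -> Vec<u32> {
//!     assert!(b < 32);
//!     let mask = (1u32 << b) - 1;
//!     values.iter().map(|&v| v & mask).collect()
//! }
//!
//! // Invert f = j + (1 - j) * 2^-b to estimate j from the raw match fraction f.
//! fn corrected_similarity(f: f32, b: u32) -> f32 {
//!     let chance = 0.5f32.powi(b as i32);
//!     (f - chance) / (1.0 - chance)
//! }
//! ```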
//!
//! ## Usage
//!
//! The main entry point of this library is the [`Sketcher`] object.
//! Construct it in either the forward or the canonical variant, passing `k`, `s`, and `b`.
//! Then call either [`Sketcher::bottom_sketch`] or [`Sketcher::sketch`] on it, and use the
//! `similarity` functions on the returned [`BottomSketch`] and [`BucketSketch`] objects.
//!
//! ```
//! use packed_seq::SeqVec;
//!
//! let k = 31; // Hash all 31-mers.
//! let s = 8192; // Keep 8192 hashes (bottom sketch) or buckets (bucket sketch).
//! let b = 8; // Store 8 bits per bucket.
//!
//! // Use `new_rc` for a canonical (reverse-complement aware) hash.
//! // `new_fwd` uses a plain forward hash instead.
//! let sketcher = simd_sketch::Sketcher::new_rc(k, s, b);
//!
//! // Generate two random sequences of 2M characters.
//! let n = 2_000_000;
//! let seq1 = packed_seq::PackedSeqVec::random(n);
//! let seq2 = packed_seq::PackedSeqVec::random(n);
//!
//! // Bottom-sketch variant
//!
//! let sketch1: simd_sketch::BottomSketch = sketcher.bottom_sketch(seq1.as_slice());
//! let sketch2: simd_sketch::BottomSketch = sketcher.bottom_sketch(seq2.as_slice());
//!
//! // Value between 0 and 1, estimating the fraction of shared k-mers.
//! let similarity: f32 = sketch1.similarity(&sketch2);
//!
//! // Bucket sketch variant
//!
//! let sketch1: simd_sketch::BucketSketch = sketcher.sketch(seq1.as_slice());
//! let sketch2: simd_sketch::BucketSketch = sketcher.sketch(seq2.as_slice());
//!
//! // Value between 0 and 1, estimating the fraction of shared k-mers.
//! let similarity: f32 = sketch1.similarity(&sketch2);
//! ```
//!
//! **TODO:** There is currently no support for merging sketches, or for
//! sketching multiple sequences into one sketch. It's not hard; I just need to find a good API.
//! Please reach out if you're interested in this.
//!
//! **TODO:** If you would like a binary instead of a library, again, please reach out :)
//!
//! ## Implementation notes
//!
//! This library works by partitioning the input sequence into 8 chunks
//! and processing those in parallel using SIMD.
//! This builds on the [`packed-seq`](../packed_seq/index.html) and [`simd-minimizers`](../simd_minimizers/index.html) crates.
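//!
//! Splitting a sequence into chunks requires neighbouring chunks to overlap by `k - 1`
//! characters, so that every k-mer is hashed exactly once. The minimal sketch below shows one
//! way to do that bookkeeping, assuming equal-size contiguous lanes plus a scalar tail for the
//! leftover k-mers; the function is hypothetical and `packed-seq` may split differently.
//!
//! ```rust
//! use std::ops::Range;
//!
//! // Split the k-mer start positions `0..n - k + 1` into 8 equal lanes plus a tail.
//! // A lane covering starts `r` must read characters `r.start..r.end + k - 1`.
//! fn split_into_lanes(n: usize, k: usize) -> ([Range<usize>; 8], Range<usize>) {
//!     assert!(k >= 1);
//!     let kmers = if n >= k { n - k + 1 } else { 0 };
//!     let per_lane = kmers / 8;
//!     let lanes = std::array::from_fn(|i| i * per_lane..(i + 1) * per_lane);
//!     // Leftover k-mers that do not fill a full row of 8 lanes.
//!     let tail = 8 * per_lane..kmers;
//!     (lanes, tail)
//! }
//! ```
//!
//! For `n = 100` and `k = 31` there are 70 k-mers: eight lanes of 8 k-mers each, and a tail of 6.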
//!
//! For the bottom sketch, the largest hash should be around `target = u32::MAX * s / n` (ignoring duplicates).
//! To keep the algorithm branch-free, we first collect all hashes below `bound = factor * target`, where `factor` starts at 2.
//! Then we sort the collected hashes. If at least `s` remain after deduplication, we return the bottom `s`.
//! Otherwise, we increase `factor` by roughly 1.5x and retry. The
//! factor is cached in the [`Sketcher`] to make sketching multiple genomes more efficient.
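//!
//! A minimal scalar sketch of this threshold-and-retry loop, operating on precomputed hashes.
//! The function name is hypothetical; the starting factor of 2 and the ~1.5x growth mirror
//! the [`Sketcher`] code, but unlike the library this version truncates instead of padding
//! with `u32::MAX` when fewer than `s` distinct hashes exist.
//!
//! ```rust
//! fn bottom_s(hashes: &[u32], s: usize) -> Vec<u32> {
//!     let n = hashes.len().max(1);
//!     // Expected value of the s-th smallest hash, clamped to at least 1.
//!     let target = ((u32::MAX as usize).saturating_mul(s) / n).max(1);
//!     let mut factor: usize = 2;
//!     loop {
//!         let bound = target.saturating_mul(factor).min(u32::MAX as usize) as u32;
//!         // Keep only the hashes below the current bound.
//!         let mut out: Vec<u32> = hashes.iter().copied().filter(|&h| h < bound).collect();
//!         out.sort_unstable();
//!         out.dedup();
//!         if bound == u32::MAX || out.len() >= s {
//!             out.truncate(s);
//!             return out;
//!         }
//!         // Too few survivors: grow the factor by about 1.5x and retry.
//!         factor += (factor + 1) / 2;
//!     }
//! }
//! ```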
//!
//! For the bucket sketch, we use the same approach, and increase the factor until we find a k-mer hash in every bucket.
//! In expectation, this needs to collect a fraction of around `s * ln(s) / n` of the hashes, rather than `s / n`,
//! since every one of the `s` buckets must receive at least one hash.
//! In practice this doesn't matter much, as hashing all input k-mers is the bottleneck,
//! and processing the small sample of collected hashes is relatively fast.
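//!
//! A back-of-the-envelope check of this fraction, assuming uniformly random hashes: if a
//! fraction `f` of the `n` hashes is kept, each bucket receives about `f * n / s` of them, so
//! (Poisson approximation) it stays empty with probability `exp(-f * n / s)`. The expected number
//! of empty buckets drops to about one at `f = s * ln(s) / n`. The helper names are hypothetical.
//!
//! ```rust
//! // Expected number of empty buckets after keeping a fraction `f` of `n` hashes.
//! fn expected_empty(n: f64, s: f64, f: f64) -> f64 {
//!     s * (-f * n / s).exp()
//! }
//!
//! // Kept fraction at which about one bucket is expected to remain empty.
//! fn fraction_for_one_empty(n: f64, s: f64) -> f64 {
//!     s * s.ln() / n
//! }
//! ```
//!
//! For example, for `n = 5_000_000` and `s = 10000`, this is a fraction of about `0.018`,
//! versus `s / n = 0.002` for the bottom sketch.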
//!
//! For the bucket sketch, we assign each hash to its bucket via its remainder modulo `s`.
//! We compute this efficiently using [fast-mod](https://github.com/lemire/fastmod/blob/master/include/fastmod.h).
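//!
//! The fast-mod trick precomputes `m = ceil(2^64 / d)` once, after which both the remainder and
//! the quotient of a 32-bit hash cost only multiplications. A standalone sketch of the two
//! formulas from `fastmod.h` (the function names here are hypothetical), checked against `%`
//! and `/`; it assumes `d >= 2`, since `m` overflows for `d = 1`:
//!
//! ```rust
//! // m = ceil(2^64 / d), valid for d >= 2.
//! fn compute_m(d: u32) -> u64 {
//!     assert!(d >= 2);
//!     u64::MAX / d as u64 + 1
//! }
//!
//! // h % d: the low 64 bits of m * h encode the fractional part of h / d.
//! fn fast_mod(m: u64, d: u32, h: u32) -> u32 {
//!     let lowbits = m.wrapping_mul(h as u64);
//!     ((lowbits as u128 * d as u128) >> 64) as u32
//! }
//!
//! // h / d: the high 64 bits of the 128-bit product m * h.
//! fn fast_div(m: u64, h: u32) -> u32 {
//!     ((m as u128 * h as u128) >> 64) as u32
//! }
//! ```
//!
//! Since `h == (h / s) * s + h % s`, a hash is fully determined by its bucket and its quotient.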
//!
//! ## Performance
//!
//! Sketching a 3 GB human genome takes around 2 seconds with this library
//! (once the scaling factor is large enough to avoid a second pass).
//! That's typically a few times faster than parsing a FASTA file.
//!
//! [BinDash](https://github.com/zhaoxiaofei/bindash) instead takes 180 s (90x
//! longer) when running on a single thread.
//!
//! Comparing sketches is relatively fast, but can become a bottleneck when there are many input sequences,
//! since the number of comparisons grows quadratically. In this case, prefer the bucket sketch.
//! As an example, when sketching 5 MB bacterial genomes using `s=10000`, each sketch takes 4 ms,
//! and comparing two sketches takes 1.6 µs.
//! Comparisons start to dominate the total runtime once there are more than 5000 input sequences.
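//!
//! The crossover follows from a simple cost model, assuming every pair of sequences is compared
//! once: sketching costs `n * t_sketch` and comparing costs `n * (n - 1) / 2 * t_compare`, which
//! are equal at `n = 1 + 2 * t_sketch / t_compare`. With the numbers above (4 ms and 1.6 µs),
//! that is `n = 5001`. The helper names are hypothetical.
//!
//! ```rust
//! // Total seconds spent sketching `n` sequences.
//! fn sketch_time(n: f64, t_sketch: f64) -> f64 {
//!     n * t_sketch
//! }
//!
//! // Total seconds spent comparing all n * (n - 1) / 2 pairs.
//! fn compare_time(n: f64, t_compare: f64) -> f64 {
//!     n * (n - 1.0) / 2.0 * t_compare
//! }
//!
//! // Number of sequences at which both totals are equal.
//! fn crossover(t_sketch: f64, t_compare: f64) -> f64 {
//!     1.0 + 2.0 * t_sketch / t_compare
//! }
//! ```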

mod intrinsics;

use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

use packed_seq::{u32x8, Seq};
use simd_minimizers::private::nthash::NtHasher;
use tracing::debug;

/// Bucket values truncated to `b` bits.
enum BitSketch {
    B32(Vec<u32>),
    B16(Vec<u16>),
    B8(Vec<u8>),
    /// One bit per bucket, packed 64 buckets per `u64`.
    B1(Vec<u64>),
}

impl BitSketch {
    /// Truncate each value to its low `b` bits. For `b = 1`, the bits are packed so that
    /// bucket `64 * w + i` is stored in bit `i` of word `w`.
    fn new(b: usize, vals: Vec<u32>) -> Self {
        match b {
            32 => BitSketch::B32(vals),
            16 => BitSketch::B16(vals.into_iter().map(|x| x as u16).collect()),
            8 => BitSketch::B8(vals.into_iter().map(|x| x as u8).collect()),
            1 => BitSketch::B1({
                // Packing requires a multiple of 64 buckets, i.e. `s % 64 == 0`.
                assert_eq!(vals.len() % 64, 0);
                vals.chunks_exact(64)
                    .map(|xs| {
                        xs.iter()
                            .enumerate()
                            .fold(0u64, |bits, (i, x)| bits | (((x & 1) as u64) << i))
                    })
                    .collect()
            }),
            _ => panic!("Unsupported bit width. Must be 1, 8, 16, or 32."),
        }
    }
}

/// A sketch containing the `s` smallest k-mer hashes.
pub struct BottomSketch {
    rc: bool,
    k: usize,
    b: usize,
    bottom: Vec<u32>,
}

impl BottomSketch {
    /// Compute the similarity between two `BottomSketch`es.
    ///
    /// Merges the two sorted lists until `s` hashes of the union have been seen,
    /// and returns the fraction of those that occur in both sketches.
    pub fn similarity(&self, other: &Self) -> f32 {
        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        assert_eq!(self.b, other.b);
        let a = &self.bottom;
        let b = &other.bottom;
        assert_eq!(a.len(), b.len());
        let mut intersection_size = 0;
        let mut union_size = 0;
        let mut i = 0;
        let mut j = 0;
        // Branch-free merge step: advance whichever side holds the smaller hash,
        // or both when they are equal (a shared hash).
        while union_size < a.len() {
            intersection_size += (a[i] == b[j]) as usize;
            let di = (a[i] <= b[j]) as usize;
            let dj = (a[i] >= b[j]) as usize;
            i += di;
            j += dj;
            union_size += 1;
        }

        intersection_size as f32 / a.len() as f32
    }
}

/// A sketch containing the smallest k-mer hash for each remainder mod `s`.
pub struct BucketSketch {
    rc: bool,
    k: usize,
    b: usize,
    buckets: BitSketch,
}

impl BucketSketch {
    /// Compute the similarity between two `BucketSketch`es.
    pub fn similarity(&self, other: &Self) -> f32 {
        assert_eq!(self.rc, other.rc);
        assert_eq!(self.k, other.k);
        assert_eq!(self.b, other.b);
        match (&self.buckets, &other.buckets) {
            (BitSketch::B32(a), BitSketch::B32(b)) => Self::inner_similarity(a, b),
            (BitSketch::B16(a), BitSketch::B16(b)) => Self::inner_similarity(a, b),
            (BitSketch::B8(a), BitSketch::B8(b)) => Self::inner_similarity(a, b),
            (BitSketch::B1(a), BitSketch::B1(b)) => Self::b1_similarity(a, b),
            _ => panic!("Bit width mismatch"),
        }
    }

    /// Fraction of buckets whose stored values are equal.
    fn inner_similarity<T: Eq>(a: &[T], b: &[T]) -> f32 {
        assert_eq!(a.len(), b.len());
        std::iter::zip(a, b)
            .map(|(a, b)| (a == b) as u32)
            .sum::<u32>() as f32
            / a.len() as f32
    }

    /// With one bit per bucket, unrelated buckets still agree half the time.
    /// If `f` is the fraction of equal bits, `f = j + (1 - j) / 2`, so the
    /// similarity estimate is `j = 2f - 1`.
    fn b1_similarity(a: &[u64], b: &[u64]) -> f32 {
        assert_eq!(a.len(), b.len());
        // `count_zeros` of the XOR counts the bit positions where both sketches agree.
        let f = std::iter::zip(a, b)
            .map(|(a, b)| (*a ^ *b).count_zeros())
            .sum::<u32>() as f32
            / (64 * a.len()) as f32;
        2. * f - 1.
    }
}

/// An object containing the sketch parameters.
///
/// Contains internal state to optimize the implementation when sketching multiple similar sequences.
pub struct Sketcher {
    rc: bool,
    k: usize,
    s: usize,
    b: usize,

    /// Multiplier on the expected threshold `target`; starts at 2 and grows by ~1.5x
    /// whenever a pass collects too few hashes. Shared across calls.
    factor: AtomicUsize,
}

impl Sketcher {
    /// Default sketcher (canonical hash, `s = 32768`, `b = 1`) that is very fast at comparisons,
    /// but 20% slower at sketching.
    /// Use for >= 50000 seqs; a safe default when input sequences are > 500'000 characters.
    ///
    /// When sequences are < 100'000 characters, inaccuracies may occur due to empty buckets.
    pub fn default(k: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s: 32768,
            b: 1,
            factor: 2.into(),
        }
    }

    /// Default sketcher (canonical hash, `s = 8192`, `b = 8`) that is fast at sketching,
    /// but somewhat slower at comparisons.
    /// Use for <= 5000 seqs, or when input sequences are < 100'000 characters.
    pub fn default_fast_sketching(k: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s: 8192,
            b: 8,
            factor: 2.into(),
        }
    }

    /// Construct a new forward-only `Sketcher` object.
    pub fn new_fwd(k: usize, s: usize, b: usize) -> Self {
        Sketcher {
            rc: false,
            k,
            s,
            b,
            factor: 2.into(),
        }
    }

    /// Construct a new reverse-complement-aware `Sketcher` object.
    pub fn new_rc(k: usize, s: usize, b: usize) -> Self {
        Sketcher {
            rc: true,
            k,
            s,
            b,
            factor: 2.into(),
        }
    }
}

impl Sketcher {
    /// Return the `s` smallest `u32` k-mer hashes.
    /// Prefer [`Sketcher::sketch`] instead, which is much faster and just as
    /// accurate when input sequences are not too short.
    pub fn bottom_sketch<'s, S: Seq<'s>>(&self, seq: S) -> BottomSketch {
        // Iterate all k-mers and compute 32-bit ntHashes.
        let n = seq.len();
        let mut out = vec![];
        loop {
            // Expected value of the s-th smallest hash. Multiplying before dividing, and
            // clamping to at least 1, avoids a zero `target` (and an endless retry loop)
            // for inputs longer than `u32::MAX` characters.
            let target = ((u32::MAX as usize).saturating_mul(self.s) / n).max(1);
            let bound =
                (target.saturating_mul(self.factor.load(SeqCst))).min(u32::MAX as usize) as u32;

            self.collect_up_to_bound(seq, bound, &mut out);

            if bound == u32::MAX || out.len() >= self.s {
                out.sort_unstable();
                out.dedup();
                if bound == u32::MAX || out.len() >= self.s {
                    // Keep the `s` smallest; pad with `u32::MAX` if there are fewer distinct hashes.
                    out.resize(self.s, u32::MAX);

                    break BottomSketch {
                        rc: self.rc,
                        k: self.k,
                        b: self.b,
                        bottom: out,
                    };
                }
            }
            // Too few hashes below `bound`: grow the factor by ~1.5x and retry.
            self.factor
                .fetch_add((self.factor.load(SeqCst) + 1) / 2, SeqCst);
            debug!("Increase factor to {}", self.factor.load(SeqCst));
        }
    }

    /// s-buckets sketch. Splits the hashes into `s` buckets and returns the smallest hash per bucket.
    /// Buckets are determined via the remainder mod `s`.
    pub fn sketch<'s, S: Seq<'s>>(&self, seq: S) -> BucketSketch {
        // Iterate all k-mers and compute 32-bit ntHashes.
        let n = seq.len();
        let mut out = vec![];
        let mut buckets = vec![u32::MAX; self.s];
        loop {
            // See `bottom_sketch` for how `target` and `bound` are chosen.
            let target = ((u32::MAX as usize).saturating_mul(self.s) / n).max(1);
            let bound =
                (target.saturating_mul(self.factor.load(SeqCst))).min(u32::MAX as usize) as u32;

            self.collect_up_to_bound(seq, bound, &mut out);

            if bound == u32::MAX || out.len() >= self.s {
                let m = FM32::new(self.s as u32);
                // Keep the minimum hash per bucket. A retry collects a superset of the
                // previous hashes, so reusing `buckets` across passes is safe.
                for &hash in &out {
                    let bucket = m.fastmod(hash);
                    buckets[bucket] = buckets[bucket].min(hash);
                }
                let empty = buckets.iter().filter(|&&x| x == u32::MAX).count();
                if bound == u32::MAX || empty == 0 {
                    break BucketSketch {
                        rc: self.rc,
                        k: self.k,
                        b: self.b,
                        // The bucket index already encodes `hash % s`, so only
                        // the quotient `hash / s` is stored.
                        buckets: BitSketch::new(
                            self.b,
                            buckets.into_iter().map(|x| m.fastdiv(x) as u32).collect(),
                        ),
                    };
                }
            }
            self.factor
                .fetch_add((self.factor.load(SeqCst) + 1) / 2, SeqCst);
            debug!("Increase factor to {}", self.factor.load(SeqCst));
        }
    }

    /// Collect all k-mer hashes below `bound` into `out`, replacing its previous contents.
    fn collect_up_to_bound<'s, S: Seq<'s>>(&self, seq: S, bound: u32, out: &mut Vec<u32>) {
        if self.rc {
            collect_up_to_bound_generic::<true, S>(seq, self.k, bound, out);
        } else {
            collect_up_to_bound_generic::<false, S>(seq, self.k, bound, out);
        }
    }
}

fn collect_up_to_bound_generic<'s, const RC: bool, S: Seq<'s>>(
    seq: S,
    k: usize,
    bound: u32,
    out: &mut Vec<u32>,
) {
    let simd_bound = u32x8::splat(bound);

    let (hashes_head, hashes_tail) =
        simd_minimizers::private::nthash::nthash_seq_simd::<RC, S, NtHasher>(seq, k, 1);

    out.clear();
    let mut write_idx = 0;
    for hashes in hashes_head {
        let mask = hashes.cmp_lt(simd_bound);
        // Keep at least 8 free slots after `write_idx` for `append_from_mask`.
        if write_idx + 8 >= out.len() {
            out.resize(write_idx * 3 / 2 + 8, 0);
        }
        unsafe { intrinsics::append_from_mask(hashes, mask, out, &mut write_idx) };
    }

    out.resize(write_idx, 0);

    // Remaining hashes not covered by the SIMD loop; use the same strict
    // `< bound` comparison as `cmp_lt` above.
    for hash in hashes_tail {
        if hash < bound {
            out.push(hash);
        }
    }
}

/// FastMod32, using the low 32 bits of the hash.
/// Taken from <https://github.com/lemire/fastmod/blob/master/include/fastmod.h>.
///
/// Requires `d >= 2`: for `d = 1`, computing `m` overflows.
#[derive(Copy, Clone, Debug)]
struct FM32 {
    d: u64,
    m: u64,
}
impl FM32 {
    /// Precompute `m = ceil(2^64 / d)`.
    fn new(d: u32) -> Self {
        Self {
            d: d as u64,
            m: u64::MAX / d as u64 + 1,
        }
    }
    /// `h % d`, computed with multiplications instead of a division.
    fn fastmod(self, h: u32) -> usize {
        let lowbits = self.m.wrapping_mul(h as u64);
        ((lowbits as u128 * self.d as u128) >> 64) as usize
    }
    /// `h / d`, computed with a single multiplication.
    fn fastdiv(self, h: u32) -> usize {
        ((self.m as u128 * h as u128) >> 64) as u32 as usize
    }
}

#[cfg(test)]
#[test]
fn test() {
    use packed_seq::SeqVec;
    let b = 16;

    let k = 31;
    for n in 31..100 {
        let s = n - k + 1;
        let seq = packed_seq::PackedSeqVec::random(n);
        let sketcher = crate::Sketcher::new_fwd(k, s, b);
        let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
        assert_eq!(bottom.len(), s);
        assert!(bottom.is_sorted());

        let s = s.min(10);
        let seq = packed_seq::PackedSeqVec::random(n);
        let sketcher = crate::Sketcher::new_fwd(k, s, b);
        let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
        assert_eq!(bottom.len(), s);
        assert!(bottom.is_sorted());
    }
}

#[cfg(test)]
#[test]
fn rc() {
    use packed_seq::SeqVec;

    let b = 32;
    for k in (0..10).map(|_| rand::random_range(1..100)) {
        for n in (0..10).map(|_| rand::random_range(k..1000)) {
            for s in (0..10).map(|_| rand::random_range(0..n - k + 1)) {
                let seq = packed_seq::AsciiSeqVec::random(n);
                let sketcher = crate::Sketcher::new_rc(k, s, b);
                let bottom = sketcher.bottom_sketch(seq.as_slice()).bottom;
                assert_eq!(bottom.len(), s);
                assert!(bottom.is_sorted());

                let seq_rc = packed_seq::AsciiSeqVec::from_ascii(
                    &seq.seq
                        .iter()
                        .rev()
                        .map(|c| packed_seq::complement_char(*c))
                        .collect::<Vec<_>>(),
                );

                let bottom_rc = sketcher.bottom_sketch(seq_rc.as_slice()).bottom;
                assert_eq!(bottom, bottom_rc);
            }
        }
    }
}