simd_minimizers/
nthash.rs

1//! NtHash the kmers in a sequence.
2use std::array::from_fn;
3
4use super::intrinsics;
5use crate::S;
6use packed_seq::complement_base;
7use packed_seq::Seq;
8use wide::u32x8;
9
/// Helper trait for the "Captures trick": adding `+ Captures<&'s ()>` to an
/// `impl Trait` return type lets the returned value capture the lifetime `'s`
/// without imposing a `+ 's` outlives-bound on the whole type.
/// Blanket-implemented for every type, so it carries no runtime meaning.
pub trait Captures<U> {}
impl<T: ?Sized, U> Captures<U> for T {}
12
/// Original ntHash seed values.
/// One entry per 2-bit packed base (index 0..4); the original 64-bit ntHash
/// constants are truncated to their low 32 bits here since this crate computes
/// 32-bit hashes.
// TODO: Update to guarantee unique hash values for k<=16?
const HASHES_F: [u32; 4] = [
    0x3c8b_fbb3_95c6_0474u64 as u32,
    0x3193_c185_62a0_2b4cu64 as u32,
    0x2032_3ed0_8257_2324u64 as u32,
    0x2955_49f5_4be2_4456u64 as u32,
];
21
/// Per-character hash functions used by the rolling ntHash implementations
/// below, in both scalar and 8-lane SIMD form.
///
/// For a base `b`:
/// - `f(b)` is the forward-strand character hash;
/// - `c(b)` is the hash of the complement of `b`;
/// - the `_rot` variants return the same value pre-rotated left by `k-1` bits,
///   which is the rotation a character's contribution has accumulated by the
///   time it leaves a window of width `k`.
pub trait CharHasher: Clone {
    /// Like [`CharHasher::new`], but takes a sequence value so `SEQ` can be
    /// inferred at the call site; the value itself is unused.
    fn new_from_val<'s, SEQ: Seq<'s>>(k: usize, _: SEQ) -> Self {
        Self::new::<SEQ>(k)
    }
    /// Build a hasher for k-mers of length `k`.
    fn new<'s, SEQ: Seq<'s>>(k: usize) -> Self;
    /// Forward hash of base `b`.
    fn f(&self, b: u8) -> u32;
    /// Hash of the complement of base `b`.
    fn c(&self, b: u8) -> u32;
    /// `f(b)` rotated left by `k-1` bits.
    fn f_rot(&self, b: u8) -> u32;
    /// `c(b)` rotated left by `k-1` bits.
    fn c_rot(&self, b: u8) -> u32;
    /// Lane-wise `f` over 8 bases.
    fn simd_f(&self, b: u32x8) -> u32x8;
    /// Lane-wise `c` over 8 bases.
    fn simd_c(&self, b: u32x8) -> u32x8;
    /// Lane-wise `f_rot` over 8 bases.
    fn simd_f_rot(&self, b: u32x8) -> u32x8;
    /// Lane-wise `c_rot` over 8 bases.
    fn simd_c_rot(&self, b: u32x8) -> u32x8;
}
36
/// Classic ntHash character hasher: 4-entry lookup tables derived from
/// [`HASHES_F`], with complemented and pre-rotated variants precomputed in
/// [`CharHasher::new`]. The `simd_*` fields hold the same 4 entries duplicated
/// into both 128-bit halves of a `u32x8` for use with
/// `intrinsics::table_lookup`.
#[derive(Clone)]
pub struct NtHasher {
    // Scalar tables, indexed by the 2-bit base value (0..4).
    f: [u32; 4],
    c: [u32; 4],
    f_rot: [u32; 4],
    c_rot: [u32; 4],
    // SIMD tables: each 128-bit half contains a copy of the 4 entries above.
    simd_f: u32x8,
    simd_c: u32x8,
    simd_f_rot: u32x8,
    simd_c_rot: u32x8,
}
48
impl CharHasher for NtHasher {
    fn new<'s, SEQ: Seq<'s>>(k: usize) -> Self {
        // The 4-entry tables below only make sense for 2-bit packed bases.
        assert_eq!(SEQ::BITS_PER_CHAR, 2);

        // A character removed from a window of width k has been rotated k-1
        // times since it was added; precompute that rotation per base.
        let rot = k as u32 - 1;
        let f = HASHES_F;
        // Complement table: entry i holds the forward hash of complement(i).
        let c = from_fn(|i| HASHES_F[complement_base(i as u8) as usize]);
        let f_rot = f.map(|h| h.rotate_left(rot));
        let c_rot = c.map(|h| h.rotate_left(rot));
        // Duplicate the 4 entries into both 128-bit halves for table_lookup.
        let idx = [0, 1, 2, 3, 0, 1, 2, 3];
        let simd_f = idx.map(|i| f[i]).into();
        let simd_c = idx.map(|i| c[i]).into();
        let simd_f_rot = idx.map(|i| f_rot[i]).into();
        let simd_c_rot = idx.map(|i| c_rot[i]).into();

        Self {
            f,
            c,
            f_rot,
            c_rot,
            simd_f,
            simd_c,
            simd_f_rot,
            simd_c_rot,
        }
    }

    fn f(&self, b: u8) -> u32 {
        // SAFETY: b is a 2-bit packed base, so b < 4 == table length
        // (BITS_PER_CHAR == 2 is asserted in `new`) — TODO confirm all call
        // sites only pass packed bases.
        unsafe { *self.f.get_unchecked(b as usize) }
    }
    fn c(&self, b: u8) -> u32 {
        // SAFETY: as in `f`, b < 4 for 2-bit packed bases.
        unsafe { *self.c.get_unchecked(b as usize) }
    }
    fn f_rot(&self, b: u8) -> u32 {
        // SAFETY: as in `f`, b < 4 for 2-bit packed bases.
        unsafe { *self.f_rot.get_unchecked(b as usize) }
    }
    fn c_rot(&self, b: u8) -> u32 {
        // SAFETY: as in `f`, b < 4 for 2-bit packed bases.
        unsafe { *self.c_rot.get_unchecked(b as usize) }
    }

    fn simd_f(&self, b: u32x8) -> u32x8 {
        intrinsics::table_lookup(self.simd_f, b)
    }
    fn simd_c(&self, b: u32x8) -> u32x8 {
        intrinsics::table_lookup(self.simd_c, b)
    }
    fn simd_f_rot(&self, b: u32x8) -> u32x8 {
        intrinsics::table_lookup(self.simd_f_rot, b)
    }
    fn simd_c_rot(&self, b: u32x8) -> u32x8 {
        intrinsics::table_lookup(self.simd_c_rot, b)
    }
}
102
/// Multiplicative alternative to [`NtHasher`]: hashes a base by multiplying it
/// with a fixed odd constant instead of a table lookup, computing the `_rot`
/// variants on the fly from the stored rotation amount.
#[derive(Clone)]
pub struct MulHasher {
    // Rotation `(k-1) % 32` applied by the `_rot` variants.
    rot: u32,
}

/// Mixing constant.
// NOTE(review): this is the low 32 bits of 0x517cc1b727220a95, the 64-bit
// multiplier used by FxHash — presumably chosen for its bit-mixing quality.
const C: u32 = 0x517cc1b727220a95u64 as u32;
110
111impl CharHasher for MulHasher {
112    fn new<'s, SEQ: Seq<'s>>(k: usize) -> Self {
113        MulHasher {
114            rot: (k as u32 - 1) % 32,
115        }
116    }
117
118    fn f(&self, b: u8) -> u32 {
119        (b as u32).wrapping_mul(C)
120    }
121    fn c(&self, b: u8) -> u32 {
122        (complement_base(b) as u32).wrapping_mul(C)
123    }
124    fn f_rot(&self, b: u8) -> u32 {
125        (b as u32).wrapping_mul(C).rotate_left(self.rot)
126    }
127    fn c_rot(&self, b: u8) -> u32 {
128        (complement_base(b) as u32)
129            .wrapping_mul(C)
130            .rotate_left(self.rot)
131    }
132
133    fn simd_f(&self, b: u32x8) -> u32x8 {
134        b * C.into()
135    }
136    fn simd_c(&self, b: u32x8) -> u32x8 {
137        packed_seq::complement_base_simd(b) * C.into()
138    }
139    fn simd_f_rot(&self, b: u32x8) -> u32x8 {
140        let r = b * C.into();
141        (r << self.rot) | (r >> (32 - self.rot))
142    }
143    fn simd_c_rot(&self, b: u32x8) -> u32x8 {
144        let r = packed_seq::complement_base_simd(b) * C.into();
145        (r << self.rot) | (r >> (32 - self.rot))
146    }
147}
148
149/// Naively compute the 32-bit NT hash of a single k-mer.
150/// When `RC` is false, compute a forward hash.
151/// When `RC` is true, compute a canonical hash.
152/// TODO: Investigate if we can use CLMUL instruction for speedup.
153pub fn nthash_kmer<'s, const RC: bool, H: CharHasher>(seq: impl Seq<'s>) -> u32 {
154    let hasher = H::new_from_val(seq.len(), seq);
155
156    let k = seq.len();
157    let mut hfw: u32 = 0;
158    let mut hrc: u32 = 0;
159    seq.iter_bp().for_each(|a| {
160        hfw = hfw.rotate_left(1) ^ hasher.f(a);
161        if RC {
162            hrc = hrc.rotate_right(1) ^ hasher.c(a);
163        }
164    });
165    hfw.wrapping_add(hrc.rotate_left(k as u32 - 1))
166}
167
/// Returns a scalar iterator over the 32-bit NT hashes of all k-mers in the sequence.
/// Prefer `hash_seq_simd`.
///
/// Set `RC` to true for canonical ntHash.
pub fn nthash_seq_scalar<'s, const RC: bool, H: CharHasher>(
    seq: impl Seq<'s>,
    k: usize,
) -> impl ExactSizeIterator<Item = u32> + Captures<&'s ()> + Clone {
    assert!(k > 0);
    let hasher = H::new_from_val(k, seq);

    // Rolling forward and reverse-complement hashes of the current window.
    let mut hfw: u32 = 0;
    let mut hrc: u32 = 0;
    // `add` yields the incoming base; `remove` lags k-1 positions behind and
    // yields the base falling out of the window.
    let mut add = seq.iter_bp();
    let remove = seq.iter_bp();
    // Warm up on the first k-1 bases; no hashes are emitted for these.
    add.by_ref().take(k - 1).for_each(|a| {
        hfw = hfw.rotate_left(1) ^ hasher.f(a);
        if RC {
            // The rc hash rolls in the opposite direction and is kept
            // pre-rotated by k-1 bits (`c_rot` on add, plain `c` on remove),
            // so no per-k-mer rotation is needed when combining below.
            hrc = hrc.rotate_right(1) ^ hasher.c_rot(a);
        }
    });
    add.zip(remove).map(move |(a, r)| {
        // Roll in the new base; `hfw_out` is the hash of the full window.
        let hfw_out = hfw.rotate_left(1) ^ hasher.f(a);
        // Cancel the leaving base, whose contribution has been rotated k-1
        // times since it was added.
        hfw = hfw_out ^ hasher.f_rot(r);
        if RC {
            let hrc_out = hrc.rotate_right(1) ^ hasher.c_rot(a);
            hrc = hrc_out ^ hasher.c(r);
            // Canonical hash: wrapping sum of forward and rc hashes.
            hfw_out.wrapping_add(hrc_out)
        } else {
            hfw_out
        }
    })
}
201
202/// Returns a simd-iterator over the 8 chunks 32-bit ntHashes of all k-mers in the sequence.
203/// The tail is returned separately.
204/// Returned chunks overlap by w-1 hashes. Set w=1 for non-overlapping chunks.
205///
206/// Set `RC` to true for canonical ntHash.
207pub fn nthash_seq_simd<'s, const RC: bool, SEQ: Seq<'s>, H: CharHasher>(
208    seq: impl Seq<'s>,
209    k: usize,
210    w: usize,
211) -> (
212    impl ExactSizeIterator<Item = S> + Captures<&'s ()> + Clone,
213    usize,
214) {
215    let (add_remove, padding) = seq.par_iter_bp_delayed(k + w - 1, k - 1);
216
217    let mut it = add_remove.map(nthash_mapper::<RC, SEQ, H>(k, w));
218    it.by_ref().take(k - 1).for_each(drop);
219
220    (it, padding)
221}
222
/// A function that 'eats' added and removed bases, and returns the updated hash.
/// The distance between them must be k-1, and the first k-1 removed bases must be 0.
/// The first k-1 returned values will be useless.
///
/// Set `RC` to true for canonical ntHash.
pub fn nthash_mapper<'s, const RC: bool, SEQ: Seq<'s>, H: CharHasher>(
    k: usize,
    w: usize,
) -> impl FnMut((S, S)) -> S + Clone {
    let hasher = H::new::<SEQ>(k);

    assert!(k > 0);
    assert!(w > 0);
    // Each 128-bit half has a copy of the 4 32-bit hashes.

    // Pre-roll the state over k-1 zero-bases. Since the caller feeds 0 as the
    // removed character for the first k-1 steps, these contributions cancel
    // out exactly once real characters fill the window.
    let mut fw = 0u32;
    let mut rc = 0u32;
    for _ in 0..k - 1 {
        fw = fw.rotate_left(1) ^ hasher.f(0);
        rc = rc.rotate_right(1) ^ hasher.c_rot(0);
    }

    // All 8 lanes start from the same pre-rolled scalar state.
    let mut h_fw = S::splat(fw);
    let mut h_rc = S::splat(rc);

    move |(a, r)| {
        // (x << 1) | (x >> 31) is a lane-wise rotate_left(1).
        let hfw_out = ((h_fw << 1) | (h_fw >> 31)) ^ hasher.simd_f(a);
        // Cancel the leaving character, rotated k-1 times since it was added.
        h_fw = hfw_out ^ hasher.simd_f_rot(r);
        if RC {
            // The rc hash rolls in the opposite direction and is kept
            // pre-rotated by k-1 (c_rot on add, plain c on remove), mirroring
            // the scalar version.
            let hrc_out = ((h_rc >> 1) | (h_rc << 31)) ^ hasher.simd_c_rot(a);
            h_rc = hrc_out ^ hasher.simd_c(r);
            // Wrapping SIMD add
            hfw_out + hrc_out
        } else {
            hfw_out
        }
    }
}
260}