simd_minimizers/
syncmers.rs

//! Collect syncmer positions from a SIMD iterator of minimizer positions into a flat `Vec<u32>`.

#![allow(clippy::uninit_vec)]

use std::{
    array::{self, from_fn},
    cell::RefCell,
};

use crate::{S, minimizers::SKIPPED};
use packed_seq::{ChunkIt, L, PaddedIt, intrinsics::transpose};
use wide::u32x8;

/// Collect positions of all syncmers.
/// `OPEN`:
/// - `false`: closed syncmers (minimizer at the first or last position of the window)
/// - `true`: open syncmers (minimizer exactly in the middle of the window)
pub fn collect_syncmers_scalar<const OPEN: bool>(
    w: usize,
    it: impl Iterator<Item = u32>,
    out_vec: &mut Vec<u32>,
) {
    if OPEN {
        assert!(
            w % 2 == 1,
            "Open syncmers require an odd window size, so that there is a unique middle element."
        );
    }
    // Expose the full capacity so we can write without bounds checks;
    // the uninitialized tail is truncated away at the end.
    unsafe { out_vec.set_len(out_vec.capacity()) };
    let mut idx = 0;
    it.enumerate().for_each(|(i, min_pos)| {
        let is_syncmer = if OPEN {
            // Open syncmer: the minimizer sits exactly in the middle of window `i`.
            min_pos as usize == i + w / 2
        } else {
            // Closed syncmer: the minimizer sits at the first or last position of window `i`.
            min_pos as usize == i || min_pos as usize == i + w - 1
        };
        if is_syncmer {
            if idx == out_vec.len() {
                out_vec.reserve(1);
                unsafe { out_vec.set_len(out_vec.capacity()) };
            }
            *unsafe { out_vec.get_unchecked_mut(idx) } = i as u32;
            idx += 1;
        }
    });
    out_vec.truncate(idx);
}
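
// A minimal sanity check for the scalar path, added here as a sketch: the
// minimizer positions below are hand-picked rather than derived from a real
// sequence, and this module is not part of the crate's original test suite.
#[cfg(test)]
mod scalar_syncmer_sketch {
    use super::collect_syncmers_scalar;

    #[test]
    fn closed_and_open_syncmers() {
        // Absolute minimizer position of each window i, for w = 3:
        // window 0 -> 0, window 1 -> 2, window 2 -> 2, window 3 -> 4, window 4 -> 4.
        let min_pos = [0u32, 2, 2, 4, 4];

        // Closed: minimizer at position i or i + w - 1 of window i.
        let mut closed = Vec::new();
        collect_syncmers_scalar::<false>(3, min_pos.iter().copied(), &mut closed);
        assert_eq!(closed, [0, 2, 4]);

        // Open: minimizer at position i + w / 2 (requires odd w).
        let mut open = Vec::new();
        collect_syncmers_scalar::<true>(3, min_pos.iter().copied(), &mut open);
        assert_eq!(open, [1, 3]);
    }
}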

pub trait CollectSyncmers: Sized {
    /// Collect all indices where syncmers start.
    ///
    /// Automatically skips `SKIPPED` values from ambiguous windows; positions are `u32`, so sequences must be shorter than around 2^32 - 2.
    fn collect_syncmers<const OPEN: bool>(self, w: usize) -> Vec<u32> {
        let mut v = vec![];
        self.collect_syncmers_into::<OPEN>(w, &mut v);
        v
    }

    /// Collect all indices where syncmers start into `out_vec`.
    ///
    /// Automatically skips `SKIPPED` values from ambiguous windows; positions are `u32`, so sequences must be shorter than around 2^32 - 2.
    fn collect_syncmers_into<const OPEN: bool>(self, w: usize, out_vec: &mut Vec<u32>);
}
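
// Hedged usage sketch: `collect_syncmers` consumes a `PaddedIt` of per-window
// minimizer positions produced elsewhere in this crate. The pipeline below is
// illustrative only; `minimizer_positions_it` is a hypothetical name for the
// upstream function that yields that iterator, not a verified API.
//
//     let syncmer_starts: Vec<u32> =
//         minimizer_positions_it(seq, k, w).collect_syncmers::<false>(w);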

// Per-thread scratch buffers, one per SIMD lane, reused across calls to avoid reallocation.
thread_local! {
    static CACHE: RefCell<[Vec<u32>; 8]> = RefCell::new(array::from_fn(|_| Vec::new()));
}

impl<I: ChunkIt<u32x8>> CollectSyncmers for PaddedIt<I> {
    // Mostly copied from `Collect::collect_minimizers_into`.
    #[inline(always)]
    fn collect_syncmers_into<const OPEN: bool>(self, w: usize, out_vec: &mut Vec<u32>) {
        let Self { it, padding } = self;
        CACHE.with(
            #[inline(always)]
            |v| {
                let mut v = v.borrow_mut();

                let mut write_idx = [0; 8];

                let len = it.len();
                let mut lane_offsets: u32x8 = u32x8::from(from_fn(|i| (i * len) as u32));

                // Distribute the padding over the lanes, starting from the last:
                // fully padded lanes are masked out entirely up front, and the
                // first partially padded lane (`padding_idx`) is masked once the
                // main loop reaches iteration `padding_i`.
                let mut mask = u32x8::ZERO;
                let mut padding_i = 0;
                let mut padding_idx = 0;
                assert!(padding <= L * len, "padding {padding} <= L {L} * len {len}");
                let mut remaining_padding = padding;
                for i in (0..8).rev() {
                    if remaining_padding >= len {
                        mask.as_array_mut()[i] = u32::MAX;
                        remaining_padding -= len;
                        continue;
                    }
                    padding_i = len - remaining_padding;
                    padding_idx = i;
                    break;
                }
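
                // Worked example (illustrative numbers, not taken from the code
                // above): with len = 10 and padding = 23, lanes 7 and 6 are
                // fully masked (23 = 10 + 10 + 3), and lane 5 becomes masked
                // once the main loop reaches i = padding_i = 10 - 3 = 7.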

                // FIXME: Is this one slow?
                let mut m = [u32x8::ZERO; 8];
                let mut i = 0;
                it.for_each(
                    #[inline(always)]
                    |x| {
                        if i == padding_i {
                            mask.as_array_mut()[padding_idx] = u32::MAX;
                        }
                        let x = x | mask;

                        // Every non-syncmer minimizer position is masked out below.
                        let is_syncmer = if OPEN {
                            x.cmp_eq(lane_offsets + S::splat((w / 2) as u32))
                        } else {
                            x.cmp_eq(lane_offsets) | x.cmp_eq(lane_offsets + S::splat(w as u32 - 1))
                        };
                        // Current window position if this window is a syncmer, else u32::MAX.
                        let y = is_syncmer.blend(lane_offsets, u32x8::MAX);

                        m[i % 8] = y;
                        if i % 8 == 7 {
                            // Transpose the 8x8 block so that each output register
                            // holds 8 consecutive values of a single lane.
                            let t = transpose(m);
                            for j in 0..8 {
                                let lane = t[j];
                                if write_idx[j] + 8 > v[j].len() {
                                    v[j].reserve(8);
                                    unsafe {
                                        let new_len = v[j].capacity();
                                        v[j].set_len(new_len);
                                    }
                                }
                                unsafe {
                                    crate::intrinsics::append_filtered_vals(
                                        lane,
                                        // Skip masked-out values.
                                        lane.cmp_eq(u32x8::MAX),
                                        &mut v[j],
                                        &mut write_idx[j],
                                    );
                                }
                            }
                        }
                        i += 1;
                        lane_offsets += S::ONE;
                    },
                );

                // Truncate each lane buffer to the number of values actually written.
                for j in 0..8 {
                    v[j].truncate(write_idx[j]);
                }

                // Manually write the unfinished tail of length k = i % 8.
                let t = transpose(m);
                let k = i % 8;
                for j in 0..8 {
                    let lane = t[j].as_array_ref();
                    for &x in lane.iter().take(k) {
                        if x < SKIPPED {
                            v[j].push(x);
                        }
                    }
                }

                // Flatten the per-lane buffers into `out_vec`.
                for lane in v.iter() {
                    out_vec.extend_from_slice(lane.as_slice());
                }
            },
        )
    }
}
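
// A possible cross-check between the SIMD and scalar paths, left as a hedged
// comment because building a `PaddedIt` requires upstream machinery from this
// crate. Both `minimizer_positions_it` and `scalar_minimizer_positions` are
// hypothetical names, not verified APIs.
//
//     let simd: Vec<u32> =
//         minimizer_positions_it(seq, k, w).collect_syncmers::<false>(w);
//     let mut scalar = Vec::new();
//     collect_syncmers_scalar::<false>(w, scalar_minimizer_positions(seq, k, w), &mut scalar);
//     assert_eq!(simd, scalar);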