// umi_core/dedup.rs
1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fs::File;
3use std::io::{self, Write as IoWrite};
4
5use rust_htslib::bam::{self, Read as BamRead, Record};
6
/// Trait for RNG used in reservoir-sampling tie-breaks.
///
/// Currently implemented by `PythonRandom` (MT19937 matching `CPython`) to get
/// identical output for compat tests. Can be swapped for any fast RNG once
/// exact-match testing is no longer needed.
///
/// `ReadBuffer::add` draws one value per equal-MAPQ read and replaces the
/// stored read when the draw is below `1 / tie_count`, so implementations must
/// honour the `[0, 1)` range.
pub(crate) trait TieBreakRng {
    /// Return a float in `[0, 1)`.
    fn random(&mut self) -> f64;
}
16
/// Mersenne Twister 19937 PRNG, matching `CPython`'s random module exactly.
///
/// Python `umi_tools` uses seeded `random.random()` for reservoir-sampling
/// tie-breaks in read selection. We replicate the identical float sequence.
///
/// The routines mirror the MT19937 reference code that `CPython` uses:
/// `init_genrand` / `init_by_array` for seeding, `generate` for the twist and
/// `next_u32` for tempering. Floats come from the `TieBreakRng` impl.
pub(crate) struct PythonRandom {
    /// Generator state: `N` 32-bit words.
    mt: [u32; 624],
    /// Index of the next state word to hand out. A value of `N` or more means
    /// the state is exhausted and `next_u32` calls `generate` first.
    index: usize,
}

impl PythonRandom {
    /// Number of 32-bit words in the state.
    const N: usize = 624;
    /// Offset of the partner word mixed into each twisted word (`mt[kk + M]`).
    const M: usize = 397;
    /// Twist constant, XORed in when the low bit of the combined word is 1.
    const MATRIX_A: u32 = 0x9908_b0df;
    /// Selects the most significant bit of a state word.
    const UPPER_MASK: u32 = 0x8000_0000;
    /// Selects the low 31 bits of a state word.
    const LOWER_MASK: u32 = 0x7fff_ffff;

    /// Seed the same way `CPython` `random.seed(int)` does:
    /// `init_genrand(19_650_218)` then `init_by_array(&[seed])`.
    ///
    /// NOTE(review): `CPython` feeds an integer seed to `init_by_array` as one
    /// key word per 32 bits of magnitude, so this single-word form only covers
    /// seeds below 2^32. `DedupConfig::random_seed` is a `u64`; confirm how
    /// callers narrow it.
    pub(crate) fn new(seed: u32) -> Self {
        let mut rng = Self::init_genrand(19_650_218);
        rng.init_by_array(&[seed]);
        rng
    }

    /// Fill the state from a single 32-bit seed using the reference linear
    /// recurrence `mt[i] = 1812433253 * (mt[i-1] ^ (mt[i-1] >> 30)) + i`
    /// (all arithmetic mod 2^32).
    ///
    /// The returned generator has `index == N`, so the first draw twists the state.
    #[allow(clippy::cast_possible_truncation)]
    fn init_genrand(seed: u32) -> Self {
        let mut mt = [0u32; Self::N];
        mt[0] = seed;
        for i in 1..Self::N {
            mt[i] = 1_812_433_253u32
                .wrapping_mul(mt[i - 1] ^ (mt[i - 1] >> 30))
                .wrapping_add(i as u32); // i < 624, fits u32
        }
        Self { mt, index: Self::N }
    }

    /// Mix `key` into the current state (reference `init_by_array`).
    ///
    /// Runs `max(N, key.len())` key-mixing steps, then `N - 1` scrambling
    /// steps, and finally sets `mt[0]` to `UPPER_MASK`. The reference code does
    /// this last step to guarantee a non-zero state.
    ///
    /// # Panics
    ///
    /// Panics if `key` is empty, because `key[0]` is read on the first step.
    #[allow(clippy::cast_possible_truncation)]
    fn init_by_array(&mut self, key: &[u32]) {
        // `i` walks the state from 1, since each step reads `mt[i - 1]`.
        // `j` walks the key cyclically.
        let mut i: usize = 1;
        let mut j: usize = 0;
        let k = Self::N.max(key.len());
        for _ in 0..k {
            self.mt[i] = (self.mt[i]
                ^ ((self.mt[i - 1] ^ (self.mt[i - 1] >> 30)).wrapping_mul(1_664_525)))
            .wrapping_add(key[j])
            .wrapping_add(j as u32); // j < key.len(), fits u32
            i += 1;
            j += 1;
            if i >= Self::N {
                // Wrapped past the end: carry the last word into slot 0 and restart at 1.
                self.mt[0] = self.mt[Self::N - 1];
                i = 1;
            }
            if j >= key.len() {
                j = 0;
            }
        }
        for _ in 0..Self::N - 1 {
            self.mt[i] = (self.mt[i]
                ^ ((self.mt[i - 1] ^ (self.mt[i - 1] >> 30)).wrapping_mul(1_566_083_941)))
            .wrapping_sub(i as u32); // i < 624, fits u32
            i += 1;
            if i >= Self::N {
                self.mt[0] = self.mt[Self::N - 1];
                i = 1;
            }
        }
        self.mt[0] = Self::UPPER_MASK;
    }

    /// Regenerate all `N` state words in place (the MT19937 "twist") and reset
    /// `index` to 0.
    fn generate(&mut self) {
        // Lookup table in place of a branch: XOR in `MATRIX_A` iff `y` is odd.
        static MAG01: [u32; 2] = [0, PythonRandom::MATRIX_A];
        // Words 0..N-M: the partner word `mt[kk + M]` has not been rewritten yet.
        for kk in 0..Self::N - Self::M {
            let y = (self.mt[kk] & Self::UPPER_MASK) | (self.mt[kk + 1] & Self::LOWER_MASK);
            self.mt[kk] = self.mt[kk + Self::M] ^ (y >> 1) ^ MAG01[(y & 1) as usize];
        }
        // Words N-M..N-1: the partner index wraps to `kk + M - N`, a word already
        // rewritten in this pass.
        for kk in Self::N - Self::M..Self::N - 1 {
            let y = (self.mt[kk] & Self::UPPER_MASK) | (self.mt[kk + 1] & Self::LOWER_MASK);
            self.mt[kk] = self.mt[kk + Self::M - Self::N] ^ (y >> 1) ^ MAG01[(y & 1) as usize];
        }
        // Last word: its successor is the (already rewritten) `mt[0]`, its partner `mt[M - 1]`.
        let y = (self.mt[Self::N - 1] & Self::UPPER_MASK) | (self.mt[0] & Self::LOWER_MASK);
        self.mt[Self::N - 1] = self.mt[Self::M - 1] ^ (y >> 1) ^ MAG01[(y & 1) as usize];
        self.index = 0;
    }

    /// Return the next 32-bit output.
    ///
    /// Takes the next state word, twisting first if the state is exhausted,
    /// then applies the MT19937 tempering shifts and masks.
    fn next_u32(&mut self) -> u32 {
        if self.index >= Self::N {
            self.generate();
        }
        let mut y = self.mt[self.index];
        self.index += 1;
        y ^= y >> 11;
        y ^= (y << 7) & 0x9d2c_5680;
        y ^= (y << 15) & 0xefc6_0000;
        y ^= y >> 18;
        y
    }
}
114
115impl TieBreakRng for PythonRandom {
116    /// `CPython` `genrand_res53`: 53-bit precision float in `[0, 1)`.
117    fn random(&mut self) -> f64 {
118        let a = self.next_u32() >> 5;
119        let b = self.next_u32() >> 6;
120        (f64::from(a) * 67_108_864.0 + f64::from(b)) * (1.0 / 9_007_199_254_740_992.0)
121    }
122}
123
124/// MT19937 PRNG matching `NumPy`'s `np.random.seed(int)` + `np.random.random()`.
125///
126/// `NumPy` seeds with `init_genrand(seed)` directly (unlike `CPython` which uses
127/// `init_by_array`). Output generation (`genrand_res53`) is identical.
128struct NumpyRandom {
129    mt: [u32; 624],
130    index: usize,
131}
132
133impl NumpyRandom {
134    const N: usize = 624;
135
136    fn new(seed: u32) -> Self {
137        PythonRandom::init_genrand(seed).into()
138    }
139
140    fn random(&mut self) -> f64 {
141        let a = self.next_u32() >> 5;
142        let b = self.next_u32() >> 6;
143        (f64::from(a) * 67_108_864.0 + f64::from(b)) * (1.0 / 9_007_199_254_740_992.0)
144    }
145
146    fn next_u32(&mut self) -> u32 {
147        if self.index >= Self::N {
148            self.generate();
149        }
150        let mut y = self.mt[self.index];
151        self.index += 1;
152        y ^= y >> 11;
153        y ^= (y << 7) & 0x9d2c_5680;
154        y ^= (y << 15) & 0xefc6_0000;
155        y ^= y >> 18;
156        y
157    }
158
159    fn generate(&mut self) {
160        static MAG01: [u32; 2] = [0, PythonRandom::MATRIX_A];
161        for kk in 0..PythonRandom::N - PythonRandom::M {
162            let y = (self.mt[kk] & PythonRandom::UPPER_MASK)
163                | (self.mt[kk + 1] & PythonRandom::LOWER_MASK);
164            self.mt[kk] = self.mt[kk + PythonRandom::M] ^ (y >> 1) ^ MAG01[(y & 1) as usize];
165        }
166        for kk in PythonRandom::N - PythonRandom::M..PythonRandom::N - 1 {
167            let y = (self.mt[kk] & PythonRandom::UPPER_MASK)
168                | (self.mt[kk + 1] & PythonRandom::LOWER_MASK);
169            self.mt[kk] = self.mt[kk + PythonRandom::M - PythonRandom::N]
170                ^ (y >> 1)
171                ^ MAG01[(y & 1) as usize];
172        }
173        let y = (self.mt[PythonRandom::N - 1] & PythonRandom::UPPER_MASK)
174            | (self.mt[0] & PythonRandom::LOWER_MASK);
175        self.mt[PythonRandom::N - 1] =
176            self.mt[PythonRandom::M - 1] ^ (y >> 1) ^ MAG01[(y & 1) as usize];
177        self.index = 0;
178    }
179}
180
181impl From<PythonRandom> for NumpyRandom {
182    fn from(pr: PythonRandom) -> Self {
183        Self {
184            mt: pr.mt,
185            index: pr.index,
186        }
187    }
188}
189
/// UMI grouping strategy applied by `select_umis` / `count_umis` to each
/// `(pos, key)` group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DedupMethod {
    /// Keep every distinct UMI.
    Unique,
    /// Keep UMIs whose read count exceeds 1% of the group's median count.
    Percentile,
    /// Keep one UMI per connected component of the undirected edit-distance
    /// graph: the one with the highest count.
    Cluster,
    /// Per connected component, keep the greedy minimum set of high-count
    /// UMIs whose neighbourhoods cover the component (`min_set_cover`).
    Adjacency,
    /// Like `Cluster`, but an edge `A → B` additionally requires
    /// `count[A] >= 2 * count[B] - 1` (`build_directional_adjacency_list`).
    Directional,
}
198
/// Options for one deduplication run.
///
/// NOTE(review): the code that consumes this struct is outside this view, so
/// the per-field notes below are inferred from names and types. Confirm them
/// against the driver before relying on them.
// Several independent on/off switches trip `clippy::struct_excessive_bools`;
// they presumably map to separate CLI flags (TODO confirm), so the lint is allowed.
#[allow(clippy::struct_excessive_bools)]
pub struct DedupConfig {
    /// UMI grouping strategy (see `DedupMethod`).
    pub method: DedupMethod,
    /// Presumably: group by position only, disregarding the UMI.
    pub ignore_umi: bool,
    /// Presumably: the byte separating the UMI from the rest of the read name.
    pub umi_separator: u8,
    /// Presumably seeds the tie-break RNG. NOTE(review): `PythonRandom::new`
    /// takes a `u32`; check how values above `u32::MAX` are handled.
    pub random_seed: u64,
    /// Presumably: write SAM instead of BAM.
    pub out_sam: bool,
    /// Presumably: restrict processing to this reference sequence.
    pub chrom: Option<String>,
    /// Presumably the `edit_threshold` handed to UMI selection (maximum
    /// Hamming distance for two UMIs to be adjacent).
    pub edit_distance_threshold: u32,
    /// Presumably: fraction of reads to subsample.
    pub subset: Option<f32>,
    /// Presumably names how UMIs are extracted (e.g. read name vs tag).
    pub extract_umi_method: String,
    /// Presumably: BAM tag holding the UMI when extracting from tags.
    pub umi_tag: Option<String>,
    /// Presumably: deduplicate per gene rather than per position.
    pub per_gene: bool,
    /// Presumably: BAM tag holding the gene assignment.
    pub gene_tag: Option<String>,
    /// Presumably: regex of gene-tag values whose reads are skipped.
    pub skip_tags_regex: Option<String>,
    /// Presumably: output prefix for statistics files (`None` disables stats).
    pub output_stats: Option<String>,
    /// Presumably: treat input as paired-end.
    pub paired: bool,
    /// Presumably: don't use template length when grouping.
    pub ignore_tlen: bool,
    /// Presumably forwarded to `ReadBuffer` draining, where UMIs absent from
    /// the set are dropped. `None` keeps all UMIs.
    pub umi_whitelist: Option<HashSet<Vec<u8>>>,
}
219
/// Summary counters for a deduplication run.
///
/// NOTE(review): these fields are filled in outside this view; the notes
/// below are inferred from the names. Confirm against the driver.
pub struct DedupStats {
    /// Presumably the number of reads read from the input.
    pub input_reads: u64,
    /// Presumably the number of reads written after deduplication.
    pub output_reads: u64,
    /// Presumably the number of position groups processed.
    pub positions: u64,
}
225
226/// Length of a trailing/leading soft-clip, or 0 if the CIGAR op isn't `S`.
227pub(crate) fn soft_clip_len(op: Option<&rust_htslib::bam::record::Cigar>) -> i64 {
228    match op {
229        Some(c) if c.char() == 'S' => i64::from(c.len()),
230        _ => 0,
231    }
232}
233
234/// Returns `(start, pos)` for a read.
235///
236/// - `start`: leftmost aligned position (used for buffer-flush decisions)
237/// - `pos`: 5′ coordinate accounting for soft-clipping (used for grouping)
238///
239/// Matches Python `get_read_position()` with default `soft_clip_threshold=4`.
240pub(crate) fn get_read_position(record: &Record) -> (i64, i64) {
241    let cigar = record.cigar();
242    if record.is_reverse() {
243        let start = record.pos();
244        let pos = cigar.end_pos() + soft_clip_len(cigar.last());
245        (start, pos)
246    } else {
247        let pos = record.pos() - soft_clip_len(cigar.first());
248        (pos, pos)
249    }
250}
251
/// Sub-key within a position group: `(is_reverse, is_spliced, tlen, read_length)`.
/// With default options, this collapses to `(is_reverse, false, 0, 0)`.
///
/// `ReadBuffer` stores groups in a `BTreeMap` keyed by this tuple. The derived
/// tuple ordering (field by field, left to right) therefore decides the order
/// in which keys at the same position are flushed.
pub(crate) type GroupKey = (bool, bool, i64, usize);
255
/// Holds per-UMI read selection state: best record + reservoir-sampling counter.
///
/// One slot exists per distinct UMI within a `(pos, key)` group of `ReadBuffer`.
pub(crate) struct UmiSlot {
    /// Read currently selected to represent this UMI.
    pub(crate) record: Record,
    /// MAPQ of the best read seen so far. `ReadBuffer::add` updates it
    /// together with `record`.
    pub(crate) mapq: u8,
    /// Number of reads with MAPQ equal to `mapq` seen since that MAPQ became
    /// the best. Reset to 0 when a higher MAPQ arrives; the `n`-th tie
    /// replaces `record` with probability `1 / n`.
    pub(crate) tie_count: u32,
    /// Total number of reads observed for this UMI in the group.
    pub(crate) count: u32,
    /// Insertion order within the (pos, key) group — used for deterministic
    /// tiebreaking to match Python dict insertion order.
    pub(crate) insertion_order: u32,
}
266
267/// Buffered read collector that mirrors Python `umi_tools`' `reads_dict`.
268///
269/// Structure: `pos → key → umi → UmiSlot`
270///
271/// `pos` is the 5′ coordinate; `key` is `(is_reverse, …)`.
272/// When flushing, positions are emitted in sorted order and keys within
273/// each position are emitted in sorted order (matching Python's
274/// `sorted(reads_dict[p].keys())`).
275struct ReadBuffer {
276    groups: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, UmiSlot>>>,
277    /// Per-(pos, key) insertion counters for deterministic ordering.
278    insertion_counters: BTreeMap<i64, BTreeMap<GroupKey, u32>>,
279}
280
281impl ReadBuffer {
282    const fn new() -> Self {
283        Self {
284            groups: BTreeMap::new(),
285            insertion_counters: BTreeMap::new(),
286        }
287    }
288
289    /// Add a record to the buffer, performing reservoir-sampling read selection.
290    fn add(
291        &mut self,
292        record: Record,
293        pos: i64,
294        key: GroupKey,
295        umi: Vec<u8>,
296        rng: &mut impl TieBreakRng,
297    ) {
298        let umi_map = self.groups.entry(pos).or_default().entry(key).or_default();
299
300        let Some(slot) = umi_map.get_mut(&umi) else {
301            let counter = self
302                .insertion_counters
303                .entry(pos)
304                .or_default()
305                .entry(key)
306                .or_default();
307            let order = *counter;
308            *counter += 1;
309            let mapq = record.mapq();
310            umi_map.insert(
311                umi,
312                UmiSlot {
313                    record,
314                    mapq,
315                    tie_count: 0,
316                    count: 1,
317                    insertion_order: order,
318                },
319            );
320            return;
321        };
322
323        slot.count += 1;
324
325        let record_mapq = record.mapq();
326        match slot.mapq.cmp(&record_mapq) {
327            std::cmp::Ordering::Greater => {}
328            std::cmp::Ordering::Less => {
329                slot.record = record;
330                slot.mapq = record_mapq;
331                slot.tie_count = 0;
332            }
333            std::cmp::Ordering::Equal => {
334                slot.tie_count += 1;
335                if rng.random() < 1.0 / f64::from(slot.tie_count) {
336                    slot.record = record;
337                }
338            }
339        }
340    }
341
342    /// Drain all position groups with `pos <= threshold`, applying UMI dedup selection.
343    fn drain_up_to(
344        &mut self,
345        threshold: i64,
346        method: DedupMethod,
347        edit_threshold: u32,
348        stats_ctx: &mut Option<StatsContext>,
349        umi_whitelist: Option<&HashSet<Vec<u8>>>,
350    ) -> Vec<Record> {
351        let rest = self.groups.split_off(&(threshold + 1));
352        let drained = std::mem::replace(&mut self.groups, rest);
353        // Clean up insertion counters for drained positions
354        let rest_counters = self.insertion_counters.split_off(&(threshold + 1));
355        let _ = std::mem::replace(&mut self.insertion_counters, rest_counters);
356        Self::apply_selection(drained, method, edit_threshold, stats_ctx, umi_whitelist)
357    }
358
359    /// Drain all remaining position groups, applying UMI dedup selection.
360    fn drain_all(
361        &mut self,
362        method: DedupMethod,
363        edit_threshold: u32,
364        stats_ctx: &mut Option<StatsContext>,
365        umi_whitelist: Option<&HashSet<Vec<u8>>>,
366    ) -> Vec<Record> {
367        let drained = std::mem::take(&mut self.groups);
368        self.insertion_counters.clear();
369        Self::apply_selection(drained, method, edit_threshold, stats_ctx, umi_whitelist)
370    }
371
372    /// Apply method-specific UMI selection to drained position groups.
373    fn apply_selection(
374        groups: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, UmiSlot>>>,
375        method: DedupMethod,
376        edit_threshold: u32,
377        stats_ctx: &mut Option<StatsContext>,
378        umi_whitelist: Option<&HashSet<Vec<u8>>>,
379    ) -> Vec<Record> {
380        let mut records = Vec::new();
381        for key_map in groups.into_values() {
382            for umi_map in key_map.into_values() {
383                if stats_ctx.is_some() {
384                    let selected_with_counts =
385                        select_umis_with_cluster_counts(method, &umi_map, edit_threshold);
386                    let mut bundle_records: Vec<&Record> = Vec::new();
387                    let mut selected_umis = Vec::new();
388                    let mut cluster_counts = Vec::new();
389                    for (umi, cluster_count) in &selected_with_counts {
390                        if umi_whitelist.is_some_and(|wl| !wl.contains(umi)) {
391                            continue;
392                        }
393                        if let Some(slot) = umi_map.get(umi) {
394                            bundle_records.push(&slot.record);
395                            selected_umis.push(umi.clone());
396                            cluster_counts.push(*cluster_count);
397                        }
398                    }
399                    if let Some(ctx) = stats_ctx.as_mut() {
400                        ctx.collector.record_bundle(
401                            &umi_map,
402                            &selected_umis,
403                            &cluster_counts,
404                            &bundle_records,
405                            ctx.umi_separator,
406                            &mut ctx.read_gen,
407                        );
408                    }
409                    for r in bundle_records {
410                        records.push(r.clone());
411                    }
412                } else {
413                    let selected = select_umis(method, &umi_map, edit_threshold);
414                    for umi in &selected {
415                        if umi_whitelist.is_some_and(|wl| !wl.contains(umi)) {
416                            continue;
417                        }
418                        if let Some(slot) = umi_map.get(umi) {
419                            records.push(slot.record.clone());
420                        }
421                    }
422                }
423            }
424        }
425        records
426    }
427}
428
/// Bundles the stats collector + read generator for passing through drain calls.
///
/// When present, `ReadBuffer::apply_selection` makes one `record_bundle` call
/// per drained `(pos, key)` group and passes `read_gen` and `umi_separator`
/// along to it.
struct StatsContext {
    /// Receives the per-group selection results.
    collector: StatsCollector,
    /// Handed mutably to `record_bundle`. NOTE(review): presumably supplies
    /// random comparison reads for the stats; confirm in `StatsCollector`.
    read_gen: RandomReadGenerator,
    /// Forwarded unchanged to `record_bundle`.
    umi_separator: u8,
}
435
/// Hamming distance between two byte slices.
///
/// Returns the number of positions at which `a` and `b` differ. When the
/// lengths differ, `u32::MAX` is returned (matching Python's `np.inf` return),
/// so the pair exceeds any threshold below `u32::MAX`. A mismatch count too
/// large for `u32` saturates to `u32::MAX` instead of being truncated.
pub(crate) fn edit_distance(a: &[u8], b: &[u8]) -> u32 {
    if a.len() != b.len() {
        return u32::MAX;
    }
    let mismatches = a.iter().zip(b).filter(|(x, y)| x != y).count();
    u32::try_from(mismatches).unwrap_or(u32::MAX)
}
446
447/// Build undirected adjacency list (for cluster + adjacency methods).
448/// Edge between A and B iff `edit_distance(A, B) <= threshold`.
449pub(crate) fn build_adjacency_list<'a>(
450    umis: &[&'a [u8]],
451    threshold: u32,
452) -> HashMap<&'a [u8], Vec<&'a [u8]>> {
453    let mut adj: HashMap<&'a [u8], Vec<&'a [u8]>> = HashMap::new();
454    for umi in umis {
455        adj.entry(umi).or_default();
456    }
457    for i in 0..umis.len() {
458        for j in (i + 1)..umis.len() {
459            if edit_distance(umis[i], umis[j]) <= threshold {
460                adj.get_mut(umis[i])
461                    .expect("UMI pre-inserted")
462                    .push(umis[j]);
463                adj.get_mut(umis[j])
464                    .expect("UMI pre-inserted")
465                    .push(umis[i]);
466            }
467        }
468    }
469    adj
470}
471
472/// Build directed adjacency list (for directional method).
473/// Edge A→B iff `edit_distance(A,B) <= threshold AND counts[A] >= 2*counts[B] - 1`.
474pub(crate) fn build_directional_adjacency_list<'a>(
475    umis: &[&'a [u8]],
476    counts: &HashMap<&[u8], u32>,
477    threshold: u32,
478) -> HashMap<&'a [u8], Vec<&'a [u8]>> {
479    let mut adj: HashMap<&'a [u8], Vec<&'a [u8]>> = HashMap::new();
480    for umi in umis {
481        adj.entry(umi).or_default();
482    }
483    for i in 0..umis.len() {
484        for j in (i + 1)..umis.len() {
485            if edit_distance(umis[i], umis[j]) <= threshold {
486                let ca = counts[umis[i]];
487                let cb = counts[umis[j]];
488                if ca >= (2 * cb).saturating_sub(1) {
489                    adj.get_mut(umis[i])
490                        .expect("UMI pre-inserted")
491                        .push(umis[j]);
492                }
493                if cb >= (2 * ca).saturating_sub(1) {
494                    adj.get_mut(umis[j])
495                        .expect("UMI pre-inserted")
496                        .push(umis[i]);
497                }
498            }
499        }
500    }
501    adj
502}
503
/// Collect every node reachable from `start` by following edges in `adj_list`.
///
/// `start` is always included, even when it has no entry in `adj_list`. The
/// traversal order is irrelevant: the reachable nodes are returned sorted in
/// ascending byte order. For a symmetric adjacency list this is the connected
/// component containing `start`; for a directed one, the set reachable from it.
pub(crate) fn bfs<'a>(
    start: &'a [u8],
    adj_list: &HashMap<&'a [u8], Vec<&'a [u8]>>,
) -> Vec<&'a [u8]> {
    let mut visited: HashSet<&'a [u8]> = HashSet::from([start]);
    let mut pending: Vec<&'a [u8]> = vec![start];
    while let Some(current) = pending.pop() {
        let Some(neighbours) = adj_list.get(current) else {
            continue;
        };
        for &candidate in neighbours {
            if visited.insert(candidate) {
                pending.push(candidate);
            }
        }
    }
    let mut reachable: Vec<&'a [u8]> = visited.into_iter().collect();
    reachable.sort_unstable();
    reachable
}
526
527/// Find connected components by iterating UMIs in count-descending order,
528/// running BFS from each unvisited node. Matches Python `_get_connected_components_adjacency`.
529pub(crate) fn connected_components<'a>(
530    umis: &[&'a [u8]],
531    counts: &HashMap<&[u8], u32>,
532    orders: &HashMap<&[u8], u32>,
533    adj_list: &HashMap<&'a [u8], Vec<&'a [u8]>>,
534) -> Vec<Vec<&'a [u8]>> {
535    // Sort UMIs by count descending, then insertion order ascending for ties
536    let mut sorted_umis: Vec<&[u8]> = umis.to_vec();
537    sorted_umis.sort_by(|a, b| {
538        counts[b]
539            .cmp(&counts[a])
540            .then_with(|| orders[a].cmp(&orders[b]))
541    });
542
543    let mut found: HashSet<&[u8]> = HashSet::new();
544    let mut components: Vec<Vec<&'a [u8]>> = Vec::new();
545    for umi in &sorted_umis {
546        if !found.contains(*umi) {
547            let component = bfs(umi, adj_list);
548            for &node in &component {
549                found.insert(node);
550            }
551            components.push(component);
552        }
553    }
554    components
555}
556
/// Greedy min-set-cover: select fewest UMIs (by descending count) to "cover"
/// all UMIs in the cluster via adjacency. Matches Python `_get_best_min_account`.
///
/// Nodes are ranked by descending count, with ascending bytes breaking ties.
/// The function returns the shortest proper prefix of that ranking whose
/// members, together with their neighbours in `adj_list`, include every node
/// of `cluster`. If no proper prefix covers the cluster, the whole ranking is
/// returned. Clusters with at most one node (including an empty one) are
/// returned unchanged.
///
/// # Panics
///
/// Panics if a node of a multi-node `cluster` is missing from `counts`.
pub(crate) fn min_set_cover<'a>(
    cluster: &[&'a [u8]],
    adj_list: &HashMap<&'a [u8], Vec<&'a [u8]>>,
    counts: &HashMap<&[u8], u32>,
) -> Vec<&'a [u8]> {
    // Also guards the `len() - 1` below against underflow on an empty cluster.
    if cluster.len() <= 1 {
        return cluster.to_vec();
    }
    let mut ranked: Vec<&'a [u8]> = cluster.to_vec();
    // Sort by count desc, lex asc (BFS output is lex-sorted; Python's stable sort preserves that)
    ranked.sort_by(|a, b| counts[*b].cmp(&counts[*a]).then_with(|| a.cmp(b)));

    // Grow the covered set one ranked node at a time instead of rebuilding it
    // for every prefix; coverage of prefix `i` is coverage of prefix `i - 1`
    // plus node `i` and its neighbours.
    let mut covered: HashSet<&[u8]> = HashSet::with_capacity(cluster.len());
    let last_needed = ranked[..ranked.len() - 1].iter().position(|&node| {
        covered.insert(node);
        if let Some(neighbors) = adj_list.get(node) {
            covered.extend(neighbors.iter().copied());
        }
        cluster.iter().all(|n| covered.contains(n))
    });
    if let Some(i) = last_needed {
        ranked.truncate(i + 1);
    }
    // Without a covering proper prefix this is the full ranking, as before.
    ranked
}
591
592/// Select UMIs to keep for one (pos, key) group. Returns UMIs whose records to emit.
593#[allow(clippy::too_many_lines)]
594pub(crate) fn select_umis(
595    method: DedupMethod,
596    umi_map: &HashMap<Vec<u8>, UmiSlot>,
597    edit_threshold: u32,
598) -> Vec<Vec<u8>> {
599    // Build count and insertion-order maps for sorting (matches Python dict insertion order)
600    let counts: HashMap<&[u8], u32> = umi_map
601        .iter()
602        .map(|(k, v)| (k.as_slice(), v.count))
603        .collect();
604    let orders: HashMap<&[u8], u32> = umi_map
605        .iter()
606        .map(|(k, v)| (k.as_slice(), v.insertion_order))
607        .collect();
608    // Sort key for within-component representative selection: count desc, lex asc.
609    // BFS produces lex-sorted components; Python's stable sort preserves that.
610    let lex_sort = |a: &[u8], b: &[u8]| -> std::cmp::Ordering {
611        counts[b].cmp(&counts[a]).then_with(|| a.cmp(b))
612    };
613
614    match method {
615        DedupMethod::Unique => {
616            // Python returns UMIs in dict insertion order (no count sorting)
617            let mut umis: Vec<Vec<u8>> = umi_map.keys().cloned().collect();
618            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
619            umis
620        }
621
622        DedupMethod::Percentile => {
623            if counts.len() <= 1 {
624                return umi_map.keys().cloned().collect();
625            }
626            let all_counts: Vec<u32> = counts.values().copied().collect();
627            let threshold = median(&all_counts) / 100.0;
628            // Python filters then preserves dict insertion order
629            let mut umis: Vec<Vec<u8>> = umi_map
630                .iter()
631                .filter(|(_, slot)| f64::from(slot.count) > threshold)
632                .map(|(umi, _)| umi.clone())
633                .collect();
634            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
635            umis
636        }
637
638        DedupMethod::Cluster => {
639            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
640            let adj_list = build_adjacency_list(&umis, edit_threshold);
641            let components = connected_components(&umis, &counts, &orders, &adj_list);
642            // Representative per component: highest count, lex tiebreak
643            components
644                .into_iter()
645                .map(|mut comp| {
646                    comp.sort_by(|a, b| lex_sort(a, b));
647                    comp.into_iter()
648                        .next()
649                        .expect("component is non-empty")
650                        .to_vec()
651                })
652                .collect()
653        }
654
655        DedupMethod::Adjacency => {
656            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
657            let adj_list = build_adjacency_list(&umis, edit_threshold);
658            let components = connected_components(&umis, &counts, &orders, &adj_list);
659            let mut result = Vec::new();
660            for component in components {
661                if component.len() == 1 {
662                    result.push(component[0].to_vec());
663                } else {
664                    let lead_umis = min_set_cover(&component, &adj_list, &counts);
665                    result.extend(lead_umis.into_iter().map(<[u8]>::to_vec));
666                }
667            }
668            result
669        }
670
671        DedupMethod::Directional => {
672            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
673            let adj_list = build_directional_adjacency_list(&umis, &counts, edit_threshold);
674            let components = connected_components(&umis, &counts, &orders, &adj_list);
675            let mut observed: HashSet<&[u8]> = HashSet::new();
676            let mut result = Vec::new();
677            for component in components {
678                if component.len() == 1 {
679                    let umi = component[0];
680                    observed.insert(umi);
681                    result.push(umi.to_vec());
682                } else {
683                    // Sort by count desc, lex asc (BFS output is lex-sorted,
684                    // Python's stable sort preserves that for equal counts)
685                    let mut sorted_comp = component;
686                    sorted_comp.sort_by(|a, b| lex_sort(a, b));
687                    let mut group_lead = None;
688                    for node in sorted_comp {
689                        if observed.insert(node) && group_lead.is_none() {
690                            group_lead = Some(node);
691                        }
692                    }
693                    if let Some(lead) = group_lead {
694                        result.push(lead.to_vec());
695                    }
696                }
697            }
698            result
699        }
700    }
701}
702
703/// Count deduplicated UMI groups from raw count/order maps.
704///
705/// Same logic as `select_umis` but takes `HashMap<Vec<u8>, u32>` instead of
706/// `UmiSlot`, and returns only the count of surviving UMI groups.
707#[allow(clippy::implicit_hasher)]
708#[must_use]
709pub fn count_umis(
710    method: DedupMethod,
711    counts: &HashMap<Vec<u8>, u32>,
712    orders: &HashMap<Vec<u8>, u32>,
713    edit_threshold: u32,
714) -> usize {
715    let count_refs: HashMap<&[u8], u32> = counts.iter().map(|(k, v)| (k.as_slice(), *v)).collect();
716    let order_refs: HashMap<&[u8], u32> = orders.iter().map(|(k, v)| (k.as_slice(), *v)).collect();
717    let lex_sort = |a: &[u8], b: &[u8]| -> std::cmp::Ordering {
718        count_refs[b].cmp(&count_refs[a]).then_with(|| a.cmp(b))
719    };
720
721    match method {
722        DedupMethod::Unique => counts.len(),
723
724        DedupMethod::Percentile => {
725            if counts.len() <= 1 {
726                return counts.len();
727            }
728            let all_counts: Vec<u32> = counts.values().copied().collect();
729            let threshold = median(&all_counts) / 100.0;
730            counts
731                .values()
732                .filter(|&&c| f64::from(c) > threshold)
733                .count()
734        }
735
736        DedupMethod::Cluster => {
737            let umis: Vec<&[u8]> = counts.keys().map(Vec::as_slice).collect();
738            let adj_list = build_adjacency_list(&umis, edit_threshold);
739            let components = connected_components(&umis, &count_refs, &order_refs, &adj_list);
740            components.len()
741        }
742
743        DedupMethod::Adjacency => {
744            let umis: Vec<&[u8]> = counts.keys().map(Vec::as_slice).collect();
745            let adj_list = build_adjacency_list(&umis, edit_threshold);
746            let components = connected_components(&umis, &count_refs, &order_refs, &adj_list);
747            let mut total = 0;
748            for component in components {
749                if component.len() == 1 {
750                    total += 1;
751                } else {
752                    total += min_set_cover(&component, &adj_list, &count_refs).len();
753                }
754            }
755            total
756        }
757
758        DedupMethod::Directional => {
759            let umis: Vec<&[u8]> = counts.keys().map(Vec::as_slice).collect();
760            let adj_list = build_directional_adjacency_list(&umis, &count_refs, edit_threshold);
761            let components = connected_components(&umis, &count_refs, &order_refs, &adj_list);
762            let mut observed: HashSet<&[u8]> = HashSet::new();
763            let mut total = 0;
764            for component in components {
765                if component.len() == 1 {
766                    let umi = component[0];
767                    observed.insert(umi);
768                    total += 1;
769                } else {
770                    let mut sorted_comp = component;
771                    sorted_comp.sort_by(|a, b| lex_sort(a, b));
772                    let mut found_lead = false;
773                    for node in sorted_comp {
774                        if observed.insert(node) && !found_lead {
775                            found_lead = true;
776                            total += 1;
777                        }
778                    }
779                }
780            }
781            total
782        }
783    }
784}
785
/// Extract the UMI and optional cell barcode from a read name using the `umis` method.
///
/// `qname` is split on `:`. A field starting with `UMI_` supplies the UMI and
/// a field starting with `CELL_` supplies the cell barcode, each with its
/// prefix removed. If a prefix occurs more than once, the last occurrence
/// wins. A missing UMI yields an empty vector; a missing barcode yields `None`.
#[must_use]
pub fn extract_umi_umis(qname: &[u8]) -> (Vec<u8>, Option<Vec<u8>>) {
    let (umi, cell) = qname
        .split(|&byte| byte == b':')
        .fold((None, None), |(umi, cell), field| {
            if let Some(sequence) = field.strip_prefix(b"UMI_") {
                (Some(sequence), cell)
            } else if let Some(barcode) = field.strip_prefix(b"CELL_") {
                (umi, Some(barcode))
            } else {
                (umi, cell)
            }
        });
    (
        umi.map(<[u8]>::to_vec).unwrap_or_default(),
        cell.map(<[u8]>::to_vec),
    )
}
803
/// Compute the median of a slice of u32 values, returned as f64.
///
/// For an even length this is the midpoint of the two middle values. An empty
/// slice yields `f64::NAN` (as `np.median([])` does) instead of an
/// out-of-range index at `n / 2 - 1`.
pub(crate) fn median(values: &[u32]) -> f64 {
    if values.is_empty() {
        return f64::NAN;
    }
    let mut sorted = values.to_vec();
    sorted.sort_unstable();
    let mid = sorted.len() / 2;
    if sorted.len().is_multiple_of(2) {
        f64::midpoint(f64::from(sorted[mid - 1]), f64::from(sorted[mid]))
    } else {
        f64::from(sorted[mid])
    }
}
815
816/// Like `select_umis`, but also returns the total count for each cluster
817/// (sum of all UMI counts in the cluster, not just the representative).
818/// Returns `(selected_umi, cluster_total_count)` pairs.
819#[allow(clippy::too_many_lines)]
820fn select_umis_with_cluster_counts(
821    method: DedupMethod,
822    umi_map: &HashMap<Vec<u8>, UmiSlot>,
823    edit_threshold: u32,
824) -> Vec<(Vec<u8>, u32)> {
825    let counts: HashMap<&[u8], u32> = umi_map
826        .iter()
827        .map(|(k, v)| (k.as_slice(), v.count))
828        .collect();
829    let orders: HashMap<&[u8], u32> = umi_map
830        .iter()
831        .map(|(k, v)| (k.as_slice(), v.insertion_order))
832        .collect();
833    let lex_sort = |a: &[u8], b: &[u8]| -> std::cmp::Ordering {
834        counts[b].cmp(&counts[a]).then_with(|| a.cmp(b))
835    };
836
837    match method {
838        DedupMethod::Unique => {
839            let mut umis: Vec<Vec<u8>> = umi_map.keys().cloned().collect();
840            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
841            umis.into_iter()
842                .map(|u| {
843                    let c = counts[u.as_slice()];
844                    (u, c)
845                })
846                .collect()
847        }
848
849        DedupMethod::Percentile => {
850            if counts.len() <= 1 {
851                return umi_map.iter().map(|(u, s)| (u.clone(), s.count)).collect();
852            }
853            let all_counts: Vec<u32> = counts.values().copied().collect();
854            let threshold = median(&all_counts) / 100.0;
855            let mut umis: Vec<Vec<u8>> = umi_map
856                .iter()
857                .filter(|(_, slot)| f64::from(slot.count) > threshold)
858                .map(|(umi, _)| umi.clone())
859                .collect();
860            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
861            umis.into_iter()
862                .map(|u| {
863                    let c = counts[u.as_slice()];
864                    (u, c)
865                })
866                .collect()
867        }
868
869        DedupMethod::Cluster => {
870            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
871            let adj_list = build_adjacency_list(&umis, edit_threshold);
872            let components = connected_components(&umis, &counts, &orders, &adj_list);
873            components
874                .into_iter()
875                .map(|mut comp| {
876                    let cluster_count: u32 = comp.iter().map(|u| counts[*u]).sum();
877                    comp.sort_by(|a, b| lex_sort(a, b));
878                    (
879                        comp.into_iter()
880                            .next()
881                            .expect("component is non-empty")
882                            .to_vec(),
883                        cluster_count,
884                    )
885                })
886                .collect()
887        }
888
889        DedupMethod::Adjacency => {
890            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
891            let adj_list = build_adjacency_list(&umis, edit_threshold);
892            let components = connected_components(&umis, &counts, &orders, &adj_list);
893            let mut result = Vec::new();
894            for component in components {
895                if component.len() == 1 {
896                    let c = counts[component[0]];
897                    result.push((component[0].to_vec(), c));
898                } else {
899                    let lead_umis = min_set_cover(&component, &adj_list, &counts);
900                    // Each lead UMI's cluster: itself + its unobserved neighbors
901                    let mut observed: HashSet<&[u8]> = HashSet::new();
902                    for &lead in &lead_umis {
903                        let mut cluster_count = counts[lead];
904                        observed.insert(lead);
905                        if let Some(neighbors) = adj_list.get(lead) {
906                            for &n in neighbors {
907                                if observed.insert(n) {
908                                    cluster_count += counts[n];
909                                }
910                            }
911                        }
912                        result.push((lead.to_vec(), cluster_count));
913                    }
914                }
915            }
916            result
917        }
918
919        DedupMethod::Directional => {
920            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
921            let adj_list = build_directional_adjacency_list(&umis, &counts, edit_threshold);
922            let components = connected_components(&umis, &counts, &orders, &adj_list);
923            let mut observed: HashSet<&[u8]> = HashSet::new();
924            let mut result = Vec::new();
925            for component in components {
926                if component.len() == 1 {
927                    let umi = component[0];
928                    let c = counts[umi];
929                    observed.insert(umi);
930                    result.push((umi.to_vec(), c));
931                } else {
932                    let mut sorted_comp = component;
933                    sorted_comp.sort_by(|a, b| lex_sort(a, b));
934                    let mut group_lead = None;
935                    let mut cluster_count: u32 = 0;
936                    for node in sorted_comp {
937                        if observed.insert(node) {
938                            cluster_count += counts[node];
939                            if group_lead.is_none() {
940                                group_lead = Some(node);
941                            }
942                        }
943                    }
944                    if let Some(lead) = group_lead {
945                        result.push((lead.to_vec(), cluster_count));
946                    }
947                }
948            }
949            result
950        }
951    }
952}
953
954/// Mean pairwise Hamming distance between UMIs. Returns -1.0 for single UMI.
955#[allow(clippy::cast_precision_loss)]
956fn get_average_umi_distance(umis: &[&[u8]]) -> f64 {
957    if umis.len() <= 1 {
958        return -1.0;
959    }
960    let mut total: u64 = 0;
961    let mut count: u64 = 0;
962    for i in 0..umis.len() {
963        for j in (i + 1)..umis.len() {
964            total += u64::from(edit_distance(umis[i], umis[j]));
965            count += 1;
966        }
967    }
968    total as f64 / count as f64
969}
970
971/// Pre-scans BAM to build UMI frequency distribution for null model sampling.
972struct RandomReadGenerator {
973    keys: Vec<Vec<u8>>,
974    cdf: Vec<f64>,
975    rng: NumpyRandom,
976    random_umis: Vec<Vec<u8>>,
977    random_ix: usize,
978    fill_size: usize,
979}
980
981impl RandomReadGenerator {
982    fn new(
983        bam_path: &str,
984        umi_separator: u8,
985        extract_method: &str,
986        umi_tag: Option<&str>,
987        chrom: Option<&str>,
988        seed: u32,
989    ) -> Result<Self, DedupError> {
990        let mut reader =
991            bam::Reader::from_path(bam_path).map_err(|e| DedupError::BamOpen(e.to_string()))?;
992
993        let chrom_tid: Option<i32> = chrom
994            .map(|c| {
995                let tid = reader
996                    .header()
997                    .tid(c.as_bytes())
998                    .ok_or_else(|| DedupError::UnknownChrom(c.to_string()))?;
999                #[allow(clippy::cast_possible_wrap)]
1000                Ok(tid as i32)
1001            })
1002            .transpose()?;
1003
1004        // Count UMI frequencies, preserving insertion order (order of first appearance).
1005        let mut umi_order: Vec<Vec<u8>> = Vec::new();
1006        let mut umi_counts: HashMap<Vec<u8>, u64> = HashMap::new();
1007
1008        for result in reader.records() {
1009            let record = result.map_err(|e| DedupError::BamRead(e.to_string()))?;
1010            if record.is_unmapped() {
1011                continue;
1012            }
1013            if record.is_last_in_template() {
1014                continue;
1015            }
1016            if let Some(filter_tid) = chrom_tid
1017                && record.tid() != filter_tid
1018            {
1019                continue;
1020            }
1021            let umi = if extract_method == "tag" {
1022                match extract_umi_from_tag(&record, umi_tag.unwrap_or("RX")) {
1023                    Some(u) => u,
1024                    None => continue,
1025                }
1026            } else {
1027                extract_umi_from_name(&record, umi_separator)
1028            };
1029            let entry = umi_counts.entry(umi.clone());
1030            if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
1031                umi_order.push(umi);
1032            }
1033            *entry.or_insert(0) += 1;
1034        }
1035
1036        // Build CDF from frequencies in insertion order.
1037        #[allow(clippy::cast_precision_loss)]
1038        let total: f64 = umi_counts.values().sum::<u64>() as f64;
1039        let mut cdf = Vec::with_capacity(umi_order.len());
1040        let mut cumsum = 0.0;
1041        for key in &umi_order {
1042            #[allow(clippy::cast_precision_loss)]
1043            {
1044                cumsum += umi_counts[key] as f64 / total;
1045            }
1046            cdf.push(cumsum);
1047        }
1048
1049        let mut rng = Self {
1050            keys: umi_order,
1051            cdf,
1052            rng: NumpyRandom::new(seed),
1053            random_umis: Vec::new(),
1054            random_ix: 0,
1055            fill_size: 100_000,
1056        };
1057        rng.refill();
1058        Ok(rng)
1059    }
1060
1061    fn refill(&mut self) {
1062        self.random_umis.clear();
1063        self.random_umis.reserve(self.fill_size);
1064        for _ in 0..self.fill_size {
1065            let r = self.rng.random();
1066            let idx = self
1067                .cdf
1068                .partition_point(|&c| c <= r)
1069                .min(self.keys.len() - 1);
1070            self.random_umis.push(self.keys[idx].clone());
1071        }
1072        self.random_ix = 0;
1073    }
1074
1075    fn get_umis(&mut self, n: usize) -> Vec<Vec<u8>> {
1076        if n >= self.fill_size - self.random_ix {
1077            if n > self.fill_size {
1078                self.fill_size = n * 2;
1079            }
1080            self.refill();
1081        }
1082        let result = self.random_umis[self.random_ix..self.random_ix + n].to_vec();
1083        self.random_ix += n;
1084        result
1085    }
1086}
1087
1088/// Accumulates per-bundle stats during dedup for the 3 stats output files.
1089struct StatsCollector {
1090    // Per-UMI-per-position: (umi, count) tuples
1091    pre_umi_counts: Vec<(Vec<u8>, u32)>,
1092    post_umi_counts: Vec<(Vec<u8>, u32)>,
1093    // Edit distance stats per bundle
1094    pre_cluster_stats: Vec<f64>,
1095    post_cluster_stats: Vec<f64>,
1096    pre_cluster_stats_null: Vec<f64>,
1097    post_cluster_stats_null: Vec<f64>,
1098}
1099
1100impl StatsCollector {
1101    const fn new() -> Self {
1102        Self {
1103            pre_umi_counts: Vec::new(),
1104            post_umi_counts: Vec::new(),
1105            pre_cluster_stats: Vec::new(),
1106            post_cluster_stats: Vec::new(),
1107            pre_cluster_stats_null: Vec::new(),
1108            post_cluster_stats_null: Vec::new(),
1109        }
1110    }
1111
1112    fn record_bundle(
1113        &mut self,
1114        umi_map: &HashMap<Vec<u8>, UmiSlot>,
1115        selected_umis: &[Vec<u8>],
1116        cluster_counts: &[u32],
1117        selected_records: &[&Record],
1118        umi_separator: u8,
1119        read_gen: &mut RandomReadGenerator,
1120    ) {
1121        // Pre-dedup: all UMIs in the bundle
1122        let pre_umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
1123        for (umi, slot) in umi_map {
1124            self.pre_umi_counts.push((umi.clone(), slot.count));
1125        }
1126        let avg_dist = get_average_umi_distance(&pre_umis);
1127        self.pre_cluster_stats.push(avg_dist);
1128
1129        let cluster_size = pre_umis.len();
1130        let random_umis = read_gen.get_umis(cluster_size);
1131        let random_refs: Vec<&[u8]> = random_umis.iter().map(Vec::as_slice).collect();
1132        let avg_null = get_average_umi_distance(&random_refs);
1133        self.pre_cluster_stats_null.push(avg_null);
1134
1135        // Post-dedup: selected UMIs with cluster-aggregated counts
1136        for (umi, &count) in selected_umis.iter().zip(cluster_counts) {
1137            self.post_umi_counts.push((umi.clone(), count));
1138        }
1139
1140        // Post-dedup edit distance from the actual output records' UMIs
1141        let post_umis: Vec<Vec<u8>> = selected_records
1142            .iter()
1143            .map(|r| extract_umi_from_name(r, umi_separator))
1144            .collect();
1145        let post_refs: Vec<&[u8]> = post_umis.iter().map(Vec::as_slice).collect();
1146        let avg_post = get_average_umi_distance(&post_refs);
1147        self.post_cluster_stats.push(avg_post);
1148
1149        let post_size = post_umis.len();
1150        let random_umis_post = read_gen.get_umis(post_size);
1151        let random_post_refs: Vec<&[u8]> = random_umis_post.iter().map(Vec::as_slice).collect();
1152        let avg_null_post = get_average_umi_distance(&random_post_refs);
1153        self.post_cluster_stats_null.push(avg_null_post);
1154    }
1155
1156    fn write_files(&self, prefix: &str, method_name: &str) -> Result<(), DedupError> {
1157        self.write_per_umi_per_position(prefix)?;
1158        self.write_per_umi(prefix)?;
1159        self.write_edit_distance(prefix, method_name)?;
1160        Ok(())
1161    }
1162
1163    fn write_per_umi_per_position(&self, prefix: &str) -> Result<(), DedupError> {
1164        let mut pre_counts: HashMap<u32, u32> = HashMap::new();
1165        let mut post_counts: HashMap<u32, u32> = HashMap::new();
1166        for (_, count) in &self.pre_umi_counts {
1167            *pre_counts.entry(*count).or_default() += 1;
1168        }
1169        for (_, count) in &self.post_umi_counts {
1170            *post_counts.entry(*count).or_default() += 1;
1171        }
1172
1173        let mut all_counts: Vec<u32> = pre_counts
1174            .keys()
1175            .chain(post_counts.keys())
1176            .copied()
1177            .collect::<HashSet<u32>>()
1178            .into_iter()
1179            .collect();
1180        all_counts.sort_unstable();
1181
1182        let path = format!("{prefix}_per_umi_per_position.tsv");
1183        let mut f =
1184            File::create(&path).map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1185        writeln!(f, "counts\tinstances_pre\tinstances_post")
1186            .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1187        for count in &all_counts {
1188            let pre = pre_counts.get(count).unwrap_or(&0);
1189            let post = post_counts.get(count).unwrap_or(&0);
1190            writeln!(f, "{count}\t{pre}\t{post}")
1191                .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1192        }
1193        Ok(())
1194    }
1195
1196    fn write_per_umi(&self, prefix: &str) -> Result<(), DedupError> {
1197        // Aggregate per UMI: median_counts, times_observed, total_counts
1198        let pre_agg = Self::aggregate_per_umi(&self.pre_umi_counts);
1199        let post_agg = Self::aggregate_per_umi(&self.post_umi_counts);
1200
1201        // Sorted union of UMI keys
1202        let mut all_umis: Vec<Vec<u8>> = pre_agg
1203            .keys()
1204            .chain(post_agg.keys())
1205            .cloned()
1206            .collect::<HashSet<Vec<u8>>>()
1207            .into_iter()
1208            .collect();
1209        all_umis.sort();
1210
1211        let path = format!("{prefix}_per_umi.tsv");
1212        let mut f =
1213            File::create(&path).map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1214        writeln!(
1215            f,
1216            "UMI\tmedian_counts_pre\ttimes_observed_pre\ttotal_counts_pre\t\
1217             median_counts_post\ttimes_observed_post\ttotal_counts_post"
1218        )
1219        .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1220
1221        for umi in &all_umis {
1222            let (med_pre, obs_pre, tot_pre) = pre_agg.get(umi).unwrap_or(&(0, 0, 0));
1223            let (med_post, obs_post, tot_post) = post_agg.get(umi).unwrap_or(&(0, 0, 0));
1224            let umi_str = std::str::from_utf8(umi).unwrap_or("?");
1225            writeln!(
1226                f,
1227                "{umi_str}\t{med_pre}\t{obs_pre}\t{tot_pre}\t{med_post}\t{obs_post}\t{tot_post}"
1228            )
1229            .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1230        }
1231        Ok(())
1232    }
1233
1234    /// Returns map: umi → (`median_counts`, `times_observed`, `total_counts`)
1235    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
1236    fn aggregate_per_umi(umi_counts: &[(Vec<u8>, u32)]) -> HashMap<Vec<u8>, (i64, i64, i64)> {
1237        let mut grouped: HashMap<Vec<u8>, Vec<u32>> = HashMap::new();
1238        for (umi, count) in umi_counts {
1239            grouped.entry(umi.clone()).or_default().push(*count);
1240        }
1241        grouped
1242            .into_iter()
1243            .map(|(umi, counts)| {
1244                let times_observed = counts.len() as i64;
1245                let total: i64 = counts.iter().map(|&c| i64::from(c)).sum();
1246                let med = median(&counts);
1247                // Python: .fillna(0).astype(int) truncates toward zero (same as floor for positive)
1248                let median_int = med as i64;
1249                (umi, (median_int, times_observed, total))
1250            })
1251            .collect()
1252    }
1253
1254    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
1255    fn write_edit_distance(&self, prefix: &str, method_name: &str) -> Result<(), DedupError> {
1256        // Find max edit distance across all stats
1257        let all_stats = self
1258            .pre_cluster_stats
1259            .iter()
1260            .chain(&self.post_cluster_stats)
1261            .chain(&self.pre_cluster_stats_null)
1262            .chain(&self.post_cluster_stats_null);
1263        let max_ed = all_stats.copied().fold(0.0_f64, f64::max) as i32;
1264
1265        // bins = range(-1, max_ed + 2)  →  [-1, 0, 1, ..., max_ed+1]
1266        let bins: Vec<i32> = (-1..=max_ed + 1).collect();
1267        let nbins = bins.len();
1268
1269        let digitize = |values: &[f64]| -> Vec<usize> {
1270            // np.digitize(values, bins, right=True): returns i such that
1271            // bins[i-1] < v <= bins[i]. Equivalent to searchsorted(side='left').
1272            values
1273                .iter()
1274                .map(|&v| bins.partition_point(|&b| f64::from(b) < v).min(nbins))
1275                .collect()
1276        };
1277
1278        let bincount = |binned: &[usize], minlength: usize| -> Vec<u64> {
1279            let mut counts = vec![0u64; minlength];
1280            for &b in binned {
1281                if b < counts.len() {
1282                    counts[b] += 1;
1283                }
1284            }
1285            counts
1286        };
1287
1288        let minlength = (max_ed + 3) as usize;
1289
1290        let pre_binned = digitize(&self.pre_cluster_stats);
1291        let post_binned = digitize(&self.post_cluster_stats);
1292        let pre_null_binned = digitize(&self.pre_cluster_stats_null);
1293        let post_null_binned = digitize(&self.post_cluster_stats_null);
1294
1295        let pre_counts = bincount(&pre_binned, minlength);
1296        let post_counts = bincount(&post_binned, minlength);
1297        let pre_null_counts = bincount(&pre_null_binned, minlength);
1298        let post_null_counts = bincount(&post_null_binned, minlength);
1299
1300        let path = format!("{prefix}_edit_distance.tsv");
1301        let mut f =
1302            File::create(&path).map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1303        writeln!(
1304            f,
1305            "unique\tunique_null\t{method_name}\t{method_name}_null\tedit_distance"
1306        )
1307        .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1308
1309        for i in 0..minlength {
1310            let ed_label = if i == 0 {
1311                "Single_UMI".to_string()
1312            } else if i < bins.len() {
1313                bins[i].to_string()
1314            } else {
1315                (i - 1).to_string()
1316            };
1317            let pre = pre_counts.get(i).unwrap_or(&0);
1318            let post = post_counts.get(i).unwrap_or(&0);
1319            let pre_null = pre_null_counts.get(i).unwrap_or(&0);
1320            let post_null = post_null_counts.get(i).unwrap_or(&0);
1321            writeln!(f, "{pre}\t{pre_null}\t{post}\t{post_null}\t{ed_label}")
1322                .map_err(|e| DedupError::StatsWrite(path.clone(), e.to_string()))?;
1323        }
1324        Ok(())
1325    }
1326}
1327
/// Deduplicate the mapped reads of the BAM at `input_path` according to `config`.
///
/// Reads are buffered under a group key (position plus strand/template-length
/// key, or a per-gene ID from the gene tag in per-gene mode), and the buffer's
/// drain step chooses which records survive. Survivors are sorted by
/// `(tid, pos)` and written to stdout as BAM, or as SAM when `config.out_sam`
/// is set. In paired mode a second pass over the input adds the R2 mate of
/// every surviving R1. When `config.output_stats` is set, the three stats TSVs
/// are written under that prefix after the reads.
///
/// `output` is currently unused; the writer targets stdout directly.
///
/// Returns read counts in `DedupStats`. NOTE(review): `positions` is
/// initialised to 0 and never updated in this function.
///
/// # Errors
///
/// Returns `DedupError` on BAM I/O failures or unknown chromosome filter.
/// Also returned for an invalid `skip_tags_regex` and for stats-file write failures.
#[allow(clippy::too_many_lines)]
pub fn run_dedup(
    config: &DedupConfig,
    input_path: &str,
    output: &mut dyn io::Write,
) -> Result<DedupStats, DedupError> {
    let mut reader =
        bam::Reader::from_path(input_path).map_err(|e| DedupError::BamOpen(e.to_string()))?;
    let header = bam::Header::from_template(reader.header());

    let format = if config.out_sam {
        bam::Format::Sam
    } else {
        bam::Format::Bam
    };

    // NOTE(review): the stdout writer is created before the chromosome,
    // regex and stats setup below, any of which can still fail.
    let mut writer = bam::Writer::from_stdout(&header, format)
        .map_err(|e| DedupError::BamWrite(e.to_string()))?;

    // Optional chromosome filter
    let chrom_filter: Option<i32> = config
        .chrom
        .as_ref()
        .map(|c| {
            let tid = reader
                .header()
                .tid(c.as_bytes())
                .ok_or_else(|| DedupError::UnknownChrom(c.clone()))?;
            #[allow(clippy::cast_possible_wrap)]
            Ok(tid as i32)
        })
        .transpose()?;

    // One generator is shared by the subset draws and `buffer.add` below, so
    // the order in which reads consume it is part of the output contract.
    #[allow(clippy::cast_possible_truncation)]
    let mut rng = PythonRandom::new(config.random_seed as u32);
    let mut buffer = ReadBuffer::new();
    let mut stats = DedupStats {
        input_reads: 0,
        output_reads: 0,
        positions: 0,
    };

    // Collect all selected records, then sort by coordinate before writing.
    // Matches Python umi_tools which calls `pysam.sort()` after processing.
    let mut output_records: Vec<Record> = Vec::new();

    // Start and tid of the previous positional read; drive buffer flushing.
    let mut last_start: i64 = 0;
    let mut last_chrom: i32 = -1;

    // Per-gene mode: gene tag value → sequential i64 ID used as "position".
    let skip_regex = config
        .skip_tags_regex
        .as_ref()
        .map(|s| regex::Regex::new(s).map_err(|e| DedupError::InvalidRegex(e.to_string())))
        .transpose()?;
    let mut gene_ids: HashMap<Vec<u8>, i64> = HashMap::new();
    let mut next_gene_id: i64 = 0;

    // Stats collection (optional, only when --output-stats is set)
    #[allow(clippy::cast_possible_truncation)]
    let mut stats_ctx: Option<StatsContext> = config
        .output_stats
        .as_ref()
        .map(|_| {
            let read_gen = RandomReadGenerator::new(
                input_path,
                config.umi_separator,
                &config.extract_umi_method,
                config.umi_tag.as_deref(),
                config.chrom.as_deref(),
                config.random_seed as u32,
            )?;
            Ok(StatsContext {
                collector: StatsCollector::new(),
                read_gen,
                umi_separator: config.umi_separator,
            })
        })
        .transpose()?;

    let wl_ref = config.umi_whitelist.as_ref();

    for result in reader.records() {
        let record = result.map_err(|e| DedupError::BamRead(e.to_string()))?;

        if record.is_unmapped() {
            continue;
        }

        // Paired mode: skip R2 reads and R1s with unmapped mates.
        if config.paired {
            if record.is_last_in_template() {
                continue;
            }
            if record.is_mate_unmapped() {
                continue;
            }
        }

        let tid = record.tid();

        // Chromosome filter
        if chrom_filter.is_some_and(|filter_tid| tid != filter_tid) {
            continue;
        }

        stats.input_reads += 1;

        // Subset check consumes one RNG call per mapped read (before buffer.add)
        if config.subset.is_some_and(|s| rng.random() >= f64::from(s)) {
            continue;
        }

        // Reads whose UMI tag is missing are dropped (tag method only).
        let umi = if config.ignore_umi {
            Vec::new()
        } else if config.extract_umi_method == "tag" {
            match extract_umi_from_tag(&record, config.umi_tag.as_deref().unwrap_or("RX")) {
                Some(u) => u,
                None => continue,
            }
        } else {
            extract_umi_from_name(&record, config.umi_separator)
        };

        if config.per_gene {
            // Per-gene mode: group by gene tag value instead of position.
            let gene_tag_name = config.gene_tag.as_deref().unwrap_or("XF");
            let Some(gene) = extract_umi_from_tag(&record, gene_tag_name) else {
                continue;
            };
            if skip_regex
                .as_ref()
                .is_some_and(|re| re.is_match(std::str::from_utf8(&gene).unwrap_or("")))
            {
                continue;
            }
            // First appearance of a gene gets the next sequential ID.
            let gene_id = *gene_ids.entry(gene).or_insert_with(|| {
                let id = next_gene_id;
                next_gene_id += 1;
                id
            });
            buffer.add(record, gene_id, (false, false, 0, 0), umi, &mut rng);
        } else {
            let (start, pos) = get_read_position(&record);

            // Flush buffer when moving far enough or changing chromosome.
            // On a chromosome change everything is drained; otherwise, once the
            // start has advanced >1000 past the previous read's start, the
            // buffer is drained up to `start - 1000` (presumably positions at
            // or below that threshold; see `ReadBuffer::drain_up_to`).
            if tid != last_chrom {
                output_records.extend(buffer.drain_all(
                    config.method,
                    config.edit_distance_threshold,
                    &mut stats_ctx,
                    wl_ref,
                ));
            } else if start > last_start + 1000 {
                let threshold = start - 1000;
                output_records.extend(buffer.drain_up_to(
                    threshold,
                    config.method,
                    config.edit_distance_threshold,
                    &mut stats_ctx,
                    wl_ref,
                ));
            }

            last_start = start;
            last_chrom = tid;

            // Template length only separates groups in paired mode unless
            // `ignore_tlen` is set.
            let tlen = if config.paired && !config.ignore_tlen {
                record.insert_size()
            } else {
                0
            };
            let key: GroupKey = (record.is_reverse(), false, tlen, 0);
            buffer.add(record, pos, key, umi, &mut rng);
        }
    }

    // Drain whatever is still buffered after the last read.
    output_records.extend(buffer.drain_all(
        config.method,
        config.edit_distance_threshold,
        &mut stats_ctx,
        wl_ref,
    ));

    // Paired mode: second pass to find R2 mates of surviving R1 reads.
    // An R2 matches when its (qname, tid, pos) equals a survivor's
    // (qname, mate tid, mate pos); each key is consumed at most once.
    if config.paired {
        let mut mate_set: HashSet<(Vec<u8>, i32, i64)> = HashSet::new();
        for r1 in &output_records {
            mate_set.insert((r1.qname().to_vec(), r1.mtid(), r1.mpos()));
        }
        let mut reader2 =
            bam::Reader::from_path(input_path).map_err(|e| DedupError::BamOpen(e.to_string()))?;
        for result in reader2.records() {
            let record = result.map_err(|e| DedupError::BamRead(e.to_string()))?;
            if record.is_unmapped() || record.is_mate_unmapped() {
                continue;
            }
            if !record.is_last_in_template() {
                continue;
            }
            let key = (record.qname().to_vec(), record.tid(), record.pos());
            if mate_set.remove(&key) {
                output_records.push(record);
            }
        }
    }

    // Sort by coordinate (tid, pos) to match `pysam.sort()` / `samtools sort`.
    output_records.sort_by(|a, b| a.tid().cmp(&b.tid()).then_with(|| a.pos().cmp(&b.pos())));

    stats.output_reads = output_records.len() as u64;
    for r in &output_records {
        writer
            .write(r)
            .map_err(|e| DedupError::BamWrite(e.to_string()))?;
    }

    // Drop writer to flush SAM/BAM output.
    // The output arg is unused for now (Writer writes to stdout directly).
    let _ = output;
    drop(writer);

    // Write stats files if requested
    if let (Some(prefix), Some(ctx)) = (&config.output_stats, &stats_ctx) {
        let method_name = match config.method {
            DedupMethod::Unique => "unique",
            DedupMethod::Percentile => "percentile",
            DedupMethod::Cluster => "cluster",
            DedupMethod::Adjacency => "adjacency",
            DedupMethod::Directional => "directional",
        };
        ctx.collector.write_files(prefix, method_name)?;
    }

    Ok(stats)
}
1567
1568pub(crate) fn extract_umi_from_tag(record: &Record, tag: &str) -> Option<Vec<u8>> {
1569    match record.aux(tag.as_bytes()) {
1570        Ok(rust_htslib::bam::record::Aux::String(s)) => Some(s.as_bytes().to_vec()),
1571        _ => None,
1572    }
1573}
1574
1575pub(crate) fn extract_umi_from_name(record: &Record, separator: u8) -> Vec<u8> {
1576    let name = record.qname();
1577    name.iter()
1578        .rposition(|&b| b == separator)
1579        .map_or_else(|| name.to_vec(), |pos| name[pos + 1..].to_vec())
1580}
1581
/// Errors produced while deduplicating (returned by `run_dedup`).
#[derive(Debug, thiserror::Error)]
pub enum DedupError {
    /// An input BAM could not be opened; holds the underlying error text.
    #[error("failed to open BAM: {0}")]
    BamOpen(String),
    /// Iterating BAM records failed; holds the underlying error text.
    #[error("failed to read BAM record: {0}")]
    BamRead(String),
    /// Creating the output writer or writing a record failed.
    #[error("failed to write BAM/SAM: {0}")]
    BamWrite(String),
    /// The requested chromosome name is not in the BAM header; holds the name.
    #[error("unknown chromosome: {0}")]
    UnknownChrom(String),
    /// The skip-tags regular expression did not compile; holds the regex error.
    #[error("invalid regex: {0}")]
    InvalidRegex(String),
    /// Creating or writing a stats file failed: `(path, error text)`.
    #[error("failed to write stats file {0}: {1}")]
    StatsWrite(String, String),
}
1597
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeding with 123456789 must reproduce CPython's first ten
    /// `random.random()` values to within 1e-15.
    #[test]
    fn python_random_matches() {
        const EXPECTED: [f64; 10] = [
            0.641_400_616_185_872_6,
            0.542_189_268_096_949_5,
            0.993_175_066_283_272_1,
            0.843_252_136_686_916_6,
            0.811_733_928_337_940_6,
            0.397_173_710_078_000_4,
            0.937_095_107_912_042_5,
            0.689_102_653_165_816_2,
            0.397_110_488_525_983_74,
            0.351_025_192_423_044_75,
        ];
        let mut rng = PythonRandom::new(123_456_789);
        for exp in EXPECTED {
            let got = rng.random();
            assert!(
                (got - exp).abs() < 1e-15,
                "mismatch: got {got:.20}, expected {exp:.20}"
            );
        }
    }
}