Skip to main content

umi_core/
whitelist.rs

1use std::collections::HashMap;
2use std::f64::consts::PI;
3use std::io::{BufWriter, Write};
4
5use needletail::parser::{FastqReader, FastxReader};
6
7use crate::error::ExtractError;
8use crate::pattern::BarcodePattern;
9
/// Method for detecting the knee point in the barcode frequency distribution.
///
/// Only consulted by `determine_whitelist` when no explicit cell number is given.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum KneeMethod {
    /// Distance-to-chord search on the cumulative count curve (the default).
    #[default]
    Distance,
    /// Local-minimum search on a Gaussian KDE of log10-transformed counts.
    Density,
}
17
/// How to handle whitelist barcodes whose edit distance to a higher-count
/// whitelist barcode is within the error-correct threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EdAboveThreshold {
    /// Remove the lower-count barcode from the whitelist.
    Discard,
    /// Remove the lower-count barcode and, when it has exactly one near miss of
    /// the same length within the Hamming threshold, record it as a correction
    /// of that higher-count barcode.
    Correct,
}
25
/// Configuration for the whitelist subcommand.
pub struct WhitelistConfig {
    /// Pattern used to extract the cell barcode from each read.
    pub pattern: BarcodePattern,
    /// Knee detection method; only used when `cell_number` is `None`.
    pub knee_method: KneeMethod,
    /// Explicit cell count. When set, knee detection is skipped and a count
    /// cutoff is taken from the barcode at this rank.
    pub cell_number: Option<usize>,
    /// Expected number of cells; only passed to the density knee method.
    pub expect_cells: Option<usize>,
    /// Maximum distance for error correction: Hamming distance when building the
    /// correction map, prefix edit distance when detecting errors above threshold.
    pub error_correct_threshold: usize,
    /// When `Some`, whitelist barcodes that look like errors of higher-count
    /// whitelist barcodes are removed according to the chosen mode.
    pub ed_above_threshold: Option<EdAboveThreshold>,
    /// Read limit: counting stops once the number of reads pulled from the
    /// input exceeds this value.
    pub subset_reads: usize,
}
36
/// A single whitelisted barcode with its count and error-correction mappings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WhitelistEntry {
    /// The whitelisted cell barcode.
    pub barcode: String,
    /// Number of reads observed with exactly this barcode.
    pub count: u64,
    /// Barcodes corrected to this one, each paired with its own read count.
    pub corrections: Vec<(String, u64)>,
}
43
/// Statistics from a whitelist run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WhitelistStats {
    /// Number of FASTQ records pulled from the input.
    pub input_reads: u64,
    /// Reads rejected because they were too short or the pattern did not match.
    pub no_match: u64,
}
49
50/// Run the whitelist pipeline: count barcodes, find knee, build error correction, write TSV.
51///
52/// # Errors
53/// Returns error on I/O or pattern-matching failures.
54pub fn run_whitelist<R: std::io::Read + Send, W: Write, FW: Write>(
55    config: &WhitelistConfig,
56    input: R,
57    output: W,
58    filtered_out: Option<FW>,
59) -> Result<WhitelistStats, ExtractError> {
60    let (all_counts, first_seen, stats) =
61        count_barcodes(&config.pattern, input, config.subset_reads, filtered_out)?;
62
63    let whitelist = determine_whitelist(
64        &all_counts,
65        config.knee_method,
66        config.cell_number,
67        config.expect_cells,
68    );
69
70    let mut corrections =
71        build_error_correction_map(&all_counts, &whitelist, config.error_correct_threshold);
72
73    let whitelist = if let Some(mode) = config.ed_above_threshold {
74        error_detect_above_threshold(
75            &all_counts,
76            &first_seen,
77            whitelist,
78            &mut corrections,
79            config.error_correct_threshold,
80            mode,
81        )
82    } else {
83        whitelist
84    };
85
86    let mut entries: Vec<WhitelistEntry> = whitelist
87        .into_iter()
88        .map(|bc| {
89            let count = all_counts.get(&bc).copied().unwrap_or(0);
90            let corr = corrections.get(&bc).cloned().unwrap_or_default();
91            WhitelistEntry {
92                barcode: bc,
93                count,
94                corrections: corr,
95            }
96        })
97        .collect();
98
99    entries.sort_by(|a, b| a.barcode.cmp(&b.barcode));
100
101    let mut writer = BufWriter::new(output);
102    write_whitelist_tsv(&entries, &mut writer)?;
103    writer.flush()?;
104
105    Ok(stats)
106}
107
108/// Read FASTQ, extract cell barcodes, count frequencies.
109/// Optionally writes non-matching reads to `filtered_out`.
110#[allow(clippy::type_complexity)]
111fn count_barcodes<R: std::io::Read + Send, FW: Write>(
112    pattern: &BarcodePattern,
113    input: R,
114    subset_reads: usize,
115    filtered_out: Option<FW>,
116) -> Result<(HashMap<String, u64>, HashMap<String, usize>, WhitelistStats), ExtractError> {
117    let mut counts: HashMap<String, u64> = HashMap::new();
118    let mut first_seen: HashMap<String, usize> = HashMap::new();
119    let mut seen_order: usize = 0;
120    let mut stats = WhitelistStats {
121        input_reads: 0,
122        no_match: 0,
123    };
124    let mut filt_writer = filtered_out.map(BufWriter::new);
125
126    let mut reader = FastqReader::new(input);
127
128    while let Some(result) = reader.next() {
129        let record = result.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
130        stats.input_reads += 1;
131
132        if stats.input_reads > subset_reads as u64 {
133            break;
134        }
135
136        let seq = record.seq();
137        let qual = record
138            .qual()
139            .ok_or_else(|| ExtractError::FastqParse("missing quality scores".into()))?;
140
141        match pattern.extract(&seq, qual) {
142            Ok(extraction) => {
143                let cell = String::from_utf8_lossy(&extraction.cell_barcode).into_owned();
144                if !cell.is_empty() {
145                    if !counts.contains_key(&cell) {
146                        first_seen.insert(cell.clone(), seen_order);
147                        seen_order += 1;
148                    }
149                    *counts.entry(cell).or_insert(0) += 1;
150                }
151            }
152            Err(ExtractError::ReadTooShort { .. } | ExtractError::RegexNoMatch) => {
153                stats.no_match += 1;
154                if let Some(fw) = filt_writer.as_mut() {
155                    write_fastq_record(fw, record.id(), &seq, qual)?;
156                }
157            }
158            Err(e) => return Err(e),
159        }
160    }
161
162    if let Some(fw) = filt_writer.as_mut() {
163        fw.flush()?;
164    }
165
166    Ok((counts, first_seen, stats))
167}
168
169/// Write a FASTQ record (used for filtered-out output).
170fn write_fastq_record<W: Write>(
171    writer: &mut W,
172    id: &[u8],
173    seq: &[u8],
174    qual: &[u8],
175) -> Result<(), ExtractError> {
176    writer.write_all(b"@")?;
177    writer.write_all(id)?;
178    writer.write_all(b"\n")?;
179    writer.write_all(seq)?;
180    writer.write_all(b"\n+\n")?;
181    writer.write_all(qual)?;
182    writer.write_all(b"\n")?;
183    Ok(())
184}
185
186/// Determine which barcodes to whitelist based on knee detection or explicit cell number.
187fn determine_whitelist(
188    all_counts: &HashMap<String, u64>,
189    knee_method: KneeMethod,
190    cell_number: Option<usize>,
191    expect_cells: Option<usize>,
192) -> Vec<String> {
193    let mut sorted_barcodes: Vec<(&String, &u64)> = all_counts.iter().collect();
194    sorted_barcodes.sort_by(|a, b| b.1.cmp(a.1));
195
196    if let Some(n) = cell_number {
197        if n == 0 || sorted_barcodes.is_empty() {
198            return Vec::new();
199        }
200        let threshold_idx = n.min(sorted_barcodes.len()) - 1;
201        let threshold = *sorted_barcodes[threshold_idx].1;
202        sorted_barcodes
203            .iter()
204            .filter(|(_, count)| **count > threshold)
205            .map(|(bc, _)| (*bc).clone())
206            .collect()
207    } else {
208        match knee_method {
209            KneeMethod::Distance => {
210                let counts: Vec<u64> = sorted_barcodes.iter().map(|(_, c)| **c).collect();
211                if counts.is_empty() {
212                    return Vec::new();
213                }
214                let knee = knee_distance(&counts);
215                sorted_barcodes[..=knee]
216                    .iter()
217                    .map(|(bc, _)| (*bc).clone())
218                    .collect()
219            }
220            KneeMethod::Density => knee_density(&sorted_barcodes, expect_cells),
221        }
222    }
223}
224
225/// Iterative distance-to-diagonal knee detection on cumulative counts.
226fn knee_distance(sorted_desc_counts: &[u64]) -> usize {
227    let values = cumulative_sum(sorted_desc_counts);
228    let mut prev = 0;
229    let mut knee = get_max_distance_index(&values);
230    for _ in 0..100 {
231        if knee == prev {
232            break;
233        }
234        prev = knee;
235        let end = (knee * 3).min(values.len());
236        knee = get_max_distance_index(&values[..end]);
237    }
238    knee
239}
240
/// Locate the point farthest (perpendicular distance) from the chord joining
/// the first and last points of the curve `(i, values[i])`.
///
/// Returns 0 for curves with fewer than two points, for a zero-length chord,
/// and when no point lies off the chord. On equal distances the earliest index
/// wins.
#[allow(clippy::cast_precision_loss)]
fn get_max_distance_index(values: &[f64]) -> usize {
    let count = values.len();
    if count <= 1 {
        return 0;
    }

    // The chord starts at (0, values[0]); work relative to that origin.
    let origin_y = values[0];
    let dx = (count - 1) as f64;
    let dy = values[count - 1] - origin_y;
    let chord = dx.hypot(dy);
    if chord == 0.0 {
        return 0;
    }
    let (ux, uy) = (dx / chord, dy / chord);

    let mut winner = (0_usize, 0.0_f64);
    for (idx, &y) in values.iter().enumerate() {
        let px = idx as f64;
        let py = y - origin_y;
        // Length of the projection onto the unit chord direction.
        let along = px.mul_add(ux, py * uy);
        // What is left after removing the projection is the perpendicular part.
        let off_chord = (px - along * ux).hypot(py - along * uy);
        if off_chord > winner.1 {
            winner = (idx, off_chord);
        }
    }
    winner.0
}
275
/// Running totals of `counts`, accumulated and returned as `f64`.
///
/// Element `i` of the result is the sum of `counts[..=i]`.
#[allow(clippy::cast_precision_loss)]
fn cumulative_sum(counts: &[u64]) -> Vec<f64> {
    let mut totals = Vec::with_capacity(counts.len());
    totals.extend(counts.iter().scan(0.0_f64, |running, &c| {
        *running += c as f64;
        Some(*running)
    }));
    totals
}
287
288/// Density-based knee detection using Gaussian KDE on log10-transformed counts.
289/// Matches scipy's `gaussian_kde(data, bw_method=0.1)` behavior.
290#[allow(clippy::cast_precision_loss)]
291fn knee_density(sorted_barcodes: &[(&String, &u64)], expect_cells: Option<usize>) -> Vec<String> {
292    if sorted_barcodes.is_empty() {
293        return Vec::new();
294    }
295
296    let max_count = *sorted_barcodes[0].1 as f64;
297    let abundance_threshold = max_count * 0.001;
298
299    // Log10-transform counts above abundance threshold
300    let log_counts: Vec<f64> = sorted_barcodes
301        .iter()
302        .map(|(_, c)| **c as f64)
303        .filter(|&c| c > abundance_threshold)
304        .map(f64::log10)
305        .collect();
306
307    if log_counts.is_empty() {
308        return Vec::new();
309    }
310
311    let bw = sample_std(&log_counts) * 0.1;
312    if bw <= 0.0 {
313        return Vec::new();
314    }
315
316    let log_min = log_counts.iter().copied().fold(f64::INFINITY, f64::min);
317    let log_max = log_counts.iter().copied().fold(f64::NEG_INFINITY, f64::max);
318
319    let num_points: usize = 10_000;
320    let xx: Vec<f64> = (0..num_points)
321        .map(|i| (log_max - log_min).mul_add(i as f64 / (num_points - 1) as f64, log_min))
322        .collect();
323
324    let density = gaussian_kde(&log_counts, bw, &xx);
325
326    // Find local minima: density[i] < density[i-1] && density[i] < density[i+1]
327    let local_mins: Vec<usize> = (1..density.len() - 1)
328        .filter(|&i| density[i] < density[i - 1] && density[i] < density[i + 1])
329        .collect();
330
331    if local_mins.is_empty() {
332        return Vec::new();
333    }
334
335    // Select the appropriate local minimum by iterating in reverse
336    let mut selected_min: Option<usize> = None;
337    for &min_idx in local_mins.iter().rev() {
338        let threshold = 10.0_f64.powf(xx[min_idx]);
339        let passing_count = sorted_barcodes
340            .iter()
341            .filter(|(_, c)| **c as f64 > threshold)
342            .count();
343
344        if let Some(expected) = expect_cells {
345            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
346            let lo = (expected as f64 * 0.1) as usize;
347            if passing_count > lo && passing_count <= expected {
348                selected_min = Some(min_idx);
349                break;
350            }
351        } else {
352            let xx_values = xx.len();
353            let at_least_20pct = min_idx as f64 >= 0.2 * xx_values as f64;
354            let far_from_max = log_max - xx[min_idx] > 0.5;
355            let below_half_max = xx[min_idx] < log_max / 2.0;
356
357            if at_least_20pct && (far_from_max || below_half_max) {
358                selected_min = Some(min_idx);
359                break;
360            }
361        }
362    }
363
364    let Some(min_idx) = selected_min else {
365        return Vec::new();
366    };
367
368    let threshold = 10.0_f64.powf(xx[min_idx]);
369    sorted_barcodes
370        .iter()
371        .filter(|(_, c)| **c as f64 > threshold)
372        .map(|(bc, _)| (*bc).clone())
373        .collect()
374}
375
/// Gaussian KDE evaluation matching scipy's `gaussian_kde` behavior.
///
/// Evaluates, at every entry of `points`, the average of Gaussian kernels of
/// standard deviation `bw` centred on each value in `data`. The result has one
/// density value per point, in the same order.
#[allow(clippy::cast_precision_loss)]
fn gaussian_kde(data: &[f64], bw: f64, points: &[f64]) -> Vec<f64> {
    let sample_count = data.len() as f64;
    // 1 / (n * bw * sqrt(2*pi)): kernel normalisation and averaging in one factor.
    let norm = 1.0 / (sample_count * bw * (2.0 * PI).sqrt());

    let kernel_total = |x: f64| -> f64 {
        data.iter()
            .map(|&sample| {
                let z = (x - sample) / bw;
                (-0.5 * z * z).exp()
            })
            .sum::<f64>()
    };

    points.iter().map(|&x| norm * kernel_total(x)).collect()
}
395
/// Sample standard deviation (ddof=1, Bessel's correction).
///
/// Returns 0.0 when fewer than two values are supplied.
#[allow(clippy::cast_precision_loss)]
fn sample_std(data: &[f64]) -> f64 {
    if data.len() < 2 {
        return 0.0;
    }
    let n = data.len() as f64;
    let mean = data.iter().sum::<f64>() / n;
    let squared_dev: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
    (squared_dev / (n - 1.0)).sqrt()
}
407
408/// Detect whitelist barcodes that may be errors of higher-count whitelist barcodes.
409/// Removes them from the whitelist, optionally correcting substitution errors.
410fn error_detect_above_threshold(
411    all_counts: &HashMap<String, u64>,
412    first_seen: &HashMap<String, usize>,
413    whitelist: Vec<String>,
414    corrections: &mut HashMap<String, Vec<(String, u64)>>,
415    threshold: usize,
416    mode: EdAboveThreshold,
417) -> Vec<String> {
418    // Sort whitelist by count ascending, with first-seen order as tiebreaker
419    // (matches Python's stable sort on Counter.most_common() insertion order)
420    let mut sorted_wl: Vec<String> = whitelist;
421    sorted_wl.sort_by(|a, b| {
422        let count_a = all_counts.get(a).copied().unwrap_or(0);
423        let count_b = all_counts.get(b).copied().unwrap_or(0);
424        count_a
425            .cmp(&count_b)
426            .then_with(|| first_seen.get(a).cmp(&first_seen.get(b)))
427    });
428
429    let mut discard: std::collections::HashSet<String> = std::collections::HashSet::new();
430
431    for ix in 0..sorted_wl.len() {
432        let cb = &sorted_wl[ix];
433
434        // Find near misses among higher-count barcodes
435        let mut near_misses: Vec<String> = Vec::new();
436        for higher_bc in &sorted_wl[ix + 1..] {
437            let cb_len = cb.len();
438            let h_len = higher_bc.len();
439            if cb_len.max(h_len) > cb_len.min(h_len) + threshold {
440                continue;
441            }
442            if prefix_edit_distance(cb.as_bytes(), higher_bc.as_bytes()) <= threshold {
443                near_misses.push(higher_bc.clone());
444                if near_misses.len() > 1 {
445                    break;
446                }
447            }
448        }
449
450        if near_misses.is_empty() {
451            continue;
452        }
453
454        match mode {
455            EdAboveThreshold::Discard => {
456                discard.insert(cb.clone());
457            }
458            EdAboveThreshold::Correct => {
459                if near_misses.len() == 1
460                    && cb.len() == near_misses[0].len()
461                    && hamming_distance(cb.as_bytes(), near_misses[0].as_bytes()) <= threshold
462                {
463                    // Pure substitution: correct by adding to the higher-count barcode's map
464                    let count = all_counts.get(cb).copied().unwrap_or(0);
465                    corrections
466                        .entry(near_misses[0].clone())
467                        .or_default()
468                        .push((cb.clone(), count));
469                    // Re-sort after adding
470                    if let Some(corr_list) = corrections.get_mut(&near_misses[0]) {
471                        corr_list.sort_by(|a, b| a.0.cmp(&b.0));
472                    }
473                }
474                discard.insert(cb.clone());
475            }
476        }
477    }
478
479    sorted_wl
480        .into_iter()
481        .filter(|bc| !discard.contains(bc))
482        .collect()
483}
484
/// Semi-global edit distance: the smallest edit distance between the whole of
/// pattern `a` and any prefix of target `b` (the empty prefix and all of `b`
/// included). Matches Python `regex.compile("(a){e<=N}").match(b)` semantics
/// where `match()` anchors at the start but doesn't require consuming the
/// entire target.
fn prefix_edit_distance(a: &[u8], b: &[u8]) -> usize {
    let width = b.len() + 1;

    // `above[j]` holds the distance between the pattern consumed so far and
    // the first `j` bytes of the target.
    let mut above: Vec<usize> = (0..width).collect();
    let mut row = vec![0_usize; width];

    for (i, &pat) in a.iter().enumerate() {
        row[0] = i + 1;
        for (j, &tgt) in b.iter().enumerate() {
            let substitute = above[j] + usize::from(pat != tgt);
            let skip_pattern_byte = above[j + 1] + 1;
            let skip_target_byte = row[j] + 1;
            row[j + 1] = skip_pattern_byte.min(skip_target_byte).min(substitute);
        }
        std::mem::swap(&mut above, &mut row);
    }

    // The best score over every prefix length of the target.
    above.iter().copied().min().unwrap_or(usize::MAX)
}
508
509/// Build error correction map: for each non-whitelist barcode, find if it maps
510/// uniquely to exactly one whitelist barcode within the Hamming distance threshold.
511fn build_error_correction_map(
512    all_counts: &HashMap<String, u64>,
513    whitelist: &[String],
514    threshold: usize,
515) -> HashMap<String, Vec<(String, u64)>> {
516    let mut corrections: HashMap<String, Vec<(String, u64)>> = HashMap::new();
517
518    for (barcode, &count) in all_counts {
519        if whitelist.contains(barcode) {
520            continue;
521        }
522
523        let mut matches: Vec<&String> = Vec::new();
524        for wl_bc in whitelist {
525            if hamming_distance(barcode.as_bytes(), wl_bc.as_bytes()) <= threshold {
526                matches.push(wl_bc);
527            }
528        }
529
530        if matches.len() == 1 {
531            corrections
532                .entry(matches[0].clone())
533                .or_default()
534                .push((barcode.clone(), count));
535        }
536    }
537
538    for corr_list in corrections.values_mut() {
539        corr_list.sort_by(|a, b| a.0.cmp(&b.0));
540    }
541
542    corrections
543}
544
/// Number of positions at which two equal-length byte strings differ.
/// Returns `usize::MAX` if lengths differ.
fn hamming_distance(a: &[u8], b: &[u8]) -> usize {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(x, y)| usize::from(x != y)).sum()
    } else {
        usize::MAX
    }
}
552
553/// Write whitelist entries as 4-column TSV.
554fn write_whitelist_tsv<W: Write>(
555    entries: &[WhitelistEntry],
556    writer: &mut W,
557) -> Result<(), ExtractError> {
558    for entry in entries {
559        let error_barcodes: String = entry
560            .corrections
561            .iter()
562            .map(|(bc, _)| bc.as_str())
563            .collect::<Vec<_>>()
564            .join(",");
565        let error_counts: String = entry
566            .corrections
567            .iter()
568            .map(|(_, count)| count.to_string())
569            .collect::<Vec<_>>()
570            .join(",");
571
572        writer.write_all(entry.barcode.as_bytes())?;
573        writer.write_all(b"\t")?;
574        writer.write_all(error_barcodes.as_bytes())?;
575        writer.write_all(b"\t")?;
576        writer.write_all(entry.count.to_string().as_bytes())?;
577        writer.write_all(b"\t")?;
578        writer.write_all(error_counts.as_bytes())?;
579        writer.write_all(b"\n")?;
580    }
581
582    Ok(())
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies within `tol` of `expected`.
    fn approx(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Identical strings differ at no position.
    #[test]
    fn test_hamming_distance_same() {
        let distance = hamming_distance(b"ACGT", b"ACGT");
        assert_eq!(distance, 0);
    }

    /// A single substituted base counts once.
    #[test]
    fn test_hamming_distance_one() {
        let distance = hamming_distance(b"ACGT", b"ACGA");
        assert_eq!(distance, 1);
    }

    /// Unequal lengths are reported with the `usize::MAX` sentinel.
    #[test]
    fn test_hamming_distance_different_length() {
        let distance = hamming_distance(b"ACGT", b"ACG");
        assert_eq!(distance, usize::MAX);
    }

    /// Each output element is the running total up to that position.
    #[test]
    fn test_cumulative_sum() {
        let running = cumulative_sum(&[10, 5, 3, 1]);
        assert_eq!(running, [10.0, 15.0, 18.0, 19.0]);
    }

    /// A concave curve has its knee strictly between the endpoints.
    #[test]
    fn test_get_max_distance_index() {
        let curve = [10.0, 15.0, 18.0, 19.0, 20.0];
        let idx = get_max_distance_index(&curve);
        assert!((1..curve.len() - 1).contains(&idx));
    }

    /// Uses Bessel's correction: sqrt(32/7) ≈ 2.138 for this data set.
    #[test]
    fn test_sample_std() {
        let spread = sample_std(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!(approx(spread, 2.138, 0.01));
    }

    /// One sample, bw = 1, evaluated at the sample:
    /// 1/(1*1*sqrt(2*pi)) * exp(0) ≈ 0.3989.
    #[test]
    fn test_gaussian_kde_single_point() {
        let density = gaussian_kde(&[0.0], 1.0, &[0.0]);
        assert!(approx(density[0], 0.3989, 0.001));
    }
}