1use std::collections::HashMap;
2use std::f64::consts::PI;
3use std::io::{BufWriter, Write};
4
5use needletail::parser::{FastqReader, FastxReader};
6
7use crate::error::ExtractError;
8use crate::pattern::BarcodePattern;
9
/// Strategy used to find the "knee" that separates abundant cell barcodes
/// from the long tail of low-count barcodes.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly and used
/// in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum KneeMethod {
    /// Pick the point of the cumulative-count curve that lies farthest from
    /// the straight line joining the curve's first and last points.
    #[default]
    Distance,
    /// Pick a local minimum of a Gaussian kernel density estimate computed
    /// over the log10 barcode counts.
    Density,
}
17
/// What to do with a whitelisted barcode that lies within the error
/// threshold of another whitelisted barcode with an equal or higher count.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly and used
/// in `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EdAboveThreshold {
    /// Remove the lower-count barcode from the whitelist.
    Discard,
    /// Remove the lower-count barcode from the whitelist and, when it has
    /// exactly one near match of the same length within the Hamming
    /// threshold, record it as a correction of that near match.
    Correct,
}
25
26pub struct WhitelistConfig {
28 pub pattern: BarcodePattern,
29 pub knee_method: KneeMethod,
30 pub cell_number: Option<usize>,
31 pub expect_cells: Option<usize>,
32 pub error_correct_threshold: usize,
33 pub ed_above_threshold: Option<EdAboveThreshold>,
34 pub subset_reads: usize,
35}
36
/// One row of the whitelist output.
///
/// The standard traits are derived so entries can be logged, copied and
/// compared (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WhitelistEntry {
    /// The whitelisted cell barcode.
    pub barcode: String,
    /// Number of reads observed with exactly this barcode.
    pub count: u64,
    /// Barcodes that are corrected to `barcode`, each paired with its own
    /// read count.
    pub corrections: Vec<(String, u64)>,
}
43
/// Read-level counters collected while scanning the input.
///
/// The standard traits are derived so the counters can be logged, compared
/// and default-initialised to zero.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WhitelistStats {
    /// Number of records pulled from the input.
    pub input_reads: u64,
    /// Number of reads from which no barcode could be extracted.
    pub no_match: u64,
}
49
50pub fn run_whitelist<R: std::io::Read + Send, W: Write, FW: Write>(
55 config: &WhitelistConfig,
56 input: R,
57 output: W,
58 filtered_out: Option<FW>,
59) -> Result<WhitelistStats, ExtractError> {
60 let (all_counts, first_seen, stats) =
61 count_barcodes(&config.pattern, input, config.subset_reads, filtered_out)?;
62
63 let whitelist = determine_whitelist(
64 &all_counts,
65 config.knee_method,
66 config.cell_number,
67 config.expect_cells,
68 );
69
70 let mut corrections =
71 build_error_correction_map(&all_counts, &whitelist, config.error_correct_threshold);
72
73 let whitelist = if let Some(mode) = config.ed_above_threshold {
74 error_detect_above_threshold(
75 &all_counts,
76 &first_seen,
77 whitelist,
78 &mut corrections,
79 config.error_correct_threshold,
80 mode,
81 )
82 } else {
83 whitelist
84 };
85
86 let mut entries: Vec<WhitelistEntry> = whitelist
87 .into_iter()
88 .map(|bc| {
89 let count = all_counts.get(&bc).copied().unwrap_or(0);
90 let corr = corrections.get(&bc).cloned().unwrap_or_default();
91 WhitelistEntry {
92 barcode: bc,
93 count,
94 corrections: corr,
95 }
96 })
97 .collect();
98
99 entries.sort_by(|a, b| a.barcode.cmp(&b.barcode));
100
101 let mut writer = BufWriter::new(output);
102 write_whitelist_tsv(&entries, &mut writer)?;
103 writer.flush()?;
104
105 Ok(stats)
106}
107
108#[allow(clippy::type_complexity)]
111fn count_barcodes<R: std::io::Read + Send, FW: Write>(
112 pattern: &BarcodePattern,
113 input: R,
114 subset_reads: usize,
115 filtered_out: Option<FW>,
116) -> Result<(HashMap<String, u64>, HashMap<String, usize>, WhitelistStats), ExtractError> {
117 let mut counts: HashMap<String, u64> = HashMap::new();
118 let mut first_seen: HashMap<String, usize> = HashMap::new();
119 let mut seen_order: usize = 0;
120 let mut stats = WhitelistStats {
121 input_reads: 0,
122 no_match: 0,
123 };
124 let mut filt_writer = filtered_out.map(BufWriter::new);
125
126 let mut reader = FastqReader::new(input);
127
128 while let Some(result) = reader.next() {
129 let record = result.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
130 stats.input_reads += 1;
131
132 if stats.input_reads > subset_reads as u64 {
133 break;
134 }
135
136 let seq = record.seq();
137 let qual = record
138 .qual()
139 .ok_or_else(|| ExtractError::FastqParse("missing quality scores".into()))?;
140
141 match pattern.extract(&seq, qual) {
142 Ok(extraction) => {
143 let cell = String::from_utf8_lossy(&extraction.cell_barcode).into_owned();
144 if !cell.is_empty() {
145 if !counts.contains_key(&cell) {
146 first_seen.insert(cell.clone(), seen_order);
147 seen_order += 1;
148 }
149 *counts.entry(cell).or_insert(0) += 1;
150 }
151 }
152 Err(ExtractError::ReadTooShort { .. } | ExtractError::RegexNoMatch) => {
153 stats.no_match += 1;
154 if let Some(fw) = filt_writer.as_mut() {
155 write_fastq_record(fw, record.id(), &seq, qual)?;
156 }
157 }
158 Err(e) => return Err(e),
159 }
160 }
161
162 if let Some(fw) = filt_writer.as_mut() {
163 fw.flush()?;
164 }
165
166 Ok((counts, first_seen, stats))
167}
168
169fn write_fastq_record<W: Write>(
171 writer: &mut W,
172 id: &[u8],
173 seq: &[u8],
174 qual: &[u8],
175) -> Result<(), ExtractError> {
176 writer.write_all(b"@")?;
177 writer.write_all(id)?;
178 writer.write_all(b"\n")?;
179 writer.write_all(seq)?;
180 writer.write_all(b"\n+\n")?;
181 writer.write_all(qual)?;
182 writer.write_all(b"\n")?;
183 Ok(())
184}
185
186fn determine_whitelist(
188 all_counts: &HashMap<String, u64>,
189 knee_method: KneeMethod,
190 cell_number: Option<usize>,
191 expect_cells: Option<usize>,
192) -> Vec<String> {
193 let mut sorted_barcodes: Vec<(&String, &u64)> = all_counts.iter().collect();
194 sorted_barcodes.sort_by(|a, b| b.1.cmp(a.1));
195
196 if let Some(n) = cell_number {
197 if n == 0 || sorted_barcodes.is_empty() {
198 return Vec::new();
199 }
200 let threshold_idx = n.min(sorted_barcodes.len()) - 1;
201 let threshold = *sorted_barcodes[threshold_idx].1;
202 sorted_barcodes
203 .iter()
204 .filter(|(_, count)| **count > threshold)
205 .map(|(bc, _)| (*bc).clone())
206 .collect()
207 } else {
208 match knee_method {
209 KneeMethod::Distance => {
210 let counts: Vec<u64> = sorted_barcodes.iter().map(|(_, c)| **c).collect();
211 if counts.is_empty() {
212 return Vec::new();
213 }
214 let knee = knee_distance(&counts);
215 sorted_barcodes[..=knee]
216 .iter()
217 .map(|(bc, _)| (*bc).clone())
218 .collect()
219 }
220 KneeMethod::Density => knee_density(&sorted_barcodes, expect_cells),
221 }
222 }
223}
224
225fn knee_distance(sorted_desc_counts: &[u64]) -> usize {
227 let values = cumulative_sum(sorted_desc_counts);
228 let mut prev = 0;
229 let mut knee = get_max_distance_index(&values);
230 for _ in 0..100 {
231 if knee == prev {
232 break;
233 }
234 prev = knee;
235 let end = (knee * 3).min(values.len());
236 knee = get_max_distance_index(&values[..end]);
237 }
238 knee
239}
240
/// Returns the index of the point of `values` that lies farthest from the
/// straight line through the first point `(0, values[0])` and the last point
/// `(n - 1, values[n - 1])`.
///
/// Returns 0 when there are fewer than two points or when no point has a
/// strictly positive distance; on ties the lowest index wins.
// Indices are converted to f64 to be used as x-coordinates.
#[allow(clippy::cast_precision_loss)]
fn get_max_distance_index(values: &[f64]) -> usize {
    let point_count = values.len();
    if point_count <= 1 {
        return 0;
    }

    let origin_y = values[0];
    let run = (point_count - 1) as f64 - 0.0_f64;
    let rise = values[point_count - 1] - origin_y;
    let chord_len = run.hypot(rise);
    if chord_len == 0.0 {
        return 0;
    }
    let unit_x = run / chord_len;
    let unit_y = rise / chord_len;

    let (farthest_idx, _) = values.iter().enumerate().fold(
        (0_usize, 0.0_f64),
        |(best_idx, best_dist), (i, &y)| {
            // Vector from the first point to the current point.
            let dx = i as f64 - 0.0_f64;
            let dy = y - origin_y;
            // Remove the component along the chord; what is left is the
            // perpendicular offset.
            let along = dx.mul_add(unit_x, dy * unit_y);
            let dist = (dx - along * unit_x).hypot(dy - along * unit_y);
            if dist > best_dist {
                (i, dist)
            } else {
                (best_idx, best_dist)
            }
        },
    );
    farthest_idx
}
275
/// Returns the running total of `counts` as floating-point values; element
/// `i` of the result is the sum of `counts[0..=i]`.
// Counts are converted to f64 because the knee search works on floats.
#[allow(clippy::cast_precision_loss)]
fn cumulative_sum(counts: &[u64]) -> Vec<f64> {
    counts
        .iter()
        .scan(0.0_f64, |running, &c| {
            *running += c as f64;
            Some(*running)
        })
        .collect()
}
287
288#[allow(clippy::cast_precision_loss)]
291fn knee_density(sorted_barcodes: &[(&String, &u64)], expect_cells: Option<usize>) -> Vec<String> {
292 if sorted_barcodes.is_empty() {
293 return Vec::new();
294 }
295
296 let max_count = *sorted_barcodes[0].1 as f64;
297 let abundance_threshold = max_count * 0.001;
298
299 let log_counts: Vec<f64> = sorted_barcodes
301 .iter()
302 .map(|(_, c)| **c as f64)
303 .filter(|&c| c > abundance_threshold)
304 .map(f64::log10)
305 .collect();
306
307 if log_counts.is_empty() {
308 return Vec::new();
309 }
310
311 let bw = sample_std(&log_counts) * 0.1;
312 if bw <= 0.0 {
313 return Vec::new();
314 }
315
316 let log_min = log_counts.iter().copied().fold(f64::INFINITY, f64::min);
317 let log_max = log_counts.iter().copied().fold(f64::NEG_INFINITY, f64::max);
318
319 let num_points: usize = 10_000;
320 let xx: Vec<f64> = (0..num_points)
321 .map(|i| (log_max - log_min).mul_add(i as f64 / (num_points - 1) as f64, log_min))
322 .collect();
323
324 let density = gaussian_kde(&log_counts, bw, &xx);
325
326 let local_mins: Vec<usize> = (1..density.len() - 1)
328 .filter(|&i| density[i] < density[i - 1] && density[i] < density[i + 1])
329 .collect();
330
331 if local_mins.is_empty() {
332 return Vec::new();
333 }
334
335 let mut selected_min: Option<usize> = None;
337 for &min_idx in local_mins.iter().rev() {
338 let threshold = 10.0_f64.powf(xx[min_idx]);
339 let passing_count = sorted_barcodes
340 .iter()
341 .filter(|(_, c)| **c as f64 > threshold)
342 .count();
343
344 if let Some(expected) = expect_cells {
345 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
346 let lo = (expected as f64 * 0.1) as usize;
347 if passing_count > lo && passing_count <= expected {
348 selected_min = Some(min_idx);
349 break;
350 }
351 } else {
352 let xx_values = xx.len();
353 let at_least_20pct = min_idx as f64 >= 0.2 * xx_values as f64;
354 let far_from_max = log_max - xx[min_idx] > 0.5;
355 let below_half_max = xx[min_idx] < log_max / 2.0;
356
357 if at_least_20pct && (far_from_max || below_half_max) {
358 selected_min = Some(min_idx);
359 break;
360 }
361 }
362 }
363
364 let Some(min_idx) = selected_min else {
365 return Vec::new();
366 };
367
368 let threshold = 10.0_f64.powf(xx[min_idx]);
369 sorted_barcodes
370 .iter()
371 .filter(|(_, c)| **c as f64 > threshold)
372 .map(|(bc, _)| (*bc).clone())
373 .collect()
374}
375
/// Evaluates a Gaussian kernel density estimate of `data` with bandwidth `bw`
/// at every position in `points`.
///
/// Each output value is `1 / (n * bw * sqrt(2π))` times the sum over all
/// samples of `exp(-0.5 * ((x - sample) / bw)^2)`.
// The sample count is converted to f64 for the normalisation factor.
#[allow(clippy::cast_precision_loss)]
fn gaussian_kde(data: &[f64], bw: f64, points: &[f64]) -> Vec<f64> {
    let sample_count = data.len() as f64;
    let normalisation = 1.0 / (sample_count * bw * (2.0 * PI).sqrt());

    let mut density = Vec::with_capacity(points.len());
    for &x in points {
        let mut kernel_sum = 0.0_f64;
        for &sample in data {
            let z = (x - sample) / bw;
            kernel_sum += (-0.5 * z * z).exp();
        }
        density.push(normalisation * kernel_sum);
    }
    density
}
395
/// Returns the sample standard deviation of `data` (divisor `n - 1`), or 0.0
/// when there are fewer than two values.
// The element count is converted to f64 for the division.
#[allow(clippy::cast_precision_loss)]
fn sample_std(data: &[f64]) -> f64 {
    if data.len() < 2 {
        return 0.0;
    }
    let count = data.len() as f64;
    let mean = data.iter().sum::<f64>() / count;
    let squared_deviations: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
    (squared_deviations / (count - 1.0)).sqrt()
}
407
408fn error_detect_above_threshold(
411 all_counts: &HashMap<String, u64>,
412 first_seen: &HashMap<String, usize>,
413 whitelist: Vec<String>,
414 corrections: &mut HashMap<String, Vec<(String, u64)>>,
415 threshold: usize,
416 mode: EdAboveThreshold,
417) -> Vec<String> {
418 let mut sorted_wl: Vec<String> = whitelist;
421 sorted_wl.sort_by(|a, b| {
422 let count_a = all_counts.get(a).copied().unwrap_or(0);
423 let count_b = all_counts.get(b).copied().unwrap_or(0);
424 count_a
425 .cmp(&count_b)
426 .then_with(|| first_seen.get(a).cmp(&first_seen.get(b)))
427 });
428
429 let mut discard: std::collections::HashSet<String> = std::collections::HashSet::new();
430
431 for ix in 0..sorted_wl.len() {
432 let cb = &sorted_wl[ix];
433
434 let mut near_misses: Vec<String> = Vec::new();
436 for higher_bc in &sorted_wl[ix + 1..] {
437 let cb_len = cb.len();
438 let h_len = higher_bc.len();
439 if cb_len.max(h_len) > cb_len.min(h_len) + threshold {
440 continue;
441 }
442 if prefix_edit_distance(cb.as_bytes(), higher_bc.as_bytes()) <= threshold {
443 near_misses.push(higher_bc.clone());
444 if near_misses.len() > 1 {
445 break;
446 }
447 }
448 }
449
450 if near_misses.is_empty() {
451 continue;
452 }
453
454 match mode {
455 EdAboveThreshold::Discard => {
456 discard.insert(cb.clone());
457 }
458 EdAboveThreshold::Correct => {
459 if near_misses.len() == 1
460 && cb.len() == near_misses[0].len()
461 && hamming_distance(cb.as_bytes(), near_misses[0].as_bytes()) <= threshold
462 {
463 let count = all_counts.get(cb).copied().unwrap_or(0);
465 corrections
466 .entry(near_misses[0].clone())
467 .or_default()
468 .push((cb.clone(), count));
469 if let Some(corr_list) = corrections.get_mut(&near_misses[0]) {
471 corr_list.sort_by(|a, b| a.0.cmp(&b.0));
472 }
473 }
474 discard.insert(cb.clone());
475 }
476 }
477 }
478
479 sorted_wl
480 .into_iter()
481 .filter(|bc| !discard.contains(bc))
482 .collect()
483}
484
/// Returns the smallest edit distance (insertions, deletions, substitutions)
/// between the whole of `a` and any prefix of `b`, including the empty prefix.
fn prefix_edit_distance(a: &[u8], b: &[u8]) -> usize {
    // `row[j]` holds the distance between the processed part of `a` and the
    // first `j` bytes of `b`; it is updated in place for each byte of `a`.
    let mut row: Vec<usize> = (0..=b.len()).collect();

    for (i, &byte_a) in a.iter().enumerate() {
        // Value of the previous row one column to the left of the cell being
        // updated.
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &byte_b) in b.iter().enumerate() {
            let above = row[j + 1];
            let substitution = diagonal + usize::from(byte_a != byte_b);
            row[j + 1] = (above + 1).min(row[j] + 1).min(substitution);
            diagonal = above;
        }
    }

    // Any prefix of `b` may be the match, so take the best column.
    row.into_iter().min().unwrap_or(usize::MAX)
}
508
509fn build_error_correction_map(
512 all_counts: &HashMap<String, u64>,
513 whitelist: &[String],
514 threshold: usize,
515) -> HashMap<String, Vec<(String, u64)>> {
516 let mut corrections: HashMap<String, Vec<(String, u64)>> = HashMap::new();
517
518 for (barcode, &count) in all_counts {
519 if whitelist.contains(barcode) {
520 continue;
521 }
522
523 let mut matches: Vec<&String> = Vec::new();
524 for wl_bc in whitelist {
525 if hamming_distance(barcode.as_bytes(), wl_bc.as_bytes()) <= threshold {
526 matches.push(wl_bc);
527 }
528 }
529
530 if matches.len() == 1 {
531 corrections
532 .entry(matches[0].clone())
533 .or_default()
534 .push((barcode.clone(), count));
535 }
536 }
537
538 for corr_list in corrections.values_mut() {
539 corr_list.sort_by(|a, b| a.0.cmp(&b.0));
540 }
541
542 corrections
543}
544
/// Counts the positions at which `a` and `b` differ.
///
/// Returns `usize::MAX` when the two slices have different lengths.
fn hamming_distance(a: &[u8], b: &[u8]) -> usize {
    if a.len() == b.len() {
        let mut mismatches = 0;
        for (x, y) in a.iter().zip(b) {
            if x != y {
                mismatches += 1;
            }
        }
        mismatches
    } else {
        usize::MAX
    }
}
552
553fn write_whitelist_tsv<W: Write>(
555 entries: &[WhitelistEntry],
556 writer: &mut W,
557) -> Result<(), ExtractError> {
558 for entry in entries {
559 let error_barcodes: String = entry
560 .corrections
561 .iter()
562 .map(|(bc, _)| bc.as_str())
563 .collect::<Vec<_>>()
564 .join(",");
565 let error_counts: String = entry
566 .corrections
567 .iter()
568 .map(|(_, count)| count.to_string())
569 .collect::<Vec<_>>()
570 .join(",");
571
572 writer.write_all(entry.barcode.as_bytes())?;
573 writer.write_all(b"\t")?;
574 writer.write_all(error_barcodes.as_bytes())?;
575 writer.write_all(b"\t")?;
576 writer.write_all(entry.count.to_string().as_bytes())?;
577 writer.write_all(b"\t")?;
578 writer.write_all(error_counts.as_bytes())?;
579 writer.write_all(b"\n")?;
580 }
581
582 Ok(())
583}
584
#[cfg(test)]
mod tests {
    //! Unit tests for the numeric and distance helpers of this module.

    use super::*;

    /// Identical sequences have distance zero.
    #[test]
    fn test_hamming_distance_same() {
        let distance = hamming_distance(b"ACGT", b"ACGT");
        assert_eq!(distance, 0);
    }

    /// A single substitution counts as one.
    #[test]
    fn test_hamming_distance_one() {
        let distance = hamming_distance(b"ACGT", b"ACGA");
        assert_eq!(distance, 1);
    }

    /// Sequences of different lengths yield the sentinel value.
    #[test]
    fn test_hamming_distance_different_length() {
        let distance = hamming_distance(b"ACGT", b"ACG");
        assert_eq!(distance, usize::MAX);
    }

    /// The running total is produced element by element.
    #[test]
    fn test_cumulative_sum() {
        let totals = cumulative_sum(&[10, 5, 3, 1]);
        assert_eq!(totals, vec![10.0, 15.0, 18.0, 19.0]);
    }

    /// For a concave curve the farthest point is an interior one.
    #[test]
    fn test_get_max_distance_index() {
        let values = [10.0, 15.0, 18.0, 19.0, 20.0];
        let idx = get_max_distance_index(&values);
        assert!(idx > 0 && idx < values.len() - 1);
    }

    /// Sample standard deviation uses the `n - 1` divisor.
    #[test]
    fn test_sample_std() {
        let deviation = sample_std(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!((deviation - 2.138).abs() < 0.01);
    }

    /// With one sample and unit bandwidth the density at the sample is the
    /// peak of the standard normal distribution.
    #[test]
    fn test_gaussian_kde_single_point() {
        let density = gaussian_kde(&[0.0], 1.0, &[0.0]);
        assert!((density[0] - 0.3989).abs() < 0.001);
    }
}