1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fs::File;
3use std::io::{BufWriter, Write};
4
5use rust_htslib::bam::{self, Read as BamRead, Record};
6
7use crate::dedup::{
8 DedupMethod, GroupKey, PythonRandom, TieBreakRng, build_adjacency_list,
9 build_directional_adjacency_list, connected_components, extract_umi_from_name,
10 extract_umi_from_tag, get_read_position, median, min_set_cover,
11};
12
/// Policy for reads whose mate is aligned to a different contig.
///
/// Only consulted by `run_group` when `GroupConfig::paired` is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChimericPairs {
    /// Drop the read: it is neither grouped nor written to the output.
    Discard,
    /// Pass the read through to the output untouched, without grouping it.
    Output,
    /// Treat the read like any other read and include it in UMI grouping.
    Use,
}
19
/// Policy for unmapped reads and, in paired mode, for reads whose mate is
/// unmapped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnmappedHandling {
    /// Unmapped reads are dropped; in paired mode a read with an unmapped
    /// mate is dropped as well.
    Discard,
    /// Unmapped reads are passed through to the output untouched; in paired
    /// mode a read with an unmapped mate is passed through without grouping.
    Output,
    /// Unmapped reads are passed through to the output untouched; in paired
    /// mode a mapped read with an unmapped mate takes part in UMI grouping.
    Use,
}
26
/// Settings that control a single `run_group` invocation.
#[allow(clippy::struct_excessive_bools)]
pub struct GroupConfig {
    /// Grouping algorithm applied to the UMIs of each bucket.
    pub method: DedupMethod,
    /// When set, every read is given an empty UMI instead of the one encoded
    /// in its name.
    pub ignore_umi: bool,
    /// Separator byte handed to `extract_umi_from_name`.
    pub umi_separator: u8,
    /// Seed for the subsampling random number generator (truncated to 32 bits).
    pub random_seed: u64,
    /// Write SAM instead of BAM to stdout.
    pub out_sam: bool,
    /// Whether the collected records are written to the stdout writer at all.
    pub output_bam: bool,
    /// When set, the output keeps processing order; otherwise mapped records
    /// are sorted by (tid, position) and unmapped records are placed last.
    pub no_sort_output: bool,
    /// If set, only reads on this contig are grouped. A name missing from the
    /// header makes `run_group` fail with `GroupError::UnknownChrom`.
    pub chrom: Option<String>,
    /// Path of the per-read TSV report; no report is written when `None`.
    pub group_out: Option<String>,
    /// Edit-distance threshold handed to the adjacency-list builders.
    pub edit_distance_threshold: u32,
    /// Subsampling threshold: a read is kept only when the random draw made
    /// for it is below this value. `None` disables subsampling.
    pub subset: Option<f32>,
    /// Group reads per gene instead of per alignment position.
    pub per_gene: bool,
    /// Name of the aux tag holding the gene; `"XF"` is used when `None`.
    pub gene_tag: Option<String>,
    /// Regular expression; in per-gene mode reads whose gene matches it are
    /// skipped.
    pub skip_tags_regex: Option<String>,
    /// Use the contig name as the gene. Requires `per_gene`.
    pub per_contig: bool,
    /// Enables the chimeric / unmapped-mate policies and makes the insert
    /// size part of the grouping key.
    pub paired: bool,
    /// What to do with reads whose mate maps to another contig (paired mode).
    pub chimeric_pairs: ChimericPairs,
    /// What to do with unmapped reads and reads with an unmapped mate.
    pub unmapped_handling: UnmappedHandling,
}
48
/// Read counters returned by `run_group`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GroupStats {
    /// Number of mapped, non-read-2 records that passed the contig filter
    /// (counted before subsampling and before any pair-based filtering).
    pub input_reads: u64,
    /// Number of records in the final output list, whether or not they were
    /// actually written to stdout.
    pub output_reads: u64,
}
53
/// All reads sharing one UMI inside a single (position, key) bucket.
struct GroupSlot {
    /// The buffered records carrying this UMI, in arrival order.
    records: Vec<Record>,
    /// Number of records stored in `records`.
    count: u32,
    /// Rank at which this UMI was first seen within its bucket (0-based).
    insertion_order: u32,
}
59
/// Reads waiting to be grouped, organised as position -> key -> UMI -> slot.
struct GroupBuffer {
    /// Buffered reads. The outer key is the position (or the gene id in
    /// per-gene mode) supplied by the caller of `add`.
    groups: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>>,
    /// Next `insertion_order` value to hand out for each (position, key) bucket.
    insertion_counters: BTreeMap<i64, BTreeMap<GroupKey, u32>>,
}
64
65impl GroupBuffer {
66 const fn new() -> Self {
67 Self {
68 groups: BTreeMap::new(),
69 insertion_counters: BTreeMap::new(),
70 }
71 }
72
73 fn add(&mut self, record: Record, pos: i64, key: GroupKey, umi: Vec<u8>) {
74 let umi_map = self.groups.entry(pos).or_default().entry(key).or_default();
75
76 if let Some(slot) = umi_map.get_mut(&umi) {
77 slot.count += 1;
78 slot.records.push(record);
79 return;
80 }
81
82 let counter = self
83 .insertion_counters
84 .entry(pos)
85 .or_default()
86 .entry(key)
87 .or_default();
88 let order = *counter;
89 *counter += 1;
90
91 umi_map.insert(
92 umi,
93 GroupSlot {
94 records: vec![record],
95 count: 1,
96 insertion_order: order,
97 },
98 );
99 }
100
101 fn drain_up_to(
102 &mut self,
103 threshold: i64,
104 ) -> BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>> {
105 let rest = self.groups.split_off(&(threshold + 1));
106 let drained = std::mem::replace(&mut self.groups, rest);
107 let rest_counters = self.insertion_counters.split_off(&(threshold + 1));
108 let _ = std::mem::replace(&mut self.insertion_counters, rest_counters);
109 drained
110 }
111
112 fn drain_all(&mut self) -> BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>> {
113 let drained = std::mem::take(&mut self.groups);
114 self.insertion_counters.clear();
115 drained
116 }
117}
118
119#[allow(clippy::too_many_lines)]
122fn assign_groups(
123 method: DedupMethod,
124 umi_map: &HashMap<Vec<u8>, GroupSlot>,
125 edit_threshold: u32,
126) -> Vec<Vec<Vec<u8>>> {
127 let counts: HashMap<&[u8], u32> = umi_map
128 .iter()
129 .map(|(k, v)| (k.as_slice(), v.count))
130 .collect();
131 let orders: HashMap<&[u8], u32> = umi_map
132 .iter()
133 .map(|(k, v)| (k.as_slice(), v.insertion_order))
134 .collect();
135
136 let lex_sort = |a: &[u8], b: &[u8]| -> std::cmp::Ordering {
137 counts[b].cmp(&counts[a]).then_with(|| a.cmp(b))
138 };
139
140 match method {
141 DedupMethod::Unique => {
142 let mut umis: Vec<Vec<u8>> = umi_map.keys().cloned().collect();
143 umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
144 umis.into_iter().map(|u| vec![u]).collect()
145 }
146
147 DedupMethod::Percentile => {
148 if counts.len() <= 1 {
149 return umi_map.keys().cloned().map(|u| vec![u]).collect();
150 }
151 let all_counts: Vec<u32> = counts.values().copied().collect();
152 let threshold = median(&all_counts) / 100.0;
153 let mut umis: Vec<Vec<u8>> = umi_map
154 .iter()
155 .filter(|(_, slot)| f64::from(slot.count) > threshold)
156 .map(|(umi, _)| umi.clone())
157 .collect();
158 umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
159 umis.into_iter().map(|u| vec![u]).collect()
160 }
161
162 DedupMethod::Cluster => {
163 let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
164 let adj_list = build_adjacency_list(&umis, edit_threshold);
165 let components = connected_components(&umis, &counts, &orders, &adj_list);
166 components
167 .into_iter()
168 .map(|mut comp| {
169 comp.sort_by(|a, b| lex_sort(a, b));
170 comp.into_iter().map(<[u8]>::to_vec).collect()
171 })
172 .collect()
173 }
174
175 DedupMethod::Adjacency => {
176 let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
177 let adj_list = build_adjacency_list(&umis, edit_threshold);
178 let components = connected_components(&umis, &counts, &orders, &adj_list);
179 let mut groups = Vec::new();
182 for component in components {
183 if component.len() == 1 {
184 groups.push(component.into_iter().map(<[u8]>::to_vec).collect());
185 } else {
186 let lead_umis = min_set_cover(&component, &adj_list, &counts);
187 let mut observed: HashSet<&[u8]> = lead_umis.iter().copied().collect();
188 for &lead in &lead_umis {
189 let connected: HashSet<&[u8]> = adj_list
190 .get(lead)
191 .map_or_else(HashSet::new, |ns| ns.iter().copied().collect());
192 let mut group = vec![lead.to_vec()];
193 for node in connected {
194 if observed.insert(node) {
195 group.push(node.to_vec());
196 }
197 }
198 groups.push(group);
199 }
200 }
201 }
202 groups
203 }
204
205 DedupMethod::Directional => {
206 let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
207 let adj_list = build_directional_adjacency_list(&umis, &counts, edit_threshold);
208 let components = connected_components(&umis, &counts, &orders, &adj_list);
209 let mut observed: HashSet<&[u8]> = HashSet::new();
213 let mut groups = Vec::new();
214 for mut comp in components {
215 comp.sort_by(|a, b| lex_sort(a, b));
216 if comp.len() == 1 {
217 observed.insert(comp[0]);
218 groups.push(comp.into_iter().map(<[u8]>::to_vec).collect());
219 } else {
220 let mut filtered: Vec<Vec<u8>> = Vec::new();
221 for node in comp {
222 if observed.insert(node) {
223 filtered.push(node.to_vec());
224 }
225 }
226 if !filtered.is_empty() {
227 groups.push(filtered);
228 }
229 }
230 }
231 groups
232 }
233 }
234}
235
236#[allow(clippy::cast_sign_loss)]
238fn process_drained(
239 drained: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>>,
240 method: DedupMethod,
241 edit_threshold: u32,
242 unique_id: &mut u32,
243 tsv_writer: &mut Option<BufWriter<File>>,
244 header_view: &bam::HeaderView,
245 gene_labels: &HashMap<i64, String>,
246) -> Result<Vec<Record>, GroupError> {
247 let mut output_records = Vec::new();
248
249 let entries: Vec<_> = if gene_labels.is_empty() {
251 drained.into_iter().collect()
252 } else {
253 let mut v: Vec<_> = drained.into_iter().collect();
254 v.sort_by(|(a, _), (b, _)| {
255 let la = gene_labels.get(a).map_or("", String::as_str);
256 let lb = gene_labels.get(b).map_or("", String::as_str);
257 la.cmp(lb)
258 });
259 v
260 };
261
262 for (pos, key_map) in entries {
263 let gene_label = gene_labels.get(&pos).map_or("NA", String::as_str);
264
265 for (_, mut umi_map) in key_map {
266 let groups = assign_groups(method, &umi_map, edit_threshold);
267
268 for group in &groups {
269 let top_umi = &group[0];
270 let group_count: u32 = group.iter().map(|u| umi_map[u].count).sum();
271 let top_umi_str = std::str::from_utf8(top_umi).unwrap_or("");
272
273 for umi in group {
274 let slot = umi_map.remove(umi).expect("UMI must exist in umi_map");
275
276 for record in slot.records {
277 if let Some(w) = tsv_writer.as_mut() {
278 let read_name = std::str::from_utf8(record.qname()).unwrap_or("");
279 let contig =
280 std::str::from_utf8(header_view.tid2name(record.tid() as u32))
281 .unwrap_or("");
282 let umi_str = std::str::from_utf8(umi).unwrap_or("");
283 let (_, read_pos) = get_read_position(&record);
284
285 writeln!(
286 w,
287 "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}",
288 read_name,
289 contig,
290 read_pos,
291 gene_label,
292 umi_str,
293 slot.count,
294 top_umi_str,
295 group_count,
296 *unique_id,
297 )
298 .map_err(|e| GroupError::TsvWrite(e.to_string()))?;
299 }
300
301 let mut tagged = record;
302 tagged
303 .push_aux(
304 b"UG",
305 #[allow(clippy::cast_possible_wrap)]
306 rust_htslib::bam::record::Aux::I32(*unique_id as i32),
307 )
308 .ok();
309 tagged
310 .push_aux(b"BX", rust_htslib::bam::record::Aux::String(top_umi_str))
311 .ok();
312
313 output_records.push(tagged);
314 }
315 }
316
317 *unique_id += 1;
318 }
319 }
320 }
321
322 Ok(output_records)
323}
324
325#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
329pub fn run_group(config: &GroupConfig, input_path: &str) -> Result<GroupStats, GroupError> {
330 if config.per_contig && !config.per_gene {
331 return Err(GroupError::PerContigRequiresPerGene);
332 }
333
334 let mut reader =
335 bam::Reader::from_path(input_path).map_err(|e| GroupError::BamOpen(e.to_string()))?;
336 let header = bam::Header::from_template(reader.header());
337 let header_view = reader.header().clone();
338
339 let format = if config.out_sam {
340 bam::Format::Sam
341 } else {
342 bam::Format::Bam
343 };
344
345 let mut writer = bam::Writer::from_stdout(&header, format)
346 .map_err(|e| GroupError::BamWrite(e.to_string()))?;
347
348 let chrom_filter: Option<i32> = config
350 .chrom
351 .as_ref()
352 .map(|c| {
353 let tid = reader
354 .header()
355 .tid(c.as_bytes())
356 .ok_or_else(|| GroupError::UnknownChrom(c.clone()))?;
357 #[allow(clippy::cast_possible_wrap)]
358 Ok(tid as i32)
359 })
360 .transpose()?;
361
362 let mut tsv_writer: Option<BufWriter<File>> = config
364 .group_out
365 .as_ref()
366 .map(|path| {
367 let file =
368 File::create(path).map_err(|e| GroupError::TsvWrite(e.to_string()))?;
369 let mut w = BufWriter::new(file);
370 writeln!(
371 w,
372 "read_id\tcontig\tposition\tgene\tumi\tumi_count\tfinal_umi\tfinal_umi_count\tunique_id"
373 )
374 .map_err(|e| GroupError::TsvWrite(e.to_string()))?;
375 Ok(w)
376 })
377 .transpose()?;
378
379 let skip_regex = config
380 .skip_tags_regex
381 .as_ref()
382 .map(|s| regex::Regex::new(s).map_err(|e| GroupError::InvalidRegex(e.to_string())))
383 .transpose()?;
384
385 let output_unmapped = config.unmapped_handling == UnmappedHandling::Output
386 || config.unmapped_handling == UnmappedHandling::Use;
387
388 let mut buffer = GroupBuffer::new();
389 let mut stats = GroupStats {
390 input_reads: 0,
391 output_reads: 0,
392 };
393
394 #[allow(clippy::cast_possible_truncation)]
395 let mut rng = PythonRandom::new(config.random_seed as u32);
396
397 let mut output_records: Vec<Record> = Vec::new();
398 let mut unique_id: u32 = 0;
399
400 let mut last_start: i64 = 0;
401 let mut last_chrom: i32 = -1;
402
403 let mut gene_ids: HashMap<Vec<u8>, i64> = HashMap::new();
405 let mut gene_labels: HashMap<i64, String> = HashMap::new();
406 let mut next_gene_id: i64 = 0;
407
408 for result in reader.records() {
409 let record = result.map_err(|e| GroupError::BamRead(e.to_string()))?;
410
411 if record.is_last_in_template() {
413 if record.is_unmapped() {
414 if output_unmapped {
415 output_records.push(record);
416 }
417 } else {
418 output_records.push(record);
419 }
420 continue;
421 }
422
423 if record.is_unmapped() {
425 if output_unmapped {
426 output_records.push(record);
427 }
428 continue;
429 }
430
431 let tid = record.tid();
432
433 if chrom_filter.is_some_and(|filter_tid| tid != filter_tid) {
434 continue;
435 }
436
437 stats.input_reads += 1;
438
439 if config.subset.is_some_and(|s| rng.random() >= f64::from(s)) {
441 continue;
442 }
443
444 if config.paired {
446 let is_chimeric =
447 !record.is_mate_unmapped() && record.tid() != record.mtid() && record.mtid() >= 0;
448
449 if is_chimeric {
450 match config.chimeric_pairs {
451 ChimericPairs::Discard => continue,
452 ChimericPairs::Output => {
453 output_records.push(record);
454 continue;
455 }
456 ChimericPairs::Use => {} }
458 }
459
460 if record.is_mate_unmapped() {
461 match config.unmapped_handling {
462 UnmappedHandling::Discard => continue,
463 UnmappedHandling::Output => {
464 output_records.push(record);
465 continue;
466 }
467 UnmappedHandling::Use => {} }
469 }
470 }
471
472 if config.per_gene {
473 let gene = if config.per_contig {
475 #[allow(clippy::cast_sign_loss)]
476 Some(header_view.tid2name(tid as u32).to_vec())
477 } else {
478 let gene_tag_name = config.gene_tag.as_deref().unwrap_or("XF");
479 extract_umi_from_tag(&record, gene_tag_name)
480 };
481
482 let Some(gene) = gene else {
483 continue;
484 };
485
486 if skip_regex
487 .as_ref()
488 .is_some_and(|re| re.is_match(std::str::from_utf8(&gene).unwrap_or("")))
489 {
490 continue;
491 }
492
493 let gene_id = *gene_ids.entry(gene.clone()).or_insert_with(|| {
494 let id = next_gene_id;
495 gene_labels.insert(id, String::from_utf8_lossy(&gene).into_owned());
496 next_gene_id += 1;
497 id
498 });
499
500 if tid != last_chrom && last_chrom >= 0 {
502 output_records.extend(process_drained(
503 buffer.drain_all(),
504 config.method,
505 config.edit_distance_threshold,
506 &mut unique_id,
507 &mut tsv_writer,
508 &header_view,
509 &gene_labels,
510 )?);
511 }
512 last_chrom = tid;
513
514 let key: GroupKey = (false, false, 0, 0);
515 let umi = if config.ignore_umi {
516 Vec::new()
517 } else {
518 extract_umi_from_name(&record, config.umi_separator)
519 };
520 buffer.add(record, gene_id, key, umi);
521 } else {
522 let (start, pos) = get_read_position(&record);
524
525 if tid != last_chrom {
526 output_records.extend(process_drained(
527 buffer.drain_all(),
528 config.method,
529 config.edit_distance_threshold,
530 &mut unique_id,
531 &mut tsv_writer,
532 &header_view,
533 &gene_labels,
534 )?);
535 } else if start > last_start + 1000 {
536 let threshold = start - 1000;
537 output_records.extend(process_drained(
538 buffer.drain_up_to(threshold),
539 config.method,
540 config.edit_distance_threshold,
541 &mut unique_id,
542 &mut tsv_writer,
543 &header_view,
544 &gene_labels,
545 )?);
546 }
547
548 last_start = start;
549 last_chrom = tid;
550
551 let tlen =
555 if config.paired && !record.is_mate_unmapped() && record.tid() == record.mtid() {
556 record.insert_size()
557 } else {
558 0
559 };
560 let key: GroupKey = (record.is_reverse(), false, tlen, 0);
561
562 let umi = if config.ignore_umi {
563 Vec::new()
564 } else {
565 extract_umi_from_name(&record, config.umi_separator)
566 };
567
568 buffer.add(record, pos, key, umi);
569 }
570 }
571
572 output_records.extend(process_drained(
573 buffer.drain_all(),
574 config.method,
575 config.edit_distance_threshold,
576 &mut unique_id,
577 &mut tsv_writer,
578 &header_view,
579 &gene_labels,
580 )?);
581
582 if let Some(w) = tsv_writer.as_mut() {
584 w.flush().map_err(|e| GroupError::TsvWrite(e.to_string()))?;
585 }
586
587 if !config.no_sort_output {
590 let (mut mapped, unmapped): (Vec<_>, Vec<_>) =
591 output_records.into_iter().partition(|r| !r.is_unmapped());
592 mapped.sort_by(|a, b| a.tid().cmp(&b.tid()).then_with(|| a.pos().cmp(&b.pos())));
593 mapped.extend(unmapped);
594 output_records = mapped;
595 }
596
597 stats.output_reads = output_records.len() as u64;
598
599 if config.output_bam {
600 for r in &output_records {
601 writer
602 .write(r)
603 .map_err(|e| GroupError::BamWrite(e.to_string()))?;
604 }
605 }
606
607 drop(writer);
608
609 Ok(stats)
610}
611
/// Errors reported by `run_group` and its helpers.
///
/// Variants with a `String` payload carry the text of the underlying error.
#[derive(Debug, thiserror::Error)]
pub enum GroupError {
    /// The input BAM file could not be opened.
    #[error("failed to open BAM: {0}")]
    BamOpen(String),
    /// A record could not be read from the input BAM.
    #[error("failed to read BAM record: {0}")]
    BamRead(String),
    /// The stdout writer could not be created or a record could not be written.
    #[error("failed to write BAM/SAM: {0}")]
    BamWrite(String),
    /// The group report file could not be created, written or flushed.
    #[error("failed to write TSV: {0}")]
    TsvWrite(String),
    /// The requested contig name is not present in the BAM header.
    #[error("unknown chromosome: {0}")]
    UnknownChrom(String),
    /// The skip-tags pattern is not a valid regular expression.
    #[error("invalid regex: {0}")]
    InvalidRegex(String),
    /// `per_contig` was requested without `per_gene`.
    #[error("--per-contig requires --per-gene")]
    PerContigRequiresPerGene,
}