Skip to main content

umi_core/
group.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fs::File;
3use std::io::{BufWriter, Write};
4
5use rust_htslib::bam::{self, Read as BamRead, Record};
6
7use crate::dedup::{
8    DedupMethod, GroupKey, PythonRandom, TieBreakRng, build_adjacency_list,
9    build_directional_adjacency_list, connected_components, extract_umi_from_name,
10    extract_umi_from_tag, get_read_position, median, min_set_cover,
11};
12
/// Policy for paired-mode reads whose mate is mapped to a different
/// reference sequence (a "chimeric" pair).
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum ChimericPairs {
    /// Skip the read: it is neither grouped nor written.
    Discard,
    /// Write the read unchanged, without grouping it or adding tags.
    Output,
    /// Group the read like any other; the TLEN part of its group key is 0.
    Use,
}
19
/// Policy for unmapped reads and, in paired mode, for reads whose mate is
/// unmapped.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum UnmappedHandling {
    /// Drop unmapped reads and skip reads with an unmapped mate.
    Discard,
    /// Write unmapped reads and reads with an unmapped mate unchanged,
    /// without grouping them.
    Output,
    /// Write unmapped reads unchanged; group reads with an unmapped mate
    /// using a TLEN of 0 in their group key.
    Use,
}
26
27#[allow(clippy::struct_excessive_bools)]
28pub struct GroupConfig {
29    pub method: DedupMethod,
30    pub ignore_umi: bool,
31    pub umi_separator: u8,
32    pub random_seed: u64,
33    pub out_sam: bool,
34    pub output_bam: bool,
35    pub no_sort_output: bool,
36    pub chrom: Option<String>,
37    pub group_out: Option<String>,
38    pub edit_distance_threshold: u32,
39    pub subset: Option<f32>,
40    pub per_gene: bool,
41    pub gene_tag: Option<String>,
42    pub skip_tags_regex: Option<String>,
43    pub per_contig: bool,
44    pub paired: bool,
45    pub chimeric_pairs: ChimericPairs,
46    pub unmapped_handling: UnmappedHandling,
47}
48
/// Read counters reported by `run_group`.
pub struct GroupStats {
    /// Mapped reads that passed the chromosome filter (counted before subsetting).
    pub input_reads: u64,
    /// Records emitted after grouping, including passthrough reads.
    pub output_reads: u64,
}
53
54struct GroupSlot {
55    records: Vec<Record>,
56    count: u32,
57    insertion_order: u32,
58}
59
60struct GroupBuffer {
61    groups: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>>,
62    insertion_counters: BTreeMap<i64, BTreeMap<GroupKey, u32>>,
63}
64
65impl GroupBuffer {
66    const fn new() -> Self {
67        Self {
68            groups: BTreeMap::new(),
69            insertion_counters: BTreeMap::new(),
70        }
71    }
72
73    fn add(&mut self, record: Record, pos: i64, key: GroupKey, umi: Vec<u8>) {
74        let umi_map = self.groups.entry(pos).or_default().entry(key).or_default();
75
76        if let Some(slot) = umi_map.get_mut(&umi) {
77            slot.count += 1;
78            slot.records.push(record);
79            return;
80        }
81
82        let counter = self
83            .insertion_counters
84            .entry(pos)
85            .or_default()
86            .entry(key)
87            .or_default();
88        let order = *counter;
89        *counter += 1;
90
91        umi_map.insert(
92            umi,
93            GroupSlot {
94                records: vec![record],
95                count: 1,
96                insertion_order: order,
97            },
98        );
99    }
100
101    fn drain_up_to(
102        &mut self,
103        threshold: i64,
104    ) -> BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>> {
105        let rest = self.groups.split_off(&(threshold + 1));
106        let drained = std::mem::replace(&mut self.groups, rest);
107        let rest_counters = self.insertion_counters.split_off(&(threshold + 1));
108        let _ = std::mem::replace(&mut self.insertion_counters, rest_counters);
109        drained
110    }
111
112    fn drain_all(&mut self) -> BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>> {
113        let drained = std::mem::take(&mut self.groups);
114        self.insertion_counters.clear();
115        drained
116    }
117}
118
119/// Assign UMIs to groups. Returns groups where each group is a list of UMIs
120/// sorted by count descending, lex ascending. First UMI is the representative.
121#[allow(clippy::too_many_lines)]
122fn assign_groups(
123    method: DedupMethod,
124    umi_map: &HashMap<Vec<u8>, GroupSlot>,
125    edit_threshold: u32,
126) -> Vec<Vec<Vec<u8>>> {
127    let counts: HashMap<&[u8], u32> = umi_map
128        .iter()
129        .map(|(k, v)| (k.as_slice(), v.count))
130        .collect();
131    let orders: HashMap<&[u8], u32> = umi_map
132        .iter()
133        .map(|(k, v)| (k.as_slice(), v.insertion_order))
134        .collect();
135
136    let lex_sort = |a: &[u8], b: &[u8]| -> std::cmp::Ordering {
137        counts[b].cmp(&counts[a]).then_with(|| a.cmp(b))
138    };
139
140    match method {
141        DedupMethod::Unique => {
142            let mut umis: Vec<Vec<u8>> = umi_map.keys().cloned().collect();
143            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
144            umis.into_iter().map(|u| vec![u]).collect()
145        }
146
147        DedupMethod::Percentile => {
148            if counts.len() <= 1 {
149                return umi_map.keys().cloned().map(|u| vec![u]).collect();
150            }
151            let all_counts: Vec<u32> = counts.values().copied().collect();
152            let threshold = median(&all_counts) / 100.0;
153            let mut umis: Vec<Vec<u8>> = umi_map
154                .iter()
155                .filter(|(_, slot)| f64::from(slot.count) > threshold)
156                .map(|(umi, _)| umi.clone())
157                .collect();
158            umis.sort_by(|a, b| orders[a.as_slice()].cmp(&orders[b.as_slice()]));
159            umis.into_iter().map(|u| vec![u]).collect()
160        }
161
162        DedupMethod::Cluster => {
163            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
164            let adj_list = build_adjacency_list(&umis, edit_threshold);
165            let components = connected_components(&umis, &counts, &orders, &adj_list);
166            components
167                .into_iter()
168                .map(|mut comp| {
169                    comp.sort_by(|a, b| lex_sort(a, b));
170                    comp.into_iter().map(<[u8]>::to_vec).collect()
171                })
172                .collect()
173        }
174
175        DedupMethod::Adjacency => {
176            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
177            let adj_list = build_adjacency_list(&umis, edit_threshold);
178            let components = connected_components(&umis, &counts, &orders, &adj_list);
179            // Adjacency splits components via min_set_cover, grouping
180            // connected nodes around each lead UMI.
181            let mut groups = Vec::new();
182            for component in components {
183                if component.len() == 1 {
184                    groups.push(component.into_iter().map(<[u8]>::to_vec).collect());
185                } else {
186                    let lead_umis = min_set_cover(&component, &adj_list, &counts);
187                    let mut observed: HashSet<&[u8]> = lead_umis.iter().copied().collect();
188                    for &lead in &lead_umis {
189                        let connected: HashSet<&[u8]> = adj_list
190                            .get(lead)
191                            .map_or_else(HashSet::new, |ns| ns.iter().copied().collect());
192                        let mut group = vec![lead.to_vec()];
193                        for node in connected {
194                            if observed.insert(node) {
195                                group.push(node.to_vec());
196                            }
197                        }
198                        groups.push(group);
199                    }
200                }
201            }
202            groups
203        }
204
205        DedupMethod::Directional => {
206            let umis: Vec<&[u8]> = umi_map.keys().map(Vec::as_slice).collect();
207            let adj_list = build_directional_adjacency_list(&umis, &counts, edit_threshold);
208            let components = connected_components(&umis, &counts, &orders, &adj_list);
209            // Directed BFS can produce overlapping components. Filter already-
210            // observed UMIs so each UMI is assigned to exactly one group,
211            // matching Python's _group_directional logic.
212            let mut observed: HashSet<&[u8]> = HashSet::new();
213            let mut groups = Vec::new();
214            for mut comp in components {
215                comp.sort_by(|a, b| lex_sort(a, b));
216                if comp.len() == 1 {
217                    observed.insert(comp[0]);
218                    groups.push(comp.into_iter().map(<[u8]>::to_vec).collect());
219                } else {
220                    let mut filtered: Vec<Vec<u8>> = Vec::new();
221                    for node in comp {
222                        if observed.insert(node) {
223                            filtered.push(node.to_vec());
224                        }
225                    }
226                    if !filtered.is_empty() {
227                        groups.push(filtered);
228                    }
229                }
230            }
231            groups
232        }
233    }
234}
235
236/// Process drained position groups: assign UMI groups, annotate records, write TSV rows.
237#[allow(clippy::cast_sign_loss)]
238fn process_drained(
239    drained: BTreeMap<i64, BTreeMap<GroupKey, HashMap<Vec<u8>, GroupSlot>>>,
240    method: DedupMethod,
241    edit_threshold: u32,
242    unique_id: &mut u32,
243    tsv_writer: &mut Option<BufWriter<File>>,
244    header_view: &bam::HeaderView,
245    gene_labels: &HashMap<i64, String>,
246) -> Result<Vec<Record>, GroupError> {
247    let mut output_records = Vec::new();
248
249    // In per-gene mode, Python sorts genes alphabetically; replicate that order.
250    let entries: Vec<_> = if gene_labels.is_empty() {
251        drained.into_iter().collect()
252    } else {
253        let mut v: Vec<_> = drained.into_iter().collect();
254        v.sort_by(|(a, _), (b, _)| {
255            let la = gene_labels.get(a).map_or("", String::as_str);
256            let lb = gene_labels.get(b).map_or("", String::as_str);
257            la.cmp(lb)
258        });
259        v
260    };
261
262    for (pos, key_map) in entries {
263        let gene_label = gene_labels.get(&pos).map_or("NA", String::as_str);
264
265        for (_, mut umi_map) in key_map {
266            let groups = assign_groups(method, &umi_map, edit_threshold);
267
268            for group in &groups {
269                let top_umi = &group[0];
270                let group_count: u32 = group.iter().map(|u| umi_map[u].count).sum();
271                let top_umi_str = std::str::from_utf8(top_umi).unwrap_or("");
272
273                for umi in group {
274                    let slot = umi_map.remove(umi).expect("UMI must exist in umi_map");
275
276                    for record in slot.records {
277                        if let Some(w) = tsv_writer.as_mut() {
278                            let read_name = std::str::from_utf8(record.qname()).unwrap_or("");
279                            let contig =
280                                std::str::from_utf8(header_view.tid2name(record.tid() as u32))
281                                    .unwrap_or("");
282                            let umi_str = std::str::from_utf8(umi).unwrap_or("");
283                            let (_, read_pos) = get_read_position(&record);
284
285                            writeln!(
286                                w,
287                                "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}",
288                                read_name,
289                                contig,
290                                read_pos,
291                                gene_label,
292                                umi_str,
293                                slot.count,
294                                top_umi_str,
295                                group_count,
296                                *unique_id,
297                            )
298                            .map_err(|e| GroupError::TsvWrite(e.to_string()))?;
299                        }
300
301                        let mut tagged = record;
302                        tagged
303                            .push_aux(
304                                b"UG",
305                                #[allow(clippy::cast_possible_wrap)]
306                                rust_htslib::bam::record::Aux::I32(*unique_id as i32),
307                            )
308                            .ok();
309                        tagged
310                            .push_aux(b"BX", rust_htslib::bam::record::Aux::String(top_umi_str))
311                            .ok();
312
313                        output_records.push(tagged);
314                    }
315                }
316
317                *unique_id += 1;
318            }
319        }
320    }
321
322    Ok(output_records)
323}
324
/// Group the reads of the BAM file at `input_path` by UMI.
///
/// Mapped reads are buffered per bundle and flushed through `process_drained`.
/// A bundle is a genomic position in coordinate mode, or a gene/contig in
/// per-gene mode. `process_drained` tags each record with `UG`/`BX` and
/// optionally writes the TSV report.
///
/// All output records are held in memory until the input is exhausted. They
/// are then optionally sorted by (tid, pos) with unmapped reads last, and
/// written to stdout as SAM or BAM when `config.output_bam` is set.
///
/// # Errors
///
/// Returns `GroupError` on BAM I/O failures or unknown chromosome filter.
///
/// Variants:
/// * `PerContigRequiresPerGene`: `per_contig` is set without `per_gene`.
/// * `BamOpen` / `BamRead` / `BamWrite`: BAM input/output failures.
/// * `UnknownChrom`: `config.chrom` is not in the header.
/// * `TsvWrite`: the TSV report cannot be created, written or flushed.
/// * `InvalidRegex`: `config.skip_tags_regex` does not compile.
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
pub fn run_group(config: &GroupConfig, input_path: &str) -> Result<GroupStats, GroupError> {
    if config.per_contig && !config.per_gene {
        return Err(GroupError::PerContigRequiresPerGene);
    }

    let mut reader =
        bam::Reader::from_path(input_path).map_err(|e| GroupError::BamOpen(e.to_string()))?;
    let header = bam::Header::from_template(reader.header());
    let header_view = reader.header().clone();

    let format = if config.out_sam {
        bam::Format::Sam
    } else {
        bam::Format::Bam
    };

    // NOTE(review): the stdout writer is created even when `config.output_bam`
    // is false; confirm whether rust-htslib emits the header at construction.
    let mut writer = bam::Writer::from_stdout(&header, format)
        .map_err(|e| GroupError::BamWrite(e.to_string()))?;

    // Optional chromosome filter
    let chrom_filter: Option<i32> = config
        .chrom
        .as_ref()
        .map(|c| {
            let tid = reader
                .header()
                .tid(c.as_bytes())
                .ok_or_else(|| GroupError::UnknownChrom(c.clone()))?;
            #[allow(clippy::cast_possible_wrap)]
            Ok(tid as i32)
        })
        .transpose()?;

    // Open TSV writer
    let mut tsv_writer: Option<BufWriter<File>> = config
        .group_out
        .as_ref()
        .map(|path| {
            let file =
                File::create(path).map_err(|e| GroupError::TsvWrite(e.to_string()))?;
            let mut w = BufWriter::new(file);
            writeln!(
                w,
                "read_id\tcontig\tposition\tgene\tumi\tumi_count\tfinal_umi\tfinal_umi_count\tunique_id"
            )
            .map_err(|e| GroupError::TsvWrite(e.to_string()))?;
            Ok(w)
        })
        .transpose()?;

    // Compiled once; applied to gene names in per-gene mode.
    let skip_regex = config
        .skip_tags_regex
        .as_ref()
        .map(|s| regex::Regex::new(s).map_err(|e| GroupError::InvalidRegex(e.to_string())))
        .transpose()?;

    // Unmapped reads are written (ungrouped) under both `Output` and `Use`.
    let output_unmapped = config.unmapped_handling == UnmappedHandling::Output
        || config.unmapped_handling == UnmappedHandling::Use;

    let mut buffer = GroupBuffer::new();
    let mut stats = GroupStats {
        input_reads: 0,
        output_reads: 0,
    };

    // Only consumed when `config.subset` is set; the seed is truncated to 32 bits.
    #[allow(clippy::cast_possible_truncation)]
    let mut rng = PythonRandom::new(config.random_seed as u32);

    let mut output_records: Vec<Record> = Vec::new();
    let mut unique_id: u32 = 0;

    // Start and contig of the previous buffered read, used to decide when
    // bundles can be flushed; -1 means no read has been buffered yet.
    let mut last_start: i64 = 0;
    let mut last_chrom: i32 = -1;

    // Per-gene state: map gene name → sequential ID, and reverse map for TSV labels
    let mut gene_ids: HashMap<Vec<u8>, i64> = HashMap::new();
    let mut gene_labels: HashMap<i64, String> = HashMap::new();
    let mut next_gene_id: i64 = 0;

    for result in reader.records() {
        let record = result.map_err(|e| GroupError::BamRead(e.to_string()))?;

        // R2 reads are passthrough (no grouping, no tags).
        // NOTE(review): this runs before the chromosome filter and subsetting,
        // so R2s on other contigs, or mates of subsetted-out R1s, are still
        // written; confirm this is intended.
        if record.is_last_in_template() {
            if record.is_unmapped() {
                if output_unmapped {
                    output_records.push(record);
                }
            } else {
                output_records.push(record);
            }
            continue;
        }

        // Handle unmapped reads (R1 in paired mode, or any read in single-end)
        if record.is_unmapped() {
            if output_unmapped {
                output_records.push(record);
            }
            continue;
        }

        let tid = record.tid();

        if chrom_filter.is_some_and(|filter_tid| tid != filter_tid) {
            continue;
        }

        stats.input_reads += 1;

        // Subset check consumes one RNG call per mapped read (before buffer.add)
        if config.subset.is_some_and(|s| rng.random() >= f64::from(s)) {
            continue;
        }

        // Paired-mode filtering for R1 reads
        if config.paired {
            // Chimeric: the mate is mapped, but to a different reference.
            let is_chimeric =
                !record.is_mate_unmapped() && record.tid() != record.mtid() && record.mtid() >= 0;

            if is_chimeric {
                match config.chimeric_pairs {
                    ChimericPairs::Discard => continue,
                    ChimericPairs::Output => {
                        output_records.push(record);
                        continue;
                    }
                    ChimericPairs::Use => {} // fall through to grouping with TLEN=0
                }
            }

            if record.is_mate_unmapped() {
                match config.unmapped_handling {
                    UnmappedHandling::Discard => continue,
                    UnmappedHandling::Output => {
                        output_records.push(record);
                        continue;
                    }
                    UnmappedHandling::Use => {} // fall through to grouping with TLEN=0
                }
            }
        }

        if config.per_gene {
            // Per-gene mode: group by gene tag value (or contig name) instead of position
            let gene = if config.per_contig {
                #[allow(clippy::cast_sign_loss)]
                Some(header_view.tid2name(tid as u32).to_vec())
            } else {
                let gene_tag_name = config.gene_tag.as_deref().unwrap_or("XF");
                extract_umi_from_tag(&record, gene_tag_name)
            };

            // Reads without a gene are not grouped or written.
            let Some(gene) = gene else {
                continue;
            };

            if skip_regex
                .as_ref()
                .is_some_and(|re| re.is_match(std::str::from_utf8(&gene).unwrap_or("")))
            {
                continue;
            }

            // The gene ID stands in for the position key in the buffer.
            let gene_id = *gene_ids.entry(gene.clone()).or_insert_with(|| {
                let id = next_gene_id;
                gene_labels.insert(id, String::from_utf8_lossy(&gene).into_owned());
                next_gene_id += 1;
                id
            });

            // In per-gene mode, flush all when chromosome changes (no position-based flushing)
            if tid != last_chrom && last_chrom >= 0 {
                output_records.extend(process_drained(
                    buffer.drain_all(),
                    config.method,
                    config.edit_distance_threshold,
                    &mut unique_id,
                    &mut tsv_writer,
                    &header_view,
                    &gene_labels,
                )?);
            }
            last_chrom = tid;

            let key: GroupKey = (false, false, 0, 0);
            let umi = if config.ignore_umi {
                Vec::new()
            } else {
                extract_umi_from_name(&record, config.umi_separator)
            };
            buffer.add(record, gene_id, key, umi);
        } else {
            // Standard coordinate mode
            let (start, pos) = get_read_position(&record);

            // Flush everything on a contig change. On the same contig, flush
            // bundles at positions <= start - 1000 once the read start has
            // advanced more than 1000 past the previous read (input presumably
            // coordinate-sorted — TODO confirm).
            if tid != last_chrom {
                output_records.extend(process_drained(
                    buffer.drain_all(),
                    config.method,
                    config.edit_distance_threshold,
                    &mut unique_id,
                    &mut tsv_writer,
                    &header_view,
                    &gene_labels,
                )?);
            } else if start > last_start + 1000 {
                let threshold = start - 1000;
                output_records.extend(process_drained(
                    buffer.drain_up_to(threshold),
                    config.method,
                    config.edit_distance_threshold,
                    &mut unique_id,
                    &mut tsv_writer,
                    &header_view,
                    &gene_labels,
                )?);
            }

            last_start = start;
            last_chrom = tid;

            // For paired non-chimeric reads, include signed TLEN in the group key.
            // Python sorts GroupKeys as tuples: (is_reverse, is_spliced, tlen, r_length).
            // We place signed tlen in position 2 (i64) to match Python's sorted() ordering.
            let tlen =
                if config.paired && !record.is_mate_unmapped() && record.tid() == record.mtid() {
                    record.insert_size()
                } else {
                    0
                };
            let key: GroupKey = (record.is_reverse(), false, tlen, 0);

            let umi = if config.ignore_umi {
                Vec::new()
            } else {
                extract_umi_from_name(&record, config.umi_separator)
            };

            buffer.add(record, pos, key, umi);
        }
    }

    // Input exhausted: group whatever is still buffered.
    output_records.extend(process_drained(
        buffer.drain_all(),
        config.method,
        config.edit_distance_threshold,
        &mut unique_id,
        &mut tsv_writer,
        &header_view,
        &gene_labels,
    )?);

    // Flush TSV
    if let Some(w) = tsv_writer.as_mut() {
        w.flush().map_err(|e| GroupError::TsvWrite(e.to_string()))?;
    }

    // Sort by coordinate unless --no-sort-output.
    // Unmapped reads are placed after all mapped reads (matching Python).
    // `sort_by` is stable, so records at the same (tid, pos) keep processing order.
    if !config.no_sort_output {
        let (mut mapped, unmapped): (Vec<_>, Vec<_>) =
            output_records.into_iter().partition(|r| !r.is_unmapped());
        mapped.sort_by(|a, b| a.tid().cmp(&b.tid()).then_with(|| a.pos().cmp(&b.pos())));
        mapped.extend(unmapped);
        output_records = mapped;
    }

    // Counted even when `output_bam` is false and nothing is written.
    stats.output_reads = output_records.len() as u64;

    if config.output_bam {
        for r in &output_records {
            writer
                .write(r)
                .map_err(|e| GroupError::BamWrite(e.to_string()))?;
        }
    }

    // Close the output explicitly before returning.
    drop(writer);

    Ok(stats)
}
611
612#[derive(Debug, thiserror::Error)]
613pub enum GroupError {
614    #[error("failed to open BAM: {0}")]
615    BamOpen(String),
616    #[error("failed to read BAM record: {0}")]
617    BamRead(String),
618    #[error("failed to write BAM/SAM: {0}")]
619    BamWrite(String),
620    #[error("failed to write TSV: {0}")]
621    TsvWrite(String),
622    #[error("unknown chromosome: {0}")]
623    UnknownChrom(String),
624    #[error("invalid regex: {0}")]
625    InvalidRegex(String),
626    #[error("--per-contig requires --per-gene")]
627    PerContigRequiresPerGene,
628}