Skip to main content

umi_core/
count.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::io::{self, BufRead, Write as IoWrite};
3
4use rust_htslib::bam::{self, Read as BamRead, record::Aux};
5use thiserror::Error;
6
7use crate::dedup::{DedupMethod, count_umis, extract_umi_umis};
8
9#[derive(Error, Debug)]
10pub enum CountError {
11    #[error("BAM open error: {0}")]
12    BamOpen(String),
13    #[error("BAM read error: {0}")]
14    BamRead(String),
15    #[error("invalid regex: {0}")]
16    InvalidRegex(String),
17    #[error("I/O error: {0}")]
18    Io(#[from] io::Error),
19}
20
21pub struct CountConfig {
22    pub method: DedupMethod,
23    pub gene_tag: String,
24    pub skip_tags_regex: Option<String>,
25    pub per_cell: bool,
26    pub wide_format: bool,
27    pub edit_distance_threshold: u32,
28}
29
/// Read tallies returned by the counting entry points.
///
/// The derives make the type printable, comparable and default-constructible
/// (all counters zero), so callers can log it and assert on it in tests.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CountStats {
    /// Number of input reads that passed the initial filtering and were considered.
    pub input_reads: u64,
    /// Number of reads whose UMI was actually added to a gene/cell group.
    pub counted_reads: u64,
}
34
35pub struct CountTabConfig {
36    pub method: DedupMethod,
37    pub per_cell: bool,
38    pub separator: u8,
39    pub edit_distance_threshold: u32,
40}
41
/// Per-cell UMI table keyed by the raw UMI bytes.
///
/// Each value is `(count, insertion_order)`: how many times the UMI was seen
/// and the sequence number it received when it was first inserted.
type UmiCountMap = HashMap<Vec<u8>, (u32, u32)>;
44
45#[allow(clippy::missing_errors_doc)]
46pub fn run_count(
47    config: &CountConfig,
48    bam_path: &str,
49    output: &mut dyn IoWrite,
50) -> Result<CountStats, CountError> {
51    let mut reader =
52        bam::Reader::from_path(bam_path).map_err(|e| CountError::BamOpen(e.to_string()))?;
53
54    let skip_regex = config
55        .skip_tags_regex
56        .as_ref()
57        .map(|s| regex::Regex::new(s).map_err(|e| CountError::InvalidRegex(e.to_string())))
58        .transpose()?;
59
60    // BTreeMap for genes so output is sorted
61    let mut data: BTreeMap<String, CellUmiMap> = BTreeMap::new();
62    let mut stats = CountStats {
63        input_reads: 0,
64        counted_reads: 0,
65    };
66
67    for result in reader.records() {
68        let record = result.map_err(|e| CountError::BamRead(e.to_string()))?;
69
70        if record.is_unmapped() {
71            continue;
72        }
73        if record.is_paired() && record.is_last_in_template() {
74            continue;
75        }
76
77        stats.input_reads += 1;
78
79        let gene = match record.aux(config.gene_tag.as_bytes()) {
80            Ok(Aux::String(s)) => s.to_string(),
81            _ => continue,
82        };
83
84        if skip_regex.as_ref().is_some_and(|re| re.is_match(&gene)) {
85            continue;
86        }
87
88        let (umi, cell) = extract_umi_umis(record.qname());
89
90        let cell_key = if config.per_cell {
91            cell.map(|c| String::from_utf8_lossy(&c).into_owned())
92        } else {
93            None
94        };
95
96        stats.counted_reads += 1;
97
98        let cell_map = data.entry(gene).or_default();
99        cell_map.add(cell_key, umi);
100    }
101
102    if config.per_cell && config.wide_format {
103        write_wide_format(&data, config, output)?;
104    } else if config.per_cell {
105        write_long_format(&data, config, output)?;
106    } else {
107        write_gene_counts(&data, config, output)?;
108    }
109
110    Ok(stats)
111}
112
113#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
114pub fn run_count_tab(
115    config: &CountTabConfig,
116    input: &mut dyn BufRead,
117    output: &mut dyn IoWrite,
118) -> Result<CountStats, CountError> {
119    let mut stats = CountStats {
120        input_reads: 0,
121        counted_reads: 0,
122    };
123
124    if config.per_cell {
125        writeln!(output, "cell\tgene\tcount")?;
126    } else {
127        writeln!(output, "gene\tcount")?;
128    }
129
130    let mut current_gene: Option<String> = None;
131    let mut cell_umis = CellUmiMap::default();
132
133    let mut line_buf = String::new();
134    loop {
135        line_buf.clear();
136        let n = input.read_line(&mut line_buf)?;
137        if n == 0 {
138            break;
139        }
140        let line = line_buf.trim_end_matches('\n').trim_end_matches('\r');
141        if line.is_empty() {
142            continue;
143        }
144
145        let mut cols = line.splitn(2, '\t');
146        let Some(read_name) = cols.next() else {
147            continue;
148        };
149        let Some(gene) = cols.next() else {
150            continue;
151        };
152        let gene = gene.to_string();
153
154        stats.input_reads += 1;
155
156        // When gene changes, flush previous gene
157        if current_gene.as_ref().is_some_and(|g| *g != gene) {
158            flush_count_tab_gene(
159                current_gene.as_deref().expect("checked above"),
160                &cell_umis,
161                config,
162                output,
163            )?;
164            cell_umis = CellUmiMap::default();
165        }
166        current_gene = Some(gene);
167
168        let sep = config.separator;
169        let parts: Vec<&str> = read_name.split(|c: char| c as u8 == sep).collect();
170        let umi = parts
171            .last()
172            .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
173
174        let cell_key = if config.per_cell && parts.len() >= 2 {
175            Some(parts[parts.len() - 2].to_string())
176        } else {
177            None
178        };
179
180        stats.counted_reads += 1;
181        cell_umis.add(cell_key, umi);
182    }
183
184    if let Some(ref gene) = current_gene {
185        flush_count_tab_gene(gene, &cell_umis, config, output)?;
186    }
187
188    Ok(stats)
189}
190
191#[derive(Default)]
192struct CellUmiMap {
193    cells: Vec<(Option<String>, UmiCountMap)>,
194    cell_index: HashMap<Option<String>, usize>,
195    next_order: u32,
196}
197
198impl CellUmiMap {
199    fn add(&mut self, cell: Option<String>, umi: Vec<u8>) {
200        let idx = if let Some(&i) = self.cell_index.get(&cell) {
201            i
202        } else {
203            let i = self.cells.len();
204            self.cell_index.insert(cell.clone(), i);
205            self.cells.push((cell, HashMap::new()));
206            i
207        };
208        let entry = self.cells[idx].1.entry(umi).or_insert_with(|| {
209            let order = self.next_order;
210            self.next_order += 1;
211            (0, order)
212        });
213        entry.0 += 1;
214    }
215
216    fn dedup_count(
217        &self,
218        method: DedupMethod,
219        edit_threshold: u32,
220    ) -> Vec<(&Option<String>, usize)> {
221        self.cells
222            .iter()
223            .map(|(cell, umi_map)| {
224                let counts: HashMap<Vec<u8>, u32> =
225                    umi_map.iter().map(|(k, &(c, _))| (k.clone(), c)).collect();
226                let orders: HashMap<Vec<u8>, u32> =
227                    umi_map.iter().map(|(k, &(_, o))| (k.clone(), o)).collect();
228                let n = count_umis(method, &counts, &orders, edit_threshold);
229                (cell, n)
230            })
231            .collect()
232    }
233}
234
235fn write_gene_counts(
236    data: &BTreeMap<String, CellUmiMap>,
237    config: &CountConfig,
238    output: &mut dyn IoWrite,
239) -> Result<(), CountError> {
240    writeln!(output, "gene\tcount")?;
241    for (gene, cell_map) in data {
242        let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
243        let total: usize = results.iter().map(|(_, n)| n).sum();
244        writeln!(output, "{gene}\t{total}")?;
245    }
246    Ok(())
247}
248
249fn write_long_format(
250    data: &BTreeMap<String, CellUmiMap>,
251    config: &CountConfig,
252    output: &mut dyn IoWrite,
253) -> Result<(), CountError> {
254    writeln!(output, "gene\tcell\tcount")?;
255    for (gene, cell_map) in data {
256        let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
257        let mut sorted: Vec<_> = results
258            .into_iter()
259            .filter_map(|(cell, n)| cell.as_ref().map(|c| (c.clone(), n)))
260            .collect();
261        sorted.sort_by(|a, b| a.0.cmp(&b.0));
262        for (cell, count) in sorted {
263            writeln!(output, "{gene}\t{cell}\t{count}")?;
264        }
265    }
266    Ok(())
267}
268
269fn write_wide_format(
270    data: &BTreeMap<String, CellUmiMap>,
271    config: &CountConfig,
272    output: &mut dyn IoWrite,
273) -> Result<(), CountError> {
274    let mut all_cells: BTreeSet<String> = BTreeSet::new();
275    for cell_map in data.values() {
276        for (cell, _) in &cell_map.cells {
277            if let Some(c) = cell {
278                all_cells.insert(c.clone());
279            }
280        }
281    }
282    let cell_list: Vec<&String> = all_cells.iter().collect();
283
284    write!(output, "gene")?;
285    for cell in &cell_list {
286        write!(output, "\t{cell}")?;
287    }
288    writeln!(output)?;
289
290    for (gene, cell_map) in data {
291        let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
292        let cell_counts: HashMap<&str, usize> = results
293            .into_iter()
294            .filter_map(|(cell, n)| cell.as_ref().map(|c| (c.as_str(), n)))
295            .collect();
296
297        write!(output, "{gene}")?;
298        for cell in &cell_list {
299            let count = cell_counts.get(cell.as_str()).copied().unwrap_or(0);
300            write!(output, "\t{count}")?;
301        }
302        writeln!(output)?;
303    }
304    Ok(())
305}
306
307fn flush_count_tab_gene(
308    gene: &str,
309    cell_umis: &CellUmiMap,
310    config: &CountTabConfig,
311    output: &mut dyn IoWrite,
312) -> Result<(), CountError> {
313    let results = cell_umis.dedup_count(config.method, config.edit_distance_threshold);
314
315    if config.per_cell {
316        for (cell, count) in results {
317            let cell_str = cell.as_deref().unwrap_or("");
318            writeln!(output, "{cell_str}\t{gene}\t{count}")?;
319        }
320    } else {
321        let total: usize = results.iter().map(|(_, n)| n).sum();
322        writeln!(output, "{gene}\t{total}")?;
323    }
324    Ok(())
325}