1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::io::{self, BufRead, Write as IoWrite};
3
4use rust_htslib::bam::{self, Read as BamRead, record::Aux};
5use thiserror::Error;
6
7use crate::dedup::{DedupMethod, count_umis, extract_umi_umis};
8
9#[derive(Error, Debug)]
10pub enum CountError {
11 #[error("BAM open error: {0}")]
12 BamOpen(String),
13 #[error("BAM read error: {0}")]
14 BamRead(String),
15 #[error("invalid regex: {0}")]
16 InvalidRegex(String),
17 #[error("I/O error: {0}")]
18 Io(#[from] io::Error),
19}
20
21pub struct CountConfig {
22 pub method: DedupMethod,
23 pub gene_tag: String,
24 pub skip_tags_regex: Option<String>,
25 pub per_cell: bool,
26 pub wide_format: bool,
27 pub edit_distance_threshold: u32,
28}
29
/// Read tallies returned by the counting entry points.
///
/// The derives make the result printable, comparable and default-constructible
/// (all counters start at zero), which callers and tests commonly need.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CountStats {
    /// Number of reads considered after the initial record filters.
    pub input_reads: u64,
    /// Number of reads that actually contributed a UMI to some gene.
    pub counted_reads: u64,
}
34
35pub struct CountTabConfig {
36 pub method: DedupMethod,
37 pub per_cell: bool,
38 pub separator: u8,
39 pub edit_distance_threshold: u32,
40}
41
/// Per-UMI bookkeeping: UMI bytes mapped to `(read_count, first_seen_order)`.
type UmiCountMap = HashMap<Vec<u8>, (u32, u32)>;
44
45#[allow(clippy::missing_errors_doc)]
46pub fn run_count(
47 config: &CountConfig,
48 bam_path: &str,
49 output: &mut dyn IoWrite,
50) -> Result<CountStats, CountError> {
51 let mut reader =
52 bam::Reader::from_path(bam_path).map_err(|e| CountError::BamOpen(e.to_string()))?;
53
54 let skip_regex = config
55 .skip_tags_regex
56 .as_ref()
57 .map(|s| regex::Regex::new(s).map_err(|e| CountError::InvalidRegex(e.to_string())))
58 .transpose()?;
59
60 let mut data: BTreeMap<String, CellUmiMap> = BTreeMap::new();
62 let mut stats = CountStats {
63 input_reads: 0,
64 counted_reads: 0,
65 };
66
67 for result in reader.records() {
68 let record = result.map_err(|e| CountError::BamRead(e.to_string()))?;
69
70 if record.is_unmapped() {
71 continue;
72 }
73 if record.is_paired() && record.is_last_in_template() {
74 continue;
75 }
76
77 stats.input_reads += 1;
78
79 let gene = match record.aux(config.gene_tag.as_bytes()) {
80 Ok(Aux::String(s)) => s.to_string(),
81 _ => continue,
82 };
83
84 if skip_regex.as_ref().is_some_and(|re| re.is_match(&gene)) {
85 continue;
86 }
87
88 let (umi, cell) = extract_umi_umis(record.qname());
89
90 let cell_key = if config.per_cell {
91 cell.map(|c| String::from_utf8_lossy(&c).into_owned())
92 } else {
93 None
94 };
95
96 stats.counted_reads += 1;
97
98 let cell_map = data.entry(gene).or_default();
99 cell_map.add(cell_key, umi);
100 }
101
102 if config.per_cell && config.wide_format {
103 write_wide_format(&data, config, output)?;
104 } else if config.per_cell {
105 write_long_format(&data, config, output)?;
106 } else {
107 write_gene_counts(&data, config, output)?;
108 }
109
110 Ok(stats)
111}
112
113#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
114pub fn run_count_tab(
115 config: &CountTabConfig,
116 input: &mut dyn BufRead,
117 output: &mut dyn IoWrite,
118) -> Result<CountStats, CountError> {
119 let mut stats = CountStats {
120 input_reads: 0,
121 counted_reads: 0,
122 };
123
124 if config.per_cell {
125 writeln!(output, "cell\tgene\tcount")?;
126 } else {
127 writeln!(output, "gene\tcount")?;
128 }
129
130 let mut current_gene: Option<String> = None;
131 let mut cell_umis = CellUmiMap::default();
132
133 let mut line_buf = String::new();
134 loop {
135 line_buf.clear();
136 let n = input.read_line(&mut line_buf)?;
137 if n == 0 {
138 break;
139 }
140 let line = line_buf.trim_end_matches('\n').trim_end_matches('\r');
141 if line.is_empty() {
142 continue;
143 }
144
145 let mut cols = line.splitn(2, '\t');
146 let Some(read_name) = cols.next() else {
147 continue;
148 };
149 let Some(gene) = cols.next() else {
150 continue;
151 };
152 let gene = gene.to_string();
153
154 stats.input_reads += 1;
155
156 if current_gene.as_ref().is_some_and(|g| *g != gene) {
158 flush_count_tab_gene(
159 current_gene.as_deref().expect("checked above"),
160 &cell_umis,
161 config,
162 output,
163 )?;
164 cell_umis = CellUmiMap::default();
165 }
166 current_gene = Some(gene);
167
168 let sep = config.separator;
169 let parts: Vec<&str> = read_name.split(|c: char| c as u8 == sep).collect();
170 let umi = parts
171 .last()
172 .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
173
174 let cell_key = if config.per_cell && parts.len() >= 2 {
175 Some(parts[parts.len() - 2].to_string())
176 } else {
177 None
178 };
179
180 stats.counted_reads += 1;
181 cell_umis.add(cell_key, umi);
182 }
183
184 if let Some(ref gene) = current_gene {
185 flush_count_tab_gene(gene, &cell_umis, config, output)?;
186 }
187
188 Ok(stats)
189}
190
191#[derive(Default)]
192struct CellUmiMap {
193 cells: Vec<(Option<String>, UmiCountMap)>,
194 cell_index: HashMap<Option<String>, usize>,
195 next_order: u32,
196}
197
198impl CellUmiMap {
199 fn add(&mut self, cell: Option<String>, umi: Vec<u8>) {
200 let idx = if let Some(&i) = self.cell_index.get(&cell) {
201 i
202 } else {
203 let i = self.cells.len();
204 self.cell_index.insert(cell.clone(), i);
205 self.cells.push((cell, HashMap::new()));
206 i
207 };
208 let entry = self.cells[idx].1.entry(umi).or_insert_with(|| {
209 let order = self.next_order;
210 self.next_order += 1;
211 (0, order)
212 });
213 entry.0 += 1;
214 }
215
216 fn dedup_count(
217 &self,
218 method: DedupMethod,
219 edit_threshold: u32,
220 ) -> Vec<(&Option<String>, usize)> {
221 self.cells
222 .iter()
223 .map(|(cell, umi_map)| {
224 let counts: HashMap<Vec<u8>, u32> =
225 umi_map.iter().map(|(k, &(c, _))| (k.clone(), c)).collect();
226 let orders: HashMap<Vec<u8>, u32> =
227 umi_map.iter().map(|(k, &(_, o))| (k.clone(), o)).collect();
228 let n = count_umis(method, &counts, &orders, edit_threshold);
229 (cell, n)
230 })
231 .collect()
232 }
233}
234
235fn write_gene_counts(
236 data: &BTreeMap<String, CellUmiMap>,
237 config: &CountConfig,
238 output: &mut dyn IoWrite,
239) -> Result<(), CountError> {
240 writeln!(output, "gene\tcount")?;
241 for (gene, cell_map) in data {
242 let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
243 let total: usize = results.iter().map(|(_, n)| n).sum();
244 writeln!(output, "{gene}\t{total}")?;
245 }
246 Ok(())
247}
248
249fn write_long_format(
250 data: &BTreeMap<String, CellUmiMap>,
251 config: &CountConfig,
252 output: &mut dyn IoWrite,
253) -> Result<(), CountError> {
254 writeln!(output, "gene\tcell\tcount")?;
255 for (gene, cell_map) in data {
256 let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
257 let mut sorted: Vec<_> = results
258 .into_iter()
259 .filter_map(|(cell, n)| cell.as_ref().map(|c| (c.clone(), n)))
260 .collect();
261 sorted.sort_by(|a, b| a.0.cmp(&b.0));
262 for (cell, count) in sorted {
263 writeln!(output, "{gene}\t{cell}\t{count}")?;
264 }
265 }
266 Ok(())
267}
268
269fn write_wide_format(
270 data: &BTreeMap<String, CellUmiMap>,
271 config: &CountConfig,
272 output: &mut dyn IoWrite,
273) -> Result<(), CountError> {
274 let mut all_cells: BTreeSet<String> = BTreeSet::new();
275 for cell_map in data.values() {
276 for (cell, _) in &cell_map.cells {
277 if let Some(c) = cell {
278 all_cells.insert(c.clone());
279 }
280 }
281 }
282 let cell_list: Vec<&String> = all_cells.iter().collect();
283
284 write!(output, "gene")?;
285 for cell in &cell_list {
286 write!(output, "\t{cell}")?;
287 }
288 writeln!(output)?;
289
290 for (gene, cell_map) in data {
291 let results = cell_map.dedup_count(config.method, config.edit_distance_threshold);
292 let cell_counts: HashMap<&str, usize> = results
293 .into_iter()
294 .filter_map(|(cell, n)| cell.as_ref().map(|c| (c.as_str(), n)))
295 .collect();
296
297 write!(output, "{gene}")?;
298 for cell in &cell_list {
299 let count = cell_counts.get(cell.as_str()).copied().unwrap_or(0);
300 write!(output, "\t{count}")?;
301 }
302 writeln!(output)?;
303 }
304 Ok(())
305}
306
307fn flush_count_tab_gene(
308 gene: &str,
309 cell_umis: &CellUmiMap,
310 config: &CountTabConfig,
311 output: &mut dyn IoWrite,
312) -> Result<(), CountError> {
313 let results = cell_umis.dedup_count(config.method, config.edit_distance_threshold);
314
315 if config.per_cell {
316 for (cell, count) in results {
317 let cell_str = cell.as_deref().unwrap_or("");
318 writeln!(output, "{cell_str}\t{gene}\t{count}")?;
319 }
320 } else {
321 let total: usize = results.iter().map(|(_, n)| n).sum();
322 writeln!(output, "{gene}\t{total}")?;
323 }
324 Ok(())
325}