Skip to main content

umi_core/
pattern.rs

1use std::ops::Range;
2
3use crate::error::ExtractError;
4
/// Barcodes pulled out of one read, together with what remains of the read.
///
/// All fields are raw byte vectors; sequence and quality vectors that belong
/// together are filled from the same positions of the input read.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Bases captured as the UMI.
    pub umi: Vec<u8>,
    /// Quality values taken from the same positions as the UMI bases.
    pub umi_quality: Vec<u8>,
    /// Bases captured as the cell barcode; empty when the pattern captures none.
    pub cell_barcode: Vec<u8>,
    /// Read sequence left over once the barcode bases have been removed.
    pub trimmed_sequence: Vec<u8>,
    /// Quality values belonging to `trimmed_sequence`.
    pub trimmed_quality: Vec<u8>,
}
14
/// Read end from which the string method takes its barcode region.
///
/// Only the string method consults this; the regex method has no use for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrimeEnd {
    /// The barcode region is the leading bases of the read.
    Five,
    /// The barcode region is the trailing bases of the read.
    Three,
}
21
22/// A parsed barcode pattern that knows how to extract UMI/cell/sample bases from a read.
23#[derive(Debug, Clone)]
24pub enum BarcodePattern {
25    String(StringPattern),
26    Regex(RegexPattern),
27}
28
29impl BarcodePattern {
30    /// # Errors
31    /// Returns error if the read is too short (string method) or doesn't match (regex method).
32    pub fn extract(
33        &self,
34        sequence: &[u8],
35        quality: &[u8],
36    ) -> Result<ExtractionResult, ExtractError> {
37        match self {
38            Self::String(p) => p.extract(sequence, quality),
39            Self::Regex(p) => p.extract(sequence, quality),
40        }
41    }
42}
43
44/// String-method pattern using fixed-position characters.
45///
46/// Pattern characters:
47/// - `N` — UMI base (extracted to read name)
48/// - `C` — Cell barcode base (extracted to read name)
49/// - `X` — Sample/discard base (stays in output sequence, removed from barcode region)
50#[derive(Debug, Clone)]
51pub struct StringPattern {
52    umi_positions: Vec<usize>,
53    cell_positions: Vec<usize>,
54    sample_positions: Vec<usize>,
55    umi_range: Option<Range<usize>>,
56    cell_range: Option<Range<usize>>,
57    sample_range: Option<Range<usize>>,
58    pattern_length: usize,
59    prime_end: PrimeEnd,
60}
61
62impl StringPattern {
63    /// Parse a string-method pattern like `NNNXXXXNN`.
64    ///
65    /// # Errors
66    /// Returns error if pattern is empty or contains characters other than N, X, C.
67    pub fn parse(pattern_str: &str, prime_end: PrimeEnd) -> Result<Self, ExtractError> {
68        if pattern_str.is_empty() {
69            return Err(ExtractError::InvalidPattern(
70                "pattern must not be empty".into(),
71            ));
72        }
73
74        let mut umi_positions = Vec::new();
75        let mut cell_positions = Vec::new();
76        let mut sample_positions = Vec::new();
77
78        for (i, ch) in pattern_str.chars().enumerate() {
79            match ch {
80                'N' => umi_positions.push(i),
81                'C' => cell_positions.push(i),
82                'X' => sample_positions.push(i),
83                other => {
84                    return Err(ExtractError::InvalidPattern(format!(
85                        "pattern contains invalid character '{other}' at position {i}; \
86                         only N, X, C are allowed"
87                    )));
88                }
89            }
90        }
91
92        let umi_range = as_contiguous_range(&umi_positions);
93        let cell_range = as_contiguous_range(&cell_positions);
94        let sample_range = as_contiguous_range(&sample_positions);
95
96        Ok(Self {
97            umi_positions,
98            cell_positions,
99            sample_positions,
100            umi_range,
101            cell_range,
102            sample_range,
103            pattern_length: pattern_str.len(),
104            prime_end,
105        })
106    }
107
108    /// Extract barcodes from a read's sequence and quality strings.
109    ///
110    /// # Errors
111    /// Returns error if the read is shorter than the pattern.
112    pub fn extract(
113        &self,
114        sequence: &[u8],
115        quality: &[u8],
116    ) -> Result<ExtractionResult, ExtractError> {
117        if sequence.len() < self.pattern_length {
118            return Err(ExtractError::ReadTooShort {
119                read_len: sequence.len(),
120                pattern_len: self.pattern_length,
121            });
122        }
123
124        let (barcode_region, remaining_seq, barcode_qual, remaining_qual) = match self.prime_end {
125            PrimeEnd::Five => (
126                &sequence[..self.pattern_length],
127                &sequence[self.pattern_length..],
128                &quality[..self.pattern_length],
129                &quality[self.pattern_length..],
130            ),
131            PrimeEnd::Three => (
132                &sequence[sequence.len() - self.pattern_length..],
133                &sequence[..sequence.len() - self.pattern_length],
134                &quality[quality.len() - self.pattern_length..],
135                &quality[..quality.len() - self.pattern_length],
136            ),
137        };
138
139        let umi = extract_slice(barcode_region, self.umi_range.as_ref(), &self.umi_positions);
140        let umi_quality = extract_slice(barcode_qual, self.umi_range.as_ref(), &self.umi_positions);
141        let cell_barcode = extract_slice(
142            barcode_region,
143            self.cell_range.as_ref(),
144            &self.cell_positions,
145        );
146
147        let (trimmed_sequence, trimmed_quality) = if self.sample_positions.is_empty() {
148            (remaining_seq.to_vec(), remaining_qual.to_vec())
149        } else {
150            let sample_seq = extract_slice(
151                barcode_region,
152                self.sample_range.as_ref(),
153                &self.sample_positions,
154            );
155            let sample_qual = extract_slice(
156                barcode_qual,
157                self.sample_range.as_ref(),
158                &self.sample_positions,
159            );
160            match self.prime_end {
161                PrimeEnd::Five => (
162                    join_slices(&sample_seq, remaining_seq),
163                    join_slices(&sample_qual, remaining_qual),
164                ),
165                PrimeEnd::Three => (
166                    join_slices(remaining_seq, &sample_seq),
167                    join_slices(remaining_qual, &sample_qual),
168                ),
169            }
170        };
171
172        Ok(ExtractionResult {
173            umi,
174            umi_quality,
175            cell_barcode,
176            trimmed_sequence,
177            trimmed_quality,
178        })
179    }
180}
181
182/// Regex-method pattern using named capture groups.
183///
184/// Groups starting with `umi_` are extracted as UMI, `cell_` as cell barcode,
185/// `discard_` as bases to remove. Everything else is kept in the output sequence.
186#[derive(Debug, Clone)]
187pub struct RegexPattern {
188    pattern: regex::Regex,
189}
190
191impl RegexPattern {
192    /// Parse a regex pattern string.
193    ///
194    /// # Errors
195    /// Returns error if the regex is invalid or has no `umi_` or `cell_` groups.
196    pub fn parse(pattern_str: &str) -> Result<Self, ExtractError> {
197        let processed = preprocess_fuzzy(pattern_str)?;
198
199        let pattern = regex::Regex::new(&processed)
200            .map_err(|e| ExtractError::InvalidPattern(format!("invalid regex: {e}")))?;
201
202        let has_barcode_group = pattern
203            .capture_names()
204            .flatten()
205            .any(|name| name.starts_with("umi_") || name.starts_with("cell_"));
206
207        if !has_barcode_group {
208            return Err(ExtractError::InvalidPattern(
209                "regex must contain at least one named group starting with 'umi_' or 'cell_'"
210                    .into(),
211            ));
212        }
213
214        Ok(Self { pattern })
215    }
216
217    /// Extract barcodes from a read's sequence and quality strings.
218    ///
219    /// # Errors
220    /// Returns `RegexNoMatch` if the regex doesn't match the sequence.
221    pub fn extract(
222        &self,
223        sequence: &[u8],
224        quality: &[u8],
225    ) -> Result<ExtractionResult, ExtractError> {
226        let seq_str = std::str::from_utf8(sequence)
227            .map_err(|e| ExtractError::FastqParse(format!("non-UTF8 sequence: {e}")))?;
228
229        let caps = self
230            .pattern
231            .captures(seq_str)
232            .ok_or(ExtractError::RegexNoMatch)?;
233
234        // Collect named group spans into (name, start, end) sorted by name
235        let mut umi_spans: Vec<(&str, usize, usize)> = Vec::new();
236        let mut cell_spans: Vec<(&str, usize, usize)> = Vec::new();
237        let mut discard_spans: Vec<(usize, usize)> = Vec::new();
238
239        for name in self.pattern.capture_names().flatten() {
240            if let Some(m) = caps.name(name) {
241                let span = (m.start(), m.end());
242                if name.starts_with("umi_") {
243                    umi_spans.push((name, span.0, span.1));
244                } else if name.starts_with("cell_") {
245                    cell_spans.push((name, span.0, span.1));
246                } else if name.starts_with("discard_") {
247                    discard_spans.push(span);
248                }
249            }
250        }
251
252        // Sort by group name for deterministic concatenation
253        umi_spans.sort_by_key(|&(name, _, _)| name);
254        cell_spans.sort_by_key(|&(name, _, _)| name);
255
256        // Build extracted-position bitmask (O(n) lookup instead of O(n*m) Vec::contains)
257        let mut extracted = vec![false; sequence.len()];
258        for &(_, start, end) in &umi_spans {
259            extracted[start..end].fill(true);
260        }
261        for &(_, start, end) in &cell_spans {
262            extracted[start..end].fill(true);
263        }
264        for &(start, end) in &discard_spans {
265            extracted[start..end].fill(true);
266        }
267
268        // Build UMI and cell by concatenating group values in sorted name order
269        let mut umi = Vec::new();
270        let mut umi_quality = Vec::new();
271        for &(_, start, end) in &umi_spans {
272            umi.extend_from_slice(&sequence[start..end]);
273            umi_quality.extend_from_slice(&quality[start..end]);
274        }
275
276        let mut cell_barcode = Vec::new();
277        for &(_, start, end) in &cell_spans {
278            cell_barcode.extend_from_slice(&sequence[start..end]);
279        }
280
281        // Build trimmed sequence/quality: keep positions not in any extraction set
282        let mut trimmed_sequence = Vec::new();
283        let mut trimmed_quality = Vec::new();
284
285        for (i, &is_extracted) in extracted.iter().enumerate() {
286            if !is_extracted {
287                trimmed_sequence.push(sequence[i]);
288                trimmed_quality.push(quality[i]);
289            }
290        }
291
292        Ok(ExtractionResult {
293            umi,
294            umi_quality,
295            cell_barcode,
296            trimmed_sequence,
297            trimmed_quality,
298        })
299    }
300}
301
302/// Pre-process a regex string, replacing `CHAR{s<=N}` fuzzy quantifiers.
303///
304/// In Python's `regex` module, `{s<=N}` applies to the single preceding character
305/// (not to an entire literal sequence). For N >= 1, `CHAR{s<=N}` matches any
306/// single character, equivalent to `.`. For N == 0, it's an exact match (no-op).
307fn preprocess_fuzzy(pattern_str: &str) -> Result<String, ExtractError> {
308    let mut result = String::with_capacity(pattern_str.len());
309    let bytes = pattern_str.as_bytes();
310    let len = bytes.len();
311    let mut i = 0;
312
313    while i < len {
314        if bytes[i] == b'{'
315            && i + 4 < len
316            && bytes[i + 1] == b's'
317            && bytes[i + 2] == b'<'
318            && bytes[i + 3] == b'='
319        {
320            let num_start = i + 4;
321            let mut num_end = num_start;
322            while num_end < len && bytes[num_end].is_ascii_digit() {
323                num_end += 1;
324            }
325            if num_end == num_start || num_end >= len || bytes[num_end] != b'}' {
326                return Err(ExtractError::InvalidPattern(format!(
327                    "malformed fuzzy quantifier at position {i}"
328                )));
329            }
330            let max_subs: usize = std::str::from_utf8(&bytes[num_start..num_end])
331                .expect("ASCII digits validated above")
332                .parse()
333                .expect("ASCII digits validated above");
334
335            if result.is_empty() {
336                return Err(ExtractError::InvalidPattern(format!(
337                    "fuzzy quantifier at position {i} has no preceding character"
338                )));
339            }
340
341            if max_subs >= 1 {
342                // Replace the preceding character with '.' (any character)
343                result.pop();
344                result.push('.');
345            }
346            // For max_subs == 0, keep the character as-is (exact match)
347
348            i = num_end + 1;
349        } else {
350            result.push(bytes[i] as char);
351            i += 1;
352        }
353    }
354
355    Ok(result)
356}
357
/// Collapse `positions` into a range when it counts upward in steps of one.
///
/// For `[a, a+1, ..., b-1]` the result is `Some(a..b)`. An empty slice, or any
/// slice with a gap, a repeat or a descending step, yields `None`.
fn as_contiguous_range(positions: &[usize]) -> Option<Range<usize>> {
    let (&first, following) = positions.split_first()?;
    for (offset, &pos) in following.iter().enumerate() {
        // `following[offset]` is `positions[offset + 1]`.
        if pos != first + offset + 1 {
            return None;
        }
    }
    Some(first..first + positions.len())
}
368
/// Copy selected bytes out of `source`.
///
/// When `range` is given, that sub-slice is copied and `positions` is ignored.
/// Otherwise the bytes at each index in `positions` are gathered in order.
///
/// # Panics
/// Panics if the range or any position lies outside `source`.
fn extract_slice(source: &[u8], range: Option<&Range<usize>>, positions: &[usize]) -> Vec<u8> {
    match range {
        Some(span) => source[span.start..span.end].to_vec(),
        None => positions.iter().map(|&index| source[index]).collect(),
    }
}
375
/// Return a new vector holding the bytes of `a` followed by the bytes of `b`.
fn join_slices(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
382
#[cfg(test)]
mod tests {
    use super::*;

    // Read shared by the 5' string-method test and the regex equivalence test.
    const READ_SEQ: &[u8] = b"CAGGTTCAATCTCGGTGGGACCTC";
    const READ_QUAL: &[u8] = b"1=DFFFFHHHHHJJJFGIJIJJIJ";

    // Regex counterpart of the string pattern `NNNXXXXNN`.
    const UMI_SPLIT_REGEX: &str = r"^(?P<umi_1>.{3}).{4}(?P<umi_2>.{2})";

    /// Parse a string-method pattern that is expected to be valid.
    fn string_pattern(spec: &str, end: PrimeEnd) -> StringPattern {
        StringPattern::parse(spec, end).unwrap()
    }

    /// Parse a regex-method pattern that is expected to be valid.
    fn regex_pattern(spec: &str) -> RegexPattern {
        RegexPattern::parse(spec).unwrap()
    }

    // --- StringPattern tests ---

    #[test]
    fn parse_valid_pattern() {
        let parsed = string_pattern("NNNXXXXNN", PrimeEnd::Five);
        assert_eq!(parsed.pattern_length, 9);
        assert_eq!(parsed.umi_positions, vec![0, 1, 2, 7, 8]);
        assert_eq!(parsed.sample_positions, vec![3, 4, 5, 6]);
        assert!(parsed.cell_positions.is_empty());
    }

    #[test]
    fn parse_pattern_with_cell() {
        let parsed = string_pattern("CCCNNNNXXXX", PrimeEnd::Five);
        assert_eq!(parsed.cell_positions, vec![0, 1, 2]);
        assert_eq!(parsed.umi_positions, vec![3, 4, 5, 6]);
        assert_eq!(parsed.sample_positions, vec![7, 8, 9, 10]);
    }

    #[test]
    fn parse_invalid_pattern() {
        for bad_spec in ["NNNZXXNN", ""] {
            assert!(StringPattern::parse(bad_spec, PrimeEnd::Five).is_err());
        }
    }

    #[test]
    fn extract_5prime_nnnxxxxnn() {
        let extracted = string_pattern("NNNXXXXNN", PrimeEnd::Five)
            .extract(READ_SEQ, READ_QUAL)
            .unwrap();

        assert_eq!(extracted.umi, b"CAGAA");
        assert_eq!(extracted.umi_quality, b"1=DHH");
        assert!(extracted.cell_barcode.is_empty());
        assert_eq!(extracted.trimmed_sequence, b"GTTCTCTCGGTGGGACCTC");
        assert_eq!(extracted.trimmed_quality, b"FFFFHHHJJJFGIJIJJIJ");
    }

    #[test]
    fn extract_read_too_short() {
        let parsed = string_pattern("NNNXXXXNN", PrimeEnd::Five);
        let outcome = parsed.extract(b"ACGT", b"IIII");
        assert!(outcome.is_err());
    }

    #[test]
    fn extract_3prime() {
        let extracted = string_pattern("NNXX", PrimeEnd::Three)
            .extract(b"ACGTAATTGG", b"IIIIIIIIII")
            .unwrap();

        assert_eq!(extracted.umi, b"TT");
        assert_eq!(extracted.trimmed_sequence, b"ACGTAAGG");
    }

    // --- RegexPattern tests ---

    #[test]
    fn regex_parse_valid() {
        let parsed = regex_pattern(UMI_SPLIT_REGEX);
        assert!(parsed.pattern.is_match("CAGGTTCAATCTCGGTGGGACCTC"));
    }

    #[test]
    fn regex_parse_no_barcode_groups() {
        let outcome = RegexPattern::parse(r"^(.{3}).{4}(.{2})");
        assert!(outcome.is_err());
    }

    #[test]
    fn regex_parse_invalid_regex() {
        let outcome = RegexPattern::parse(r"^(?P<umi_1>.{3");
        assert!(outcome.is_err());
    }

    #[test]
    fn regex_extract_equivalent_to_string() {
        // The regex form must give the same result as the string pattern NNNXXXXNN.
        let from_string = string_pattern("NNNXXXXNN", PrimeEnd::Five)
            .extract(READ_SEQ, READ_QUAL)
            .unwrap();
        let from_regex = regex_pattern(UMI_SPLIT_REGEX)
            .extract(READ_SEQ, READ_QUAL)
            .unwrap();

        assert_eq!(from_string.umi, from_regex.umi);
        assert_eq!(from_string.cell_barcode, from_regex.cell_barcode);
        assert_eq!(from_string.trimmed_sequence, from_regex.trimmed_sequence);
        assert_eq!(from_string.trimmed_quality, from_regex.trimmed_quality);
    }

    #[test]
    fn regex_extract_with_cell() {
        let parsed = regex_pattern(r"^(?P<cell_1>.{3})(?P<umi_1>.{4})(?P<discard_1>.{2})");

        let extracted = parsed.extract(b"ABCDEFGHIJKLM", b"1234567890ABC").unwrap();

        assert_eq!(extracted.cell_barcode, b"ABC");
        assert_eq!(extracted.umi, b"DEFG");
        // Positions 0-8 are extracted or discarded, leaving JKLM (positions 9-12).
        assert_eq!(extracted.trimmed_sequence, b"JKLM");
        assert_eq!(extracted.trimmed_quality, b"0ABC");
    }

    #[test]
    fn regex_no_match() {
        let parsed = regex_pattern(r"^(?P<umi_1>ZZZZZ)");
        let outcome = parsed.extract(b"ACGTACGT", b"IIIIIIII");
        assert!(matches!(outcome, Err(ExtractError::RegexNoMatch)));
    }
}