Skip to main content

umi_core/
extract.rs

1use std::collections::{HashMap, HashSet};
2use std::io::{BufWriter, Write};
3
4use needletail::parser::{FastqReader, FastxReader, SequenceRecord};
5
6use crate::error::ExtractError;
7use crate::pattern::{BarcodePattern, ExtractionResult};
8
/// Returns `true` if any base in `umi_quality` has a score below `threshold`,
/// where a base's score is its byte value minus the encoding `offset`.
///
/// The test is evaluated as `q < offset + threshold` in `u16`, which is the
/// exact integer form of `q - offset < threshold` and cannot overflow.
/// Bytes below `offset` (negative scores) therefore always fail, even when
/// `threshold` is 0. An empty `umi_quality` never fails.
fn fails_quality_filter(umi_quality: &[u8], threshold: u8, offset: u8) -> bool {
    let min_passing = u16::from(offset) + u16::from(threshold);
    umi_quality.iter().any(|&q| u16::from(q) < min_passing)
}
16
/// Encoding scheme used for FASTQ quality characters.
///
/// A base's score is its byte value minus [`QualityEncoding::offset`].
/// `Phred33` is the default.
#[derive(Debug, Clone, Copy, Default)]
pub enum QualityEncoding {
    /// Phred scores stored with an ASCII offset of 33 (`!` is score 0).
    #[default]
    Phred33,
    /// Phred scores stored with an ASCII offset of 64 (`@` is score 0).
    Phred64,
    /// Solexa qualities, read with an ASCII offset of 59 (`;` is score 0).
    Solexa,
}

impl QualityEncoding {
    /// Byte value subtracted from a quality character to obtain its score,
    /// expressed as the character that encodes score 0.
    #[must_use]
    pub const fn offset(self) -> u8 {
        match self {
            QualityEncoding::Solexa => b';',
            QualityEncoding::Phred33 => b'!',
            QualityEncoding::Phred64 => b'@',
        }
    }
}
36
/// Configuration for the extract command.
///
/// Which fields matter depends on the entry point used; the notes below
/// describe how the functions in this file read them.
#[derive(Debug, Clone)]
pub struct ExtractConfig {
    /// Pattern applied to read1 (single-end, read1-pattern paired, and
    /// either-read modes).
    pub pattern: Option<BarcodePattern>,
    /// Pattern applied to read2 (read2-only paired and either-read modes).
    pub pattern2: Option<BarcodePattern>,
    /// Byte inserted before the cell barcode (when non-empty) and before the
    /// UMI in rewritten read names.
    pub umi_separator: u8,
    /// Minimum score every UMI base must reach; `None` disables the filter.
    pub quality_filter_threshold: Option<u8>,
    /// Encoding whose offset converts quality bytes to scores for the filter.
    pub quality_encoding: QualityEncoding,
    /// Accepted cell barcodes. In this file only read1-pattern paired mode
    /// consults it; barcodes not present are corrected or discarded.
    pub whitelist: Option<HashSet<Vec<u8>>>,
    /// Replacement barcodes for barcodes missing from `whitelist`; only
    /// consulted when `whitelist` is set and the barcode is not in it.
    pub correction_map: Option<HashMap<Vec<u8>, Vec<u8>>>,
    /// Rejected cell barcodes, checked before the whitelist in read1-pattern
    /// paired mode. Rejections are counted in `whitelist_filtered`.
    pub blacklist: Option<HashSet<Vec<u8>>>,
    /// Strip a trailing `/1` or `/2` from read2's name before inserting the
    /// barcode/UMI (read1-pattern paired mode).
    pub ignore_read_pair_suffixes: bool,
    /// Read1-pattern paired mode only: treat read1 as an ordered subset of
    /// read2 and skip read2 records until names match, instead of requiring
    /// the two files to advance in lockstep.
    pub reconcile_pairs: bool,
}
51
/// Statistics from an extraction run.
///
/// Every read (or read pair) counted in `input_reads` ends up in at most one
/// of the outcome counters below, unless an error aborts the run.
#[derive(Debug, Default)]
pub struct ExtractStats {
    /// Records (single-end) or pairs examined. In reconcile mode this counts
    /// read1 records; read2 records skipped while searching are not counted.
    pub input_reads: u64,
    /// Reads or pairs written to the main output(s).
    pub output_reads: u64,
    /// Reads the pattern reported as too short. Either-read mode counts these
    /// as `no_match` instead.
    pub too_short: u64,
    /// Reads (or pairs, in either-read mode) where the pattern did not match.
    pub no_match: u64,
    /// Reads or pairs dropped because a UMI base fell below the quality threshold.
    pub quality_filtered: u64,
    /// Pairs dropped by the blacklist, or by the whitelist when no correction
    /// applied (read1-pattern paired mode).
    pub whitelist_filtered: u64,
    /// Either-read mode: pairs where both reads matched, which are discarded.
    pub both_matched: u64,
}
63
64/// Extract UMIs from FASTQ reads, writing modified reads to `output`.
65///
66/// # Errors
67/// Returns error on I/O or parse failures.
68pub fn extract_reads<R: std::io::Read + Send, W: Write>(
69    config: &ExtractConfig,
70    input: R,
71    output: W,
72) -> Result<ExtractStats, ExtractError> {
73    let pattern = config.pattern.as_ref().ok_or_else(|| {
74        ExtractError::InvalidPattern("no pattern provided for single-end extraction".into())
75    })?;
76
77    let mut stats = ExtractStats::default();
78    let mut writer = BufWriter::with_capacity(64 * 1024, output);
79    let mut reader = FastqReader::new(input);
80
81    while let Some(result) = reader.next() {
82        let record = result.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
83        stats.input_reads += 1;
84
85        match process_record(&record, pattern, config.umi_separator) {
86            Ok(processed) => {
87                if let Some(threshold) = config.quality_filter_threshold
88                    && fails_quality_filter(
89                        &processed.umi_quality,
90                        threshold,
91                        config.quality_encoding.offset(),
92                    )
93                {
94                    stats.quality_filtered += 1;
95                    continue;
96                }
97                write_fastq_record(&mut writer, &processed.id, &processed.seq, &processed.qual)?;
98                stats.output_reads += 1;
99            }
100            Err(ExtractError::ReadTooShort { .. }) => {
101                stats.too_short += 1;
102            }
103            Err(ExtractError::RegexNoMatch) => {
104                stats.no_match += 1;
105            }
106            Err(e) => return Err(e),
107        }
108    }
109
110    writer.flush()?;
111    Ok(stats)
112}
113
/// Result of applying a barcode pattern to one single-end record.
struct ProcessedRecord {
    /// Rewritten header with the cell barcode/UMI inserted after the read name.
    id: Vec<u8>,
    /// Sequence after the pattern's trimming.
    seq: Vec<u8>,
    /// Quality string after the pattern's trimming.
    qual: Vec<u8>,
    /// Quality characters of the extracted UMI, used by the quality filter.
    umi_quality: Vec<u8>,
}
120
121fn process_record(
122    record: &SequenceRecord,
123    pattern: &BarcodePattern,
124    umi_separator: u8,
125) -> Result<ProcessedRecord, ExtractError> {
126    let seq = record.seq();
127    let qual = record
128        .qual()
129        .ok_or_else(|| ExtractError::FastqParse("missing quality scores in FASTQ record".into()))?;
130
131    let result = pattern.extract(&seq, qual)?;
132
133    let id = build_read_name(
134        record.id(),
135        &result.cell_barcode,
136        &result.umi,
137        umi_separator,
138        false,
139    );
140
141    Ok(ProcessedRecord {
142        id,
143        seq: result.trimmed_sequence,
144        qual: result.trimmed_quality,
145        umi_quality: result.umi_quality,
146    })
147}
148
/// Return the read name from a FASTQ header: everything before the first
/// space, or the whole header when it contains no space.
fn read_name(header: &[u8]) -> &[u8] {
    header.split(|&byte| byte == b' ').next().unwrap_or(header)
}
156
/// Remove a trailing `/1` or `/2` pair marker from a read name, if present.
fn strip_pair_suffix(name: &[u8]) -> &[u8] {
    match name {
        [stem @ .., b'/', b'1' | b'2'] => stem,
        _ => name,
    }
}

/// Build a new read identifier with barcode(s) inserted after the read name.
///
/// The header is split at its first space: `NAME COMMENT` becomes
/// `NAME{sep}CELL{sep}UMI COMMENT`. The `{sep}CELL` part is omitted when
/// `cell` is empty, while `{sep}UMI` is always written (even for an empty UMI).
/// When `strip_suffixes` is true, a trailing `/1` or `/2` is removed from NAME.
fn build_read_name(
    header: &[u8],
    cell: &[u8],
    umi: &[u8],
    separator: u8,
    strip_suffixes: bool,
) -> Vec<u8> {
    // `comment` keeps its leading space, or is empty when the header has none.
    let name_end = header
        .iter()
        .position(|&b| b == b' ')
        .unwrap_or(header.len());
    let (name, comment) = header.split_at(name_end);
    let name = if strip_suffixes {
        strip_pair_suffix(name)
    } else {
        name
    };

    let mut read_id = Vec::with_capacity(header.len() + 1 + cell.len() + 1 + umi.len());
    read_id.extend_from_slice(name);

    let tags = Some(cell)
        .filter(|c| !c.is_empty())
        .into_iter()
        .chain(Some(umi));
    for tag in tags {
        read_id.push(separator);
        read_id.extend_from_slice(tag);
    }

    read_id.extend_from_slice(comment);
    read_id
}
203
204fn write_fastq_record<W: Write>(
205    writer: &mut W,
206    id: &[u8],
207    seq: &[u8],
208    qual: &[u8],
209) -> Result<(), ExtractError> {
210    writer.write_all(b"@")?;
211    writer.write_all(id)?;
212    writer.write_all(b"\n")?;
213    writer.write_all(seq)?;
214    writer.write_all(b"\n+\n")?;
215    writer.write_all(qual)?;
216    writer.write_all(b"\n")?;
217    Ok(())
218}
219
220/// Extract UMIs from paired-end FASTQ reads (read2-only pattern mode).
221///
222/// Pattern is applied to read2 only. UMI is appended to both read names.
223/// Read1 is written untrimmed to `output1`, read2 is written trimmed to `output2`.
224///
225/// # Errors
226/// Returns error on I/O failures, parse errors, or mismatched read counts.
227pub fn extract_reads_paired<R1, R2, W1, W2>(
228    config: &ExtractConfig,
229    input1: R1,
230    input2: R2,
231    output1: W1,
232    output2: W2,
233) -> Result<ExtractStats, ExtractError>
234where
235    R1: std::io::Read + Send,
236    R2: std::io::Read + Send,
237    W1: Write,
238    W2: Write,
239{
240    let pattern2 = config.pattern2.as_ref().ok_or_else(|| {
241        ExtractError::InvalidPattern("no pattern2 provided for paired-end extraction".into())
242    })?;
243
244    let mut stats = ExtractStats::default();
245    let mut writer1 = BufWriter::with_capacity(64 * 1024, output1);
246    let mut writer2 = BufWriter::with_capacity(64 * 1024, output2);
247    let mut reader1 = FastqReader::new(input1);
248    let mut reader2 = FastqReader::new(input2);
249
250    loop {
251        let rec1 = reader1.next();
252        let rec2 = reader2.next();
253
254        match (rec1, rec2) {
255            (Some(r1), Some(r2)) => {
256                let r1 = r1.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
257                let r2 = r2.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
258                stats.input_reads += 1;
259
260                let r2_seq = r2.seq();
261                let r2_qual = r2.qual().ok_or_else(|| {
262                    ExtractError::FastqParse("missing quality scores in read2".into())
263                })?;
264
265                let extraction = match pattern2.extract(&r2_seq, r2_qual) {
266                    Ok(result) => result,
267                    Err(ExtractError::ReadTooShort { .. }) => {
268                        stats.too_short += 1;
269                        continue;
270                    }
271                    Err(ExtractError::RegexNoMatch) => {
272                        stats.no_match += 1;
273                        continue;
274                    }
275                    Err(e) => return Err(e),
276                };
277
278                if let Some(threshold) = config.quality_filter_threshold
279                    && fails_quality_filter(
280                        &extraction.umi_quality,
281                        threshold,
282                        config.quality_encoding.offset(),
283                    )
284                {
285                    stats.quality_filtered += 1;
286                    continue;
287                }
288
289                let r1_id = build_read_name(
290                    r1.id(),
291                    &extraction.cell_barcode,
292                    &extraction.umi,
293                    config.umi_separator,
294                    false,
295                );
296                let r2_id = build_read_name(
297                    r2.id(),
298                    &extraction.cell_barcode,
299                    &extraction.umi,
300                    config.umi_separator,
301                    false,
302                );
303
304                // Read1: untrimmed, with new read name
305                let r1_seq = r1.seq();
306                let r1_qual = r1.qual().ok_or_else(|| {
307                    ExtractError::FastqParse("missing quality scores in read1".into())
308                })?;
309                write_fastq_record(&mut writer1, &r1_id, &r1_seq, r1_qual)?;
310
311                // Read2: trimmed, with new read name
312                write_fastq_record(
313                    &mut writer2,
314                    &r2_id,
315                    &extraction.trimmed_sequence,
316                    &extraction.trimmed_quality,
317                )?;
318
319                stats.output_reads += 1;
320            }
321            (None, None) => break,
322            _ => {
323                return Err(ExtractError::FastqParse(
324                    "read1 and read2 files have different numbers of records".into(),
325                ));
326            }
327        }
328    }
329
330    writer1.flush()?;
331    writer2.flush()?;
332    Ok(stats)
333}
334
335/// Process a single read pair in the r1-pattern extraction path.
336///
337/// Returns `true` if the pair produced output, `false` if filtered/skipped.
338fn process_r1_pattern_pair<W: Write>(
339    r1: &SequenceRecord,
340    r2: &SequenceRecord,
341    pattern: &BarcodePattern,
342    config: &ExtractConfig,
343    stats: &mut ExtractStats,
344    writer: &mut W,
345) -> Result<bool, ExtractError> {
346    let r1_seq = r1.seq();
347    let r1_qual = r1
348        .qual()
349        .ok_or_else(|| ExtractError::FastqParse("missing quality scores in read1".into()))?;
350
351    let extraction = match pattern.extract(&r1_seq, r1_qual) {
352        Ok(result) => result,
353        Err(ExtractError::ReadTooShort { .. }) => {
354            stats.too_short += 1;
355            return Ok(false);
356        }
357        Err(ExtractError::RegexNoMatch) => {
358            stats.no_match += 1;
359            return Ok(false);
360        }
361        Err(e) => return Err(e),
362    };
363
364    if let Some(threshold) = config.quality_filter_threshold
365        && fails_quality_filter(
366            &extraction.umi_quality,
367            threshold,
368            config.quality_encoding.offset(),
369        )
370    {
371        stats.quality_filtered += 1;
372        return Ok(false);
373    }
374
375    if let Some(ref blacklist) = config.blacklist
376        && blacklist.contains(&extraction.cell_barcode)
377    {
378        stats.whitelist_filtered += 1;
379        return Ok(false);
380    }
381
382    let cell_barcode = if let Some(ref whitelist) = config.whitelist {
383        if whitelist.contains(&extraction.cell_barcode) {
384            extraction.cell_barcode.clone()
385        } else if let Some(ref correction_map) = config.correction_map
386            && let Some(corrected) = correction_map.get(&extraction.cell_barcode)
387        {
388            corrected.clone()
389        } else {
390            stats.whitelist_filtered += 1;
391            return Ok(false);
392        }
393    } else {
394        extraction.cell_barcode.clone()
395    };
396
397    let r2_id = build_read_name(
398        r2.id(),
399        &cell_barcode,
400        &extraction.umi,
401        config.umi_separator,
402        config.ignore_read_pair_suffixes,
403    );
404
405    let r2_seq = r2.seq();
406    let r2_qual = r2
407        .qual()
408        .ok_or_else(|| ExtractError::FastqParse("missing quality scores in read2".into()))?;
409    write_fastq_record(writer, &r2_id, &r2_seq, r2_qual)?;
410
411    stats.output_reads += 1;
412    Ok(true)
413}
414
415/// Write original untrimmed reads to filtered output files (headers NOT modified).
416fn write_filtered_pair<W: Write>(
417    r1: &SequenceRecord,
418    r2: &SequenceRecord,
419    filt_writer1: &mut Option<BufWriter<W>>,
420    filt_writer2: &mut Option<BufWriter<W>>,
421) -> Result<(), ExtractError> {
422    if let Some(fw) = filt_writer1.as_mut() {
423        let r1_qual = r1
424            .qual()
425            .ok_or_else(|| ExtractError::FastqParse("missing quality scores in read1".into()))?;
426        write_fastq_record(fw, r1.id(), &r1.seq(), r1_qual)?;
427    }
428    if let Some(fw) = filt_writer2.as_mut() {
429        let r2_qual = r2
430            .qual()
431            .ok_or_else(|| ExtractError::FastqParse("missing quality scores in read2".into()))?;
432        write_fastq_record(fw, r2.id(), &r2.seq(), r2_qual)?;
433    }
434    Ok(())
435}
436
/// Extract UMIs from paired-end FASTQ reads (read1-pattern mode with read2 output).
///
/// Pattern is applied to read1 to extract cell barcode + UMI. Read2 is written
/// untrimmed to `output` with the cell+UMI inserted into read2's header.
/// A pair is not written to `output` when its read1 is too short or does not
/// match, when the UMI fails the quality filter, when the barcode is
/// blacklisted, or when the barcode is absent from the whitelist (if provided)
/// and has no entry in the correction map.
///
/// Pairs that are not written to `output` have their original, unmodified
/// records written to `filtered_out1` / `filtered_out2` when those are `Some`.
///
/// Two pairing modes are supported, selected by `config.reconcile_pairs`:
/// - lockstep (default): both files must contain the same number of records,
///   paired by position;
/// - reconcile: read1 is an ordered subset of read2. For each read1, read2 is
///   advanced until a record with the same name (header up to the first space)
///   is found. Non-matching read2 records are skipped without being counted,
///   and any read2 records left after read1 is exhausted are ignored.
///
/// # Errors
/// Returns error if `config.pattern` is `None`, on I/O failures, parse errors,
/// missing quality strings, or mismatched read counts (lockstep). In reconcile
/// mode, also returns an error if read2 is exhausted before a read1 name is
/// found.
pub fn extract_reads_paired_r1_pattern<R1, R2, W>(
    config: &ExtractConfig,
    input1: R1,
    input2: R2,
    output: W,
    filtered_out1: Option<Box<dyn Write>>,
    filtered_out2: Option<Box<dyn Write>>,
) -> Result<ExtractStats, ExtractError>
where
    R1: std::io::Read + Send,
    R2: std::io::Read + Send,
    W: Write,
{
    let pattern = config.pattern.as_ref().ok_or_else(|| {
        ExtractError::InvalidPattern(
            "no pattern provided for paired-end read1-pattern extraction".into(),
        )
    })?;

    let mut stats = ExtractStats::default();
    let mut writer = BufWriter::with_capacity(64 * 1024, output);
    // Filtered-output sinks use BufWriter's default capacity.
    let mut filt_writer1 = filtered_out1.map(BufWriter::new);
    let mut filt_writer2 = filtered_out2.map(BufWriter::new);
    let mut reader1 = FastqReader::new(input1);
    let mut reader2 = FastqReader::new(input2);

    if config.reconcile_pairs {
        // Reconcile mode: read1 is a pre-filtered subset, read2 is the full set.
        // Both files maintain original sequencing order. For each read1 record,
        // advance read2 until a matching read name is found; skip unmatched read2s.
        // NOTE(review): `read_name` only cuts at the first space and does not strip
        // `/1` / `/2`, so suffixed mates never compare equal here — confirm intended.
        while let Some(r1_result) = reader1.next() {
            let r1 = r1_result.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
            // Only read1 records are counted as inputs in this mode.
            stats.input_reads += 1;
            let r1_name = read_name(r1.id());

            loop {
                match reader2.next() {
                    Some(r2_result) => {
                        let r2 = r2_result.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
                        if read_name(r2.id()) == r1_name {
                            let kept = process_r1_pattern_pair(
                                &r1,
                                &r2,
                                pattern,
                                config,
                                &mut stats,
                                &mut writer,
                            )?;
                            if !kept {
                                write_filtered_pair(
                                    &r1,
                                    &r2,
                                    &mut filt_writer1,
                                    &mut filt_writer2,
                                )?;
                            }
                            break;
                        }
                        // Name mismatch: this read2 has no read1 mate; drop it silently.
                    }
                    None => {
                        return Err(ExtractError::FastqParse(format!(
                            "read2 exhausted before finding match for read1: {}",
                            String::from_utf8_lossy(r1_name)
                        )));
                    }
                }
            }
        }
    } else {
        // Lockstep mode: read1 and read2 must have matching records in order.
        // Names are not compared; records are paired purely by position.
        loop {
            let rec1 = reader1.next();
            let rec2 = reader2.next();

            match (rec1, rec2) {
                (Some(r1), Some(r2)) => {
                    let r1 = r1.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
                    let r2 = r2.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
                    stats.input_reads += 1;
                    let kept = process_r1_pattern_pair(
                        &r1,
                        &r2,
                        pattern,
                        config,
                        &mut stats,
                        &mut writer,
                    )?;
                    if !kept {
                        write_filtered_pair(&r1, &r2, &mut filt_writer1, &mut filt_writer2)?;
                    }
                }
                (None, None) => break,
                // One file ended before the other.
                _ => {
                    return Err(ExtractError::FastqParse(
                        "read1 and read2 files have different numbers of records".into(),
                    ));
                }
            }
        }
    }

    // Flush explicitly so write errors surface instead of being lost on drop.
    writer.flush()?;
    if let Some(fw) = filt_writer1.as_mut() {
        fw.flush()?;
    }
    if let Some(fw) = filt_writer2.as_mut() {
        fw.flush()?;
    }
    Ok(stats)
}
555
556/// Try extracting with a pattern, returning `None` for recoverable failures (too short, no match).
557fn try_extract(
558    pattern: &BarcodePattern,
559    seq: &[u8],
560    qual: &[u8],
561) -> Result<Option<ExtractionResult>, ExtractError> {
562    match pattern.extract(seq, qual) {
563        Ok(result) => Ok(Some(result)),
564        Err(ExtractError::ReadTooShort { .. } | ExtractError::RegexNoMatch) => Ok(None),
565        Err(e) => Err(e),
566    }
567}
568
569/// Extract UMIs from paired-end FASTQ reads in either-read mode.
570///
571/// Both patterns are tried on their respective reads. If exactly one matches,
572/// the UMI is taken from that read (and only that read is trimmed). If both
573/// match, the pair is discarded (default `--either-read-resolve=discard`).
574/// If neither matches, the pair is discarded as `no_match`.
575///
576/// # Errors
577/// Returns error on I/O failures, parse errors, or mismatched read counts.
578#[allow(clippy::too_many_lines)]
579pub fn extract_reads_either_read<R1, R2, W1, W2>(
580    config: &ExtractConfig,
581    input1: R1,
582    input2: R2,
583    output1: W1,
584    output2: W2,
585) -> Result<ExtractStats, ExtractError>
586where
587    R1: std::io::Read + Send,
588    R2: std::io::Read + Send,
589    W1: Write,
590    W2: Write,
591{
592    let pattern1 = config.pattern.as_ref().ok_or_else(|| {
593        ExtractError::InvalidPattern("no pattern provided for either-read extraction".into())
594    })?;
595    let pattern2 = config.pattern2.as_ref().ok_or_else(|| {
596        ExtractError::InvalidPattern("no pattern2 provided for either-read extraction".into())
597    })?;
598
599    let mut stats = ExtractStats::default();
600    let mut writer1 = BufWriter::with_capacity(64 * 1024, output1);
601    let mut writer2 = BufWriter::with_capacity(64 * 1024, output2);
602    let mut reader1 = FastqReader::new(input1);
603    let mut reader2 = FastqReader::new(input2);
604
605    loop {
606        let rec1 = reader1.next();
607        let rec2 = reader2.next();
608
609        match (rec1, rec2) {
610            (Some(r1), Some(r2)) => {
611                let r1 = r1.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
612                let r2 = r2.map_err(|e| ExtractError::FastqParse(e.to_string()))?;
613                stats.input_reads += 1;
614
615                let r1_seq = r1.seq();
616                let r1_qual = r1.qual().ok_or_else(|| {
617                    ExtractError::FastqParse("missing quality scores in read1".into())
618                })?;
619                let r2_seq = r2.seq();
620                let r2_qual = r2.qual().ok_or_else(|| {
621                    ExtractError::FastqParse("missing quality scores in read2".into())
622                })?;
623
624                let r1_result = try_extract(pattern1, &r1_seq, r1_qual)?;
625                let r2_result = try_extract(pattern2, &r2_seq, r2_qual)?;
626
627                match (r1_result, r2_result) {
628                    (Some(_), Some(_)) => {
629                        stats.both_matched += 1;
630                    }
631                    (Some(extraction), None) => {
632                        if let Some(threshold) = config.quality_filter_threshold
633                            && fails_quality_filter(
634                                &extraction.umi_quality,
635                                threshold,
636                                config.quality_encoding.offset(),
637                            )
638                        {
639                            stats.quality_filtered += 1;
640                            continue;
641                        }
642
643                        // Both headers built from read1 (matches Python umi-tools behavior)
644                        let new_id = build_read_name(
645                            r1.id(),
646                            &extraction.cell_barcode,
647                            &extraction.umi,
648                            config.umi_separator,
649                            false,
650                        );
651
652                        // Read1: trimmed
653                        write_fastq_record(
654                            &mut writer1,
655                            &new_id,
656                            &extraction.trimmed_sequence,
657                            &extraction.trimmed_quality,
658                        )?;
659                        // Read2: untrimmed
660                        write_fastq_record(&mut writer2, &new_id, &r2_seq, r2_qual)?;
661
662                        stats.output_reads += 1;
663                    }
664                    (None, Some(extraction)) => {
665                        if let Some(threshold) = config.quality_filter_threshold
666                            && fails_quality_filter(
667                                &extraction.umi_quality,
668                                threshold,
669                                config.quality_encoding.offset(),
670                            )
671                        {
672                            stats.quality_filtered += 1;
673                            continue;
674                        }
675
676                        // Both headers built from read1 (matches Python umi-tools behavior)
677                        let new_id = build_read_name(
678                            r1.id(),
679                            &extraction.cell_barcode,
680                            &extraction.umi,
681                            config.umi_separator,
682                            false,
683                        );
684
685                        // Read1: untrimmed
686                        write_fastq_record(&mut writer1, &new_id, &r1_seq, r1_qual)?;
687                        // Read2: trimmed
688                        write_fastq_record(
689                            &mut writer2,
690                            &new_id,
691                            &extraction.trimmed_sequence,
692                            &extraction.trimmed_quality,
693                        )?;
694
695                        stats.output_reads += 1;
696                    }
697                    (None, None) => {
698                        stats.no_match += 1;
699                    }
700                }
701            }
702            (None, None) => break,
703            _ => {
704                return Err(ExtractError::FastqParse(
705                    "read1 and read2 files have different numbers of records".into(),
706                ));
707            }
708        }
709    }
710
711    writer1.flush()?;
712    writer2.flush()?;
713    Ok(stats)
714}