Skip to main content

rustalign_fmindex/
reference.rs

1//! Reference sequence storage and extraction
2//!
3//! This module handles loading the reference sequence from .3.rai and .4.rai
4//! files and extracting stretches for Smith-Waterman alignment.
5
6use memmap2::Mmap;
7use rustalign_common::{AlignError, AlignResult, Nuc};
8use std::fs::File;
9use std::path::Path;
10
/// A record describing an unambiguous stretch of the reference
///
/// Each record is a gap of `off` ambiguous characters followed by `len`
/// unambiguous characters. The record is a small plain-data value, so it is
/// `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RefRecord {
    /// Number of ambiguous (N) characters before this stretch
    off: u32,
    /// Length of the unambiguous stretch
    len: u32,
    /// Whether this is the first record for a reference sequence
    // Stored as parsed; nothing in this module reads it back after loading,
    // hence the lint suppression.
    #[allow(dead_code)]
    first: bool,
}
22
23/// Reference sequence storage
24///
25/// Loads and stores reference sequences from .3.rai and .4.rai RustAlign index files.
26/// The reference is stored in bit-packed format (2 bits per nucleotide).
27pub struct Reference {
28    /// Records describing unambiguous stretches
29    recs: Vec<RefRecord>,
30
31    /// Cumulative count of unambiguous characters up to each record
32    cum_unambig: Vec<u64>,
33
34    /// Approximate lengths of reference sequences (total chars including Ns)
35    ref_lens: Vec<u64>,
36
37    /// Buffer begin offsets per reference sequence (unambiguous chars only)
38    ref_offs: Vec<u64>,
39
40    /// Record begin/end indices per reference sequence
41    ref_rec_offs: Vec<(usize, usize)>,
42
43    /// Number of reference sequences
44    n_refs: u32,
45
46    /// Total size of buffer needed
47    buf_sz: u64,
48}
49
50impl Reference {
51    /// Load reference from .3.rai and .4.rai files
52    ///
53    /// # Arguments
54    /// * `path_base` - Path to the index base (without extensions)
55    pub fn load<P: AsRef<Path>>(path_base: P) -> AlignResult<Self> {
56        let path_base = path_base.as_ref();
57        let base_name = path_base
58            .file_name()
59            .ok_or_else(|| AlignError::InvalidFormat("Invalid index path".into()))?
60            .to_str()
61            .unwrap();
62
63        let file_3 = path_base.with_file_name(format!("{}.3.rai", base_name));
64
65        if !file_3.exists() {
66            return Err(AlignError::IndexCorrupted {
67                path: path_base.to_path_buf(),
68            });
69        }
70
71        // Load .3.rai (metadata - RefRecords)
72        let file3 = File::open(&file_3)?;
73        let mmap3 = unsafe { Mmap::map(&file3)? };
74
75        Self::parse_from_mmap(&mmap3)
76    }
77
78    /// Parse reference from memory-mapped .3.rai file
79    fn parse_from_mmap(meta: &[u8]) -> AlignResult<Self> {
80        use byteorder::{LittleEndian, ReadBytesExt};
81        use std::io::Cursor;
82
83        if meta.len() < 8 {
84            return Err(AlignError::InvalidFormat(
85                "Reference metadata too small".into(),
86            ));
87        }
88
89        let mut cursor = Cursor::new(meta);
90
91        // Read endian hint
92        let endian_hint = cursor
93            .read_i32::<LittleEndian>()
94            .map_err(|e| AlignError::InvalidFormat(format!("Failed to read endian hint: {}", e)))?;
95
96        if endian_hint != 1 {
97            return Err(AlignError::InvalidFormat(format!(
98                "Invalid endian hint: {} (expected 1)",
99                endian_hint
100            )));
101        }
102
103        // Read number of records (sz)
104        let num_records = cursor
105            .read_u32::<LittleEndian>()
106            .map_err(|e| AlignError::InvalidFormat(format!("Failed to read record count: {}", e)))?
107            as usize;
108
109        if num_records == 0 {
110            return Err(AlignError::InvalidFormat(
111                "Number of reference records is 0".into(),
112            ));
113        }
114
115        // Parse RefRecords
116        // Each record is: off (u32), len (u32), first (u8)
117        let mut recs = Vec::with_capacity(num_records);
118        let mut cum_unambig = Vec::with_capacity(num_records);
119        let mut ref_lens = Vec::new();
120        let mut ref_offs = Vec::new();
121        let mut ref_rec_offs = Vec::new();
122
123        let mut n_refs = 0u32;
124        let mut cumsz: u64 = 0; // cumulative unambiguous chars
125        let mut cumlen: u64 = 0; // cumulative total chars for current ref
126
127        for _ in 0..num_records {
128            let pos = cursor.position() as usize;
129            if pos + 9 > meta.len() {
130                break;
131            }
132
133            // Read RefRecord: off (u32), len (u32), first (u8)
134            let off = cursor.read_u32::<LittleEndian>().map_err(|e| {
135                AlignError::InvalidFormat(format!("Failed to read record off: {}", e))
136            })?;
137
138            let len = cursor.read_u32::<LittleEndian>().map_err(|e| {
139                AlignError::InvalidFormat(format!("Failed to read record len: {}", e))
140            })?;
141
142            let first_byte = cursor.read_u8().map_err(|e| {
143                AlignError::InvalidFormat(format!("Failed to read record first: {}", e))
144            })?;
145            let first = first_byte != 0;
146
147            if first {
148                // This is the first record for a new reference sequence
149                ref_rec_offs.push((recs.len(), recs.len()));
150                ref_offs.push(cumsz);
151                if n_refs > 0 {
152                    // Close out the previous reference
153                    ref_lens.push(cumlen);
154                    ref_rec_offs[n_refs as usize - 1].1 = recs.len();
155                }
156                cumlen = 0;
157                n_refs += 1;
158            } else if recs.is_empty() {
159                return Err(AlignError::InvalidFormat(
160                    "First record in reference index was not marked as 'first'".into(),
161                ));
162            }
163
164            cum_unambig.push(cumsz);
165            cumsz += len as u64;
166            cumlen += off as u64; // count ambiguous chars before
167            cumlen += len as u64; // count unambiguous chars
168
169            recs.push(RefRecord { off, len, first });
170        }
171
172        // Store cap entries for the last reference
173        if n_refs > 0 {
174            ref_rec_offs[n_refs as usize - 1].1 = recs.len();
175        }
176        ref_rec_offs.push((recs.len(), recs.len()));
177        ref_offs.push(cumsz);
178        ref_lens.push(cumlen);
179        cum_unambig.push(cumsz);
180
181        Ok(Self {
182            recs,
183            cum_unambig,
184            ref_lens,
185            ref_offs,
186            ref_rec_offs,
187            n_refs,
188            buf_sz: cumsz,
189        })
190    }
191
    /// Get the number of reference sequences
    ///
    /// This is the number of records that were flagged as the first record of
    /// a sequence when the record table was parsed.
    pub fn n_refs(&self) -> u32 {
        self.n_refs
    }
196
197    /// Get the approximate length of a reference sequence
198    pub fn approx_len(&self, tidx: u32) -> u64 {
199        if (tidx as usize) < self.ref_lens.len() {
200            self.ref_lens[tidx as usize]
201        } else {
202            0
203        }
204    }
205
206    /// Extract a stretch of reference sequence
207    ///
208    /// # Arguments
209    /// * `buf` - The bit-packed reference buffer from .4.rai file
210    /// * `tidx` - Reference sequence index (chromosome index)
211    /// * `toff` - Offset within the reference sequence
212    /// * `count` - Number of bases to extract
213    ///
214    /// # Returns
215    /// A vector of nucleotides
216    pub fn get_stretch(
217        &self,
218        buf: &[u8],
219        tidx: u32,
220        toff: u64,
221        count: usize,
222    ) -> AlignResult<Vec<Nuc>> {
223        if (tidx as usize) >= self.ref_rec_offs.len() {
224            return Err(AlignError::InvalidArgument(format!(
225                "Invalid reference index: {}",
226                tidx
227            )));
228        }
229
230        let (rec_begin, rec_end) = self.ref_rec_offs[tidx as usize];
231
232        // Walk through records to find the right position
233        let mut cur_off: u64 = 0; // Current offset within this reference
234        let mut buf_off: u64 = self.ref_offs.get(tidx as usize).copied().unwrap_or(0);
235
236        for reci in rec_begin..rec_end {
237            let rec = &self.recs[reci];
238            cur_off += rec.off as u64; // Skip ambiguous chars
239
240            // Check if toff falls within this record's unambiguous stretch
241            if toff < cur_off + rec.len as u64 {
242                // Found the right record
243                if toff >= cur_off {
244                    // toff is within the unambiguous part
245                    let within_rec = toff - cur_off;
246                    buf_off += within_rec;
247
248                    // Calculate byte/bit offset in bit-packed buffer
249                    let byte_off = (buf_off / 4) as usize;
250                    let bit_off = (buf_off % 4) * 2;
251
252                    // Extract nucleotides
253                    let mut result = Vec::with_capacity(count);
254                    let mut current_byte_off = byte_off;
255                    let mut current_bit_off = bit_off as u8;
256                    let mut remaining_in_rec = (rec.len as u64 - within_rec) as usize;
257                    let mut next_rec = reci + 1;
258
259                    for _ in 0..count {
260                        if remaining_in_rec == 0 {
261                            // Need to move to next record
262                            if next_rec < rec_end {
263                                // Skip ambiguous chars (treat as N)
264                                let next_rec_obj = &self.recs[next_rec];
265                                if next_rec_obj.off > 0 {
266                                    result.push(Nuc::N);
267                                    // For simplicity, just continue with Ns for ambiguous regions
268                                    // In a full implementation, we'd handle this more carefully
269                                }
270                                remaining_in_rec = next_rec_obj.len as usize;
271                                current_byte_off = (self.cum_unambig[next_rec] / 4) as usize;
272                                current_bit_off = (self.cum_unambig[next_rec] % 4 * 2) as u8;
273                                next_rec += 1;
274                            } else {
275                                // End of reference
276                                result.push(Nuc::N);
277                                continue;
278                            }
279                        }
280
281                        if current_byte_off >= buf.len() {
282                            result.push(Nuc::N);
283                            continue;
284                        }
285
286                        let byte = buf[current_byte_off];
287                        let nuc_code = (byte >> current_bit_off) & 0x03;
288
289                        let nuc = match nuc_code {
290                            0 => Nuc::A,
291                            1 => Nuc::C,
292                            2 => Nuc::G,
293                            3 => Nuc::T,
294                            _ => Nuc::N,
295                        };
296                        result.push(nuc);
297
298                        // Advance to next nucleotide
299                        current_bit_off += 2;
300                        if current_bit_off >= 8 {
301                            current_bit_off = 0;
302                            current_byte_off += 1;
303                        }
304                        remaining_in_rec -= 1;
305                    }
306
307                    return Ok(result);
308                }
309            }
310
311            cur_off += rec.len as u64;
312            buf_off += rec.len as u64;
313        }
314
315        // toff is in an ambiguous region at the end
316        Ok(vec![Nuc::N; count])
317    }
318
    /// Get the buffer size needed
    ///
    /// The value is the sum of all record `len` fields accumulated during
    /// parsing, i.e. a count of unambiguous bases rather than bytes.
    pub fn buf_sz(&self) -> u64 {
        self.buf_sz
    }
323}
324
/// Combined structure for FM-index with reference
///
/// Bundles the FM-index, the parsed reference record table, and the
/// memory-mapped packed reference buffer that the record table indexes into.
// `fm_to_ref_chr_map` is stored but not read by any code visible in this file,
// which is presumably why the dead-code lint is suppressed.
#[allow(dead_code)]
pub struct EbwtWithRef {
    /// The FM-index
    pub ebwt: super::Ebwt,
    /// The reference sequence metadata
    pub reference: Reference,
    /// The memory-mapped reference buffer
    ref_buf: Mmap,
    /// Mapping from FM-index chromosome index to Reference chromosome index
    /// FM-index uses rstarts array ordering, Reference uses file ordering
    fm_to_ref_chr_map: Vec<u32>,
}
338
339impl EbwtWithRef {
340    /// Load both FM-index and reference
341    pub fn load<P: AsRef<Path>>(path_base: P) -> AlignResult<Self> {
342        let path_base = path_base.as_ref();
343        let base_name = path_base
344            .file_name()
345            .ok_or_else(|| AlignError::InvalidFormat("Invalid index path".into()))?
346            .to_str()
347            .unwrap();
348
349        // Load the FM-index
350        let ebwt = super::Ebwt::load(path_base)?;
351
352        // Load the reference metadata
353        let reference = Reference::load(path_base)?;
354
355        // Memory-map the .4.rai file for the reference buffer
356        let file_4 = path_base.with_file_name(format!("{}.4.rai", base_name));
357        let file = File::open(&file_4)?;
358        let ref_buf = unsafe { Mmap::map(&file)? };
359
360        Ok(Self {
361            ebwt,
362            reference,
363            ref_buf,
364            fm_to_ref_chr_map: Vec::new(), // Will build below
365        })
366    }
367
    /// Extract reference sequence for a given chromosome and position
    ///
    /// Forwards to `Reference::get_stretch` using the memory-mapped reference
    /// buffer held by this value.
    ///
    /// Note: pos should be 0-based position within the chromosome
    ///
    /// NOTE(review): `chr_idx` is passed through unchanged; the
    /// `fm_to_ref_chr_map` field is not consulted here, so callers presumably
    /// need to supply an index in the reference's own ordering — confirm.
    pub fn get_reference(&self, chr_idx: u32, pos: u64, len: usize) -> AlignResult<Vec<Nuc>> {
        self.reference.get_stretch(&self.ref_buf, chr_idx, pos, len)
    }
374
    /// Get the underlying Ebwt
    ///
    /// Returns a shared borrow of the FM-index stored in this value.
    pub fn ebwt(&self) -> &super::Ebwt {
        &self.ebwt
    }
379}