Skip to main content

rustalign_simd/
dp_matrix.rs

1//! Dynamic Programming Matrix for SIMD-accelerated alignment
2//!
3//! This module implements the SSEMatrix structure from aligner_swsse.h,
4//! which stores DP matrices in a SIMD-friendly layout.
5
6// SIMD wrapper types (unused in this module but available for future use)
7
/// Matrix quartet indices matching C++ SSEMatrix
///
/// Each matrix element is a quartet of vectors, stored in this order:
/// - E: gap from above (insertion)
/// - F: gap from left (deletion)
/// - H: diagonal (match/mismatch)
/// - TMP: temporary/reserved space
///
/// The discriminant is the position of the vector inside its quartet, so
/// `QuartetIndex::H as usize` can be added directly to a quartet's base
/// vector index.
#[repr(usize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuartetIndex {
    E = 0,
    F = 1,
    H = 2,
    // Kept upper-case to mirror the C++ constant name.
    #[allow(clippy::upper_case_acronyms)]
    TMP = 3,
}
23
24/// SIMD-accelerated DP matrix
25///
26/// This matches the C++ SSEMatrix structure with the following layout:
27/// - Elements (cells) are packed into SSE register vectors
28/// - Vectors are packed into quartets (E, F, H, TMP)
29/// - Quartets are packed into columns
30///
31/// The layout is optimized for SIMD operations:
32/// ```text
33/// Column 0:     Column 1:     Column 2:
34/// +----+----+  +----+----+  +----+----+
35/// | E0 | H0 |  | E1 | H1 |  | E2 | H2 |
36/// +----+----+  +----+----+  +----+----+
37/// | F0 | TMP|  | F1 | TMP|  | F2 | TMP|
38/// +----+----+  +----+----+  +----+----+
39/// ```
40pub struct DpMatrix {
41    /// Number of vectors per cell (always 4 for E, F, H, TMP)
42    nvec_per_cell: usize,
43    /// Number of vector rows
44    nvecrow: usize,
45    /// Number of vector columns
46    nveccol: usize,
47    /// Row stride in vectors
48    rowstride: usize,
49    /// Column stride in vectors
50    colstride: usize,
51    /// Matrix buffer
52    matbuf: Vec<u8>,
53    /// Whether the matrix has been initialized
54    inited: bool,
55}
56
57impl DpMatrix {
58    /// Create a new DP matrix
59    pub fn new() -> Self {
60        Self {
61            nvec_per_cell: 4,
62            nvecrow: 0,
63            nveccol: 0,
64            rowstride: 0,
65            colstride: 0,
66            matbuf: Vec::new(),
67            inited: false,
68        }
69    }
70
71    /// Initialize the matrix with the given dimensions
72    ///
73    /// # Arguments
74    /// * `nrow` - Number of rows (query characters + 1)
75    /// * `ncol` - Number of columns (reference characters + 1)
76    /// * `nvecperrow` - Number of vectors per row
77    #[allow(clippy::manual_div_ceil)]
78    pub fn init(&mut self, nrow: usize, ncol: usize, nvecperrow: usize) {
79        // nvecperrow (wperv in C++) is the number of elements per vector
80        // Each cell has 4 vectors (E, F, H, TMP)
81        // nvecPerCol = (nrow + wperv - 1) / wperv = ceil(nrow / wperv)
82        self.nvecrow = (nrow + nvecperrow - 1) / nvecperrow; // nvecPerCol in C++
83        self.nveccol = ncol;
84        self.nvec_per_cell = 4;
85        self.colstride = 4; // nvecPerCell in C++
86        self.rowstride = nvecperrow * 4; // nper * nvecPerCell in C++
87
88        // Total vectors needed: (ncol + 1) * nvecPerCell * nvecPerCol
89        let total_vecs = (self.nveccol + 1) * self.nvec_per_cell * self.nvecrow;
90        let total_bytes = total_vecs * 16; // 16 bytes per SSE register
91
92        self.matbuf = vec![0u8; total_bytes];
93        self.inited = true;
94    }
95
96    /// Check if the matrix is initialized
97    pub fn is_initialized(&self) -> bool {
98        self.inited
99    }
100
101    /// Clear the matrix
102    pub fn clear(&mut self) {
103        self.matbuf.fill(0);
104    }
105
106    /// Get the matrix buffer as a pointer
107    #[inline]
108    pub fn ptr(&mut self) -> *mut u8 {
109        assert!(self.inited, "Matrix not initialized");
110        self.matbuf.as_mut_ptr()
111    }
112
113    /// Get the matrix buffer as a slice
114    #[inline]
115    pub fn as_slice(&self) -> &[u8] {
116        assert!(self.inited, "Matrix not initialized");
117        &self.matbuf
118    }
119
120    /// Get the matrix buffer as a mutable slice
121    #[inline]
122    pub fn as_mut_slice(&mut self) -> &mut [u8] {
123        assert!(self.inited, "Matrix not initialized");
124        &mut self.matbuf
125    }
126
127    /// Get a pointer to the E vector at the given row and column
128    ///
129    /// This matches the C++ evec() method.
130    #[inline]
131    pub fn evec_ptr(&self, row: usize, col: usize) -> usize {
132        assert!(
133            row < self.nvecrow,
134            "Row {} out of bounds (max {})",
135            row,
136            self.nvecrow
137        );
138        assert!(
139            col < self.nveccol,
140            "Col {} out of bounds (max {})",
141            col,
142            self.nveccol
143        );
144        let elt = row * self.rowstride + col * self.colstride + QuartetIndex::E as usize;
145        assert!(
146            elt * 16 < self.matbuf.len(),
147            "Element index {} out of bounds",
148            elt
149        );
150        elt * 16 // Byte offset
151    }
152
153    /// Get a pointer to the F vector at the given row and column
154    ///
155    /// This matches the C++ fvec() method.
156    #[inline]
157    pub fn fvec_ptr(&self, row: usize, col: usize) -> usize {
158        assert!(row < self.nvecrow, "Row {} out of bounds", row);
159        assert!(col < self.nveccol, "Col {} out of bounds", col);
160        let elt = row * self.rowstride + col * self.colstride + QuartetIndex::F as usize;
161        assert!(elt * 16 < self.matbuf.len());
162        elt * 16
163    }
164
165    /// Get a pointer to the H vector at the given row and column
166    ///
167    /// This matches the C++ hvec() method.
168    #[inline]
169    pub fn hvec_ptr(&self, row: usize, col: usize) -> usize {
170        assert!(row < self.nvecrow, "Row {} out of bounds", row);
171        assert!(col < self.nveccol, "Col {} out of bounds", col);
172        let elt = row * self.rowstride + col * self.colstride + QuartetIndex::H as usize;
173        assert!(elt * 16 < self.matbuf.len());
174        elt * 16
175    }
176
177    /// Get a pointer to the TMP vector at the given row and column
178    ///
179    /// This matches the C++ tmpvec() method.
180    #[inline]
181    pub fn tmpvec_ptr(&self, row: usize, col: usize) -> usize {
182        assert!(row < self.nvecrow, "Row {} out of bounds", row);
183        assert!(col < self.nveccol, "Col {} out of bounds", col);
184        let elt = row * self.rowstride + col * self.colstride + QuartetIndex::TMP as usize;
185        assert!(elt * 16 < self.matbuf.len());
186        elt * 16
187    }
188
189    /// Get the number of vector rows
190    pub fn nvecrow(&self) -> usize {
191        self.nvecrow
192    }
193
194    /// Get the number of vector columns
195    pub fn nveccol(&self) -> usize {
196        self.nveccol
197    }
198
199    /// Get the row stride in vectors
200    pub fn rowstride(&self) -> usize {
201        self.rowstride
202    }
203
204    /// Get the column stride in vectors
205    pub fn colstride(&self) -> usize {
206        self.colstride
207    }
208
209    /// Get the total size of the matrix in bytes
210    pub fn size_bytes(&self) -> usize {
211        self.matbuf.len()
212    }
213}
214
215impl Default for DpMatrix {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221/// Helper structure for working with matrix quartets
222///
223/// This provides convenient access to the E, F, H, and TMP vectors
224/// at a specific matrix position.
225pub struct MatrixQuartet<'a> {
226    matrix: &'a mut DpMatrix,
227    row: usize,
228    col: usize,
229}
230
231impl<'a> MatrixQuartet<'a> {
232    /// Create a new quartet accessor
233    pub fn new(matrix: &'a mut DpMatrix, row: usize, col: usize) -> Self {
234        assert!(row < matrix.nvecrow());
235        assert!(col < matrix.nveccol());
236        Self { matrix, row, col }
237    }
238
239    /// Get the E vector byte offset
240    #[inline]
241    pub fn e_off(&self) -> usize {
242        self.matrix.evec_ptr(self.row, self.col)
243    }
244
245    /// Get the F vector byte offset
246    #[inline]
247    pub fn f_off(&self) -> usize {
248        self.matrix.fvec_ptr(self.row, self.col)
249    }
250
251    /// Get the H vector byte offset
252    #[inline]
253    pub fn h_off(&self) -> usize {
254        self.matrix.hvec_ptr(self.row, self.col)
255    }
256
257    /// Get the TMP vector byte offset
258    #[inline]
259    pub fn tmp_off(&self) -> usize {
260        self.matrix.tmpvec_ptr(self.row, self.col)
261    }
262
263    /// Get mutable references to the matrix buffer
264    pub fn buffer(&mut self) -> &mut [u8] {
265        self.matrix.as_mut_slice()
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use super::*;
272
273    #[test]
274    fn test_dp_matrix_new() {
275        let matrix = DpMatrix::new();
276        assert!(!matrix.is_initialized());
277        assert_eq!(matrix.nvecrow(), 0);
278        assert_eq!(matrix.nveccol(), 0);
279    }
280
281    #[test]
282    fn test_dp_matrix_init() {
283        let mut matrix = DpMatrix::new();
284        matrix.init(10, 20, 2);
285        assert!(matrix.is_initialized());
286        // nvecrow = ceil(nrow / nvecperrow) = ceil(10/2) = 5
287        assert_eq!(matrix.nvecrow(), 5);
288        assert_eq!(matrix.nveccol(), 20);
289        assert!(matrix.size_bytes() > 0);
290    }
291
292    #[test]
293    fn test_dp_matrix_clear() {
294        let mut matrix = DpMatrix::new();
295        matrix.init(10, 20, 2);
296        // Write some data
297        matrix.as_mut_slice()[0] = 42;
298        matrix.clear();
299        assert_eq!(matrix.as_slice()[0], 0);
300    }
301
302    #[test]
303    fn test_dp_matrix_quartet_offsets() {
304        let mut matrix = DpMatrix::new();
305        matrix.init(10, 20, 2);
306        let e_off = matrix.evec_ptr(0, 0);
307        let f_off = matrix.fvec_ptr(0, 0);
308        let h_off = matrix.hvec_ptr(0, 0);
309        let tmp_off = matrix.tmpvec_ptr(0, 0);
310
311        // Offsets should be in order E < F < H < TMP
312        assert!(e_off < f_off);
313        assert!(f_off < h_off);
314        assert!(h_off < tmp_off);
315
316        // Each vector is 16 bytes
317        assert_eq!(f_off - e_off, 16);
318        assert_eq!(h_off - f_off, 16);
319        assert_eq!(tmp_off - h_off, 16);
320    }
321
322    #[test]
323    fn test_dp_matrix_strides() {
324        let mut matrix = DpMatrix::new();
325        matrix.init(10, 20, 2);
326        // colstride = nvecPerCell = 4
327        assert_eq!(matrix.colstride(), 4);
328        // rowstride = nper * nvecPerCell = 2 * 4 = 8
329        assert_eq!(matrix.rowstride(), 8);
330    }
331
332    #[test]
333    fn test_matrix_quartet() {
334        let mut matrix = DpMatrix::new();
335        matrix.init(10, 20, 2);
336        let mut quartet = MatrixQuartet::new(&mut matrix, 3, 5);
337
338        // Get buffer and write some data
339        let e_off = quartet.e_off();
340        {
341            let buffer = quartet.buffer();
342            buffer[e_off] = 42;
343            buffer[e_off + 1] = 43;
344
345            // Verify data is accessible
346            assert_eq!(buffer[e_off], 42);
347            assert_eq!(buffer[e_off + 1], 43);
348        }
349    }
350}