/// Position of each vector inside one [`DpMatrix`] cell.
///
/// A cell stores one vector per variant, back to back; the discriminant is the
/// vector's index within the cell, so `QuartetIndex::H as usize` is the
/// element offset of the H vector relative to the start of its cell.
#[repr(usize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuartetIndex {
    E = 0,
    F = 1,
    H = 2,
    // The upper-case name is part of the public API; renaming it to `Tmp`
    // would break existing callers.
    #[allow(clippy::upper_case_acronyms)]
    TMP = 3,
}
23
/// Byte buffer backing a vectorised dynamic-programming matrix.
///
/// The matrix is addressed in units of 16-byte vectors. Each cell, identified
/// by a (vector-row, column) pair, owns `nvec_per_cell` consecutive vectors,
/// and the two strides give the distance between neighbouring cells.
pub struct DpMatrix {
    /// Set by `init`; the accessors that expose the buffer assert on it.
    inited: bool,
    /// Number of vector-rows; valid `row` arguments are `0..nvecrow`.
    nvecrow: usize,
    /// Number of columns; valid `col` arguments are `0..nveccol`.
    nveccol: usize,
    /// Vectors stored per cell (one each for E, F, H and TMP).
    nvec_per_cell: usize,
    /// Distance, in vectors, from cell `(row, col)` to cell `(row + 1, col)`.
    rowstride: usize,
    /// Distance, in vectors, from cell `(row, col)` to cell `(row, col + 1)`.
    colstride: usize,
    /// Zero-initialised backing storage.
    matbuf: Vec<u8>,
}
56
57impl DpMatrix {
58 pub fn new() -> Self {
60 Self {
61 nvec_per_cell: 4,
62 nvecrow: 0,
63 nveccol: 0,
64 rowstride: 0,
65 colstride: 0,
66 matbuf: Vec::new(),
67 inited: false,
68 }
69 }
70
71 #[allow(clippy::manual_div_ceil)]
78 pub fn init(&mut self, nrow: usize, ncol: usize, nvecperrow: usize) {
79 self.nvecrow = (nrow + nvecperrow - 1) / nvecperrow; self.nveccol = ncol;
84 self.nvec_per_cell = 4;
85 self.colstride = 4; self.rowstride = nvecperrow * 4; let total_vecs = (self.nveccol + 1) * self.nvec_per_cell * self.nvecrow;
90 let total_bytes = total_vecs * 16; self.matbuf = vec![0u8; total_bytes];
93 self.inited = true;
94 }
95
96 pub fn is_initialized(&self) -> bool {
98 self.inited
99 }
100
101 pub fn clear(&mut self) {
103 self.matbuf.fill(0);
104 }
105
106 #[inline]
108 pub fn ptr(&mut self) -> *mut u8 {
109 assert!(self.inited, "Matrix not initialized");
110 self.matbuf.as_mut_ptr()
111 }
112
113 #[inline]
115 pub fn as_slice(&self) -> &[u8] {
116 assert!(self.inited, "Matrix not initialized");
117 &self.matbuf
118 }
119
120 #[inline]
122 pub fn as_mut_slice(&mut self) -> &mut [u8] {
123 assert!(self.inited, "Matrix not initialized");
124 &mut self.matbuf
125 }
126
127 #[inline]
131 pub fn evec_ptr(&self, row: usize, col: usize) -> usize {
132 assert!(
133 row < self.nvecrow,
134 "Row {} out of bounds (max {})",
135 row,
136 self.nvecrow
137 );
138 assert!(
139 col < self.nveccol,
140 "Col {} out of bounds (max {})",
141 col,
142 self.nveccol
143 );
144 let elt = row * self.rowstride + col * self.colstride + QuartetIndex::E as usize;
145 assert!(
146 elt * 16 < self.matbuf.len(),
147 "Element index {} out of bounds",
148 elt
149 );
150 elt * 16 }
152
153 #[inline]
157 pub fn fvec_ptr(&self, row: usize, col: usize) -> usize {
158 assert!(row < self.nvecrow, "Row {} out of bounds", row);
159 assert!(col < self.nveccol, "Col {} out of bounds", col);
160 let elt = row * self.rowstride + col * self.colstride + QuartetIndex::F as usize;
161 assert!(elt * 16 < self.matbuf.len());
162 elt * 16
163 }
164
165 #[inline]
169 pub fn hvec_ptr(&self, row: usize, col: usize) -> usize {
170 assert!(row < self.nvecrow, "Row {} out of bounds", row);
171 assert!(col < self.nveccol, "Col {} out of bounds", col);
172 let elt = row * self.rowstride + col * self.colstride + QuartetIndex::H as usize;
173 assert!(elt * 16 < self.matbuf.len());
174 elt * 16
175 }
176
177 #[inline]
181 pub fn tmpvec_ptr(&self, row: usize, col: usize) -> usize {
182 assert!(row < self.nvecrow, "Row {} out of bounds", row);
183 assert!(col < self.nveccol, "Col {} out of bounds", col);
184 let elt = row * self.rowstride + col * self.colstride + QuartetIndex::TMP as usize;
185 assert!(elt * 16 < self.matbuf.len());
186 elt * 16
187 }
188
189 pub fn nvecrow(&self) -> usize {
191 self.nvecrow
192 }
193
194 pub fn nveccol(&self) -> usize {
196 self.nveccol
197 }
198
199 pub fn rowstride(&self) -> usize {
201 self.rowstride
202 }
203
204 pub fn colstride(&self) -> usize {
206 self.colstride
207 }
208
209 pub fn size_bytes(&self) -> usize {
211 self.matbuf.len()
212 }
213}
214
215impl Default for DpMatrix {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221pub struct MatrixQuartet<'a> {
226 matrix: &'a mut DpMatrix,
227 row: usize,
228 col: usize,
229}
230
231impl<'a> MatrixQuartet<'a> {
232 pub fn new(matrix: &'a mut DpMatrix, row: usize, col: usize) -> Self {
234 assert!(row < matrix.nvecrow());
235 assert!(col < matrix.nveccol());
236 Self { matrix, row, col }
237 }
238
239 #[inline]
241 pub fn e_off(&self) -> usize {
242 self.matrix.evec_ptr(self.row, self.col)
243 }
244
245 #[inline]
247 pub fn f_off(&self) -> usize {
248 self.matrix.fvec_ptr(self.row, self.col)
249 }
250
251 #[inline]
253 pub fn h_off(&self) -> usize {
254 self.matrix.hvec_ptr(self.row, self.col)
255 }
256
257 #[inline]
259 pub fn tmp_off(&self) -> usize {
260 self.matrix.tmpvec_ptr(self.row, self.col)
261 }
262
263 pub fn buffer(&mut self) -> &mut [u8] {
265 self.matrix.as_mut_slice()
266 }
267}
268
269#[cfg(test)]
270mod tests {
271 use super::*;
272
273 #[test]
274 fn test_dp_matrix_new() {
275 let matrix = DpMatrix::new();
276 assert!(!matrix.is_initialized());
277 assert_eq!(matrix.nvecrow(), 0);
278 assert_eq!(matrix.nveccol(), 0);
279 }
280
281 #[test]
282 fn test_dp_matrix_init() {
283 let mut matrix = DpMatrix::new();
284 matrix.init(10, 20, 2);
285 assert!(matrix.is_initialized());
286 assert_eq!(matrix.nvecrow(), 5);
288 assert_eq!(matrix.nveccol(), 20);
289 assert!(matrix.size_bytes() > 0);
290 }
291
292 #[test]
293 fn test_dp_matrix_clear() {
294 let mut matrix = DpMatrix::new();
295 matrix.init(10, 20, 2);
296 matrix.as_mut_slice()[0] = 42;
298 matrix.clear();
299 assert_eq!(matrix.as_slice()[0], 0);
300 }
301
302 #[test]
303 fn test_dp_matrix_quartet_offsets() {
304 let mut matrix = DpMatrix::new();
305 matrix.init(10, 20, 2);
306 let e_off = matrix.evec_ptr(0, 0);
307 let f_off = matrix.fvec_ptr(0, 0);
308 let h_off = matrix.hvec_ptr(0, 0);
309 let tmp_off = matrix.tmpvec_ptr(0, 0);
310
311 assert!(e_off < f_off);
313 assert!(f_off < h_off);
314 assert!(h_off < tmp_off);
315
316 assert_eq!(f_off - e_off, 16);
318 assert_eq!(h_off - f_off, 16);
319 assert_eq!(tmp_off - h_off, 16);
320 }
321
322 #[test]
323 fn test_dp_matrix_strides() {
324 let mut matrix = DpMatrix::new();
325 matrix.init(10, 20, 2);
326 assert_eq!(matrix.colstride(), 4);
328 assert_eq!(matrix.rowstride(), 8);
330 }
331
332 #[test]
333 fn test_matrix_quartet() {
334 let mut matrix = DpMatrix::new();
335 matrix.init(10, 20, 2);
336 let mut quartet = MatrixQuartet::new(&mut matrix, 3, 5);
337
338 let e_off = quartet.e_off();
340 {
341 let buffer = quartet.buffer();
342 buffer[e_off] = 42;
343 buffer[e_off + 1] = 43;
344
345 assert_eq!(buffer[e_off], 42);
347 assert_eq!(buffer[e_off + 1], 43);
348 }
349 }
350}