1use crate::csr::CsrMatrix;
17use crate::error::{SparseError, SparseResult};
18use scirs2_core::numeric::{SparseElement, Zero};
19use std::fmt::Debug;
20
/// Per-tile metadata computed when a [`Csr5Matrix`] is built from CSR.
///
/// A tile is a contiguous run of `tile_width` stored entries (the last tile may
/// be shorter). The descriptor records which row each of those entries belongs
/// to, so SpMV can accumulate per-row segments without searching `row_ptr`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileDescriptor {
    /// Row that owns the tile's first stored entry (empty rows are skipped).
    pub first_row: usize,
    /// `true` when at least one row begins strictly after the tile's first entry.
    pub has_segment_boundary: bool,
    /// Number of entries after the tile's first one at which a new row begins.
    pub num_complete_rows: usize,
    /// Owning row of every entry in the tile, in storage order.
    pub row_ids: Vec<usize>,
    /// For every entry in the tile, whether it is the first stored entry of its row.
    pub is_segment_start: Vec<bool>,
}
40
41#[derive(Debug, Clone)]
46pub struct Csr5Matrix<T> {
47 pub nrows: usize,
49 pub ncols: usize,
51 pub tile_width: usize,
53 pub num_tiles: usize,
55 pub col_indices: Vec<usize>,
57 pub values: Vec<T>,
59 pub row_ptr: Vec<usize>,
61 pub tile_desc: Vec<TileDescriptor>,
63 pub tile_ptr: Vec<usize>,
65}
66
67impl<T> Csr5Matrix<T>
68where
69 T: Clone + Copy + Zero + SparseElement + Debug,
70{
71 pub fn from_csr(csr: &CsrMatrix<T>, tile_width: usize) -> SparseResult<Self> {
78 if tile_width == 0 {
79 return Err(SparseError::ValueError(
80 "tile_width must be at least 1".to_string(),
81 ));
82 }
83
84 let (nrows, ncols) = csr.shape();
85 let nnz = csr.nnz();
86
87 let col_indices = csr.indices.clone();
89 let values = csr.data.clone();
90 let row_ptr = csr.indptr.clone();
91
92 let num_tiles = if nnz == 0 {
94 0
95 } else {
96 nnz.div_ceil(tile_width)
97 };
98
99 let mut tile_ptr = Vec::with_capacity(num_tiles + 1);
101 for t in 0..=num_tiles {
102 tile_ptr.push((t * tile_width).min(nnz));
103 }
104
105 let tile_desc = Self::calibrate(&row_ptr, nrows, nnz, tile_width, num_tiles);
107
108 Ok(Self {
109 nrows,
110 ncols,
111 tile_width,
112 num_tiles,
113 col_indices,
114 values,
115 row_ptr,
116 tile_desc,
117 tile_ptr,
118 })
119 }
120
121 fn calibrate(
126 row_ptr: &[usize],
127 nrows: usize,
128 nnz: usize,
129 tile_width: usize,
130 num_tiles: usize,
131 ) -> Vec<TileDescriptor> {
132 let mut descriptors = Vec::with_capacity(num_tiles);
133
134 for t in 0..num_tiles {
135 let tile_start = t * tile_width;
136 let tile_end = nnz.min(tile_start + tile_width);
137 let tile_len = tile_end - tile_start;
138
139 let first_row = Self::find_row(row_ptr, nrows, tile_start);
141
142 let mut row_ids = Vec::with_capacity(tile_len);
144 let mut is_segment_start = Vec::with_capacity(tile_len);
145 let mut current_row = first_row;
146 let mut num_complete_rows = 0usize;
147 let mut has_boundary = false;
148
149 for pos in tile_start..tile_end {
150 while current_row < nrows && row_ptr[current_row + 1] <= pos {
152 current_row += 1;
153 }
154
155 let is_start = if pos == tile_start {
156 pos == row_ptr[current_row]
159 } else {
160 pos == row_ptr[current_row]
161 };
162
163 if is_start && pos != tile_start {
164 has_boundary = true;
165 num_complete_rows += 1;
166 }
167
168 row_ids.push(current_row);
169 is_segment_start.push(is_start);
170 }
171
172 descriptors.push(TileDescriptor {
173 first_row,
174 has_segment_boundary: has_boundary,
175 num_complete_rows,
176 row_ids,
177 is_segment_start,
178 });
179 }
180
181 descriptors
182 }
183
184 fn find_row(row_ptr: &[usize], nrows: usize, pos: usize) -> usize {
186 let mut lo = 0usize;
188 let mut hi = nrows;
189 while lo < hi {
190 let mid = lo + (hi - lo) / 2;
191 if row_ptr[mid + 1] <= pos {
192 lo = mid + 1;
193 } else {
194 hi = mid;
195 }
196 }
197 lo
198 }
199
200 pub fn spmv(&self, x: &[T]) -> SparseResult<Vec<T>>
206 where
207 T: std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
208 {
209 if x.len() != self.ncols {
210 return Err(SparseError::DimensionMismatch {
211 expected: self.ncols,
212 found: x.len(),
213 });
214 }
215
216 let mut y = vec![T::sparse_zero(); self.nrows];
217
218 if self.num_tiles == 0 {
219 return Ok(y);
220 }
221
222 let mut carries: Vec<Option<(usize, T)>> = vec![None; self.num_tiles];
229
230 for t in 0..self.num_tiles {
231 let desc = &self.tile_desc[t];
232 let tile_start = self.tile_ptr[t];
233 let tile_end = self.tile_ptr[t + 1];
234 let tile_len = tile_end - tile_start;
235
236 if tile_len == 0 {
237 continue;
238 }
239
240 let mut acc = T::sparse_zero();
242 let mut current_row = desc.first_row;
243
244 for i in 0..tile_len {
245 let pos = tile_start + i;
246 let row = desc.row_ids[i];
247
248 if row != current_row {
249 if i == 0 {
251 } else {
254 y[current_row] = y[current_row] + acc;
256 }
257 acc = T::sparse_zero();
258 current_row = row;
259 }
260
261 acc = acc + self.values[pos] * x[self.col_indices[pos]];
262 }
263
264 carries[t] = Some((current_row, acc));
266 }
267
268 for t in 0..self.num_tiles {
274 if let Some((row, val)) = carries[t] {
275 let continues = if t + 1 < self.num_tiles {
277 let next_desc = &self.tile_desc[t + 1];
278 next_desc.first_row == row
279 } else {
280 false
281 };
282
283 if continues {
284 if let Some((_, ref mut next_val)) = carries[t + 1] {
286 y[row] = y[row] + val;
291 } else {
292 y[row] = y[row] + val;
293 }
294 } else {
295 y[row] = y[row] + val;
296 }
297 }
298 }
299
300 Ok(y)
301 }
302
303 pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>>
305 where
306 T: std::cmp::PartialEq,
307 {
308 let mut row_indices: Vec<usize> = Vec::with_capacity(self.values.len());
310 let mut col_indices: Vec<usize> = Vec::with_capacity(self.values.len());
311 let mut data: Vec<T> = Vec::with_capacity(self.values.len());
312
313 for row in 0..self.nrows {
314 let start = self.row_ptr[row];
315 let end = self.row_ptr[row + 1];
316 for pos in start..end {
317 row_indices.push(row);
318 col_indices.push(self.col_indices[pos]);
319 data.push(self.values[pos]);
320 }
321 }
322
323 CsrMatrix::new(data, row_indices, col_indices, (self.nrows, self.ncols))
324 }
325
326 pub fn nnz(&self) -> usize {
328 self.values.len()
329 }
330
331 pub fn get_tile_width(&self) -> usize {
333 self.tile_width
334 }
335
336 pub fn get_num_tiles(&self) -> usize {
338 self.num_tiles
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds an `n x n` tridiagonal matrix with 2 on the diagonal and -1 off it.
    fn make_tridiag_csr(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0);
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
            }
            if i + 1 < n {
                rows.push(i);
                cols.push(i + 1);
                vals.push(-1.0);
            }
        }
        CsrMatrix::new(vals, rows, cols, (n, n)).expect("csr")
    }

    /// 5x4 matrix whose rows 1 and 3 are empty, so tiles must skip them when
    /// locating the row of a stored position.
    fn make_csr_with_empty_rows() -> CsrMatrix<f64> {
        CsrMatrix::new(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![0, 0, 2, 2, 2, 4],
            vec![0, 3, 1, 2, 3, 0],
            (5, 4),
        )
        .expect("csr")
    }

    /// Reference SpMV computed directly from the CSR arrays.
    fn csr_spmv(csr: &CsrMatrix<f64>, x: &[f64]) -> Vec<f64> {
        let (nrows, _) = csr.shape();
        let mut y = vec![0.0f64; nrows];
        for row in 0..nrows {
            for j in csr.indptr[row]..csr.indptr[row + 1] {
                y[row] += csr.data[j] * x[csr.indices[j]];
            }
        }
        y
    }

    #[test]
    fn test_csr5_spmv_matches_csr() {
        let csr = make_tridiag_csr(8);
        let x: Vec<f64> = (0..8).map(|i| (i + 1) as f64).collect();
        let y_ref = csr_spmv(&csr, &x);

        for &tw in &[4usize, 8, 16, 32] {
            let csr5 = Csr5Matrix::from_csr(&csr, tw).expect("csr5");
            let y_csr5 = csr5.spmv(&x).expect("spmv");
            for i in 0..8 {
                assert_relative_eq!(y_csr5[i], y_ref[i], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_csr5_preserves_nnz() {
        let csr = make_tridiag_csr(10);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert_eq!(csr5.nnz(), csr.nnz());
    }

    #[test]
    fn test_csr5_roundtrip() {
        let csr = make_tridiag_csr(6);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        let csr2 = csr5.to_csr().expect("to_csr");
        assert_eq!(csr2.nnz(), csr.nnz());
        let x: Vec<f64> = (0..6).map(|i| (i + 1) as f64).collect();
        let y1 = csr_spmv(&csr, &x);
        let y2 = csr_spmv(&csr2, &x);
        for i in 0..6 {
            assert_relative_eq!(y1[i], y2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_irregular_matrix() {
        let rows = vec![0, 0, 0, 0, 1, 2, 2, 3, 3, 3];
        let cols = vec![0, 1, 2, 3, 0, 0, 3, 1, 2, 3];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let csr = CsrMatrix::new(vals, rows, cols, (4, 4)).expect("csr");

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y_ref = csr_spmv(&csr, &x);

        let csr5 = Csr5Matrix::from_csr(&csr, 3).expect("csr5");
        let y_csr5 = csr5.spmv(&x).expect("spmv");

        for i in 0..4 {
            assert_relative_eq!(y_csr5[i], y_ref[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_empty_rows() {
        let csr = make_csr_with_empty_rows();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y_ref = csr_spmv(&csr, &x);

        // Every tile width from one entry per tile up to a single tile.
        for tw in 1..=8usize {
            let csr5 = Csr5Matrix::from_csr(&csr, tw).expect("csr5");
            let y = csr5.spmv(&x).expect("spmv");
            assert_eq!(y.len(), 5);
            for i in 0..5 {
                assert_relative_eq!(y[i], y_ref[i], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_csr5_tile_first_row_skips_empty_rows() {
        let csr = make_csr_with_empty_rows();
        let csr5 = Csr5Matrix::from_csr(&csr, 2).expect("csr5");
        assert_eq!(csr5.tile_ptr, vec![0, 2, 4, 6]);

        let first_rows: Vec<usize> = csr5.tile_desc.iter().map(|d| d.first_row).collect();
        assert_eq!(first_rows, vec![0, 2, 2]);

        // The last tile crosses the empty row 3 from row 2 into row 4.
        assert_eq!(csr5.tile_desc[2].row_ids, vec![2, 4]);
        assert_eq!(csr5.tile_desc[2].is_segment_start, vec![false, true]);
        assert!(csr5.tile_desc[2].has_segment_boundary);
        assert_eq!(csr5.tile_desc[2].num_complete_rows, 1);
    }

    #[test]
    fn test_csr5_empty_matrix() {
        let csr = CsrMatrix::<f64>::new(vec![], vec![], vec![], (3, 3)).expect("csr");
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert_eq!(csr5.nnz(), 0);
        assert_eq!(csr5.num_tiles, 0);
        let y = csr5.spmv(&[0.0, 0.0, 0.0]).expect("spmv");
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_csr5_tile_width_error() {
        let csr = make_tridiag_csr(4);
        assert!(Csr5Matrix::<f64>::from_csr(&csr, 0).is_err());
    }

    #[test]
    fn test_csr5_dimension_mismatch() {
        let csr = make_tridiag_csr(4);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert!(csr5.spmv(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_csr5_single_row() {
        let csr =
            CsrMatrix::new(vec![1.0, 2.0, 3.0], vec![0, 0, 0], vec![0, 1, 2], (1, 3)).expect("csr");
        let x = vec![1.0, 2.0, 3.0];
        let y_ref = csr_spmv(&csr, &x);
        let csr5 = Csr5Matrix::from_csr(&csr, 2).expect("csr5");
        let y = csr5.spmv(&x).expect("spmv");
        assert_relative_eq!(y[0], y_ref[0], epsilon = 1e-12);
    }

    #[test]
    fn test_csr5_large_tile() {
        // A tile wider than nnz: everything lands in a single tile.
        let csr = make_tridiag_csr(4);
        let x: Vec<f64> = (0..4).map(|i| (i + 1) as f64).collect();
        let y_ref = csr_spmv(&csr, &x);
        let csr5 = Csr5Matrix::from_csr(&csr, 100).expect("csr5");
        let y = csr5.spmv(&x).expect("spmv");
        for i in 0..4 {
            assert_relative_eq!(y[i], y_ref[i], epsilon = 1e-12);
        }
    }
}