scirs2_sparse/sym_csr.rs
1//! Symmetric Compressed Sparse Row (SymCSR) module
2//!
//! This module provides a specialized implementation of the CSR format
//! optimized for symmetric matrices. Only the lower triangular part of the
//! matrix (including the diagonal) is stored; upper-triangle entries are
//! implied by symmetry.
6
7use crate::csr::CsrMatrix;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use scirs2_core::numeric::{Float, SparseElement};
11use std::fmt::Debug;
12use std::ops::{Add, Div, Mul, Sub};
13
14/// Symmetric Compressed Sparse Row (SymCSR) matrix
15///
16/// This format stores only the lower triangular part of a symmetric matrix
17/// to save memory and improve performance. Operations are optimized to
18/// take advantage of symmetry when possible.
19///
20/// # Note
21///
22/// All operations maintain symmetry implicitly.
23#[derive(Debug, Clone)]
24pub struct SymCsrMatrix<T>
25where
26 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
27{
28 /// CSR format data for the lower triangular part (including diagonal)
29 pub data: Vec<T>,
30
31 /// Row pointers (indptr): indices where each row starts in indices array
32 pub indptr: Vec<usize>,
33
34 /// Column indices for each non-zero element
35 pub indices: Vec<usize>,
36
37 /// Matrix shape (rows, cols), always square
38 pub shape: (usize, usize),
39}
40
41impl<T> SymCsrMatrix<T>
42where
43 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
44{
45 /// Create a new symmetric CSR matrix from raw data
46 ///
47 /// # Arguments
48 ///
49 /// * `data` - Non-zero values in the lower triangular part
50 /// * `indptr` - Row pointers
51 /// * `indices` - Column indices
52 /// * `shape` - Matrix shape (n, n)
53 ///
54 /// # Returns
55 ///
56 /// A symmetric CSR matrix
57 ///
58 /// # Errors
59 ///
60 /// Returns an error if:
61 /// - The shape is not square
62 /// - The indices array is incompatible with indptr
63 /// - Any column index is out of bounds
64 pub fn new(
65 data: Vec<T>,
66 indptr: Vec<usize>,
67 indices: Vec<usize>,
68 shape: (usize, usize),
69 ) -> SparseResult<Self> {
70 let (rows, cols) = shape;
71
72 // Ensure matrix is square
73 if rows != cols {
74 return Err(SparseError::ValueError(
75 "Symmetric matrix must be square".to_string(),
76 ));
77 }
78
79 // Check indptr length
80 if indptr.len() != rows + 1 {
81 return Err(SparseError::ValueError(format!(
82 "indptr length ({}) must be equal to rows + 1 ({})",
83 indptr.len(),
84 rows + 1
85 )));
86 }
87
88 // Check data and indices lengths
89 let nnz = indices.len();
90 if data.len() != nnz {
91 return Err(SparseError::ValueError(format!(
92 "data length ({}) must match indices length ({})",
93 data.len(),
94 nnz
95 )));
96 }
97
98 // Check last indptr value
99 if let Some(&last) = indptr.last() {
100 if last != nnz {
101 return Err(SparseError::ValueError(format!(
102 "Last indptr value ({last}) must equal nnz ({nnz})"
103 )));
104 }
105 }
106
107 // Check that row and column indices are within bounds
108 for (i, &row_start) in indptr.iter().enumerate().take(rows) {
109 let row_end = indptr[i + 1];
110
111 for &col in &indices[row_start..row_end] {
112 if col >= cols {
113 return Err(SparseError::IndexOutOfBounds {
114 index: (i, col),
115 shape: (rows, cols),
116 });
117 }
118
119 // For symmetric matrix, ensure we only store the lower triangular part
120 if col > i {
121 return Err(SparseError::ValueError(
122 "Symmetric CSR should only store the lower triangular part".to_string(),
123 ));
124 }
125 }
126 }
127
128 Ok(Self {
129 data,
130 indptr,
131 indices,
132 shape,
133 })
134 }
135
136 /// Convert a regular CSR matrix to symmetric CSR format
137 ///
138 /// This will verify that the matrix is symmetric and extract
139 /// the lower triangular part.
140 ///
141 /// # Arguments
142 ///
143 /// * `matrix` - CSR matrix to convert
144 ///
145 /// # Returns
146 ///
147 /// A symmetric CSR matrix
148 pub fn from_csr(matrix: &CsrMatrix<T>) -> SparseResult<Self> {
149 let (rows, cols) = matrix.shape();
150
151 // Ensure matrix is square
152 if rows != cols {
153 return Err(SparseError::ValueError(
154 "Symmetric matrix must be square".to_string(),
155 ));
156 }
157
158 // Check if the matrix is symmetric
159 if !Self::is_symmetric(matrix) {
160 return Err(SparseError::ValueError(
161 "Matrix must be symmetric to convert to SymCSR format".to_string(),
162 ));
163 }
164
165 // Extract the lower triangular part
166 let mut data = Vec::new();
167 let mut indices = Vec::new();
168 let mut indptr = vec![0];
169
170 for i in 0..rows {
171 for j in matrix.indptr[i]..matrix.indptr[i + 1] {
172 let col = matrix.indices[j];
173
174 // Only include elements in lower triangular part (including diagonal)
175 if col <= i {
176 data.push(matrix.data[j]);
177 indices.push(col);
178 }
179 }
180
181 indptr.push(data.len());
182 }
183
184 Ok(Self {
185 data,
186 indptr,
187 indices,
188 shape: (rows, cols),
189 })
190 }
191
192 /// Check if a CSR matrix is symmetric
193 ///
194 /// # Arguments
195 ///
196 /// * `matrix` - CSR matrix to check
197 ///
198 /// # Returns
199 ///
200 /// `true` if the matrix is symmetric, `false` otherwise
201 pub fn is_symmetric(matrix: &CsrMatrix<T>) -> bool {
202 let (rows, cols) = matrix.shape();
203
204 // Must be square
205 if rows != cols {
206 return false;
207 }
208
209 // Compare each element (i,j) with (j,i)
210 for i in 0..rows {
211 for j_ptr in matrix.indptr[i]..matrix.indptr[i + 1] {
212 let j = matrix.indices[j_ptr];
213 let val = matrix.data[j_ptr];
214
215 // Find the corresponding (j,i) element
216 let i_val = matrix.get(j, i);
217
218 // Check if a[i,j] == a[j,i] with sufficient tolerance
219 let diff = (val - i_val).abs();
220 let epsilon = T::epsilon() * T::from(100.0).unwrap();
221 if diff > epsilon {
222 return false;
223 }
224 }
225 }
226
227 true
228 }
229
230 /// Get the shape of the matrix
231 ///
232 /// # Returns
233 ///
234 /// A tuple (rows, cols)
235 pub fn shape(&self) -> (usize, usize) {
236 self.shape
237 }
238
239 /// Get the number of stored non-zero elements
240 ///
241 /// # Returns
242 ///
243 /// The number of non-zero elements in the lower triangular part
244 pub fn nnz_stored(&self) -> usize {
245 self.data.len()
246 }
247
248 /// Get the total number of non-zero elements in the full matrix
249 ///
250 /// # Returns
251 ///
252 /// The total number of non-zero elements in the full symmetric matrix
253 pub fn nnz(&self) -> usize {
254 let diag_count = (0..self.shape.0)
255 .filter(|&i| {
256 // Count diagonal elements that are non-zero
257 let row_start = self.indptr[i];
258 let row_end = self.indptr[i + 1];
259 (row_start..row_end).any(|j_ptr| self.indices[j_ptr] == i)
260 })
261 .count();
262
263 let offdiag_count = self.data.len() - diag_count;
264
265 // Diagonal elements count once, off-diagonal elements count twice
266 diag_count + 2 * offdiag_count
267 }
268
269 /// Get a single element from the matrix
270 ///
271 /// # Arguments
272 ///
273 /// * `row` - Row index
274 /// * `col` - Column index
275 ///
276 /// # Returns
277 ///
278 /// The value at position (row, col)
279 pub fn get(&self, row: usize, col: usize) -> T {
280 // Check bounds
281 if row >= self.shape.0 || col >= self.shape.1 {
282 return T::sparse_zero();
283 }
284
285 // For symmetric matrix, if (row,col) is in upper triangular part,
286 // we look for (col,row) in the lower triangular part
287 let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
288
289 // Search for the element
290 for j in self.indptr[actual_row]..self.indptr[actual_row + 1] {
291 if self.indices[j] == actual_col {
292 return self.data[j];
293 }
294 }
295
296 T::sparse_zero()
297 }
298
299 /// Convert to standard CSR matrix (reconstructing full symmetric matrix)
300 ///
301 /// # Returns
302 ///
303 /// A standard CSR matrix with both upper and lower triangular parts
304 pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>> {
305 let n = self.shape.0;
306
307 // First, convert to triplet format for the full symmetric matrix
308 let mut data = Vec::new();
309 let mut row_indices = Vec::new();
310 let mut col_indices = Vec::new();
311
312 for i in 0..n {
313 // Add elements from lower triangular part (directly stored)
314 for j_ptr in self.indptr[i]..self.indptr[i + 1] {
315 let j = self.indices[j_ptr];
316 let val = self.data[j_ptr];
317
318 // Add the element itself
319 row_indices.push(i);
320 col_indices.push(j);
321 data.push(val);
322
323 // Add its symmetric counterpart (if not on diagonal)
324 if i != j {
325 row_indices.push(j);
326 col_indices.push(i);
327 data.push(val);
328 }
329 }
330 }
331
332 // Create the CSR matrix from triplets
333 CsrMatrix::new(data, row_indices, col_indices, self.shape)
334 }
335
336 /// Convert to dense matrix
337 ///
338 /// # Returns
339 ///
340 /// A dense matrix representation as a vector of vectors
341 pub fn to_dense(&self) -> Vec<Vec<T>> {
342 let n = self.shape.0;
343 let mut dense = vec![vec![T::sparse_zero(); n]; n];
344
345 // Fill the lower triangular part (directly from stored data)
346 for (i, row) in dense.iter_mut().enumerate().take(n) {
347 for j_ptr in self.indptr[i]..self.indptr[i + 1] {
348 let j = self.indices[j_ptr];
349 row[j] = self.data[j_ptr];
350 }
351 }
352
353 // Fill the upper triangular part (from symmetry)
354 for i in 0..n {
355 for j in 0..i {
356 dense[j][i] = dense[i][j];
357 }
358 }
359
360 dense
361 }
362}
363
364/// Array-based SymCSR implementation compatible with SparseArray trait
365#[derive(Clone)]
366pub struct SymCsrArray<T>
367where
368 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
369{
370 /// Inner matrix
371 inner: SymCsrMatrix<T>,
372}
373
374impl<T> SymCsrArray<T>
375where
376 T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone + Div<Output = T> + 'static,
377{
378 /// Create a new SymCSR array from a SymCSR matrix
379 ///
380 /// # Arguments
381 ///
382 /// * `matrix` - Symmetric CSR matrix
383 ///
384 /// # Returns
385 ///
386 /// SymCSR array
387 pub fn new(matrix: SymCsrMatrix<T>) -> Self {
388 Self { inner: matrix }
389 }
390
391 /// Create a SymCSR array from a regular CSR array
392 ///
393 /// # Arguments
394 ///
395 /// * `array` - CSR array to convert
396 ///
397 /// # Returns
398 ///
399 /// A symmetric CSR array
400 pub fn from_csr_array(array: &CsrArray<T>) -> SparseResult<Self> {
401 let shape = array.shape();
402 let (rows, cols) = shape;
403
404 // Ensure matrix is square
405 if rows != cols {
406 return Err(SparseError::ValueError(
407 "Symmetric matrix must be square".to_string(),
408 ));
409 }
410
411 // Create a temporary CSR matrix to check symmetry
412 let csrmatrix = CsrMatrix::new(
413 array.get_data().to_vec(),
414 array.get_indptr().to_vec(),
415 array.get_indices().to_vec(),
416 shape,
417 )?;
418
419 // Convert to symmetric CSR
420 let sym_csr = SymCsrMatrix::from_csr(&csrmatrix)?;
421
422 Ok(Self { inner: sym_csr })
423 }
424
425 /// Get the underlying matrix
426 ///
427 /// # Returns
428 ///
429 /// Reference to the inner SymCSR matrix
430 pub fn inner(&self) -> &SymCsrMatrix<T> {
431 &self.inner
432 }
433
434 /// Get access to the underlying data array
435 ///
436 /// # Returns
437 ///
438 /// Reference to the data array
439 pub fn data(&self) -> &[T] {
440 &self.inner.data
441 }
442
443 /// Get access to the underlying indices array
444 ///
445 /// # Returns
446 ///
447 /// Reference to the indices array
448 pub fn indices(&self) -> &[usize] {
449 &self.inner.indices
450 }
451
452 /// Get access to the underlying indptr array
453 ///
454 /// # Returns
455 ///
456 /// Reference to the indptr array
457 pub fn indptr(&self) -> &[usize] {
458 &self.inner.indptr
459 }
460
461 /// Convert to a standard CSR array
462 ///
463 /// # Returns
464 ///
465 /// CSR array containing the full symmetric matrix
466 pub fn to_csr_array(&self) -> SparseResult<CsrArray<T>> {
467 let csr = self.inner.to_csr()?;
468
469 // Convert the CsrMatrix to CsrArray using from_triplets
470 let (rows, cols, data) = csr.get_triplets();
471 let shape = csr.shape();
472
473 // Safety check - rows, cols, and data should all be the same length
474 if rows.len() != cols.len() || rows.len() != data.len() {
475 return Err(SparseError::DimensionMismatch {
476 expected: rows.len(),
477 found: cols.len().min(data.len()),
478 });
479 }
480
481 CsrArray::from_triplets(&rows, &cols, &data, shape, false)
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    // Dense form of the symmetric matrix shared by every test below.
    const FULL: [[f64; 3]; 3] = [[2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [0.0, 3.0, 1.0]];

    // SymCSR storage of `FULL`; only the lower triangle is kept:
    // [2 0 0]
    // [1 2 0]
    // [0 3 1]
    fn lower_triangle() -> SymCsrMatrix<f64> {
        let indptr = vec![0, 1, 3, 5];
        let indices = vec![0, 0, 1, 1, 2];
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap()
    }

    // Asserts that `entry(i, j)` reproduces `FULL` at every position,
    // including mirrored upper-triangle entries and unstored zeros.
    fn assert_matches_full(entry: impl Fn(usize, usize) -> f64) {
        for (i, expected_row) in FULL.iter().enumerate() {
            for (j, &expected) in expected_row.iter().enumerate() {
                assert_eq!(entry(i, j), expected, "entry ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_sym_csr_creation() {
        let sym = lower_triangle();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);
        // Three diagonal entries plus two off-diagonal entries counted twice.
        assert_eq!(sym.nnz(), 7);

        assert_matches_full(|i, j| sym.get(i, j));
    }

    #[test]
    fn test_sym_csr_from_standard() {
        // Every non-zero of `FULL` as a (row, col, value) triplet.
        let triplets: [(usize, usize, f64); 7] = [
            (0, 0, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 1, 2.0),
            (1, 2, 3.0),
            (2, 1, 3.0),
            (2, 2, 1.0),
        ];
        let row_indices: Vec<usize> = triplets.iter().map(|&(r, _, _)| r).collect();
        let col_indices: Vec<usize> = triplets.iter().map(|&(_, c, _)| c).collect();
        let data: Vec<f64> = triplets.iter().map(|&(_, _, v)| v).collect();

        let csr = CsrMatrix::new(data, row_indices, col_indices, (3, 3)).unwrap();
        let sym = SymCsrMatrix::from_csr(&csr).unwrap();
        assert_eq!(sym.shape(), (3, 3));

        // Round-trip back to a full CSR matrix and compare densely.
        let dense = sym.to_csr().unwrap().to_dense();
        assert_matches_full(|i, j| dense[i][j]);
    }

    #[test]
    fn test_sym_csr_array() {
        let sym_array = SymCsrArray::new(lower_triangle());
        assert_eq!(sym_array.inner().shape(), (3, 3));

        let csr_array = sym_array.to_csr_array().unwrap();
        assert_eq!(csr_array.shape(), (3, 3));
        assert_matches_full(|i, j| csr_array.get(i, j));
    }
}