scirs2_sparse/
sym_csr.rs

1//! Symmetric Compressed Sparse Row (SymCSR) module
2//!
3//! This module provides a specialized implementation of the CSR format
4//! optimized for symmetric matrices, storing only the lower or upper
5//! triangular part of the matrix.
6
7use crate::csr::CsrMatrix;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use crate::sparray::SparseArray;
11use num_traits::Float;
12use std::fmt::Debug;
13use std::ops::{Add, Div, Mul, Sub};
14
15/// Symmetric Compressed Sparse Row (SymCSR) matrix
16///
17/// This format stores only the lower triangular part of a symmetric matrix
18/// to save memory and improve performance. Operations are optimized to
19/// take advantage of symmetry when possible.
20///
21/// # Note
22///
23/// All operations maintain symmetry implicitly.
24#[derive(Debug, Clone)]
25pub struct SymCsrMatrix<T>
26where
27    T: Float + Debug + Copy,
28{
29    /// CSR format data for the lower triangular part (including diagonal)
30    pub data: Vec<T>,
31
32    /// Row pointers (indptr): indices where each row starts in indices array
33    pub indptr: Vec<usize>,
34
35    /// Column indices for each non-zero element
36    pub indices: Vec<usize>,
37
38    /// Matrix shape (rows, cols), always square
39    pub shape: (usize, usize),
40}
41
42impl<T> SymCsrMatrix<T>
43where
44    T: Float + Debug + Copy,
45{
46    /// Create a new symmetric CSR matrix from raw data
47    ///
48    /// # Arguments
49    ///
50    /// * `data` - Non-zero values in the lower triangular part
51    /// * `indptr` - Row pointers
52    /// * `indices` - Column indices
53    /// * `shape` - Matrix shape (n, n)
54    ///
55    /// # Returns
56    ///
57    /// A symmetric CSR matrix
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if:
62    /// - The shape is not square
63    /// - The indices array is incompatible with indptr
64    /// - Any column index is out of bounds
65    pub fn new(
66        data: Vec<T>,
67        indptr: Vec<usize>,
68        indices: Vec<usize>,
69        shape: (usize, usize),
70    ) -> SparseResult<Self> {
71        let (rows, cols) = shape;
72
73        // Ensure matrix is square
74        if rows != cols {
75            return Err(SparseError::ValueError(
76                "Symmetric matrix must be square".to_string(),
77            ));
78        }
79
80        // Check indptr length
81        if indptr.len() != rows + 1 {
82            return Err(SparseError::ValueError(format!(
83                "indptr length ({}) must be equal to rows + 1 ({})",
84                indptr.len(),
85                rows + 1
86            )));
87        }
88
89        // Check data and indices lengths
90        let nnz = indices.len();
91        if data.len() != nnz {
92            return Err(SparseError::ValueError(format!(
93                "data length ({}) must match indices length ({})",
94                data.len(),
95                nnz
96            )));
97        }
98
99        // Check last indptr value
100        if let Some(&last) = indptr.last() {
101            if last != nnz {
102                return Err(SparseError::ValueError(format!(
103                    "Last indptr value ({}) must equal nnz ({})",
104                    last, nnz
105                )));
106            }
107        }
108
109        // Check that row and column indices are within bounds
110        for (i, &row_start) in indptr.iter().enumerate().take(rows) {
111            let row_end = indptr[i + 1];
112
113            for &col in &indices[row_start..row_end] {
114                if col >= cols {
115                    return Err(SparseError::IndexOutOfBounds {
116                        index: (i, col),
117                        shape: (rows, cols),
118                    });
119                }
120
121                // For symmetric matrix, ensure we only store the lower triangular part
122                if col > i {
123                    return Err(SparseError::ValueError(
124                        "Symmetric CSR should only store the lower triangular part".to_string(),
125                    ));
126                }
127            }
128        }
129
130        Ok(Self {
131            data,
132            indptr,
133            indices,
134            shape,
135        })
136    }
137
138    /// Convert a regular CSR matrix to symmetric CSR format
139    ///
140    /// This will verify that the matrix is symmetric and extract
141    /// the lower triangular part.
142    ///
143    /// # Arguments
144    ///
145    /// * `matrix` - CSR matrix to convert
146    ///
147    /// # Returns
148    ///
149    /// A symmetric CSR matrix
150    pub fn from_csr(matrix: &CsrMatrix<T>) -> SparseResult<Self> {
151        let (rows, cols) = matrix.shape();
152
153        // Ensure matrix is square
154        if rows != cols {
155            return Err(SparseError::ValueError(
156                "Symmetric matrix must be square".to_string(),
157            ));
158        }
159
160        // Check if the matrix is symmetric
161        if !Self::is_symmetric(matrix) {
162            return Err(SparseError::ValueError(
163                "Matrix must be symmetric to convert to SymCSR format".to_string(),
164            ));
165        }
166
167        // Extract the lower triangular part
168        let mut data = Vec::new();
169        let mut indices = Vec::new();
170        let mut indptr = vec![0];
171
172        for i in 0..rows {
173            for j in matrix.indptr[i]..matrix.indptr[i + 1] {
174                let col = matrix.indices[j];
175
176                // Only include elements in lower triangular part (including diagonal)
177                if col <= i {
178                    data.push(matrix.data[j]);
179                    indices.push(col);
180                }
181            }
182
183            indptr.push(data.len());
184        }
185
186        Ok(Self {
187            data,
188            indptr,
189            indices,
190            shape: (rows, cols),
191        })
192    }
193
194    /// Check if a CSR matrix is symmetric
195    ///
196    /// # Arguments
197    ///
198    /// * `matrix` - CSR matrix to check
199    ///
200    /// # Returns
201    ///
202    /// `true` if the matrix is symmetric, `false` otherwise
203    pub fn is_symmetric(matrix: &CsrMatrix<T>) -> bool {
204        let (rows, cols) = matrix.shape();
205
206        // Must be square
207        if rows != cols {
208            return false;
209        }
210
211        // Compare each element (i,j) with (j,i)
212        for i in 0..rows {
213            for j_ptr in matrix.indptr[i]..matrix.indptr[i + 1] {
214                let j = matrix.indices[j_ptr];
215                let val = matrix.data[j_ptr];
216
217                // Find the corresponding (j,i) element
218                let i_val = matrix.get(j, i);
219
220                // Check if a[i,j] == a[j,i] with sufficient tolerance
221                let diff = (val - i_val).abs();
222                let epsilon = T::epsilon() * T::from(100.0).unwrap();
223                if diff > epsilon {
224                    return false;
225                }
226            }
227        }
228
229        true
230    }
231
232    /// Get the shape of the matrix
233    ///
234    /// # Returns
235    ///
236    /// A tuple (rows, cols)
237    pub fn shape(&self) -> (usize, usize) {
238        self.shape
239    }
240
241    /// Get the number of stored non-zero elements
242    ///
243    /// # Returns
244    ///
245    /// The number of non-zero elements in the lower triangular part
246    pub fn nnz_stored(&self) -> usize {
247        self.data.len()
248    }
249
250    /// Get the total number of non-zero elements in the full matrix
251    ///
252    /// # Returns
253    ///
254    /// The total number of non-zero elements in the full symmetric matrix
255    pub fn nnz(&self) -> usize {
256        let diag_count = (0..self.shape.0)
257            .filter(|&i| {
258                // Count diagonal elements that are non-zero
259                let row_start = self.indptr[i];
260                let row_end = self.indptr[i + 1];
261                (row_start..row_end).any(|j_ptr| self.indices[j_ptr] == i)
262            })
263            .count();
264
265        let offdiag_count = self.data.len() - diag_count;
266
267        // Diagonal elements count once, off-diagonal elements count twice
268        diag_count + 2 * offdiag_count
269    }
270
271    /// Get a single element from the matrix
272    ///
273    /// # Arguments
274    ///
275    /// * `row` - Row index
276    /// * `col` - Column index
277    ///
278    /// # Returns
279    ///
280    /// The value at position (row, col)
281    pub fn get(&self, row: usize, col: usize) -> T {
282        // Check bounds
283        if row >= self.shape.0 || col >= self.shape.1 {
284            return T::zero();
285        }
286
287        // For symmetric matrix, if (row,col) is in upper triangular part,
288        // we look for (col,row) in the lower triangular part
289        let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
290
291        // Search for the element
292        for j in self.indptr[actual_row]..self.indptr[actual_row + 1] {
293            if self.indices[j] == actual_col {
294                return self.data[j];
295            }
296        }
297
298        T::zero()
299    }
300
301    /// Convert to standard CSR matrix (reconstructing full symmetric matrix)
302    ///
303    /// # Returns
304    ///
305    /// A standard CSR matrix with both upper and lower triangular parts
306    pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>> {
307        let n = self.shape.0;
308
309        // First, convert to triplet format for the full symmetric matrix
310        let mut data = Vec::new();
311        let mut row_indices = Vec::new();
312        let mut col_indices = Vec::new();
313
314        for i in 0..n {
315            // Add elements from lower triangular part (directly stored)
316            for j_ptr in self.indptr[i]..self.indptr[i + 1] {
317                let j = self.indices[j_ptr];
318                let val = self.data[j_ptr];
319
320                // Add the element itself
321                row_indices.push(i);
322                col_indices.push(j);
323                data.push(val);
324
325                // Add its symmetric counterpart (if not on diagonal)
326                if i != j {
327                    row_indices.push(j);
328                    col_indices.push(i);
329                    data.push(val);
330                }
331            }
332        }
333
334        // Create the CSR matrix from triplets
335        CsrMatrix::new(data, row_indices, col_indices, self.shape)
336    }
337
338    /// Convert to dense matrix
339    ///
340    /// # Returns
341    ///
342    /// A dense matrix representation as a vector of vectors
343    pub fn to_dense(&self) -> Vec<Vec<T>> {
344        let n = self.shape.0;
345        let mut dense = vec![vec![T::zero(); n]; n];
346
347        // Fill the lower triangular part (directly from stored data)
348        for (i, row) in dense.iter_mut().enumerate().take(n) {
349            for j_ptr in self.indptr[i]..self.indptr[i + 1] {
350                let j = self.indices[j_ptr];
351                row[j] = self.data[j_ptr];
352            }
353        }
354
355        // Fill the upper triangular part (from symmetry)
356        for i in 0..n {
357            for j in 0..i {
358                dense[j][i] = dense[i][j];
359            }
360        }
361
362        dense
363    }
364}
365
366/// Array-based SymCSR implementation compatible with SparseArray trait
367#[derive(Debug, Clone)]
368pub struct SymCsrArray<T>
369where
370    T: Float + Debug + Copy,
371{
372    /// Inner matrix
373    inner: SymCsrMatrix<T>,
374}
375
376impl<T> SymCsrArray<T>
377where
378    T: Float
379        + Debug
380        + Copy
381        + 'static
382        + Add<Output = T>
383        + Sub<Output = T>
384        + Mul<Output = T>
385        + Div<Output = T>,
386{
387    /// Create a new SymCSR array from a SymCSR matrix
388    ///
389    /// # Arguments
390    ///
391    /// * `matrix` - Symmetric CSR matrix
392    ///
393    /// # Returns
394    ///
395    /// SymCSR array
396    pub fn new(matrix: SymCsrMatrix<T>) -> Self {
397        Self { inner: matrix }
398    }
399
400    /// Create a SymCSR array from a regular CSR array
401    ///
402    /// # Arguments
403    ///
404    /// * `array` - CSR array to convert
405    ///
406    /// # Returns
407    ///
408    /// A symmetric CSR array
409    pub fn from_csr_array(array: &CsrArray<T>) -> SparseResult<Self> {
410        let shape = array.shape();
411        let (rows, cols) = shape;
412
413        // Ensure matrix is square
414        if rows != cols {
415            return Err(SparseError::ValueError(
416                "Symmetric matrix must be square".to_string(),
417            ));
418        }
419
420        // Create a temporary CSR matrix to check symmetry
421        let csr_matrix = CsrMatrix::new(
422            array.get_data().to_vec(),
423            array.get_indptr().to_vec(),
424            array.get_indices().to_vec(),
425            shape,
426        )?;
427
428        // Convert to symmetric CSR
429        let sym_csr = SymCsrMatrix::from_csr(&csr_matrix)?;
430
431        Ok(Self { inner: sym_csr })
432    }
433
434    /// Get the underlying matrix
435    ///
436    /// # Returns
437    ///
438    /// Reference to the inner SymCSR matrix
439    pub fn inner(&self) -> &SymCsrMatrix<T> {
440        &self.inner
441    }
442
443    /// Get access to the underlying data array
444    ///
445    /// # Returns
446    ///
447    /// Reference to the data array
448    pub fn data(&self) -> &[T] {
449        &self.inner.data
450    }
451
452    /// Get access to the underlying indices array
453    ///
454    /// # Returns
455    ///
456    /// Reference to the indices array
457    pub fn indices(&self) -> &[usize] {
458        &self.inner.indices
459    }
460
461    /// Get access to the underlying indptr array
462    ///
463    /// # Returns
464    ///
465    /// Reference to the indptr array
466    pub fn indptr(&self) -> &[usize] {
467        &self.inner.indptr
468    }
469
470    /// Convert to a standard CSR array
471    ///
472    /// # Returns
473    ///
474    /// CSR array containing the full symmetric matrix
475    pub fn to_csr_array(&self) -> SparseResult<CsrArray<T>> {
476        let csr = self.inner.to_csr()?;
477
478        // Convert the CsrMatrix to CsrArray using from_triplets
479        let (rows, cols, data) = csr.get_triplets();
480        let shape = csr.shape();
481
482        // Safety check - rows, cols, and data should all be the same length
483        if rows.len() != cols.len() || rows.len() != data.len() {
484            return Err(SparseError::DimensionMismatch {
485                expected: rows.len(),
486                found: cols.len().min(data.len()),
487            });
488        }
489
490        CsrArray::from_triplets(&rows, &cols, &data, shape, false)
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    // Dense form of the 3x3 symmetric matrix used by every test:
    //   [2 1 0]
    //   [1 2 3]
    //   [0 3 1]
    const FULL: [[f64; 3]; 3] = [[2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [0.0, 3.0, 1.0]];

    // Lower-triangular CSR arrays `(data, indices, indptr)` of `FULL`:
    //   [2 0 0]
    //   [1 2 0]
    //   [0 3 1]
    fn lower_triangle() -> (Vec<f64>, Vec<usize>, Vec<usize>) {
        (
            vec![2.0, 1.0, 2.0, 3.0, 1.0],
            vec![0, 0, 1, 1, 2],
            vec![0, 1, 3, 5],
        )
    }

    #[test]
    fn test_sym_csr_creation() {
        let (data, indices, indptr) = lower_triangle();
        let sym = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);

        // Off-diagonal entries are counted once per triangle.
        assert_eq!(sym.nnz(), 7);

        // Every position, including mirrored and unstored (zero) ones.
        for (i, expected_row) in FULL.iter().enumerate() {
            for (j, &expected) in expected_row.iter().enumerate() {
                assert_eq!(sym.get(i, j), expected, "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_sym_csr_from_standard() {
        // Full symmetric matrix as COO triplets (both triangles present).
        let row_indices = vec![0, 0, 1, 1, 1, 2, 2];
        let col_indices = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 1.0];

        let csr = CsrMatrix::new(data, row_indices, col_indices, (3, 3)).unwrap();
        let sym = SymCsrMatrix::from_csr(&csr).unwrap();

        assert_eq!(sym.shape(), (3, 3));

        // Round-trip back to a full CSR matrix and compare densely.
        let round_trip = sym.to_csr().unwrap();
        let dense = round_trip.to_dense();
        for (i, expected_row) in FULL.iter().enumerate() {
            for (j, &expected) in expected_row.iter().enumerate() {
                assert_eq!(dense[i][j], expected, "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_sym_csr_array() {
        let (data, indices, indptr) = lower_triangle();
        let sym_matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
        let sym_array = SymCsrArray::new(sym_matrix);

        assert_eq!(sym_array.inner().shape(), (3, 3));

        // The standard CSR array must hold both triangles.
        let csr_array = sym_array.to_csr_array().unwrap();
        assert_eq!(csr_array.shape(), (3, 3));
        for (i, expected_row) in FULL.iter().enumerate() {
            for (j, &expected) in expected_row.iter().enumerate() {
                assert_eq!(csr_array.get(i, j), expected, "mismatch at ({}, {})", i, j);
            }
        }
    }
}