scirs2_sparse/
sym_csr.rs

//! Symmetric Compressed Sparse Row (SymCSR) module
//!
//! This module provides a specialized implementation of the CSR format
//! for symmetric matrices. Only the lower triangular part of the matrix
//! (including the diagonal) is stored; upper-triangular entries are
//! recovered from symmetry on access.
6
7use crate::csr::CsrMatrix;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use scirs2_core::numeric::{Float, SparseElement};
11use std::fmt::Debug;
12use std::ops::{Add, Div, Mul, Sub};
13
/// Symmetric Compressed Sparse Row (SymCSR) matrix
///
/// This format stores only the lower triangular part (including the diagonal)
/// of a symmetric matrix, which roughly halves the storage of the off-diagonal
/// entries. Entries strictly above the diagonal are never stored; a read of
/// `(i, j)` with `i < j` is answered from the stored `(j, i)` entry.
///
/// # Note
///
/// All operations maintain symmetry implicitly.
///
/// The fields are public, so the lower-triangular invariant checked by
/// `SymCsrMatrix::new` / `SymCsrMatrix::from_csr` is only guaranteed for values
/// built through those constructors.
#[derive(Debug, Clone)]
pub struct SymCsrMatrix<T>
where
    T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
{
    /// Values of the stored entries (lower triangular part including the
    /// diagonal), grouped by row; parallel to `indices`
    pub data: Vec<T>,

    /// Row pointers: the entries of row `i` occupy
    /// `data[indptr[i]..indptr[i + 1]]` (and the same range of `indices`);
    /// has `rows + 1` elements
    pub indptr: Vec<usize>,

    /// Column index of each stored entry; parallel to `data`
    pub indices: Vec<usize>,

    /// Matrix shape (rows, cols), always square
    pub shape: (usize, usize),
}
40
41impl<T> SymCsrMatrix<T>
42where
43    T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
44{
45    /// Create a new symmetric CSR matrix from raw data
46    ///
47    /// # Arguments
48    ///
49    /// * `data` - Non-zero values in the lower triangular part
50    /// * `indptr` - Row pointers
51    /// * `indices` - Column indices
52    /// * `shape` - Matrix shape (n, n)
53    ///
54    /// # Returns
55    ///
56    /// A symmetric CSR matrix
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if:
61    /// - The shape is not square
62    /// - The indices array is incompatible with indptr
63    /// - Any column index is out of bounds
64    pub fn new(
65        data: Vec<T>,
66        indptr: Vec<usize>,
67        indices: Vec<usize>,
68        shape: (usize, usize),
69    ) -> SparseResult<Self> {
70        let (rows, cols) = shape;
71
72        // Ensure matrix is square
73        if rows != cols {
74            return Err(SparseError::ValueError(
75                "Symmetric matrix must be square".to_string(),
76            ));
77        }
78
79        // Check indptr length
80        if indptr.len() != rows + 1 {
81            return Err(SparseError::ValueError(format!(
82                "indptr length ({}) must be equal to rows + 1 ({})",
83                indptr.len(),
84                rows + 1
85            )));
86        }
87
88        // Check data and indices lengths
89        let nnz = indices.len();
90        if data.len() != nnz {
91            return Err(SparseError::ValueError(format!(
92                "data length ({}) must match indices length ({})",
93                data.len(),
94                nnz
95            )));
96        }
97
98        // Check last indptr value
99        if let Some(&last) = indptr.last() {
100            if last != nnz {
101                return Err(SparseError::ValueError(format!(
102                    "Last indptr value ({last}) must equal nnz ({nnz})"
103                )));
104            }
105        }
106
107        // Check that row and column indices are within bounds
108        for (i, &row_start) in indptr.iter().enumerate().take(rows) {
109            let row_end = indptr[i + 1];
110
111            for &col in &indices[row_start..row_end] {
112                if col >= cols {
113                    return Err(SparseError::IndexOutOfBounds {
114                        index: (i, col),
115                        shape: (rows, cols),
116                    });
117                }
118
119                // For symmetric matrix, ensure we only store the lower triangular part
120                if col > i {
121                    return Err(SparseError::ValueError(
122                        "Symmetric CSR should only store the lower triangular part".to_string(),
123                    ));
124                }
125            }
126        }
127
128        Ok(Self {
129            data,
130            indptr,
131            indices,
132            shape,
133        })
134    }
135
136    /// Convert a regular CSR matrix to symmetric CSR format
137    ///
138    /// This will verify that the matrix is symmetric and extract
139    /// the lower triangular part.
140    ///
141    /// # Arguments
142    ///
143    /// * `matrix` - CSR matrix to convert
144    ///
145    /// # Returns
146    ///
147    /// A symmetric CSR matrix
148    pub fn from_csr(matrix: &CsrMatrix<T>) -> SparseResult<Self> {
149        let (rows, cols) = matrix.shape();
150
151        // Ensure matrix is square
152        if rows != cols {
153            return Err(SparseError::ValueError(
154                "Symmetric matrix must be square".to_string(),
155            ));
156        }
157
158        // Check if the matrix is symmetric
159        if !Self::is_symmetric(matrix) {
160            return Err(SparseError::ValueError(
161                "Matrix must be symmetric to convert to SymCSR format".to_string(),
162            ));
163        }
164
165        // Extract the lower triangular part
166        let mut data = Vec::new();
167        let mut indices = Vec::new();
168        let mut indptr = vec![0];
169
170        for i in 0..rows {
171            for j in matrix.indptr[i]..matrix.indptr[i + 1] {
172                let col = matrix.indices[j];
173
174                // Only include elements in lower triangular part (including diagonal)
175                if col <= i {
176                    data.push(matrix.data[j]);
177                    indices.push(col);
178                }
179            }
180
181            indptr.push(data.len());
182        }
183
184        Ok(Self {
185            data,
186            indptr,
187            indices,
188            shape: (rows, cols),
189        })
190    }
191
192    /// Check if a CSR matrix is symmetric
193    ///
194    /// # Arguments
195    ///
196    /// * `matrix` - CSR matrix to check
197    ///
198    /// # Returns
199    ///
200    /// `true` if the matrix is symmetric, `false` otherwise
201    pub fn is_symmetric(matrix: &CsrMatrix<T>) -> bool {
202        let (rows, cols) = matrix.shape();
203
204        // Must be square
205        if rows != cols {
206            return false;
207        }
208
209        // Compare each element (i,j) with (j,i)
210        for i in 0..rows {
211            for j_ptr in matrix.indptr[i]..matrix.indptr[i + 1] {
212                let j = matrix.indices[j_ptr];
213                let val = matrix.data[j_ptr];
214
215                // Find the corresponding (j,i) element
216                let i_val = matrix.get(j, i);
217
218                // Check if a[i,j] == a[j,i] with sufficient tolerance
219                let diff = (val - i_val).abs();
220                let epsilon = T::epsilon() * T::from(100.0).unwrap();
221                if diff > epsilon {
222                    return false;
223                }
224            }
225        }
226
227        true
228    }
229
230    /// Get the shape of the matrix
231    ///
232    /// # Returns
233    ///
234    /// A tuple (rows, cols)
235    pub fn shape(&self) -> (usize, usize) {
236        self.shape
237    }
238
239    /// Get the number of stored non-zero elements
240    ///
241    /// # Returns
242    ///
243    /// The number of non-zero elements in the lower triangular part
244    pub fn nnz_stored(&self) -> usize {
245        self.data.len()
246    }
247
248    /// Get the total number of non-zero elements in the full matrix
249    ///
250    /// # Returns
251    ///
252    /// The total number of non-zero elements in the full symmetric matrix
253    pub fn nnz(&self) -> usize {
254        let diag_count = (0..self.shape.0)
255            .filter(|&i| {
256                // Count diagonal elements that are non-zero
257                let row_start = self.indptr[i];
258                let row_end = self.indptr[i + 1];
259                (row_start..row_end).any(|j_ptr| self.indices[j_ptr] == i)
260            })
261            .count();
262
263        let offdiag_count = self.data.len() - diag_count;
264
265        // Diagonal elements count once, off-diagonal elements count twice
266        diag_count + 2 * offdiag_count
267    }
268
269    /// Get a single element from the matrix
270    ///
271    /// # Arguments
272    ///
273    /// * `row` - Row index
274    /// * `col` - Column index
275    ///
276    /// # Returns
277    ///
278    /// The value at position (row, col)
279    pub fn get(&self, row: usize, col: usize) -> T {
280        // Check bounds
281        if row >= self.shape.0 || col >= self.shape.1 {
282            return T::sparse_zero();
283        }
284
285        // For symmetric matrix, if (row,col) is in upper triangular part,
286        // we look for (col,row) in the lower triangular part
287        let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
288
289        // Search for the element
290        for j in self.indptr[actual_row]..self.indptr[actual_row + 1] {
291            if self.indices[j] == actual_col {
292                return self.data[j];
293            }
294        }
295
296        T::sparse_zero()
297    }
298
299    /// Convert to standard CSR matrix (reconstructing full symmetric matrix)
300    ///
301    /// # Returns
302    ///
303    /// A standard CSR matrix with both upper and lower triangular parts
304    pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>> {
305        let n = self.shape.0;
306
307        // First, convert to triplet format for the full symmetric matrix
308        let mut data = Vec::new();
309        let mut row_indices = Vec::new();
310        let mut col_indices = Vec::new();
311
312        for i in 0..n {
313            // Add elements from lower triangular part (directly stored)
314            for j_ptr in self.indptr[i]..self.indptr[i + 1] {
315                let j = self.indices[j_ptr];
316                let val = self.data[j_ptr];
317
318                // Add the element itself
319                row_indices.push(i);
320                col_indices.push(j);
321                data.push(val);
322
323                // Add its symmetric counterpart (if not on diagonal)
324                if i != j {
325                    row_indices.push(j);
326                    col_indices.push(i);
327                    data.push(val);
328                }
329            }
330        }
331
332        // Create the CSR matrix from triplets
333        CsrMatrix::new(data, row_indices, col_indices, self.shape)
334    }
335
336    /// Convert to dense matrix
337    ///
338    /// # Returns
339    ///
340    /// A dense matrix representation as a vector of vectors
341    pub fn to_dense(&self) -> Vec<Vec<T>> {
342        let n = self.shape.0;
343        let mut dense = vec![vec![T::sparse_zero(); n]; n];
344
345        // Fill the lower triangular part (directly from stored data)
346        for (i, row) in dense.iter_mut().enumerate().take(n) {
347            for j_ptr in self.indptr[i]..self.indptr[i + 1] {
348                let j = self.indices[j_ptr];
349                row[j] = self.data[j_ptr];
350            }
351        }
352
353        // Fill the upper triangular part (from symmetry)
354        for i in 0..n {
355            for j in 0..i {
356                dense[j][i] = dense[i][j];
357            }
358        }
359
360        dense
361    }
362}
363
/// Array-based SymCSR implementation compatible with SparseArray trait
///
/// Thin wrapper around a [`SymCsrMatrix`]: it stores only the lower
/// triangular part and can be expanded to a full [`CsrArray`] via
/// `to_csr_array`. NOTE(review): any `SparseArray` trait impl for this type
/// lives outside this file section — confirm before relying on it.
#[derive(Clone)]
pub struct SymCsrArray<T>
where
    T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone,
{
    /// Inner matrix holding the lower-triangular storage
    inner: SymCsrMatrix<T>,
}
373
374impl<T> SymCsrArray<T>
375where
376    T: SparseElement + Float + Sub<Output = T> + PartialOrd + Clone + Div<Output = T> + 'static,
377{
378    /// Create a new SymCSR array from a SymCSR matrix
379    ///
380    /// # Arguments
381    ///
382    /// * `matrix` - Symmetric CSR matrix
383    ///
384    /// # Returns
385    ///
386    /// SymCSR array
387    pub fn new(matrix: SymCsrMatrix<T>) -> Self {
388        Self { inner: matrix }
389    }
390
391    /// Create a SymCSR array from a regular CSR array
392    ///
393    /// # Arguments
394    ///
395    /// * `array` - CSR array to convert
396    ///
397    /// # Returns
398    ///
399    /// A symmetric CSR array
400    pub fn from_csr_array(array: &CsrArray<T>) -> SparseResult<Self> {
401        let shape = array.shape();
402        let (rows, cols) = shape;
403
404        // Ensure matrix is square
405        if rows != cols {
406            return Err(SparseError::ValueError(
407                "Symmetric matrix must be square".to_string(),
408            ));
409        }
410
411        // Create a temporary CSR matrix to check symmetry
412        let csrmatrix = CsrMatrix::new(
413            array.get_data().to_vec(),
414            array.get_indptr().to_vec(),
415            array.get_indices().to_vec(),
416            shape,
417        )?;
418
419        // Convert to symmetric CSR
420        let sym_csr = SymCsrMatrix::from_csr(&csrmatrix)?;
421
422        Ok(Self { inner: sym_csr })
423    }
424
425    /// Get the underlying matrix
426    ///
427    /// # Returns
428    ///
429    /// Reference to the inner SymCSR matrix
430    pub fn inner(&self) -> &SymCsrMatrix<T> {
431        &self.inner
432    }
433
434    /// Get access to the underlying data array
435    ///
436    /// # Returns
437    ///
438    /// Reference to the data array
439    pub fn data(&self) -> &[T] {
440        &self.inner.data
441    }
442
443    /// Get access to the underlying indices array
444    ///
445    /// # Returns
446    ///
447    /// Reference to the indices array
448    pub fn indices(&self) -> &[usize] {
449        &self.inner.indices
450    }
451
452    /// Get access to the underlying indptr array
453    ///
454    /// # Returns
455    ///
456    /// Reference to the indptr array
457    pub fn indptr(&self) -> &[usize] {
458        &self.inner.indptr
459    }
460
461    /// Convert to a standard CSR array
462    ///
463    /// # Returns
464    ///
465    /// CSR array containing the full symmetric matrix
466    pub fn to_csr_array(&self) -> SparseResult<CsrArray<T>> {
467        let csr = self.inner.to_csr()?;
468
469        // Convert the CsrMatrix to CsrArray using from_triplets
470        let (rows, cols, data) = csr.get_triplets();
471        let shape = csr.shape();
472
473        // Safety check - rows, cols, and data should all be the same length
474        if rows.len() != cols.len() || rows.len() != data.len() {
475            return Err(SparseError::DimensionMismatch {
476                expected: rows.len(),
477                found: cols.len().min(data.len()),
478            });
479        }
480
481        CsrArray::from_triplets(&rows, &cols, &data, shape, false)
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    /// Dense form of the symmetric matrix every test works with:
    /// [2 1 0]
    /// [1 2 3]
    /// [0 3 1]
    const EXPECTED: [[f64; 3]; 3] = [[2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [0.0, 3.0, 1.0]];

    /// Builds `EXPECTED` in SymCSR form, i.e. from its lower triangle only:
    /// [2 0 0]
    /// [1 2 0]
    /// [0 3 1]
    fn lower_triangle_fixture() -> SymCsrMatrix<f64> {
        let values = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let col_idx = vec![0, 0, 1, 1, 2];
        let row_ptr = vec![0, 1, 3, 5];
        SymCsrMatrix::new(values, row_ptr, col_idx, (3, 3)).unwrap()
    }

    #[test]
    fn test_sym_csr_creation() {
        let sym = lower_triangle_fixture();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);

        // 3 diagonal entries + 2 off-diagonal entries counted twice
        assert_eq!(sym.nnz(), 7);

        // Every position: stored, mirrored from symmetry, and unstored zeros
        for (r, expected_row) in EXPECTED.iter().enumerate() {
            for (c, &want) in expected_row.iter().enumerate() {
                assert_eq!(sym.get(r, c), want, "mismatch at ({r}, {c})");
            }
        }
    }

    #[test]
    fn test_sym_csr_from_standard() {
        // Full (both triangles) triplet description of EXPECTED
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let vals = vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 1.0];

        let full = CsrMatrix::new(vals, rows, cols, (3, 3)).unwrap();
        let sym = SymCsrMatrix::from_csr(&full).unwrap();

        assert_eq!(sym.shape(), (3, 3));

        // Round-trip back to a full CSR matrix and compare densely
        let dense = sym.to_csr().unwrap().to_dense();
        for (r, expected_row) in EXPECTED.iter().enumerate() {
            for (c, &want) in expected_row.iter().enumerate() {
                assert_eq!(dense[r][c], want, "mismatch at ({r}, {c})");
            }
        }
    }

    #[test]
    fn test_sym_csr_array() {
        let sym_array = SymCsrArray::new(lower_triangle_fixture());

        assert_eq!(sym_array.inner().shape(), (3, 3));

        // Expand to a standard CSR array and check every position
        let csr_array = sym_array.to_csr_array().unwrap();

        assert_eq!(csr_array.shape(), (3, 3));
        for (r, expected_row) in EXPECTED.iter().enumerate() {
            for (c, &want) in expected_row.iter().enumerate() {
                assert_eq!(csr_array.get(r, c), want, "mismatch at ({r}, {c})");
            }
        }
    }
}