scirs2_sparse/sym_csr.rs
//! Symmetric Compressed Sparse Row (SymCSR) module
//!
//! This module provides a specialized implementation of the CSR format
//! optimized for symmetric matrices. Only the lower triangular part of the
//! matrix (including the diagonal) is stored; the upper triangle is implied
//! by symmetry.
6
7use crate::csr::CsrMatrix;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use crate::sparray::SparseArray;
11use num_traits::Float;
12use std::fmt::Debug;
13use std::ops::{Add, Div, Mul, Sub};
14
15/// Symmetric Compressed Sparse Row (SymCSR) matrix
16///
17/// This format stores only the lower triangular part of a symmetric matrix
18/// to save memory and improve performance. Operations are optimized to
19/// take advantage of symmetry when possible.
20///
21/// # Note
22///
23/// All operations maintain symmetry implicitly.
24#[derive(Debug, Clone)]
25pub struct SymCsrMatrix<T>
26where
27 T: Float + Debug + Copy,
28{
29 /// CSR format data for the lower triangular part (including diagonal)
30 pub data: Vec<T>,
31
32 /// Row pointers (indptr): indices where each row starts in indices array
33 pub indptr: Vec<usize>,
34
35 /// Column indices for each non-zero element
36 pub indices: Vec<usize>,
37
38 /// Matrix shape (rows, cols), always square
39 pub shape: (usize, usize),
40}
41
42impl<T> SymCsrMatrix<T>
43where
44 T: Float + Debug + Copy,
45{
46 /// Create a new symmetric CSR matrix from raw data
47 ///
48 /// # Arguments
49 ///
50 /// * `data` - Non-zero values in the lower triangular part
51 /// * `indptr` - Row pointers
52 /// * `indices` - Column indices
53 /// * `shape` - Matrix shape (n, n)
54 ///
55 /// # Returns
56 ///
57 /// A symmetric CSR matrix
58 ///
59 /// # Errors
60 ///
61 /// Returns an error if:
62 /// - The shape is not square
63 /// - The indices array is incompatible with indptr
64 /// - Any column index is out of bounds
65 pub fn new(
66 data: Vec<T>,
67 indptr: Vec<usize>,
68 indices: Vec<usize>,
69 shape: (usize, usize),
70 ) -> SparseResult<Self> {
71 let (rows, cols) = shape;
72
73 // Ensure matrix is square
74 if rows != cols {
75 return Err(SparseError::ValueError(
76 "Symmetric matrix must be square".to_string(),
77 ));
78 }
79
80 // Check indptr length
81 if indptr.len() != rows + 1 {
82 return Err(SparseError::ValueError(format!(
83 "indptr length ({}) must be equal to rows + 1 ({})",
84 indptr.len(),
85 rows + 1
86 )));
87 }
88
89 // Check data and indices lengths
90 let nnz = indices.len();
91 if data.len() != nnz {
92 return Err(SparseError::ValueError(format!(
93 "data length ({}) must match indices length ({})",
94 data.len(),
95 nnz
96 )));
97 }
98
99 // Check last indptr value
100 if let Some(&last) = indptr.last() {
101 if last != nnz {
102 return Err(SparseError::ValueError(format!(
103 "Last indptr value ({}) must equal nnz ({})",
104 last, nnz
105 )));
106 }
107 }
108
109 // Check that row and column indices are within bounds
110 for (i, &row_start) in indptr.iter().enumerate().take(rows) {
111 let row_end = indptr[i + 1];
112
113 for &col in &indices[row_start..row_end] {
114 if col >= cols {
115 return Err(SparseError::IndexOutOfBounds {
116 index: (i, col),
117 shape: (rows, cols),
118 });
119 }
120
121 // For symmetric matrix, ensure we only store the lower triangular part
122 if col > i {
123 return Err(SparseError::ValueError(
124 "Symmetric CSR should only store the lower triangular part".to_string(),
125 ));
126 }
127 }
128 }
129
130 Ok(Self {
131 data,
132 indptr,
133 indices,
134 shape,
135 })
136 }
137
138 /// Convert a regular CSR matrix to symmetric CSR format
139 ///
140 /// This will verify that the matrix is symmetric and extract
141 /// the lower triangular part.
142 ///
143 /// # Arguments
144 ///
145 /// * `matrix` - CSR matrix to convert
146 ///
147 /// # Returns
148 ///
149 /// A symmetric CSR matrix
150 pub fn from_csr(matrix: &CsrMatrix<T>) -> SparseResult<Self> {
151 let (rows, cols) = matrix.shape();
152
153 // Ensure matrix is square
154 if rows != cols {
155 return Err(SparseError::ValueError(
156 "Symmetric matrix must be square".to_string(),
157 ));
158 }
159
160 // Check if the matrix is symmetric
161 if !Self::is_symmetric(matrix) {
162 return Err(SparseError::ValueError(
163 "Matrix must be symmetric to convert to SymCSR format".to_string(),
164 ));
165 }
166
167 // Extract the lower triangular part
168 let mut data = Vec::new();
169 let mut indices = Vec::new();
170 let mut indptr = vec![0];
171
172 for i in 0..rows {
173 for j in matrix.indptr[i]..matrix.indptr[i + 1] {
174 let col = matrix.indices[j];
175
176 // Only include elements in lower triangular part (including diagonal)
177 if col <= i {
178 data.push(matrix.data[j]);
179 indices.push(col);
180 }
181 }
182
183 indptr.push(data.len());
184 }
185
186 Ok(Self {
187 data,
188 indptr,
189 indices,
190 shape: (rows, cols),
191 })
192 }
193
194 /// Check if a CSR matrix is symmetric
195 ///
196 /// # Arguments
197 ///
198 /// * `matrix` - CSR matrix to check
199 ///
200 /// # Returns
201 ///
202 /// `true` if the matrix is symmetric, `false` otherwise
203 pub fn is_symmetric(matrix: &CsrMatrix<T>) -> bool {
204 let (rows, cols) = matrix.shape();
205
206 // Must be square
207 if rows != cols {
208 return false;
209 }
210
211 // Compare each element (i,j) with (j,i)
212 for i in 0..rows {
213 for j_ptr in matrix.indptr[i]..matrix.indptr[i + 1] {
214 let j = matrix.indices[j_ptr];
215 let val = matrix.data[j_ptr];
216
217 // Find the corresponding (j,i) element
218 let i_val = matrix.get(j, i);
219
220 // Check if a[i,j] == a[j,i] with sufficient tolerance
221 let diff = (val - i_val).abs();
222 let epsilon = T::epsilon() * T::from(100.0).unwrap();
223 if diff > epsilon {
224 return false;
225 }
226 }
227 }
228
229 true
230 }
231
232 /// Get the shape of the matrix
233 ///
234 /// # Returns
235 ///
236 /// A tuple (rows, cols)
237 pub fn shape(&self) -> (usize, usize) {
238 self.shape
239 }
240
241 /// Get the number of stored non-zero elements
242 ///
243 /// # Returns
244 ///
245 /// The number of non-zero elements in the lower triangular part
246 pub fn nnz_stored(&self) -> usize {
247 self.data.len()
248 }
249
250 /// Get the total number of non-zero elements in the full matrix
251 ///
252 /// # Returns
253 ///
254 /// The total number of non-zero elements in the full symmetric matrix
255 pub fn nnz(&self) -> usize {
256 let diag_count = (0..self.shape.0)
257 .filter(|&i| {
258 // Count diagonal elements that are non-zero
259 let row_start = self.indptr[i];
260 let row_end = self.indptr[i + 1];
261 (row_start..row_end).any(|j_ptr| self.indices[j_ptr] == i)
262 })
263 .count();
264
265 let offdiag_count = self.data.len() - diag_count;
266
267 // Diagonal elements count once, off-diagonal elements count twice
268 diag_count + 2 * offdiag_count
269 }
270
271 /// Get a single element from the matrix
272 ///
273 /// # Arguments
274 ///
275 /// * `row` - Row index
276 /// * `col` - Column index
277 ///
278 /// # Returns
279 ///
280 /// The value at position (row, col)
281 pub fn get(&self, row: usize, col: usize) -> T {
282 // Check bounds
283 if row >= self.shape.0 || col >= self.shape.1 {
284 return T::zero();
285 }
286
287 // For symmetric matrix, if (row,col) is in upper triangular part,
288 // we look for (col,row) in the lower triangular part
289 let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
290
291 // Search for the element
292 for j in self.indptr[actual_row]..self.indptr[actual_row + 1] {
293 if self.indices[j] == actual_col {
294 return self.data[j];
295 }
296 }
297
298 T::zero()
299 }
300
301 /// Convert to standard CSR matrix (reconstructing full symmetric matrix)
302 ///
303 /// # Returns
304 ///
305 /// A standard CSR matrix with both upper and lower triangular parts
306 pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>> {
307 let n = self.shape.0;
308
309 // First, convert to triplet format for the full symmetric matrix
310 let mut data = Vec::new();
311 let mut row_indices = Vec::new();
312 let mut col_indices = Vec::new();
313
314 for i in 0..n {
315 // Add elements from lower triangular part (directly stored)
316 for j_ptr in self.indptr[i]..self.indptr[i + 1] {
317 let j = self.indices[j_ptr];
318 let val = self.data[j_ptr];
319
320 // Add the element itself
321 row_indices.push(i);
322 col_indices.push(j);
323 data.push(val);
324
325 // Add its symmetric counterpart (if not on diagonal)
326 if i != j {
327 row_indices.push(j);
328 col_indices.push(i);
329 data.push(val);
330 }
331 }
332 }
333
334 // Create the CSR matrix from triplets
335 CsrMatrix::new(data, row_indices, col_indices, self.shape)
336 }
337
338 /// Convert to dense matrix
339 ///
340 /// # Returns
341 ///
342 /// A dense matrix representation as a vector of vectors
343 pub fn to_dense(&self) -> Vec<Vec<T>> {
344 let n = self.shape.0;
345 let mut dense = vec![vec![T::zero(); n]; n];
346
347 // Fill the lower triangular part (directly from stored data)
348 for (i, row) in dense.iter_mut().enumerate().take(n) {
349 for j_ptr in self.indptr[i]..self.indptr[i + 1] {
350 let j = self.indices[j_ptr];
351 row[j] = self.data[j_ptr];
352 }
353 }
354
355 // Fill the upper triangular part (from symmetry)
356 for i in 0..n {
357 for j in 0..i {
358 dense[j][i] = dense[i][j];
359 }
360 }
361
362 dense
363 }
364}
365
366/// Array-based SymCSR implementation compatible with SparseArray trait
367#[derive(Debug, Clone)]
368pub struct SymCsrArray<T>
369where
370 T: Float + Debug + Copy,
371{
372 /// Inner matrix
373 inner: SymCsrMatrix<T>,
374}
375
376impl<T> SymCsrArray<T>
377where
378 T: Float
379 + Debug
380 + Copy
381 + 'static
382 + Add<Output = T>
383 + Sub<Output = T>
384 + Mul<Output = T>
385 + Div<Output = T>,
386{
387 /// Create a new SymCSR array from a SymCSR matrix
388 ///
389 /// # Arguments
390 ///
391 /// * `matrix` - Symmetric CSR matrix
392 ///
393 /// # Returns
394 ///
395 /// SymCSR array
396 pub fn new(matrix: SymCsrMatrix<T>) -> Self {
397 Self { inner: matrix }
398 }
399
400 /// Create a SymCSR array from a regular CSR array
401 ///
402 /// # Arguments
403 ///
404 /// * `array` - CSR array to convert
405 ///
406 /// # Returns
407 ///
408 /// A symmetric CSR array
409 pub fn from_csr_array(array: &CsrArray<T>) -> SparseResult<Self> {
410 let shape = array.shape();
411 let (rows, cols) = shape;
412
413 // Ensure matrix is square
414 if rows != cols {
415 return Err(SparseError::ValueError(
416 "Symmetric matrix must be square".to_string(),
417 ));
418 }
419
420 // Create a temporary CSR matrix to check symmetry
421 let csr_matrix = CsrMatrix::new(
422 array.get_data().to_vec(),
423 array.get_indptr().to_vec(),
424 array.get_indices().to_vec(),
425 shape,
426 )?;
427
428 // Convert to symmetric CSR
429 let sym_csr = SymCsrMatrix::from_csr(&csr_matrix)?;
430
431 Ok(Self { inner: sym_csr })
432 }
433
434 /// Get the underlying matrix
435 ///
436 /// # Returns
437 ///
438 /// Reference to the inner SymCSR matrix
439 pub fn inner(&self) -> &SymCsrMatrix<T> {
440 &self.inner
441 }
442
443 /// Get access to the underlying data array
444 ///
445 /// # Returns
446 ///
447 /// Reference to the data array
448 pub fn data(&self) -> &[T] {
449 &self.inner.data
450 }
451
452 /// Get access to the underlying indices array
453 ///
454 /// # Returns
455 ///
456 /// Reference to the indices array
457 pub fn indices(&self) -> &[usize] {
458 &self.inner.indices
459 }
460
461 /// Get access to the underlying indptr array
462 ///
463 /// # Returns
464 ///
465 /// Reference to the indptr array
466 pub fn indptr(&self) -> &[usize] {
467 &self.inner.indptr
468 }
469
470 /// Convert to a standard CSR array
471 ///
472 /// # Returns
473 ///
474 /// CSR array containing the full symmetric matrix
475 pub fn to_csr_array(&self) -> SparseResult<CsrArray<T>> {
476 let csr = self.inner.to_csr()?;
477
478 // Convert the CsrMatrix to CsrArray using from_triplets
479 let (rows, cols, data) = csr.get_triplets();
480 let shape = csr.shape();
481
482 // Safety check - rows, cols, and data should all be the same length
483 if rows.len() != cols.len() || rows.len() != data.len() {
484 return Err(SparseError::DimensionMismatch {
485 expected: rows.len(),
486 found: cols.len().min(data.len()),
487 });
488 }
489
490 CsrArray::from_triplets(&rows, &cols, &data, shape, false)
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    #[test]
    fn test_sym_csr_creation() {
        // Create a simple symmetric matrix stored in lower triangular format
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Note: Actually represents the lower triangular part only:
        // [2 0 0]
        // [1 2 0]
        // [0 3 1]

        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        let sym = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);

        // Total non-zeros should count off-diagonal elements twice
        assert_eq!(sym.nnz(), 7);

        // Test accessing elements
        assert_eq!(sym.get(0, 0), 2.0);
        assert_eq!(sym.get(0, 1), 1.0);
        assert_eq!(sym.get(1, 0), 1.0); // From symmetry
        assert_eq!(sym.get(1, 1), 2.0);
        assert_eq!(sym.get(1, 2), 3.0);
        assert_eq!(sym.get(2, 1), 3.0); // From symmetry
        assert_eq!(sym.get(2, 2), 1.0);
        assert_eq!(sym.get(0, 2), 0.0); // Zero element - not stored
        assert_eq!(sym.get(2, 0), 0.0); // Zero element - not stored
    }

    #[test]
    fn test_sym_csr_from_standard() {
        // Create a standard CSR matrix that's symmetric
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]

        // Create it from triplets to ensure it's properly constructed
        let row_indices = vec![0, 0, 1, 1, 1, 2, 2];
        let col_indices = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 1.0];

        let csr = CsrMatrix::new(data, row_indices, col_indices, (3, 3)).unwrap();
        let sym = SymCsrMatrix::from_csr(&csr).unwrap();

        assert_eq!(sym.shape(), (3, 3));

        // Convert back to standard CSR to check
        let csr2 = sym.to_csr().unwrap();
        let dense = csr2.to_dense();

        // Check the full matrix
        assert_eq!(dense[0][0], 2.0);
        assert_eq!(dense[0][1], 1.0);
        assert_eq!(dense[0][2], 0.0);
        assert_eq!(dense[1][0], 1.0);
        assert_eq!(dense[1][1], 2.0);
        assert_eq!(dense[1][2], 3.0);
        assert_eq!(dense[2][0], 0.0);
        assert_eq!(dense[2][1], 3.0);
        assert_eq!(dense[2][2], 1.0);
    }

    #[test]
    fn test_sym_csr_array() {
        // Create a symmetric SymCSR matrix, storing only the lower triangular part
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        let sym_matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
        let sym_array = SymCsrArray::new(sym_matrix);

        assert_eq!(sym_array.inner().shape(), (3, 3));

        // Convert to standard CSR array
        let csr_array = sym_array.to_csr_array().unwrap();

        // Verify shape and values
        assert_eq!(csr_array.shape(), (3, 3));
        assert_eq!(csr_array.get(0, 0), 2.0);
        assert_eq!(csr_array.get(0, 1), 1.0);
        assert_eq!(csr_array.get(1, 0), 1.0);
        assert_eq!(csr_array.get(1, 1), 2.0);
        assert_eq!(csr_array.get(1, 2), 3.0);
        assert_eq!(csr_array.get(2, 1), 3.0);
        assert_eq!(csr_array.get(2, 2), 1.0);
        assert_eq!(csr_array.get(0, 2), 0.0);
        assert_eq!(csr_array.get(2, 0), 0.0);
    }

    #[test]
    fn test_sym_csr_to_dense_and_out_of_range_get() {
        // Lower-triangular storage of
        // [2 1 0]
        // [1 2 3]
        // [0 3 1]
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];

        let sym = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();

        assert_eq!(
            sym.to_dense(),
            vec![
                vec![2.0, 1.0, 0.0],
                vec![1.0, 2.0, 3.0],
                vec![0.0, 3.0, 1.0],
            ]
        );

        // Out-of-range positions read as zero rather than panicking
        assert_eq!(sym.get(3, 0), 0.0);
        assert_eq!(sym.get(0, 3), 0.0);
    }

    #[test]
    fn test_sym_csr_new_rejects_invalid_input() {
        // Non-square shape
        assert!(SymCsrMatrix::<f64>::new(vec![], vec![0, 0], vec![], (1, 2)).is_err());
        // indptr length is not rows + 1
        assert!(SymCsrMatrix::new(vec![1.0], vec![0, 1], vec![0], (2, 2)).is_err());
        // data / indices length mismatch
        assert!(SymCsrMatrix::new(vec![1.0], vec![0, 1, 2], vec![0, 0], (2, 2)).is_err());
        // Last indptr value does not equal nnz
        assert!(SymCsrMatrix::new(vec![1.0], vec![0, 0, 0], vec![0], (2, 2)).is_err());
        // Column index out of bounds
        assert!(SymCsrMatrix::new(vec![1.0], vec![0, 1, 1], vec![5], (2, 2)).is_err());
        // Entry above the diagonal
        assert!(SymCsrMatrix::new(vec![1.0], vec![0, 1, 1], vec![1], (2, 2)).is_err());
    }

    #[test]
    fn test_is_symmetric_detects_asymmetry() {
        // [1 2]
        // [3 1]
        let csr = CsrMatrix::new(
            vec![1.0, 2.0, 3.0, 1.0],
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            (2, 2),
        )
        .unwrap();

        assert!(!SymCsrMatrix::is_symmetric(&csr));
        assert!(SymCsrMatrix::from_csr(&csr).is_err());
    }
}