scirs2_sparse/
lib.rs

1//! Sparse module
2//!
3//! This module provides implementations of various sparse matrix and array formats and operations,
4//! similar to SciPy's `sparse` module.
5//!
6//! ## Overview
7//!
8//! * Sparse formats (CSR, CSC, COO, DOK, LIL, DIA, BSR, etc.)
9//! * Specialized sparse formats (Symmetric CSR, Symmetric COO)
10//! * Basic operations (addition, multiplication, etc.)
11//! * Sparse linear system solvers
12//! * Sparse eigenvalue computation
13//! * Conversion between different formats
14//!
15//! ## Matrix vs. Array API
16//!
17//! This module provides both a matrix-based API and an array-based API,
18//! following SciPy's transition to a more NumPy-compatible array interface.
19//!
20//! When using the array interface (e.g., `CsrArray`), please note that:
21//!
22//! - `*` performs element-wise multiplication, not matrix multiplication
23//! - Use `dot()` method for matrix multiplication
24//! - Operations like `sum` produce arrays, not matrices
25//! - Array-style slicing operations return scalars, 1D, or 2D arrays
26//!
27//! For new code, we recommend using the array interface, which is more consistent
28//! with the rest of the numerical ecosystem.
29//!
30//! ## Examples
31//!
32//! ### Matrix API (Legacy)
33//!
34//! ```
35//! use scirs2_sparse::csr::CsrMatrix;
36//!
37//! // Create a sparse matrix in CSR format
38//! let rows = vec![0, 0, 1, 2, 2];
39//! let cols = vec![0, 2, 2, 0, 1];
40//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
41//! let shape = (3, 3);
42//!
43//! let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
44//! ```
45//!
46//! ### Array API (Recommended)
47//!
48//! ```
49//! use scirs2_sparse::csr_array::CsrArray;
50//!
51//! // Create a sparse array in CSR format
52//! let rows = vec![0, 0, 1, 2, 2];
53//! let cols = vec![0, 2, 2, 0, 1];
54//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
55//! let shape = (3, 3);
56//!
57//! // From triplets (COO-like construction)
58//! let array = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
59//!
60//! // Or directly from CSR components
61//! // let array = CsrArray::new(...);
62//! ```
63
64// Export error types
65pub mod error;
66pub use error::{SparseError, SparseResult};
67
68// Base trait for sparse arrays
69pub mod sparray;
70pub use sparray::{is_sparse, SparseArray, SparseSum};
71
72// Trait for symmetric sparse arrays
73pub mod sym_sparray;
74pub use sym_sparray::SymSparseArray;
75
76// No spatial module in sparse
77
78// Array API (recommended)
79pub mod csr_array;
80pub use csr_array::CsrArray;
81
82pub mod csc_array;
83pub use csc_array::CscArray;
84
85pub mod coo_array;
86pub use coo_array::CooArray;
87
88pub mod dok_array;
89pub use dok_array::DokArray;
90
91pub mod lil_array;
92pub use lil_array::LilArray;
93
94pub mod dia_array;
95pub use dia_array::DiaArray;
96
97pub mod bsr_array;
98pub use bsr_array::BsrArray;
99
100// Symmetric array formats
101pub mod sym_csr;
102pub use sym_csr::{SymCsrArray, SymCsrMatrix};
103
104pub mod sym_coo;
105pub use sym_coo::{SymCooArray, SymCooMatrix};
106
107// Legacy matrix formats
108pub mod csr;
109pub use csr::CsrMatrix;
110
111pub mod csc;
112pub use csc::CscMatrix;
113
114pub mod coo;
115pub use coo::CooMatrix;
116
117pub mod dok;
118pub use dok::DokMatrix;
119
120pub mod lil;
121pub use lil::LilMatrix;
122
123pub mod dia;
124pub use dia::DiaMatrix;
125
126pub mod bsr;
127pub use bsr::BsrMatrix;
128
129// Utility functions
130pub mod utils;
131
132// Linear algebra with sparse matrices
133pub mod linalg;
134// Re-export the main functions from the reorganized linalg module
135pub use linalg::{
136    // Functions from solvers
137    add,
138    // Functions from iterative
139    bicg,
140    bicgstab,
141    cg,
142    diag_matrix,
143    expm,
144    // Functions from matfuncs
145    expm_multiply,
146    eye,
147    gmres,
148    inv,
149    matmul,
150    matrix_power,
151    multiply,
152    norm,
153    onenormest,
154    sparse_direct_solve,
155    sparse_lstsq,
156    spsolve,
157    // Interfaces
158    AsLinearOperator,
159    // Types from iterative
160    BiCGOptions,
161    BiCGSTABOptions,
162    BiCGSTABResult,
163    CGOptions,
164    CGSOptions,
165    CGSResult,
166    // Operator types
167    DiagonalOperator,
168    GMRESOptions,
169    // Preconditioners
170    ILU0Preconditioner,
171    IdentityOperator,
172    IterationResult,
173    JacobiPreconditioner,
174    LinearOperator,
175    SSORPreconditioner,
176    ScaledIdentityOperator,
177};
178
179// Format conversions
180pub mod convert;
181
182// Construction utilities
183pub mod construct;
184pub mod construct_sym;
185
186// Combining arrays
187pub mod combine;
188pub use combine::{block_diag, bmat, hstack, kron, kronsum, tril, triu, vstack};
189
190// Index dtype handling utilities
191pub mod index_dtype;
192pub use index_dtype::{can_cast_safely, get_index_dtype, safely_cast_index_arrays};
193
194// Optimized operations for symmetric sparse formats
195pub mod sym_ops;
196pub use sym_ops::{
197    sym_coo_matvec, sym_csr_matvec, sym_csr_quadratic_form, sym_csr_rank1_update, sym_csr_trace,
198};
199
/// Marker type named after SciPy's `scipy.sparse.SparseEfficiencyWarning`, provided
/// for API compatibility with code ported from SciPy.
///
/// This is a unit struct defined in this crate (not a re-export); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SparseEfficiencyWarning;

/// Marker type named after SciPy's `scipy.sparse.SparseWarning`, provided for API
/// compatibility with code ported from SciPy.
///
/// This is a unit struct defined in this crate (not a re-export); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SparseWarning;
203
204/// Check if an object is a sparse array
205pub fn is_sparse_array<T>(obj: &dyn SparseArray<T>) -> bool
206where
207    T: num_traits::Float
208        + std::fmt::Debug
209        + Copy
210        + std::ops::Add<Output = T>
211        + std::ops::Sub<Output = T>
212        + std::ops::Mul<Output = T>
213        + std::ops::Div<Output = T>
214        + 'static,
215{
216    sparray::is_sparse(obj)
217}
218
219/// Check if an object is a symmetric sparse array
220pub fn is_sym_sparse_array<T>(obj: &dyn SymSparseArray<T>) -> bool
221where
222    T: num_traits::Float
223        + std::fmt::Debug
224        + Copy
225        + std::ops::Add<Output = T>
226        + std::ops::Sub<Output = T>
227        + std::ops::Mul<Output = T>
228        + std::ops::Div<Output = T>
229        + 'static,
230{
231    obj.is_symmetric()
232}
233
234/// Check if an object is a sparse matrix (legacy API)
235pub fn is_sparse_matrix(obj: &dyn std::any::Any) -> bool {
236    obj.is::<CsrMatrix<f64>>()
237        || obj.is::<CscMatrix<f64>>()
238        || obj.is::<CooMatrix<f64>>()
239        || obj.is::<DokMatrix<f64>>()
240        || obj.is::<LilMatrix<f64>>()
241        || obj.is::<DiaMatrix<f64>>()
242        || obj.is::<BsrMatrix<f64>>()
243        || obj.is::<SymCsrMatrix<f64>>()
244        || obj.is::<SymCooMatrix<f64>>()
245        || obj.is::<CsrMatrix<f32>>()
246        || obj.is::<CscMatrix<f32>>()
247        || obj.is::<CooMatrix<f32>>()
248        || obj.is::<DokMatrix<f32>>()
249        || obj.is::<LilMatrix<f32>>()
250        || obj.is::<DiaMatrix<f32>>()
251        || obj.is::<BsrMatrix<f32>>()
252        || obj.is::<SymCsrMatrix<f32>>()
253        || obj.is::<SymCooMatrix<f32>>()
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Fixtures are tiny hand-written matrices so every expected value can be read
    // directly off the triplet / diagonal lists in each test.

    #[test]
    fn test_csr_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, dims, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert!(is_sparse_array(&csr));
    }

    #[test]
    fn test_coo_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, dims, false).unwrap();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert!(is_sparse_array(&coo));
    }

    #[test]
    fn test_dok_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        let built = DokArray::from_triplets(&row_idx, &col_idx, &values, dims).unwrap();

        assert_eq!(built.shape(), (3, 3));
        assert_eq!(built.nnz(), 5);
        assert!(is_sparse_array(&built));

        // Point writes followed by reads, including an unset (implicitly zero) cell.
        let mut grid = DokArray::<f64>::new((2, 2));
        grid.set(0, 0, 1.0).unwrap();
        grid.set(1, 1, 2.0).unwrap();

        for &(r, c, want) in &[(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(grid.get(r, c), want);
        }

        // Overwriting with an explicit zero is expected to drop the stored entry.
        grid.set(0, 0, 0.0).unwrap();
        assert_eq!(grid.nnz(), 1);
    }

    #[test]
    fn test_lil_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        let built = LilArray::from_triplets(&row_idx, &col_idx, &values, dims).unwrap();

        assert_eq!(built.shape(), (3, 3));
        assert_eq!(built.nnz(), 5);
        assert!(is_sparse_array(&built));

        // Point writes followed by reads, including an unset (implicitly zero) cell.
        let mut grid = LilArray::<f64>::new((2, 2));
        grid.set(0, 0, 1.0).unwrap();
        grid.set(1, 1, 2.0).unwrap();

        for &(r, c, want) in &[(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(grid.get(r, c), want);
        }

        // Each row holds a single entry, so its column indices are trivially sorted.
        assert!(grid.has_sorted_indices());

        // Overwriting with an explicit zero is expected to drop the stored entry.
        grid.set(0, 0, 0.0).unwrap();
        assert_eq!(grid.nnz(), 1);
    }

    #[test]
    fn test_dia_array() {
        use ndarray::Array1;

        // Main diagonal plus the first superdiagonal of a 3x3 matrix. The trailing
        // 0.0 only pads the superdiagonal to length 3; the nnz check below expects
        // it not to be counted.
        let bands = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 0.0]),
        ];
        let band_offsets = vec![0, 1];
        let dims = (3, 3);

        let dia = DiaArray::new(bands, band_offsets, dims).unwrap();

        assert_eq!(dia.shape(), (3, 3));
        assert_eq!(dia.nnz(), 5); // 3 on the main diagonal + 2 on the superdiagonal
        assert!(is_sparse_array(&dia));

        // The first five rows are the stored entries; the last is an off-band zero.
        let expected = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (0, 2, 0.0),
        ];
        for &(r, c, want) in &expected {
            assert_eq!(dia.get(r, c), want);
        }

        // The same matrix assembled from triplets must agree on every stored entry.
        let row_idx = vec![0, 0, 1, 1, 2];
        let col_idx = vec![0, 1, 1, 2, 2];
        let values = vec![1.0, 4.0, 2.0, 5.0, 3.0];

        let from_coo = DiaArray::from_triplets(&row_idx, &col_idx, &values, dims).unwrap();

        for &(r, c, want) in &expected[..5] {
            assert_eq!(from_coo.get(r, c), want);
        }

        // Conversion to CSR keeps the entry count and values.
        let csr = dia.to_csr().unwrap();
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 1), 4.0);
    }

    #[test]
    fn test_format_conversions() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, dims, false).unwrap();
        let csr = coo.to_csr().unwrap();

        // Densify both representations and compare them cell by cell.
        let coo_dense = coo.to_array();
        let csr_dense = csr.to_array();

        let cells = (0..dims.0).flat_map(|r| (0..dims.1).map(move |c| (r, c)));
        for (r, c) in cells {
            assert_relative_eq!(coo_dense[[r, c]], csr_dense[[r, c]]);
        }
    }

    #[test]
    fn test_dot_product() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = (3, 3);

        // The same matrix in two storage formats.
        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, dims, false).unwrap();
        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, dims, false).unwrap();

        // `dot` is the matrix product (not element-wise); square each one.
        let coo_squared = coo.dot(&coo).unwrap();
        let csr_squared = csr.dot(&csr).unwrap();

        let coo_dense = coo_squared.to_array();
        let csr_dense = csr_squared.to_array();

        let cells = (0..dims.0).flat_map(|r| (0..dims.1).map(move |c| (r, c)));
        for (r, c) in cells {
            assert_relative_eq!(coo_dense[[r, c]], csr_dense[[r, c]], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sym_csr_array() {
        // Lower-triangular storage (every column index <= its row) in CSR layout.
        // NOTE(review): row 2 stores column 2 twice (3.0 and 1.0) plus an explicit
        // 0.0 at (2, 0). The nnz expectation of 10 below matches counting every
        // stored entry with off-diagonals mirrored — confirm this fixture is intended.
        let values = vec![2.0, 1.0, 2.0, 3.0, 0.0, 3.0, 1.0];
        let col_idx = vec![0, 0, 1, 2, 0, 1, 2];
        let row_ptr = vec![0, 1, 3, 7];

        let lower = SymCsrMatrix::new(values, row_ptr, col_idx, (3, 3)).unwrap();
        let sym = SymCsrArray::new(lower);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // (1, 0) and (2, 1) are stored; (0, 1) and (1, 2) are their mirrored partners.
        let expected = [
            (0, 0, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 3.0),
            (2, 1, 3.0),
        ];
        for &(r, c, want) in &expected {
            assert_eq!(SparseArray::get(&sym, r, c), want);
        }

        // Expanding to a general CSR materializes the mirrored off-diagonal entries.
        let csr = SymSparseArray::to_csr(&sym).unwrap();
        assert_eq!(csr.nnz(), 10);
    }

    #[test]
    fn test_sym_coo_array() {
        // Lower triangle of a symmetric 3x3 matrix in COO layout.
        let values = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let row_idx = vec![0, 1, 1, 2, 2];
        let col_idx = vec![0, 0, 1, 1, 2];

        let lower = SymCooMatrix::new(values, row_idx, col_idx, (3, 3)).unwrap();
        let sym = SymCooArray::new(lower);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // (1, 0) and (2, 1) are stored; (0, 1) and (1, 2) are their mirrored partners.
        let expected = [
            (0, 0, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 3.0),
            (2, 1, 3.0),
        ];
        for &(r, c, want) in &expected {
            assert_eq!(SparseArray::get(&sym, r, c), want);
        }

        // Deliberately asymmetric input: (0, 1) = 1.5 but (1, 0) = 0.5. With the
        // enforce-symmetry flag set, the pair is expected to be averaged to 1.0.
        let rows2 = vec![0, 0, 1, 1, 2, 1, 0];
        let cols2 = vec![0, 1, 1, 2, 2, 0, 2];
        let data2 = vec![2.0, 1.5, 2.0, 3.5, 1.0, 0.5, 0.0];

        let averaged = SymCooArray::from_triplets(&rows2, &cols2, &data2, (3, 3), true).unwrap();

        assert_eq!(SparseArray::get(&averaged, 0, 1), 1.0); // (1.5 + 0.5) / 2
        assert_eq!(SparseArray::get(&averaged, 1, 0), 1.0); // mirrored partner
        assert_eq!(SparseArray::get(&averaged, 0, 2), 0.0); // explicit zero input
    }

    #[test]
    fn test_construct_sym_utils() {
        // Identity matrix built in the "csr" symmetric format.
        let eye = construct_sym::eye_sym_array::<f64>(3, "csr").unwrap();

        assert_eq!(eye.shape(), (3, 3));
        for &(r, c, want) in &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 1, 0.0)] {
            assert_eq!(SparseArray::get(&*eye, r, c), want);
        }

        // Tridiagonal matrix, built in "coo" format because the "csr" path had issues.
        let main_diag = vec![2.0, 2.0, 2.0];
        let off_diag = vec![1.0, 1.0];

        let tri = construct_sym::tridiagonal_sym_array(&main_diag, &off_diag, "coo").unwrap();

        assert_eq!(tri.shape(), (3, 3));
        let tri_expected = [
            (0, 0, 2.0), // main diagonal
            (1, 1, 2.0),
            (2, 2, 2.0),
            (0, 1, 1.0), // off-diagonal
            (1, 0, 1.0), // mirrored partner
            (1, 2, 1.0),
            (0, 2, 0.0), // outside the band
        ];
        for &(r, c, want) in &tri_expected {
            assert_eq!(SparseArray::get(&*tri, r, c), want);
        }

        // 5x5 banded matrix from its main diagonal and first two off-diagonals.
        let diagonals = vec![
            vec![2.0, 2.0, 2.0, 2.0, 2.0],
            vec![1.0, 1.0, 1.0, 1.0],
            vec![0.5, 0.5, 0.5],
        ];

        let band = construct_sym::banded_sym_array(&diagonals, 5, "csr").unwrap();

        assert_eq!(band.shape(), (5, 5));
        let band_expected = [
            (0, 0, 2.0),
            (0, 1, 1.0),
            (0, 2, 0.5),
            (2, 0, 0.5), // mirrored partner of (0, 2)
        ];
        for &(r, c, want) in &band_expected {
            assert_eq!(SparseArray::get(&*band, r, c), want);
        }
    }

    #[test]
    fn test_sym_conversions() {
        // Only the lower triangle of the symmetric matrix is supplied.
        let values = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let row_idx = vec![0, 1, 1, 2, 2];
        let col_idx = vec![0, 0, 1, 1, 2];

        let sym_coo =
            SymCooArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), true).unwrap();
        let sym_csr = sym_coo.to_sym_csr().unwrap();

        let cells: Vec<_> = (0..3).flat_map(|r| (0..3).map(move |c| (r, c))).collect();

        // Symmetric COO -> symmetric CSR must preserve every cell.
        for &(r, c) in &cells {
            assert_eq!(
                SparseArray::get(&sym_coo, r, c),
                SparseArray::get(&sym_csr, r, c)
            );
        }

        // Expand to general formats: 3 diagonal entries + 2 mirrored off-diagonal
        // pairs gives 7 stored entries.
        let csr = SymSparseArray::to_csr(&sym_coo).unwrap();
        let coo = SymSparseArray::to_coo(&sym_csr).unwrap();

        assert_eq!(csr.nnz(), 7);
        assert_eq!(coo.nnz(), 7);

        for &(r, c) in &cells {
            assert_eq!(SparseArray::get(&csr, r, c), SparseArray::get(&coo, r, c));
            assert_eq!(
                SparseArray::get(&csr, r, c),
                SparseArray::get(&sym_csr, r, c)
            );
        }
    }
}