scirs2_sparse/
lib.rs

// Crate-wide lint configuration, collapsed into a single attribute.
// NOTE(review): every entry below silences a lint for the whole crate and none
// is justified in this file; consider narrowing each to the items that need it.
#![allow(
    deprecated,
    dead_code,
    clippy::manual_div_ceil,
    clippy::manual_ok_err,
    clippy::manual_slice_fill,
    clippy::needless_range_loop,
    clippy::needless_return,
    clippy::only_used_in_recursion,
    clippy::should_implement_trait,
    clippy::vec_init_then_push,
    clippy::while_let_loop
)]
12//! # SciRS2 Sparse - Sparse Matrix Operations
13//!
14//! **scirs2-sparse** provides comprehensive sparse matrix formats and operations modeled after SciPy's
15//! `sparse` module, offering CSR, CSC, COO, DOK, LIL, DIA, BSR formats with efficient algorithms
16//! for large-scale sparse linear algebra, eigenvalue problems, and graph operations.
17//!
18//! ## 🎯 Key Features
19//!
//! - **SciPy Compatibility**: API modeled on `scipy.sparse`, with equivalently named formats and functions
21//! - **Multiple Formats**: CSR, CSC, COO, DOK, LIL, DIA, BSR with easy conversion
22//! - **Efficient Operations**: Sparse matrix-vector/matrix multiplication
23//! - **Linear Solvers**: Direct (LU, Cholesky) and iterative (CG, GMRES) solvers
24//! - **Eigenvalue Solvers**: ARPACK-based sparse eigenvalue computation
25//! - **Array API**: Modern NumPy-compatible array interface (recommended)
26//!
27//! ## 📦 Module Overview
28//!
29//! | SciRS2 Format | SciPy Equivalent | Description |
30//! |---------------|------------------|-------------|
31//! | `CsrArray` | `scipy.sparse.csr_array` | Compressed Sparse Row (efficient row slicing) |
32//! | `CscArray` | `scipy.sparse.csc_array` | Compressed Sparse Column (efficient column slicing) |
33//! | `CooArray` | `scipy.sparse.coo_array` | Coordinate format (efficient construction) |
34//! | `DokArray` | `scipy.sparse.dok_array` | Dictionary of Keys (efficient element access) |
35//! | `LilArray` | `scipy.sparse.lil_array` | List of Lists (efficient incremental construction) |
36//! | `DiaArray` | `scipy.sparse.dia_array` | Diagonal format (efficient banded matrices) |
37//! | `BsrArray` | `scipy.sparse.bsr_array` | Block Sparse Row (efficient block operations) |
38//!
39//! ## 🚀 Quick Start
40//!
41//! ```toml
42//! [dependencies]
43//! scirs2-sparse = "0.1.0-rc.1"
44//! ```
45//!
46//! ```rust
47//! use scirs2_sparse::csr_array::CsrArray;
48//!
49//! // Create sparse matrix from triplets (row, col, value)
50//! let rows = vec![0, 0, 1, 2, 2];
51//! let cols = vec![0, 2, 2, 0, 1];
52//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
53//! let sparse = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
54//! ```
55//!
56//! ## 🔒 Version: 0.1.0-rc.1 (October 03, 2025)
57//!
58//! ## Matrix vs. Array API
59//!
60//! This module provides both a matrix-based API and an array-based API,
61//! following SciPy's transition to a more NumPy-compatible array interface.
62//!
63//! When using the array interface (e.g., `CsrArray`), please note that:
64//!
65//! - `*` performs element-wise multiplication, not matrix multiplication
66//! - Use `dot()` method for matrix multiplication
67//! - Operations like `sum` produce arrays, not matrices
68//! - Array-style slicing operations return scalars, 1D, or 2D arrays
69//!
70//! For new code, we recommend using the array interface, which is more consistent
71//! with the rest of the numerical ecosystem.
72//!
73//! ## Examples
74//!
75//! ### Matrix API (Legacy)
76//!
77//! ```
78//! use scirs2_sparse::csr::CsrMatrix;
79//!
80//! // Create a sparse matrix in CSR format
81//! let rows = vec![0, 0, 1, 2, 2];
82//! let cols = vec![0, 2, 2, 0, 1];
83//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
84//! let shape = (3, 3);
85//!
86//! let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
87//! ```
88//!
89//! ### Array API (Recommended)
90//!
91//! ```
92//! use scirs2_sparse::csr_array::CsrArray;
93//!
94//! // Create a sparse array in CSR format
95//! let rows = vec![0, 0, 1, 2, 2];
96//! let cols = vec![0, 2, 2, 0, 1];
97//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
98//! let shape = (3, 3);
99//!
100//! // From triplets (COO-like construction)
101//! let array = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
102//!
//! // Alternatively, `CsrArray::new` builds an array directly from CSR
//! // components; see its documentation for the exact arguments.
105//! ```
106
107// Export error types
108pub mod error;
109pub use error::{SparseError, SparseResult};
110
111// Base trait for sparse arrays
112pub mod sparray;
113pub use sparray::{is_sparse, SparseArray, SparseSum};
114
115// Trait for symmetric sparse arrays
116pub mod sym_sparray;
117pub use sym_sparray::SymSparseArray;
118
119// No spatial module in sparse
120
121// Array API (recommended)
122pub mod csr_array;
123pub use csr_array::CsrArray;
124
125pub mod csc_array;
126pub use csc_array::CscArray;
127
128pub mod coo_array;
129pub use coo_array::CooArray;
130
131pub mod dok_array;
132pub use dok_array::DokArray;
133
134pub mod lil_array;
135pub use lil_array::LilArray;
136
137pub mod dia_array;
138pub use dia_array::DiaArray;
139
140pub mod bsr_array;
141pub use bsr_array::BsrArray;
142
143pub mod banded_array;
144pub use banded_array::BandedArray;
145
146// Symmetric array formats
147pub mod sym_csr;
148pub use sym_csr::{SymCsrArray, SymCsrMatrix};
149
150pub mod sym_coo;
151pub use sym_coo::{SymCooArray, SymCooMatrix};
152
153// Legacy matrix formats
154pub mod csr;
155pub use csr::CsrMatrix;
156
157pub mod csc;
158pub use csc::CscMatrix;
159
160pub mod coo;
161pub use coo::CooMatrix;
162
163pub mod dok;
164pub use dok::DokMatrix;
165
166pub mod lil;
167pub use lil::LilMatrix;
168
169pub mod dia;
170pub use dia::DiaMatrix;
171
172pub mod bsr;
173pub use bsr::BsrMatrix;
174
175pub mod banded;
176pub use banded::BandedMatrix;
177
178// Utility functions
179pub mod utils;
180
181// Linear algebra with sparse matrices
182pub mod linalg;
183// Re-export the main functions from the reorganized linalg module
184pub use linalg::{
185    // Functions from solvers
186    add,
187    // Functions from iterative
188    bicg,
189    bicgstab,
190    cg,
191    cholesky_decomposition,
192    // Enhanced operators
193    convolution_operator,
194    diag_matrix,
195    eigs,
196    eigsh,
197    enhanced_add,
198    enhanced_diagonal,
199    enhanced_scale,
200    enhanced_subtract,
201    expm,
202    // Functions from matfuncs
203    expm_multiply,
204    eye,
205    finite_difference_operator,
206    // GCROT solver
207    gcrot,
208    gmres,
209    incomplete_cholesky,
210    incomplete_lu,
211    inv,
212    lanczos,
213    // Decomposition functions
214    lu_decomposition,
215    matmul,
216    matrix_power,
217    multiply,
218    norm,
219    onenormest,
220    // Eigenvalue functions
221    power_iteration,
222    qr_decomposition,
223    sparse_direct_solve,
224    sparse_lstsq,
225    spsolve,
226    svd_truncated,
227    // SVD functions
228    svds,
229    // TFQMR solver
230    tfqmr,
231    ArpackOptions,
232    // Interfaces
233    AsLinearOperator,
234    // Types from iterative
235    BiCGOptions,
236    BiCGSTABOptions,
237    BiCGSTABResult,
238    // Enhanced operator types
239    BoundaryCondition,
240    CGOptions,
241    CGSOptions,
242    CGSResult,
243    CholeskyResult,
244    ConvolutionMode,
245    ConvolutionOperator,
246    // Operator types
247    DiagonalOperator,
248    EigenResult,
249    EigenvalueMethod,
250    EnhancedDiagonalOperator,
251    EnhancedDifferenceOperator,
252    EnhancedOperatorOptions,
253    EnhancedScaledOperator,
254    EnhancedSumOperator,
255    FiniteDifferenceOperator,
256    GCROTOptions,
257    GCROTResult,
258    GMRESOptions,
259    ICOptions,
260    // Preconditioners
261    ILU0Preconditioner,
262    ILUOptions,
263    IdentityOperator,
264    IterationResult,
265    JacobiPreconditioner,
266    // Decomposition types
267    LUResult,
268    LanczosOptions,
269    LinearOperator,
270    // Eigenvalue types
271    PowerIterationOptions,
272    QRResult,
273    SSORPreconditioner,
274    // SVD types
275    SVDOptions,
276    SVDResult,
277    ScaledIdentityOperator,
278    TFQMROptions,
279    TFQMRResult,
280};
281
282// Format conversions
283pub mod convert;
284
285// Construction utilities
286pub mod construct;
287pub mod construct_sym;
288
289// Combining arrays
290pub mod combine;
291pub use combine::{block_diag, bmat, hstack, kron, kronsum, tril, triu, vstack};
292
293// Index dtype handling utilities
294pub mod index_dtype;
295pub use index_dtype::{can_cast_safely, get_index_dtype, safely_cast_index_arrays};
296
297// Optimized operations for symmetric sparse formats
298pub mod sym_ops;
299pub use sym_ops::{
300    sym_coo_matvec, sym_csr_matvec, sym_csr_quadratic_form, sym_csr_rank1_update, sym_csr_trace,
301};
302
303// GPU-accelerated operations
304pub mod gpu;
305pub mod gpu_kernel_execution;
306pub mod gpu_ops;
307pub mod gpu_spmv_implementation;
308pub use gpu_kernel_execution::{
309    calculate_adaptive_workgroup_size, execute_spmv_kernel, execute_symmetric_spmv_kernel,
310    execute_triangular_solve_kernel, GpuKernelConfig, GpuMemoryManager as GpuKernelMemoryManager,
311    GpuPerformanceProfiler, MemoryStrategy,
312};
313pub use gpu_ops::{
314    gpu_sparse_matvec, gpu_sym_sparse_matvec, AdvancedGpuOps, GpuKernelScheduler, GpuMemoryManager,
315    GpuOptions, GpuProfiler, OptimizedGpuOps,
316};
317pub use gpu_spmv_implementation::GpuSpMV;
318
319// Memory-efficient algorithms and patterns
320pub mod memory_efficient;
321pub use memory_efficient::{
322    streaming_sparse_matvec, CacheAwareOps, MemoryPool, MemoryTracker, OutOfCoreProcessor,
323};
324
325// SIMD-accelerated operations
326pub mod simd_ops;
327pub use simd_ops::{
328    simd_csr_matvec, simd_sparse_elementwise, simd_sparse_linear_combination, simd_sparse_matmul,
329    simd_sparse_norm, simd_sparse_scale, simd_sparse_transpose, ElementwiseOp, SimdOptions,
330};
331
332// Parallel vector operations for iterative solvers
333pub mod parallel_vector_ops;
334pub use parallel_vector_ops::{
335    advanced_sparse_matvec_csr, parallel_axpy, parallel_dot, parallel_linear_combination,
336    parallel_norm2, parallel_sparse_matvec_csr, parallel_vector_add, parallel_vector_copy,
337    parallel_vector_scale, parallel_vector_sub, ParallelVectorOptions,
338};
339
340// Quantum-inspired sparse matrix operations (Advanced mode)
341pub mod quantum_inspired_sparse;
342pub use quantum_inspired_sparse::{
343    QuantumProcessorStats, QuantumSparseConfig, QuantumSparseProcessor, QuantumStrategy,
344};
345
346// Neural-adaptive sparse matrix operations (Advanced mode)
347pub mod neural_adaptive_sparse;
348pub use neural_adaptive_sparse::{
349    NeuralAdaptiveConfig, NeuralAdaptiveSparseProcessor, NeuralProcessorStats, OptimizationStrategy,
350};
351
352// Quantum-Neural hybrid optimization (Advanced mode)
353pub mod quantum_neural_hybrid;
354pub use quantum_neural_hybrid::{
355    HybridStrategy, QuantumNeuralConfig, QuantumNeuralHybridProcessor, QuantumNeuralHybridStats,
356};
357
358// Adaptive memory compression for advanced-large sparse matrices (Advanced mode)
359pub mod adaptive_memory_compression;
360pub use adaptive_memory_compression::{
361    AdaptiveCompressionConfig, AdaptiveMemoryCompressor, CompressedMatrix, CompressionAlgorithm,
362    MemoryStats,
363};
364
365// Real-time performance monitoring and adaptation (Advanced mode)
366pub mod realtime_performance_monitor;
367pub use realtime_performance_monitor::{
368    Alert, AlertSeverity, PerformanceMonitorConfig, PerformanceSample, ProcessorType,
369    RealTimePerformanceMonitor,
370};
371
372// Compressed sparse graph algorithms
373pub mod csgraph;
374pub use csgraph::{
375    all_pairs_shortest_path,
376    bellman_ford_single_source,
377    bfs_distances,
378    // Traversal algorithms
379    breadth_first_search,
380    compute_laplacianmatrix,
381    connected_components,
382    degree_matrix,
383    depth_first_search,
384    dijkstra_single_source,
385    floyd_warshall,
386    has_path,
387    is_connected,
388    is_laplacian,
389    is_spanning_tree,
390    // Minimum spanning trees
391    kruskal_mst,
392    // Laplacian matrices
393    laplacian,
394    largest_component,
395    minimum_spanning_tree,
396    num_edges,
397    num_vertices,
398    prim_mst,
399    reachable_vertices,
400    reconstruct_path,
401    // Graph algorithms
402    shortest_path,
403    // Shortest path algorithms
404    single_source_shortest_path,
405    spanning_tree_weight,
406    strongly_connected_components,
407    to_adjacency_list,
408    topological_sort,
409    traversegraph,
410    // Connected components
411    undirected_connected_components,
412    // Graph utilities
413    validate_graph,
414    weakly_connected_components,
415    LaplacianType,
416    MSTAlgorithm,
417    // Enums and types
418    ShortestPathMethod,
419    TraversalOrder,
420};
421
422// Re-export warnings from scipy for compatibility
423pub struct SparseEfficiencyWarning;
424pub struct SparseWarning;
425
426/// Check if an object is a sparse array
427#[allow(dead_code)]
428pub fn is_sparse_array<T>(obj: &dyn SparseArray<T>) -> bool
429where
430    T: scirs2_core::numeric::Float
431        + std::fmt::Debug
432        + Copy
433        + std::ops::Add<Output = T>
434        + std::ops::Sub<Output = T>
435        + std::ops::Mul<Output = T>
436        + std::ops::Div<Output = T>
437        + 'static,
438{
439    sparray::is_sparse(obj)
440}
441
442/// Check if an object is a symmetric sparse array
443#[allow(dead_code)]
444pub fn is_sym_sparse_array<T>(obj: &dyn SymSparseArray<T>) -> bool
445where
446    T: scirs2_core::numeric::Float
447        + std::fmt::Debug
448        + Copy
449        + std::ops::Add<Output = T>
450        + std::ops::Sub<Output = T>
451        + std::ops::Mul<Output = T>
452        + std::ops::Div<Output = T>
453        + 'static,
454{
455    obj.is_symmetric()
456}
457
458/// Check if an object is a sparse matrix (legacy API)
459#[allow(dead_code)]
460pub fn is_sparse_matrix(obj: &dyn std::any::Any) -> bool {
461    obj.is::<CsrMatrix<f64>>()
462        || obj.is::<CscMatrix<f64>>()
463        || obj.is::<CooMatrix<f64>>()
464        || obj.is::<DokMatrix<f64>>()
465        || obj.is::<LilMatrix<f64>>()
466        || obj.is::<DiaMatrix<f64>>()
467        || obj.is::<BsrMatrix<f64>>()
468        || obj.is::<SymCsrMatrix<f64>>()
469        || obj.is::<SymCooMatrix<f64>>()
470        || obj.is::<CsrMatrix<f32>>()
471        || obj.is::<CscMatrix<f32>>()
472        || obj.is::<CooMatrix<f32>>()
473        || obj.is::<DokMatrix<f32>>()
474        || obj.is::<LilMatrix<f32>>()
475        || obj.is::<DiaMatrix<f32>>()
476        || obj.is::<BsrMatrix<f32>>()
477        || obj.is::<SymCsrMatrix<f32>>()
478        || obj.is::<SymCooMatrix<f32>>()
479}
480
481#[cfg(test)]
482mod tests {
483    use super::*;
484    use approx::assert_relative_eq;
485
486    #[test]
487    fn test_csr_array() {
488        let rows = vec![0, 0, 1, 2, 2];
489        let cols = vec![0, 2, 2, 0, 1];
490        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
491        let shape = (3, 3);
492
493        let array = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
494
495        assert_eq!(array.shape(), (3, 3));
496        assert_eq!(array.nnz(), 5);
497        assert!(is_sparse_array(&array));
498    }
499
500    #[test]
501    fn test_coo_array() {
502        let rows = vec![0, 0, 1, 2, 2];
503        let cols = vec![0, 2, 2, 0, 1];
504        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
505        let shape = (3, 3);
506
507        let array = CooArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
508
509        assert_eq!(array.shape(), (3, 3));
510        assert_eq!(array.nnz(), 5);
511        assert!(is_sparse_array(&array));
512    }
513
514    #[test]
515    fn test_dok_array() {
516        let rows = vec![0, 0, 1, 2, 2];
517        let cols = vec![0, 2, 2, 0, 1];
518        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
519        let shape = (3, 3);
520
521        let array = DokArray::from_triplets(&rows, &cols, &data, shape).unwrap();
522
523        assert_eq!(array.shape(), (3, 3));
524        assert_eq!(array.nnz(), 5);
525        assert!(is_sparse_array(&array));
526
527        // Test setting and getting values
528        let mut array = DokArray::<f64>::new((2, 2));
529        array.set(0, 0, 1.0).unwrap();
530        array.set(1, 1, 2.0).unwrap();
531
532        assert_eq!(array.get(0, 0), 1.0);
533        assert_eq!(array.get(0, 1), 0.0);
534        assert_eq!(array.get(1, 1), 2.0);
535
536        // Test removing zeros
537        array.set(0, 0, 0.0).unwrap();
538        assert_eq!(array.nnz(), 1);
539    }
540
541    #[test]
542    fn test_lil_array() {
543        let rows = vec![0, 0, 1, 2, 2];
544        let cols = vec![0, 2, 2, 0, 1];
545        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
546        let shape = (3, 3);
547
548        let array = LilArray::from_triplets(&rows, &cols, &data, shape).unwrap();
549
550        assert_eq!(array.shape(), (3, 3));
551        assert_eq!(array.nnz(), 5);
552        assert!(is_sparse_array(&array));
553
554        // Test setting and getting values
555        let mut array = LilArray::<f64>::new((2, 2));
556        array.set(0, 0, 1.0).unwrap();
557        array.set(1, 1, 2.0).unwrap();
558
559        assert_eq!(array.get(0, 0), 1.0);
560        assert_eq!(array.get(0, 1), 0.0);
561        assert_eq!(array.get(1, 1), 2.0);
562
563        // Test sorted indices
564        assert!(array.has_sorted_indices());
565
566        // Test removing zeros
567        array.set(0, 0, 0.0).unwrap();
568        assert_eq!(array.nnz(), 1);
569    }
570
571    #[test]
572    fn test_dia_array() {
573        use scirs2_core::ndarray::Array1;
574
575        // Create a 3x3 diagonal matrix with main diagonal + upper diagonal
576        let data = vec![
577            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
578            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
579        ];
580        let offsets = vec![0, 1]; // Main diagonal and k=1
581        let shape = (3, 3);
582
583        let array = DiaArray::new(data, offsets, shape).unwrap();
584
585        assert_eq!(array.shape(), (3, 3));
586        assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal
587        assert!(is_sparse_array(&array));
588
589        // Test values
590        assert_eq!(array.get(0, 0), 1.0);
591        assert_eq!(array.get(1, 1), 2.0);
592        assert_eq!(array.get(2, 2), 3.0);
593        assert_eq!(array.get(0, 1), 4.0);
594        assert_eq!(array.get(1, 2), 5.0);
595        assert_eq!(array.get(0, 2), 0.0);
596
597        // Test from_triplets
598        let rows = vec![0, 0, 1, 1, 2];
599        let cols = vec![0, 1, 1, 2, 2];
600        let data_vec = vec![1.0, 4.0, 2.0, 5.0, 3.0];
601
602        let array2 = DiaArray::from_triplets(&rows, &cols, &data_vec, shape).unwrap();
603
604        // Should have same values
605        assert_eq!(array2.get(0, 0), 1.0);
606        assert_eq!(array2.get(1, 1), 2.0);
607        assert_eq!(array2.get(2, 2), 3.0);
608        assert_eq!(array2.get(0, 1), 4.0);
609        assert_eq!(array2.get(1, 2), 5.0);
610
611        // Test conversion to other formats
612        let csr = array.to_csr().unwrap();
613        assert_eq!(csr.nnz(), 5);
614        assert_eq!(csr.get(0, 0), 1.0);
615        assert_eq!(csr.get(0, 1), 4.0);
616    }
617
618    #[test]
619    fn test_format_conversions() {
620        let rows = vec![0, 0, 1, 2, 2];
621        let cols = vec![0, 2, 1, 0, 2];
622        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
623        let shape = (3, 3);
624
625        // Create a COO array
626        let coo = CooArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
627
628        // Convert to CSR
629        let csr = coo.to_csr().unwrap();
630
631        // Check values are preserved
632        let coo_dense = coo.to_array();
633        let csr_dense = csr.to_array();
634
635        for i in 0..shape.0 {
636            for j in 0..shape.1 {
637                assert_relative_eq!(coo_dense[[i, j]], csr_dense[[i, j]]);
638            }
639        }
640    }
641
642    #[test]
643    fn test_dot_product() {
644        let rows = vec![0, 0, 1, 2, 2];
645        let cols = vec![0, 2, 1, 0, 2];
646        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
647        let shape = (3, 3);
648
649        // Create arrays in different formats
650        let coo = CooArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
651        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
652
653        // Compute dot product (matrix multiplication)
654        let coo_result = coo.dot(&coo).unwrap();
655        let csr_result = csr.dot(&csr).unwrap();
656
657        // Check results match
658        let coo_dense = coo_result.to_array();
659        let csr_dense = csr_result.to_array();
660
661        for i in 0..shape.0 {
662            for j in 0..shape.1 {
663                assert_relative_eq!(coo_dense[[i, j]], csr_dense[[i, j]], epsilon = 1e-10);
664            }
665        }
666    }
667
668    #[test]
669    fn test_sym_csr_array() {
670        // Create a symmetric matrix
671        let data = vec![2.0, 1.0, 2.0, 3.0, 0.0, 3.0, 1.0];
672        let indices = vec![0, 0, 1, 2, 0, 1, 2];
673        let indptr = vec![0, 1, 3, 7];
674
675        let sym_matrix = SymCsrMatrix::new(data, indptr, indices, (3, 3)).unwrap();
676        let sym_array = SymCsrArray::new(sym_matrix);
677
678        assert_eq!(sym_array.shape(), (3, 3));
679        assert!(is_sym_sparse_array(&sym_array));
680
681        // Check values
682        assert_eq!(SparseArray::get(&sym_array, 0, 0), 2.0);
683        assert_eq!(SparseArray::get(&sym_array, 0, 1), 1.0);
684        assert_eq!(SparseArray::get(&sym_array, 1, 0), 1.0); // Symmetric element
685        assert_eq!(SparseArray::get(&sym_array, 1, 2), 3.0);
686        assert_eq!(SparseArray::get(&sym_array, 2, 1), 3.0); // Symmetric element
687
688        // Convert to standard CSR
689        let csr = SymSparseArray::to_csr(&sym_array).unwrap();
690        assert_eq!(csr.nnz(), 10); // Full matrix with symmetric elements
691    }
692
693    #[test]
694    fn test_sym_coo_array() {
695        // Create a symmetric matrix in COO format
696        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
697        let rows = vec![0, 1, 1, 2, 2];
698        let cols = vec![0, 0, 1, 1, 2];
699
700        let sym_matrix = SymCooMatrix::new(data, rows, cols, (3, 3)).unwrap();
701        let sym_array = SymCooArray::new(sym_matrix);
702
703        assert_eq!(sym_array.shape(), (3, 3));
704        assert!(is_sym_sparse_array(&sym_array));
705
706        // Check values
707        assert_eq!(SparseArray::get(&sym_array, 0, 0), 2.0);
708        assert_eq!(SparseArray::get(&sym_array, 0, 1), 1.0);
709        assert_eq!(SparseArray::get(&sym_array, 1, 0), 1.0); // Symmetric element
710        assert_eq!(SparseArray::get(&sym_array, 1, 2), 3.0);
711        assert_eq!(SparseArray::get(&sym_array, 2, 1), 3.0); // Symmetric element
712
713        // Test from_triplets with enforce symmetry
714        // Input is intentionally asymmetric - will be fixed by enforce_symmetric=true
715        let rows2 = vec![0, 0, 1, 1, 2, 1, 0];
716        let cols2 = vec![0, 1, 1, 2, 2, 0, 2];
717        let data2 = vec![2.0, 1.5, 2.0, 3.5, 1.0, 0.5, 0.0];
718
719        let sym_array2 = SymCooArray::from_triplets(&rows2, &cols2, &data2, (3, 3), true).unwrap();
720
721        // Should average the asymmetric values
722        assert_eq!(SparseArray::get(&sym_array2, 0, 1), 1.0); // Average of 1.5 and 0.5
723        assert_eq!(SparseArray::get(&sym_array2, 1, 0), 1.0); // Symmetric element
724        assert_eq!(SparseArray::get(&sym_array2, 0, 2), 0.0); // Zero element
725    }
726
727    #[test]
728    fn test_construct_sym_utils() {
729        // Test creating an identity matrix
730        let eye = construct_sym::eye_sym_array::<f64>(3, "csr").unwrap();
731
732        assert_eq!(eye.shape(), (3, 3));
733        assert_eq!(SparseArray::get(&*eye, 0, 0), 1.0);
734        assert_eq!(SparseArray::get(&*eye, 1, 1), 1.0);
735        assert_eq!(SparseArray::get(&*eye, 2, 2), 1.0);
736        assert_eq!(SparseArray::get(&*eye, 0, 1), 0.0);
737
738        // Test creating a tridiagonal matrix - with coo format since csr had issues
739        let diag = vec![2.0, 2.0, 2.0];
740        let offdiag = vec![1.0, 1.0];
741
742        let tri = construct_sym::tridiagonal_sym_array(&diag, &offdiag, "coo").unwrap();
743
744        assert_eq!(tri.shape(), (3, 3));
745        assert_eq!(SparseArray::get(&*tri, 0, 0), 2.0); // Main diagonal
746        assert_eq!(SparseArray::get(&*tri, 1, 1), 2.0);
747        assert_eq!(SparseArray::get(&*tri, 2, 2), 2.0);
748        assert_eq!(SparseArray::get(&*tri, 0, 1), 1.0); // Off-diagonal
749        assert_eq!(SparseArray::get(&*tri, 1, 0), 1.0); // Symmetric element
750        assert_eq!(SparseArray::get(&*tri, 1, 2), 1.0);
751        assert_eq!(SparseArray::get(&*tri, 0, 2), 0.0); // Zero element
752
753        // Test creating a banded matrix
754        let diagonals = vec![
755            vec![2.0, 2.0, 2.0, 2.0, 2.0], // Main diagonal
756            vec![1.0, 1.0, 1.0, 1.0],      // First off-diagonal
757            vec![0.5, 0.5, 0.5],           // Second off-diagonal
758        ];
759
760        let band = construct_sym::banded_sym_array(&diagonals, 5, "csr").unwrap();
761
762        assert_eq!(band.shape(), (5, 5));
763        assert_eq!(SparseArray::get(&*band, 0, 0), 2.0);
764        assert_eq!(SparseArray::get(&*band, 0, 1), 1.0);
765        assert_eq!(SparseArray::get(&*band, 0, 2), 0.5);
766        assert_eq!(SparseArray::get(&*band, 2, 0), 0.5); // Symmetric element
767    }
768
769    #[test]
770    fn test_sym_conversions() {
771        // Create a symmetric matrix
772        // Lower triangular part only
773        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
774        let rows = vec![0, 1, 1, 2, 2];
775        let cols = vec![0, 0, 1, 1, 2];
776
777        let sym_coo = SymCooArray::from_triplets(&rows, &cols, &data, (3, 3), true).unwrap();
778
779        // Convert to symmetric CSR
780        let sym_csr = sym_coo.to_sym_csr().unwrap();
781
782        // Check values are preserved
783        for i in 0..3 {
784            for j in 0..3 {
785                assert_eq!(
786                    SparseArray::get(&sym_coo, i, j),
787                    SparseArray::get(&sym_csr, i, j)
788                );
789            }
790        }
791
792        // Convert to standard formats
793        let csr = SymSparseArray::to_csr(&sym_coo).unwrap();
794        let coo = SymSparseArray::to_coo(&sym_csr).unwrap();
795
796        // Check full symmetric matrix in standard formats
797        assert_eq!(csr.nnz(), 7); // Accounts for symmetric pairs
798        assert_eq!(coo.nnz(), 7);
799
800        for i in 0..3 {
801            for j in 0..3 {
802                assert_eq!(SparseArray::get(&csr, i, j), SparseArray::get(&coo, i, j));
803                assert_eq!(
804                    SparseArray::get(&csr, i, j),
805                    SparseArray::get(&sym_csr, i, j)
806                );
807            }
808        }
809    }
810}