Skip to main content

sublinear_solver/matrix/
mod.rs

1//! Matrix operations and data structures for sparse linear algebra.
2//!
3//! This module provides efficient implementations of sparse matrix formats
4//! optimized for asymmetric diagonally dominant systems, with support for
5//! both traditional linear algebra operations and graph-based algorithms.
6
7use crate::error::{Result, SolverError};
8use crate::types::{ConditioningInfo, DimensionType, IndexType, Precision, SparsityInfo};
9use alloc::vec::Vec;
10use core::fmt;
11
12pub mod optimized;
13pub mod sparse;
14
15use sparse::*;
16
17// Re-export optimized types for convenience
18pub use optimized::{BufferPool, OptimizedCSRStorage, StreamingMatrix};
19
20/// Trait defining the interface for matrix operations.
21///
22/// This trait abstracts over different matrix storage formats,
23/// allowing algorithms to work with CSR, CSC, or graph adjacency
24/// representations transparently.
25pub trait Matrix: Send + Sync {
26    /// Get the number of rows in the matrix.
27    fn rows(&self) -> DimensionType;
28
29    /// Get the number of columns in the matrix.
30    fn cols(&self) -> DimensionType;
31
32    /// Get a specific matrix element, returning None if it's zero or out of bounds.
33    fn get(&self, row: usize, col: usize) -> Option<Precision>;
34
35    /// Get an iterator over non-zero elements in a specific row.
36    /// Returns (column_index, value) pairs.
37    fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;
38
39    /// Get an iterator over non-zero elements in a specific column.
40    /// Returns (row_index, value) pairs.
41    fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;
42
43    /// Perform matrix-vector multiplication: result = A * x
44    fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;
45
46    /// Perform matrix-vector multiplication with accumulation: result += A * x
47    fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;
48
49    /// Check if the matrix is diagonally dominant.
50    /// A matrix is diagonally dominant if |a_ii| >= Σ_{j≠i} |a_ij| for all i.
51    fn is_diagonally_dominant(&self) -> bool;
52
53    /// Get the diagonal dominance factor (minimum ratio of diagonal to off-diagonal).
54    fn diagonal_dominance_factor(&self) -> Option<Precision>;
55
56    /// Get the number of non-zero elements.
57    fn nnz(&self) -> usize;
58
59    /// Get sparsity pattern information.
60    fn sparsity_info(&self) -> SparsityInfo;
61
62    /// Get matrix conditioning information.
63    fn conditioning_info(&self) -> ConditioningInfo;
64
65    /// Get the storage format name.
66    fn format_name(&self) -> &'static str;
67
68    /// Check if the matrix is square.
69    fn is_square(&self) -> bool {
70        self.rows() == self.cols()
71    }
72
73    /// Get the Frobenius norm of the matrix.
74    fn frobenius_norm(&self) -> Precision {
75        let mut norm_sq = 0.0;
76        for row in 0..self.rows() {
77            for (_, value) in self.row_iter(row) {
78                norm_sq += value * value;
79            }
80        }
81        norm_sq.sqrt()
82    }
83
84    /// Estimate the spectral radius (largest eigenvalue magnitude).
85    /// Uses Gershgorin circle theorem for a conservative estimate.
86    fn spectral_radius_estimate(&self) -> Precision {
87        let mut max_radius: Precision = 0.0;
88        for row in 0..self.rows() {
89            let mut diagonal = 0.0;
90            let mut off_diagonal_sum = 0.0;
91
92            for (col, value) in self.row_iter(row) {
93                if col as usize == row {
94                    diagonal = value.abs();
95                } else {
96                    off_diagonal_sum += value.abs();
97                }
98            }
99
100            max_radius = max_radius.max(diagonal + off_diagonal_sum);
101        }
102        max_radius
103    }
104}
105
/// Storage layouts a sparse matrix can be held in.
///
/// Variant order is part of the (optional) serde representation and must not change.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SparseFormat {
    /// Compressed Sparse Row: suited to row-wise traversal and row-oriented products.
    CSR,
    /// Compressed Sparse Column: suited to column-wise traversal.
    CSC,
    /// Coordinate (triplet) list: convenient while building a matrix.
    COO,
    /// Adjacency-list view used by the graph-based algorithms.
    GraphAdjacency,
}
119
120/// Main sparse matrix implementation supporting multiple storage formats.
121#[derive(Debug, Clone)]
122#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
123pub struct SparseMatrix {
124    /// Current storage format
125    format: SparseFormat,
126    /// Matrix dimensions
127    rows: DimensionType,
128    cols: DimensionType,
129    /// Storage implementation
130    storage: SparseStorage,
131}
132
133/// Internal storage implementation for different sparse formats.
134#[derive(Debug, Clone)]
135#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
136enum SparseStorage {
137    CSR(CSRStorage),
138    CSC(CSCStorage),
139    COO(COOStorage),
140    Graph(GraphStorage),
141}
142
143impl SparseMatrix {
144    /// Create a new sparse matrix from coordinate (triplet) format.
145    ///
146    /// # Arguments
147    /// * `triplets` - Vector of (row, col, value) triplets
148    /// * `rows` - Number of rows
149    /// * `cols` - Number of columns
150    ///
151    /// # Example
152    /// ```
153    /// use sublinear_solver::SparseMatrix;
154    ///
155    /// let matrix = SparseMatrix::from_triplets(
156    ///     vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 5.0)],
157    ///     2, 2
158    /// ).unwrap();
159    /// ```
160    pub fn from_triplets(
161        triplets: Vec<(usize, usize, Precision)>,
162        rows: DimensionType,
163        cols: DimensionType,
164    ) -> Result<Self> {
165        // Validate input
166        for &(r, c, v) in &triplets {
167            if r >= rows {
168                return Err(SolverError::IndexOutOfBounds {
169                    index: r,
170                    max_index: rows - 1,
171                    context: "row index in triplet".to_string(),
172                });
173            }
174            if c >= cols {
175                return Err(SolverError::IndexOutOfBounds {
176                    index: c,
177                    max_index: cols - 1,
178                    context: "column index in triplet".to_string(),
179                });
180            }
181            if !v.is_finite() {
182                return Err(SolverError::InvalidInput {
183                    message: format!("Non-finite value {} at ({}, {})", v, r, c),
184                    parameter: Some("matrix_element".to_string()),
185                });
186            }
187        }
188
189        // Create COO storage first, then convert to CSR for efficiency
190        let coo_storage = COOStorage::from_triplets(triplets)?;
191        let csr_storage = CSRStorage::from_coo(&coo_storage, rows, cols)?;
192
193        Ok(Self {
194            format: SparseFormat::CSR,
195            rows,
196            cols,
197            storage: SparseStorage::CSR(csr_storage),
198        })
199    }
200
201    /// Create a sparse matrix from dense row-major data.
202    ///
203    /// Zero elements are automatically filtered out.
204    pub fn from_dense(
205        data: &[Precision],
206        rows: DimensionType,
207        cols: DimensionType,
208    ) -> Result<Self> {
209        if data.len() != rows * cols {
210            return Err(SolverError::DimensionMismatch {
211                expected: rows * cols,
212                actual: data.len(),
213                operation: "dense_to_sparse_conversion".to_string(),
214            });
215        }
216
217        let mut triplets = Vec::new();
218        for (i, &value) in data.iter().enumerate() {
219            if value != 0.0 {
220                let row = i / cols;
221                let col = i % cols;
222                triplets.push((row, col, value));
223            }
224        }
225
226        Self::from_triplets(triplets, rows, cols)
227    }
228
229    /// Create an identity matrix of the given size.
230    pub fn identity(size: DimensionType) -> Result<Self> {
231        let triplets: Vec<_> = (0..size).map(|i| (i, i, 1.0)).collect();
232        Self::from_triplets(triplets, size, size)
233    }
234
235    /// Create a diagonal matrix from the given diagonal values.
236    pub fn diagonal(diag: &[Precision]) -> Result<Self> {
237        let size = diag.len();
238        let triplets: Vec<_> = diag
239            .iter()
240            .enumerate()
241            .filter(|(_, &v)| v != 0.0)
242            .map(|(i, &v)| (i, i, v))
243            .collect();
244        Self::from_triplets(triplets, size, size)
245    }
246
247    /// Convert the matrix to a different storage format.
248    ///
249    /// This operation may be expensive for large matrices.
250    pub fn convert_to_format(&mut self, new_format: SparseFormat) -> Result<()> {
251        if self.format == new_format {
252            return Ok(());
253        }
254
255        match (self.format, new_format) {
256            (SparseFormat::CSR, SparseFormat::CSC) => {
257                if let SparseStorage::CSR(ref csr) = self.storage {
258                    let csc = CSCStorage::from_csr(csr, self.rows, self.cols)?;
259                    self.storage = SparseStorage::CSC(csc);
260                    self.format = SparseFormat::CSC;
261                }
262            }
263            (SparseFormat::CSC, SparseFormat::CSR) => {
264                if let SparseStorage::CSC(ref csc) = self.storage {
265                    let csr = CSRStorage::from_csc(csc, self.rows, self.cols)?;
266                    self.storage = SparseStorage::CSR(csr);
267                    self.format = SparseFormat::CSR;
268                }
269            }
270            (_, SparseFormat::GraphAdjacency) => {
271                // Convert any format to graph adjacency
272                let triplets = self.to_triplets()?;
273                let graph = GraphStorage::from_triplets(triplets, self.rows)?;
274                self.storage = SparseStorage::Graph(graph);
275                self.format = SparseFormat::GraphAdjacency;
276            }
277            _ => {
278                // For other conversions, go through COO format
279                let triplets = self.to_triplets()?;
280                let coo = COOStorage::from_triplets(triplets)?;
281
282                match new_format {
283                    SparseFormat::CSR => {
284                        let csr = CSRStorage::from_coo(&coo, self.rows, self.cols)?;
285                        self.storage = SparseStorage::CSR(csr);
286                    }
287                    SparseFormat::CSC => {
288                        let csc = CSCStorage::from_coo(&coo, self.rows, self.cols)?;
289                        self.storage = SparseStorage::CSC(csc);
290                    }
291                    SparseFormat::COO => {
292                        self.storage = SparseStorage::COO(coo);
293                    }
294                    _ => unreachable!(),
295                }
296                self.format = new_format;
297            }
298        }
299
300        Ok(())
301    }
302
303    /// Extract the matrix as coordinate triplets.
304    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
305        match &self.storage {
306            SparseStorage::CSR(csr) => csr.to_triplets(),
307            SparseStorage::CSC(csc) => csc.to_triplets(),
308            SparseStorage::COO(coo) => Ok(coo.to_triplets()),
309            SparseStorage::Graph(graph) => graph.to_triplets(),
310        }
311    }
312
313    /// Get the current storage format.
314    pub fn format(&self) -> SparseFormat {
315        self.format
316    }
317
318    /// Get a reference to the underlying CSR storage.
319    ///
320    /// Converts to CSR format if necessary.
321    pub fn as_csr(&mut self) -> Result<&CSRStorage> {
322        self.convert_to_format(SparseFormat::CSR)?;
323        match &self.storage {
324            SparseStorage::CSR(csr) => Ok(csr),
325            _ => unreachable!(),
326        }
327    }
328
329    /// Get a reference to the underlying CSC storage.
330    ///
331    /// Converts to CSC format if necessary.
332    pub fn as_csc(&mut self) -> Result<&CSCStorage> {
333        self.convert_to_format(SparseFormat::CSC)?;
334        match &self.storage {
335            SparseStorage::CSC(csc) => Ok(csc),
336            _ => unreachable!(),
337        }
338    }
339
340    /// Get a reference to the underlying graph storage.
341    ///
342    /// Converts to graph format if necessary.
343    pub fn as_graph(&mut self) -> Result<&GraphStorage> {
344        self.convert_to_format(SparseFormat::GraphAdjacency)?;
345        match &self.storage {
346            SparseStorage::Graph(graph) => Ok(graph),
347            _ => unreachable!(),
348        }
349    }
350
351    /// Scale the matrix by a scalar value.
352    pub fn scale(&mut self, factor: Precision) {
353        match &mut self.storage {
354            SparseStorage::CSR(csr) => csr.scale(factor),
355            SparseStorage::CSC(csc) => csc.scale(factor),
356            SparseStorage::COO(coo) => coo.scale(factor),
357            SparseStorage::Graph(graph) => graph.scale(factor),
358        }
359    }
360
361    /// Add a scalar multiple of the identity matrix: A = A + alpha * I
362    pub fn add_diagonal(&mut self, alpha: Precision) -> Result<()> {
363        if !self.is_square() {
364            return Err(SolverError::InvalidInput {
365                message: "Cannot add diagonal to non-square matrix".to_string(),
366                parameter: Some("matrix_dimensions".to_string()),
367            });
368        }
369
370        match &mut self.storage {
371            SparseStorage::CSR(csr) => csr.add_diagonal(alpha),
372            SparseStorage::CSC(csc) => csc.add_diagonal(alpha),
373            SparseStorage::COO(coo) => coo.add_diagonal(alpha, self.rows),
374            SparseStorage::Graph(graph) => graph.add_diagonal(alpha),
375        }
376
377        Ok(())
378    }
379}
380
381impl Matrix for SparseMatrix {
382    fn rows(&self) -> DimensionType {
383        self.rows
384    }
385
386    fn cols(&self) -> DimensionType {
387        self.cols
388    }
389
390    fn get(&self, row: usize, col: usize) -> Option<Precision> {
391        if row >= self.rows || col >= self.cols {
392            return None;
393        }
394
395        match &self.storage {
396            SparseStorage::CSR(csr) => csr.get(row, col),
397            SparseStorage::CSC(csc) => csc.get(row, col),
398            SparseStorage::COO(coo) => coo.get(row, col),
399            SparseStorage::Graph(graph) => graph.get(row, col),
400        }
401    }
402
403    fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
404        match &self.storage {
405            SparseStorage::CSR(csr) => Box::new(csr.row_iter(row)),
406            SparseStorage::CSC(csc) => Box::new(csc.row_iter(row)),
407            SparseStorage::COO(coo) => Box::new(coo.row_iter(row)),
408            SparseStorage::Graph(graph) => Box::new(graph.row_iter(row)),
409        }
410    }
411
412    fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
413        match &self.storage {
414            SparseStorage::CSR(csr) => Box::new(csr.col_iter(col)),
415            SparseStorage::CSC(csc) => Box::new(csc.col_iter(col)),
416            SparseStorage::COO(coo) => Box::new(coo.col_iter(col)),
417            SparseStorage::Graph(graph) => Box::new(graph.col_iter(col)),
418        }
419    }
420
421    fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
422        if x.len() != self.cols {
423            return Err(SolverError::DimensionMismatch {
424                expected: self.cols,
425                actual: x.len(),
426                operation: "matrix_vector_multiply".to_string(),
427            });
428        }
429        if result.len() != self.rows {
430            return Err(SolverError::DimensionMismatch {
431                expected: self.rows,
432                actual: result.len(),
433                operation: "matrix_vector_multiply".to_string(),
434            });
435        }
436
437        match &self.storage {
438            SparseStorage::CSR(csr) => csr.multiply_vector(x, result),
439            SparseStorage::CSC(csc) => csc.multiply_vector(x, result),
440            SparseStorage::COO(coo) => coo.multiply_vector(x, result),
441            SparseStorage::Graph(graph) => graph.multiply_vector(x, result),
442        }
443
444        Ok(())
445    }
446
447    fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
448        if x.len() != self.cols {
449            return Err(SolverError::DimensionMismatch {
450                expected: self.cols,
451                actual: x.len(),
452                operation: "matrix_vector_multiply_add".to_string(),
453            });
454        }
455        if result.len() != self.rows {
456            return Err(SolverError::DimensionMismatch {
457                expected: self.rows,
458                actual: result.len(),
459                operation: "matrix_vector_multiply_add".to_string(),
460            });
461        }
462
463        match &self.storage {
464            SparseStorage::CSR(csr) => csr.multiply_vector_add(x, result),
465            SparseStorage::CSC(csc) => csc.multiply_vector_add(x, result),
466            SparseStorage::COO(coo) => coo.multiply_vector_add(x, result),
467            SparseStorage::Graph(graph) => graph.multiply_vector_add(x, result),
468        }
469
470        Ok(())
471    }
472
473    fn is_diagonally_dominant(&self) -> bool {
474        for row in 0..self.rows {
475            let mut diagonal = 0.0;
476            let mut off_diagonal_sum = 0.0;
477
478            for (col, value) in self.row_iter(row) {
479                if col as usize == row {
480                    diagonal = value.abs();
481                } else {
482                    off_diagonal_sum += value.abs();
483                }
484            }
485
486            if diagonal < off_diagonal_sum {
487                return false;
488            }
489        }
490        true
491    }
492
493    fn diagonal_dominance_factor(&self) -> Option<Precision> {
494        let mut min_factor = Precision::INFINITY;
495
496        for row in 0..self.rows {
497            let mut diagonal = 0.0;
498            let mut off_diagonal_sum = 0.0;
499
500            for (col, value) in self.row_iter(row) {
501                if col as usize == row {
502                    diagonal = value.abs();
503                } else {
504                    off_diagonal_sum += value.abs();
505                }
506            }
507
508            if off_diagonal_sum > 0.0 {
509                let factor = diagonal / off_diagonal_sum;
510                min_factor = min_factor.min(factor);
511            }
512        }
513
514        if min_factor.is_finite() {
515            Some(min_factor)
516        } else {
517            None
518        }
519    }
520
521    fn nnz(&self) -> usize {
522        match &self.storage {
523            SparseStorage::CSR(csr) => csr.nnz(),
524            SparseStorage::CSC(csc) => csc.nnz(),
525            SparseStorage::COO(coo) => coo.nnz(),
526            SparseStorage::Graph(graph) => graph.nnz(),
527        }
528    }
529
530    fn sparsity_info(&self) -> SparsityInfo {
531        let mut info = SparsityInfo::new(self.nnz(), self.rows, self.cols);
532
533        // Compute additional statistics
534        let mut max_nnz_per_row = 0;
535        for row in 0..self.rows {
536            let row_nnz = self.row_iter(row).count();
537            max_nnz_per_row = max_nnz_per_row.max(row_nnz);
538        }
539        info.max_nnz_per_row = max_nnz_per_row;
540
541        // Check for banded structure (simple heuristic)
542        let mut max_bandwidth = 0;
543        for (r, c, _) in self.to_triplets().unwrap_or_default() {
544            let bandwidth = if r > c { r - c } else { c - r };
545            max_bandwidth = max_bandwidth.max(bandwidth);
546        }
547        info.bandwidth = Some(max_bandwidth);
548        info.is_banded = max_bandwidth < self.rows / 4; // Heuristic: banded if bandwidth < 25% of size
549
550        info
551    }
552
553    fn conditioning_info(&self) -> ConditioningInfo {
554        ConditioningInfo {
555            condition_number: None, // Expensive to compute exactly
556            is_diagonally_dominant: self.is_diagonally_dominant(),
557            diagonal_dominance_factor: self.diagonal_dominance_factor(),
558            spectral_radius: Some(self.spectral_radius_estimate()),
559            is_positive_definite: None, // Expensive to determine
560        }
561    }
562
563    fn format_name(&self) -> &'static str {
564        match self.format {
565            SparseFormat::CSR => "CSR",
566            SparseFormat::CSC => "CSC",
567            SparseFormat::COO => "COO",
568            SparseFormat::GraphAdjacency => "GraphAdjacency",
569        }
570    }
571}
572
573impl fmt::Display for SparseMatrix {
574    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
575        write!(
576            f,
577            "{}x{} sparse matrix ({} format, {} nnz)",
578            self.rows,
579            self.cols,
580            self.format_name(),
581            self.nnz()
582        )
583    }
584}
585
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Build a 2x2 matrix from its four entries given in row-major order.
    fn square2(entries: [Precision; 4]) -> SparseMatrix {
        let [a, b, c, d] = entries;
        SparseMatrix::from_triplets(vec![(0, 0, a), (0, 1, b), (1, 0, c), (1, 1, d)], 2, 2)
            .unwrap()
    }

    #[test]
    fn test_matrix_creation() {
        let m = square2([4.0, 1.0, 2.0, 5.0]);

        assert_eq!((m.rows(), m.cols()), (2, 2));
        assert_eq!(m.nnz(), 4);
        assert!(m.is_diagonally_dominant());
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let m = square2([2.0, 1.0, 1.0, 3.0]);
        let x = [1.0, 2.0];
        let mut y = [0.0; 2];

        m.multiply_vector(&x, &mut y).unwrap();

        // [2*1 + 1*2, 1*1 + 3*2]
        assert_eq!(y, [4.0, 7.0]);
    }

    #[test]
    fn test_diagonal_dominance() {
        // |5| >= |1| and |7| >= |2|: dominant.
        assert!(square2([5.0, 1.0, 2.0, 7.0]).is_diagonally_dominant());
        // |1| < |3| in the first row: not dominant.
        assert!(!square2([1.0, 3.0, 2.0, 2.0]).is_diagonally_dominant());
    }

    #[test]
    fn test_format_conversion() {
        let mut m =
            SparseMatrix::from_triplets(vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)], 2, 3)
                .unwrap();

        assert_eq!(m.format(), SparseFormat::CSR);

        for target in [SparseFormat::CSC, SparseFormat::GraphAdjacency] {
            m.convert_to_format(target).unwrap();
            assert_eq!(m.format(), target);
        }
    }
}