Skip to main content

ruvector_solver/
types.rs

1//! Core types for sparse linear solvers.
2//!
3//! Provides [`CsrMatrix`] for compressed sparse row storage and result types
4//! for solver convergence tracking.
5
6use std::time::Duration;
7
8// ---------------------------------------------------------------------------
9// CsrMatrix<T>
10// ---------------------------------------------------------------------------
11
/// Compressed Sparse Row (CSR) matrix.
///
/// Stores only non-zero entries for efficient sparse matrix-vector
/// multiplication in O(nnz) time with excellent cache locality.
///
/// # Layout
///
/// For a matrix with `m` rows and `nnz` non-zeros:
/// - `row_ptr` has length `m + 1`
/// - `col_indices` and `values` each have length `nnz`
/// - Row `i` spans indices `row_ptr[i]..row_ptr[i+1]`
///
/// # Invariants
///
/// Every field is public, so this type does not enforce the layout above by
/// construction. The `spmv_unchecked` methods read these arrays without
/// per-element bounds checks and therefore rely on the layout being valid.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
    /// Row pointers: `row_ptr[i]` is the start index in `col_indices`/`values`
    /// for row `i`, and `row_ptr[i + 1]` is the exclusive end index.
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero entry (parallel to `values`).
    pub col_indices: Vec<usize>,
    /// Values for each non-zero entry (parallel to `col_indices`).
    pub values: Vec<T>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
37
38impl<T: Copy + Default + std::ops::Mul<Output = T> + std::ops::AddAssign> CsrMatrix<T> {
39    /// Sparse matrix-vector multiply: `y = A * x`.
40    ///
41    /// # Panics
42    ///
43    /// Debug-asserts that `x.len() >= self.cols` and `y.len() >= self.rows`.
44    #[inline]
45    pub fn spmv(&self, x: &[T], y: &mut [T]) {
46        debug_assert!(
47            x.len() >= self.cols,
48            "spmv: x.len()={} < cols={}",
49            x.len(),
50            self.cols,
51        );
52        debug_assert!(
53            y.len() >= self.rows,
54            "spmv: y.len()={} < rows={}",
55            y.len(),
56            self.rows,
57        );
58
59        for i in 0..self.rows {
60            let mut sum = T::default();
61            let start = self.row_ptr[i];
62            let end = self.row_ptr[i + 1];
63
64            for idx in start..end {
65                sum += self.values[idx] * x[self.col_indices[idx]];
66            }
67            y[i] = sum;
68        }
69    }
70}
71
72impl CsrMatrix<f32> {
73    /// High-performance SpMV with bounds-check elimination.
74    ///
75    /// Identical to [`spmv`](Self::spmv) but uses `unsafe` indexing to
76    /// eliminate per-element bounds checks in the inner loop, which is the
77    /// single hottest path in all iterative solvers.
78    ///
79    /// On AArch64 (Apple Silicon), dispatches to NEON-accelerated SpMV
80    /// via [`spmv_simd`](crate::simd::spmv_simd) for ~3x throughput.
81    ///
82    /// # Safety contract
83    ///
84    /// The caller must ensure the CSR structure is valid (use
85    /// [`validate_csr_matrix`](crate::validation::validate_csr_matrix) once
86    /// before entering the solve loop). The `x` and `y` slices must have
87    /// lengths `>= cols` and `>= rows` respectively.
88    #[inline]
89    pub fn spmv_unchecked(&self, x: &[f32], y: &mut [f32]) {
90        debug_assert!(x.len() >= self.cols);
91        debug_assert!(y.len() >= self.rows);
92
93        // Dispatch to NEON/AVX2-accelerated SpMV when available
94        #[cfg(target_arch = "aarch64")]
95        {
96            crate::simd::spmv_simd(self, x, y);
97            return;
98        }
99
100        #[cfg(all(feature = "simd", target_arch = "x86_64"))]
101        {
102            crate::simd::spmv_simd(self, x, y);
103            return;
104        }
105
106        #[allow(unreachable_code)]
107        {
108            let vals = self.values.as_ptr();
109            let cols = self.col_indices.as_ptr();
110            let rp = self.row_ptr.as_ptr();
111
112            for i in 0..self.rows {
113                let start = unsafe { *rp.add(i) };
114                let end = unsafe { *rp.add(i + 1) };
115                let mut sum = 0.0f32;
116
117                for idx in start..end {
118                    unsafe {
119                        let v = *vals.add(idx);
120                        let c = *cols.add(idx);
121                        sum += v * *x.get_unchecked(c);
122                    }
123                }
124                unsafe { *y.get_unchecked_mut(i) = sum };
125            }
126        }
127    }
128
129    /// Fused SpMV + residual computation: computes `r[j] = rhs[j] - (A*x)[j]`
130    /// and returns `||r||^2` in a single pass, avoiding a separate allocation
131    /// for `Ax`.
132    ///
133    /// This eliminates one full memory traversal per iteration compared to
134    /// separate `spmv` + vector subtraction.
135    #[inline]
136    pub fn fused_residual_norm_sq(&self, x: &[f32], rhs: &[f32], residual: &mut [f32]) -> f64 {
137        debug_assert!(x.len() >= self.cols);
138        debug_assert!(rhs.len() >= self.rows);
139        debug_assert!(residual.len() >= self.rows);
140
141        let vals = self.values.as_ptr();
142        let cols = self.col_indices.as_ptr();
143        let rp = self.row_ptr.as_ptr();
144        let mut norm_sq = 0.0f64;
145
146        for i in 0..self.rows {
147            let start = unsafe { *rp.add(i) };
148            let end = unsafe { *rp.add(i + 1) };
149            let mut ax_i = 0.0f32;
150
151            for idx in start..end {
152                unsafe {
153                    let v = *vals.add(idx);
154                    let c = *cols.add(idx);
155                    ax_i += v * *x.get_unchecked(c);
156                }
157            }
158
159            let r_i = rhs[i] - ax_i;
160            residual[i] = r_i;
161            norm_sq += (r_i as f64) * (r_i as f64);
162        }
163
164        norm_sq
165    }
166}
167
168impl CsrMatrix<f64> {
169    /// High-performance SpMV for f64 with bounds-check elimination.
170    ///
171    /// On AArch64, dispatches to NEON-accelerated SpMV for ~2x throughput.
172    #[inline]
173    pub fn spmv_unchecked(&self, x: &[f64], y: &mut [f64]) {
174        debug_assert!(x.len() >= self.cols);
175        debug_assert!(y.len() >= self.rows);
176
177        // Dispatch to NEON/AVX2-accelerated SpMV when available
178        #[cfg(target_arch = "aarch64")]
179        {
180            crate::simd::spmv_simd_f64(self, x, y);
181            return;
182        }
183
184        #[cfg(all(feature = "simd", target_arch = "x86_64"))]
185        {
186            crate::simd::spmv_simd_f64(self, x, y);
187            return;
188        }
189
190        #[allow(unreachable_code)]
191        {
192            let vals = self.values.as_ptr();
193            let cols = self.col_indices.as_ptr();
194            let rp = self.row_ptr.as_ptr();
195
196            for i in 0..self.rows {
197                let start = unsafe { *rp.add(i) };
198                let end = unsafe { *rp.add(i + 1) };
199                let mut sum = 0.0f64;
200
201                for idx in start..end {
202                    unsafe {
203                        let v = *vals.add(idx);
204                        let c = *cols.add(idx);
205                        sum += v * *x.get_unchecked(c);
206                    }
207                }
208                unsafe { *y.get_unchecked_mut(i) = sum };
209            }
210        }
211    }
212}
213
214impl<T> CsrMatrix<T> {
215    /// Number of non-zero entries.
216    #[inline]
217    pub fn nnz(&self) -> usize {
218        self.values.len()
219    }
220
221    /// Number of non-zeros in a specific row (i.e. the row degree for an
222    /// adjacency matrix).
223    #[inline]
224    pub fn row_degree(&self, row: usize) -> usize {
225        self.row_ptr[row + 1] - self.row_ptr[row]
226    }
227
228    /// Iterate over `(col_index, &value)` pairs for the given row.
229    #[inline]
230    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, &T)> {
231        let start = self.row_ptr[row];
232        let end = self.row_ptr[row + 1];
233        self.col_indices[start..end]
234            .iter()
235            .copied()
236            .zip(self.values[start..end].iter())
237    }
238}
239
240impl<T: Copy + Default> CsrMatrix<T> {
241    /// Transpose: produces `A^T` in CSR form.
242    ///
243    /// Uses a two-pass counting sort in O(nnz + rows + cols) time and
244    /// O(nnz) extra memory. Required by backward push which operates on
245    /// the reversed adjacency structure.
246    pub fn transpose(&self) -> CsrMatrix<T> {
247        let nnz = self.nnz();
248        let t_rows = self.cols;
249        let t_cols = self.rows;
250
251        // Pass 1: count entries per new row (= old column).
252        let mut row_ptr = vec![0usize; t_rows + 1];
253        for &c in &self.col_indices {
254            row_ptr[c + 1] += 1;
255        }
256        for i in 1..=t_rows {
257            row_ptr[i] += row_ptr[i - 1];
258        }
259
260        // Pass 2: scatter entries into the transposed arrays.
261        let mut col_indices = vec![0usize; nnz];
262        let mut values = vec![T::default(); nnz];
263        let mut cursor = row_ptr.clone();
264
265        for row in 0..self.rows {
266            let start = self.row_ptr[row];
267            let end = self.row_ptr[row + 1];
268            for idx in start..end {
269                let c = self.col_indices[idx];
270                let dest = cursor[c];
271                col_indices[dest] = row;
272                values[dest] = self.values[idx];
273                cursor[c] += 1;
274            }
275        }
276
277        CsrMatrix {
278            row_ptr,
279            col_indices,
280            values,
281            rows: t_rows,
282            cols: t_cols,
283        }
284    }
285}
286
287impl<T: Copy + Default + std::ops::AddAssign> CsrMatrix<T> {
288    /// Build a CSR matrix from COO (coordinate) triplets.
289    ///
290    /// Entries are sorted by (row, col) internally. Duplicate positions at the
291    /// same (row, col) are kept as separate entries (caller should pre-merge if
292    /// needed).
293    pub fn from_coo_generic(
294        rows: usize,
295        cols: usize,
296        entries: impl IntoIterator<Item = (usize, usize, T)>,
297    ) -> Self {
298        let mut sorted: Vec<_> = entries.into_iter().collect();
299        sorted.sort_unstable_by_key(|(r, c, _)| (*r, *c));
300
301        let nnz = sorted.len();
302        let mut row_ptr = vec![0usize; rows + 1];
303        let mut col_indices = Vec::with_capacity(nnz);
304        let mut values = Vec::with_capacity(nnz);
305
306        for &(r, _, _) in &sorted {
307            assert!(r < rows, "row index {} out of bounds (rows={})", r, rows);
308            row_ptr[r + 1] += 1;
309        }
310        for i in 1..=rows {
311            row_ptr[i] += row_ptr[i - 1];
312        }
313
314        for (_, c, v) in sorted {
315            assert!(c < cols, "col index {} out of bounds (cols={})", c, cols);
316            col_indices.push(c);
317            values.push(v);
318        }
319
320        Self {
321            row_ptr,
322            col_indices,
323            values,
324            rows,
325            cols,
326        }
327    }
328}
329
330impl CsrMatrix<f32> {
331    /// Build a CSR matrix from COO (coordinate) triplets.
332    ///
333    /// Entries are sorted by (row, col) internally. Duplicate positions are
334    /// summed.
335    pub fn from_coo(
336        rows: usize,
337        cols: usize,
338        entries: impl IntoIterator<Item = (usize, usize, f32)>,
339    ) -> Self {
340        Self::from_coo_generic(rows, cols, entries)
341    }
342
343    /// Build a square identity matrix of dimension `n` in CSR format.
344    pub fn identity(n: usize) -> Self {
345        let row_ptr: Vec<usize> = (0..=n).collect();
346        let col_indices: Vec<usize> = (0..n).collect();
347        let values = vec![1.0f32; n];
348
349        Self {
350            row_ptr,
351            col_indices,
352            values,
353            rows: n,
354            cols: n,
355        }
356    }
357}
358
359impl CsrMatrix<f64> {
360    /// Build a CSR matrix from COO (coordinate) triplets (f64 variant).
361    ///
362    /// Entries are sorted by (row, col) internally.
363    pub fn from_coo(
364        rows: usize,
365        cols: usize,
366        entries: impl IntoIterator<Item = (usize, usize, f64)>,
367    ) -> Self {
368        Self::from_coo_generic(rows, cols, entries)
369    }
370
371    /// Build a square identity matrix of dimension `n` in CSR format (f64).
372    pub fn identity(n: usize) -> Self {
373        let row_ptr: Vec<usize> = (0..=n).collect();
374        let col_indices: Vec<usize> = (0..n).collect();
375        let values = vec![1.0f64; n];
376
377        Self {
378            row_ptr,
379            col_indices,
380            values,
381            rows: n,
382            cols: n,
383        }
384    }
385}
386
387// ---------------------------------------------------------------------------
388// Solver result types
389// ---------------------------------------------------------------------------
390
/// Algorithm identifier for solver selection and routing.
///
/// Each variant corresponds to a solver strategy with different complexity
/// characteristics and applicability constraints. The [`SolverRouter`] selects
/// the best algorithm based on the matrix [`SparsityProfile`] and [`QueryType`].
///
/// The `Display` implementation renders each variant as a short lowercase
/// label (for example `gauss-seidel` or `true-solver`).
///
/// [`SolverRouter`]: crate::router::SolverRouter
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Algorithm {
    /// Neumann series: `x = sum_{k=0}^{K} (I - A)^k * b`.
    ///
    /// Requires spectral radius < 1. Best for diagonally dominant, very sparse
    /// systems where the series converges in O(log(1/eps)) terms.
    Neumann,
    /// Jacobi iterative solver.
    Jacobi,
    /// Gauss-Seidel iterative solver.
    GaussSeidel,
    /// Forward Push (Andersen-Chung-Lang) for Personalized PageRank.
    ///
    /// Computes an approximate PPR vector by pushing residual mass forward
    /// along edges. Sublinear in graph size for single-source queries.
    ForwardPush,
    /// Backward Push for target-centric PPR.
    ///
    /// Dual of Forward Push: propagates contributions backward from a target
    /// node.
    BackwardPush,
    /// Conjugate Gradient (CG) iterative solver.
    ///
    /// Optimal for symmetric positive-definite systems. Converges in at most
    /// `n` steps; practical convergence depends on the condition number.
    CG,
    /// Hybrid random-walk approach combining push with Monte Carlo sampling.
    ///
    /// For large graphs where pure push is too expensive, this approach uses
    /// random walks to estimate the tail of the PageRank distribution.
    HybridRandomWalk,
    /// TRUE (Topology-aware Reduction for Updating Equations) batch solver.
    ///
    /// Exploits shared sparsity structure across a batch of right-hand sides
    /// to amortise factorisation cost. Best when `batch_size` is large.
    TRUE,
    /// Block Maximum Spanning Subgraph Preconditioned solver.
    ///
    /// Uses a maximum spanning tree preconditioner for ill-conditioned systems
    /// where CG and Neumann both struggle.
    BMSSP,
    /// Dense direct solver (LU/Cholesky fallback).
    ///
    /// Last-resort O(n^3) solver used when iterative methods fail. Only
    /// practical for small matrices.
    Dense,
}
445
446impl std::fmt::Display for Algorithm {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        match self {
449            Algorithm::Neumann => write!(f, "neumann"),
450            Algorithm::Jacobi => write!(f, "jacobi"),
451            Algorithm::GaussSeidel => write!(f, "gauss-seidel"),
452            Algorithm::ForwardPush => write!(f, "forward-push"),
453            Algorithm::BackwardPush => write!(f, "backward-push"),
454            Algorithm::CG => write!(f, "cg"),
455            Algorithm::HybridRandomWalk => write!(f, "hybrid-random-walk"),
456            Algorithm::TRUE => write!(f, "true-solver"),
457            Algorithm::BMSSP => write!(f, "bmssp"),
458            Algorithm::Dense => write!(f, "dense"),
459        }
460    }
461}
462
463// ---------------------------------------------------------------------------
464// Query & profile types for routing
465// ---------------------------------------------------------------------------
466
/// Query type describing what the caller wants to solve.
///
/// The [`SolverRouter`] inspects this together with the [`SparsityProfile`] to
/// select the most appropriate [`Algorithm`].
///
/// Variant payloads are plain `usize` values; this type performs no range
/// validation on node indices, polynomial degrees, or batch sizes.
///
/// [`SolverRouter`]: crate::router::SolverRouter
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Standard sparse linear system `Ax = b`.
    LinearSystem,

    /// Single-source Personalized PageRank.
    PageRankSingle {
        /// Source node index.
        source: usize,
    },

    /// Pairwise Personalized PageRank between two nodes.
    PageRankPairwise {
        /// Source node index.
        source: usize,
        /// Target node index.
        target: usize,
    },

    /// Spectral graph filter using polynomial expansion.
    SpectralFilter {
        /// Degree of the Chebyshev/polynomial expansion.
        polynomial_degree: usize,
    },

    /// Batch of linear systems sharing the same matrix `A` but different
    /// right-hand sides.
    BatchLinearSystem {
        /// Number of right-hand sides in the batch.
        batch_size: usize,
    },
}
505
/// Sparsity profile summarising the structural and numerical properties
/// of a matrix that are relevant for algorithm selection.
///
/// Computed once by [`SolverOrchestrator::analyze_sparsity`] and reused
/// across multiple solves on the same matrix.
///
/// This is a plain data holder: every field is public and nothing in this
/// type recomputes or cross-checks the values.
///
/// [`SolverOrchestrator::analyze_sparsity`]: crate::router::SolverOrchestrator::analyze_sparsity
#[derive(Debug, Clone)]
pub struct SparsityProfile {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Total number of non-zero entries.
    pub nnz: usize,
    /// Fraction of non-zeros: `nnz / (rows * cols)`.
    ///
    /// NOTE(review): the value reported when `rows * cols == 0` is decided by
    /// the code that fills this struct — confirm against the producer.
    pub density: f64,
    /// `true` if `|a_ii| > sum_{j != i} |a_ij|` for every row (strict
    /// inequality).
    pub is_diag_dominant: bool,
    /// Estimated spectral radius of the Jacobi iteration matrix `D^{-1}(L+U)`.
    pub estimated_spectral_radius: f64,
    /// Rough estimate of the 2-norm condition number.
    pub estimated_condition: f64,
    /// `true` if the matrix appears to be symmetric (checked on structure only).
    pub is_symmetric_structure: bool,
    /// Average number of non-zeros per row.
    pub avg_nnz_per_row: f64,
    /// Maximum number of non-zeros in any single row.
    pub max_nnz_per_row: usize,
}
536
/// Estimated computational complexity for a solve.
///
/// Returned by [`SolverOrchestrator::estimate_complexity`] to let callers
/// decide whether to proceed, batch, or reject a query.
///
/// All numeric fields are estimates; this type only carries them and does
/// not interpret or validate them.
///
/// [`SolverOrchestrator::estimate_complexity`]: crate::router::SolverOrchestrator::estimate_complexity
#[derive(Debug, Clone)]
pub struct ComplexityEstimate {
    /// Algorithm that would be selected.
    pub algorithm: Algorithm,
    /// Estimated number of floating-point operations.
    pub estimated_flops: u64,
    /// Estimated number of iterations (for iterative methods).
    pub estimated_iterations: usize,
    /// Estimated peak memory usage in bytes.
    pub estimated_memory_bytes: usize,
    /// A qualitative complexity class label.
    pub complexity_class: ComplexityClass,
}
556
/// Qualitative complexity class.
///
/// No ordering traits are derived, so classes can be tested for equality but
/// not compared with `<`/`>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComplexityClass {
    /// O(nnz * log(1/eps)) -- sublinear in matrix dimension.
    SublinearNnz,
    /// O(n * sqrt(kappa)) -- CG-like.
    SqrtCondition,
    /// O(n * nnz_per_row) -- linear scan.
    Linear,
    /// O(n^2) or worse -- superlinear.
    Quadratic,
    /// O(n^3) -- dense factorisation.
    Cubic,
}
571
/// Compute lane priority for solver scheduling.
///
/// The variants carry no data. No ordering traits are derived, so lanes can
/// be tested for equality but not ranked with `<`/`>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComputeLane {
    /// Low-latency lane for small problems.
    Fast,
    /// Default throughput lane.
    Normal,
    /// Batch lane for large problems.
    Batch,
}
582
/// Budget constraints for solver execution.
///
/// The `Default` implementation yields a 30-second time limit, 1000
/// iterations, and a tolerance of `1e-6`.
#[derive(Debug, Clone)]
pub struct ComputeBudget {
    /// Maximum wall-clock time allowed.
    pub max_time: Duration,
    /// Maximum number of iterations.
    pub max_iterations: usize,
    /// Target residual tolerance.
    pub tolerance: f64,
}
593
594impl Default for ComputeBudget {
595    fn default() -> Self {
596        Self {
597            max_time: Duration::from_secs(30),
598            max_iterations: 1000,
599            tolerance: 1e-6,
600        }
601    }
602}
603
/// Per-iteration convergence snapshot.
///
/// `SolverResult::convergence_history` stores a sequence of these.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Iteration index (0-based).
    pub iteration: usize,
    /// Residual L2 norm at this iteration.
    pub residual_norm: f64,
}
612
/// Result returned by a successful solver invocation.
///
/// Note that the solution is stored in single precision (`f32`) while the
/// residual norms are reported in double precision (`f64`).
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Solution vector x.
    pub solution: Vec<f32>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Final residual L2 norm.
    pub residual_norm: f64,
    /// Wall-clock time taken.
    pub wall_time: Duration,
    /// Per-iteration convergence history.
    pub convergence_history: Vec<ConvergenceInfo>,
    /// Algorithm used.
    pub algorithm: Algorithm,
}