// ruvector_solver/types.rs
1//! Core types for sparse linear solvers.
2//!
3//! Provides [`CsrMatrix`] for compressed sparse row storage and result types
4//! for solver convergence tracking.
5
6use std::time::Duration;
7
8// ---------------------------------------------------------------------------
9// CsrMatrix<T>
10// ---------------------------------------------------------------------------
11
/// Compressed Sparse Row (CSR) matrix.
///
/// Only the non-zero entries are stored, so a matrix-vector product costs
/// O(nnz) time with a cache-friendly, row-major traversal.
///
/// # Layout
///
/// For a matrix with `m` rows and `nnz` non-zeros:
/// - `row_ptr` has length `m + 1`
/// - `col_indices` and `values` each have length `nnz`
/// - Row `i` spans indices `row_ptr[i]..row_ptr[i+1]`
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
    /// Row pointers: `row_ptr[i]` is the start index in `col_indices`/`values`
    /// for row `i`.
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero entry.
    pub col_indices: Vec<usize>,
    /// Values for each non-zero entry.
    pub values: Vec<T>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}

impl<T: Copy + Default + std::ops::Mul<Output = T> + std::ops::AddAssign> CsrMatrix<T> {
    /// Sparse matrix-vector multiply: `y = A * x`.
    ///
    /// `T::default()` serves as the additive identity: every output row starts
    /// from it and accumulates `values[idx] * x[col_indices[idx]]`.
    ///
    /// # Panics
    ///
    /// Debug-asserts that `x.len() >= self.cols` and `y.len() >= self.rows`.
    #[inline]
    pub fn spmv(&self, x: &[T], y: &mut [T]) {
        debug_assert!(
            x.len() >= self.cols,
            "spmv: x.len()={} < cols={}",
            x.len(),
            self.cols,
        );
        debug_assert!(
            y.len() >= self.rows,
            "spmv: y.len()={} < rows={}",
            y.len(),
            self.rows,
        );

        for i in 0..self.rows {
            let span = self.row_ptr[i]..self.row_ptr[i + 1];
            let mut acc = T::default();

            // Walk the row's values and column indices in lock-step.
            for (v, c) in self.values[span.clone()]
                .iter()
                .zip(self.col_indices[span].iter())
            {
                acc += *v * x[*c];
            }
            y[i] = acc;
        }
    }
}
71
impl CsrMatrix<f32> {
    /// High-performance SpMV with bounds-check elimination.
    ///
    /// Identical to [`spmv`](Self::spmv) but uses `unsafe` indexing to
    /// eliminate per-element bounds checks in the inner loop, which is the
    /// single hottest path in all iterative solvers.
    ///
    /// # Safety contract
    ///
    /// The caller must ensure the CSR structure is valid (use
    /// [`validate_csr_matrix`](crate::validation::validate_csr_matrix) once
    /// before entering the solve loop). The `x` and `y` slices must have
    /// lengths `>= cols` and `>= rows` respectively.
    #[inline]
    pub fn spmv_unchecked(&self, x: &[f32], y: &mut [f32]) {
        // Length preconditions are checked in debug builds only; release
        // builds rely entirely on the caller upholding the contract above.
        debug_assert!(x.len() >= self.cols);
        debug_assert!(y.len() >= self.rows);

        // Raw pointers let the inner loop skip slice bounds checks.
        let vals = self.values.as_ptr();
        let cols = self.col_indices.as_ptr();
        let rp = self.row_ptr.as_ptr();

        for i in 0..self.rows {
            // SAFETY: row_ptr has length rows+1, so i and i+1 are in bounds.
            let start = unsafe { *rp.add(i) };
            let end = unsafe { *rp.add(i + 1) };
            let mut sum = 0.0f32;

            for idx in start..end {
                // SAFETY: idx < nnz (enforced by valid CSR structure),
                // col_indices[idx] < cols <= x.len() (enforced by validation).
                unsafe {
                    let v = *vals.add(idx);
                    let c = *cols.add(idx);
                    sum += v * *x.get_unchecked(c);
                }
            }
            // SAFETY: i < rows <= y.len()
            unsafe { *y.get_unchecked_mut(i) = sum };
        }
    }

    /// Fused SpMV + residual computation: computes `r[j] = rhs[j] - (A*x)[j]`
    /// and returns `||r||^2` in a single pass, avoiding a separate allocation
    /// for `Ax`.
    ///
    /// This eliminates one full memory traversal per iteration compared to
    /// separate `spmv` + vector subtraction. The squared norm is accumulated
    /// in `f64` so that summing many small `f32` residual terms loses less
    /// precision.
    ///
    /// # Safety contract
    ///
    /// Same as [`spmv_unchecked`](Self::spmv_unchecked): the CSR structure
    /// must be validated beforehand, and `x.len() >= cols`,
    /// `rhs.len() >= rows`, `residual.len() >= rows`.
    #[inline]
    pub fn fused_residual_norm_sq(
        &self,
        x: &[f32],
        rhs: &[f32],
        residual: &mut [f32],
    ) -> f64 {
        debug_assert!(x.len() >= self.cols);
        debug_assert!(rhs.len() >= self.rows);
        debug_assert!(residual.len() >= self.rows);

        let vals = self.values.as_ptr();
        let cols = self.col_indices.as_ptr();
        let rp = self.row_ptr.as_ptr();
        let mut norm_sq = 0.0f64;

        for i in 0..self.rows {
            // SAFETY: row_ptr has length rows+1, so i and i+1 are in bounds.
            let start = unsafe { *rp.add(i) };
            let end = unsafe { *rp.add(i + 1) };
            let mut ax_i = 0.0f32;

            for idx in start..end {
                // SAFETY: as in `spmv_unchecked` — valid CSR keeps idx < nnz
                // and col_indices[idx] < cols <= x.len().
                unsafe {
                    let v = *vals.add(idx);
                    let c = *cols.add(idx);
                    ax_i += v * *x.get_unchecked(c);
                }
            }

            // `rhs` and `residual` use ordinary (bounds-checked) indexing;
            // only the inner nnz loop above is on the unchecked fast path.
            let r_i = rhs[i] - ax_i;
            residual[i] = r_i;
            norm_sq += (r_i as f64) * (r_i as f64);
        }

        norm_sq
    }
}
157
impl CsrMatrix<f64> {
    /// High-performance SpMV for f64 with bounds-check elimination.
    ///
    /// # Safety contract
    ///
    /// Same as the `f32` variant: the caller must ensure the CSR structure is
    /// valid before entering the solve loop, and `x`/`y` must have lengths
    /// `>= cols` and `>= rows` respectively.
    #[inline]
    pub fn spmv_unchecked(&self, x: &[f64], y: &mut [f64]) {
        // Checked in debug builds only; release builds rely on the caller.
        debug_assert!(x.len() >= self.cols);
        debug_assert!(y.len() >= self.rows);

        let vals = self.values.as_ptr();
        let cols = self.col_indices.as_ptr();
        let rp = self.row_ptr.as_ptr();

        for i in 0..self.rows {
            // SAFETY: row_ptr has length rows+1, so i and i+1 are in bounds.
            let start = unsafe { *rp.add(i) };
            let end = unsafe { *rp.add(i + 1) };
            let mut sum = 0.0f64;

            for idx in start..end {
                // SAFETY: idx < nnz (valid CSR) and
                // col_indices[idx] < cols <= x.len() (validated up front).
                unsafe {
                    let v = *vals.add(idx);
                    let c = *cols.add(idx);
                    sum += v * *x.get_unchecked(c);
                }
            }
            // SAFETY: i < rows <= y.len()
            unsafe { *y.get_unchecked_mut(i) = sum };
        }
    }
}
185
186impl<T> CsrMatrix<T> {
187    /// Number of non-zero entries.
188    #[inline]
189    pub fn nnz(&self) -> usize {
190        self.values.len()
191    }
192
193    /// Number of non-zeros in a specific row (i.e. the row degree for an
194    /// adjacency matrix).
195    #[inline]
196    pub fn row_degree(&self, row: usize) -> usize {
197        self.row_ptr[row + 1] - self.row_ptr[row]
198    }
199
200    /// Iterate over `(col_index, &value)` pairs for the given row.
201    #[inline]
202    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, &T)> {
203        let start = self.row_ptr[row];
204        let end = self.row_ptr[row + 1];
205        self.col_indices[start..end]
206            .iter()
207            .copied()
208            .zip(self.values[start..end].iter())
209    }
210}
211
212impl<T: Copy + Default> CsrMatrix<T> {
213    /// Transpose: produces `A^T` in CSR form.
214    ///
215    /// Uses a two-pass counting sort in O(nnz + rows + cols) time and
216    /// O(nnz) extra memory. Required by backward push which operates on
217    /// the reversed adjacency structure.
218    pub fn transpose(&self) -> CsrMatrix<T> {
219        let nnz = self.nnz();
220        let t_rows = self.cols;
221        let t_cols = self.rows;
222
223        // Pass 1: count entries per new row (= old column).
224        let mut row_ptr = vec![0usize; t_rows + 1];
225        for &c in &self.col_indices {
226            row_ptr[c + 1] += 1;
227        }
228        for i in 1..=t_rows {
229            row_ptr[i] += row_ptr[i - 1];
230        }
231
232        // Pass 2: scatter entries into the transposed arrays.
233        let mut col_indices = vec![0usize; nnz];
234        let mut values = vec![T::default(); nnz];
235        let mut cursor = row_ptr.clone();
236
237        for row in 0..self.rows {
238            let start = self.row_ptr[row];
239            let end = self.row_ptr[row + 1];
240            for idx in start..end {
241                let c = self.col_indices[idx];
242                let dest = cursor[c];
243                col_indices[dest] = row;
244                values[dest] = self.values[idx];
245                cursor[c] += 1;
246            }
247        }
248
249        CsrMatrix {
250            row_ptr,
251            col_indices,
252            values,
253            rows: t_rows,
254            cols: t_cols,
255        }
256    }
257}
258
259impl<T: Copy + Default + std::ops::AddAssign> CsrMatrix<T> {
260    /// Build a CSR matrix from COO (coordinate) triplets.
261    ///
262    /// Entries are sorted by (row, col) internally. Duplicate positions at the
263    /// same (row, col) are kept as separate entries (caller should pre-merge if
264    /// needed).
265    pub fn from_coo_generic(
266        rows: usize,
267        cols: usize,
268        entries: impl IntoIterator<Item = (usize, usize, T)>,
269    ) -> Self {
270        let mut sorted: Vec<_> = entries.into_iter().collect();
271        sorted.sort_unstable_by_key(|(r, c, _)| (*r, *c));
272
273        let nnz = sorted.len();
274        let mut row_ptr = vec![0usize; rows + 1];
275        let mut col_indices = Vec::with_capacity(nnz);
276        let mut values = Vec::with_capacity(nnz);
277
278        for &(r, _, _) in &sorted {
279            assert!(r < rows, "row index {} out of bounds (rows={})", r, rows);
280            row_ptr[r + 1] += 1;
281        }
282        for i in 1..=rows {
283            row_ptr[i] += row_ptr[i - 1];
284        }
285
286        for (_, c, v) in sorted {
287            assert!(c < cols, "col index {} out of bounds (cols={})", c, cols);
288            col_indices.push(c);
289            values.push(v);
290        }
291
292        Self {
293            row_ptr,
294            col_indices,
295            values,
296            rows,
297            cols,
298        }
299    }
300}
301
302impl CsrMatrix<f32> {
303    /// Build a CSR matrix from COO (coordinate) triplets.
304    ///
305    /// Entries are sorted by (row, col) internally. Duplicate positions are
306    /// summed.
307    pub fn from_coo(
308        rows: usize,
309        cols: usize,
310        entries: impl IntoIterator<Item = (usize, usize, f32)>,
311    ) -> Self {
312        Self::from_coo_generic(rows, cols, entries)
313    }
314
315    /// Build a square identity matrix of dimension `n` in CSR format.
316    pub fn identity(n: usize) -> Self {
317        let row_ptr: Vec<usize> = (0..=n).collect();
318        let col_indices: Vec<usize> = (0..n).collect();
319        let values = vec![1.0f32; n];
320
321        Self {
322            row_ptr,
323            col_indices,
324            values,
325            rows: n,
326            cols: n,
327        }
328    }
329}
330
331impl CsrMatrix<f64> {
332    /// Build a CSR matrix from COO (coordinate) triplets (f64 variant).
333    ///
334    /// Entries are sorted by (row, col) internally.
335    pub fn from_coo(
336        rows: usize,
337        cols: usize,
338        entries: impl IntoIterator<Item = (usize, usize, f64)>,
339    ) -> Self {
340        Self::from_coo_generic(rows, cols, entries)
341    }
342
343    /// Build a square identity matrix of dimension `n` in CSR format (f64).
344    pub fn identity(n: usize) -> Self {
345        let row_ptr: Vec<usize> = (0..=n).collect();
346        let col_indices: Vec<usize> = (0..n).collect();
347        let values = vec![1.0f64; n];
348
349        Self {
350            row_ptr,
351            col_indices,
352            values,
353            rows: n,
354            cols: n,
355        }
356    }
357}
358
359// ---------------------------------------------------------------------------
360// Solver result types
361// ---------------------------------------------------------------------------
362
/// Algorithm identifier for solver selection and routing.
///
/// Each variant corresponds to a solver strategy with different complexity
/// characteristics and applicability constraints. The [`SolverRouter`] selects
/// the best algorithm based on the matrix [`SparsityProfile`] and [`QueryType`].
/// The canonical kebab-case name of each variant comes from this type's
/// `Display` impl.
///
/// [`SolverRouter`]: crate::router::SolverRouter
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Algorithm {
    /// Neumann series: `x = sum_{k=0}^{K} (I - A)^k * b`.
    ///
    /// Requires spectral radius < 1. Best for diagonally dominant, very sparse
    /// systems where the series converges in O(log(1/eps)) terms.
    Neumann,
    /// Jacobi iterative solver.
    Jacobi,
    /// Gauss-Seidel iterative solver.
    GaussSeidel,
    /// Forward Push (Andersen-Chung-Lang) for Personalized PageRank.
    ///
    /// Computes an approximate PPR vector by pushing residual mass forward
    /// along edges. Sublinear in graph size for single-source queries.
    ForwardPush,
    /// Backward Push for target-centric PPR.
    ///
    /// Dual of Forward Push: propagates contributions backward from a target
    /// node.
    BackwardPush,
    /// Conjugate Gradient (CG) iterative solver.
    ///
    /// Optimal for symmetric positive-definite systems. Converges in at most
    /// `n` steps; practical convergence depends on the condition number.
    CG,
    /// Hybrid random-walk approach combining push with Monte Carlo sampling.
    ///
    /// For large graphs where pure push is too expensive, this approach uses
    /// random walks to estimate the tail of the PageRank distribution.
    HybridRandomWalk,
    /// TRUE (Topology-aware Reduction for Updating Equations) batch solver.
    ///
    /// Exploits shared sparsity structure across a batch of right-hand sides
    /// to amortise factorisation cost. Best when `batch_size` is large.
    TRUE,
    /// Block Maximum Spanning Subgraph Preconditioned solver.
    ///
    /// Uses a maximum spanning tree preconditioner for ill-conditioned systems
    /// where CG and Neumann both struggle.
    BMSSP,
    /// Dense direct solver (LU/Cholesky fallback).
    ///
    /// Last-resort O(n^3) solver used when iterative methods fail. Only
    /// practical for small matrices.
    Dense,
}
417
418impl std::fmt::Display for Algorithm {
419    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
420        match self {
421            Algorithm::Neumann => write!(f, "neumann"),
422            Algorithm::Jacobi => write!(f, "jacobi"),
423            Algorithm::GaussSeidel => write!(f, "gauss-seidel"),
424            Algorithm::ForwardPush => write!(f, "forward-push"),
425            Algorithm::BackwardPush => write!(f, "backward-push"),
426            Algorithm::CG => write!(f, "cg"),
427            Algorithm::HybridRandomWalk => write!(f, "hybrid-random-walk"),
428            Algorithm::TRUE => write!(f, "true-solver"),
429            Algorithm::BMSSP => write!(f, "bmssp"),
430            Algorithm::Dense => write!(f, "dense"),
431        }
432    }
433}
434
435// ---------------------------------------------------------------------------
436// Query & profile types for routing
437// ---------------------------------------------------------------------------
438
/// Query type describing what the caller wants to solve.
///
/// The [`SolverRouter`] inspects this together with the [`SparsityProfile`] to
/// select the most appropriate [`Algorithm`]. Variants carry only the routing
/// parameters; the matrix and right-hand side(s) are supplied separately.
///
/// [`SolverRouter`]: crate::router::SolverRouter
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Standard sparse linear system `Ax = b`.
    LinearSystem,

    /// Single-source Personalized PageRank.
    PageRankSingle {
        /// Source node index.
        source: usize,
    },

    /// Pairwise Personalized PageRank between two nodes.
    PageRankPairwise {
        /// Source node index.
        source: usize,
        /// Target node index.
        target: usize,
    },

    /// Spectral graph filter using polynomial expansion.
    SpectralFilter {
        /// Degree of the Chebyshev/polynomial expansion.
        polynomial_degree: usize,
    },

    /// Batch of linear systems sharing the same matrix `A` but different
    /// right-hand sides. See [`Algorithm::TRUE`] for a solver specialised
    /// for this case.
    BatchLinearSystem {
        /// Number of right-hand sides in the batch.
        batch_size: usize,
    },
}
477
/// Sparsity profile summarising the structural and numerical properties
/// of a matrix that are relevant for algorithm selection.
///
/// Computed once by [`SolverOrchestrator::analyze_sparsity`] and reused
/// across multiple solves on the same matrix. The `estimated_*` fields are
/// heuristic approximations, not exact quantities.
///
/// [`SolverOrchestrator::analyze_sparsity`]: crate::router::SolverOrchestrator::analyze_sparsity
#[derive(Debug, Clone)]
pub struct SparsityProfile {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Total number of non-zero entries.
    pub nnz: usize,
    /// Fraction of non-zeros: `nnz / (rows * cols)`.
    pub density: f64,
    /// `true` if `|a_ii| > sum_{j != i} |a_ij|` for every row (strict
    /// diagonal dominance).
    pub is_diag_dominant: bool,
    /// Estimated spectral radius of the Jacobi iteration matrix `D^{-1}(L+U)`.
    pub estimated_spectral_radius: f64,
    /// Rough estimate of the 2-norm condition number.
    pub estimated_condition: f64,
    /// `true` if the matrix appears to be symmetric (checked on structure only,
    /// not on values).
    pub is_symmetric_structure: bool,
    /// Average number of non-zeros per row.
    pub avg_nnz_per_row: f64,
    /// Maximum number of non-zeros in any single row.
    pub max_nnz_per_row: usize,
}
508
/// Estimated computational complexity for a solve.
///
/// Returned by [`SolverOrchestrator::estimate_complexity`] to let callers
/// decide whether to proceed, batch, or reject a query. All figures are
/// estimates for planning, not measurements.
///
/// [`SolverOrchestrator::estimate_complexity`]: crate::router::SolverOrchestrator::estimate_complexity
#[derive(Debug, Clone)]
pub struct ComplexityEstimate {
    /// Algorithm that would be selected.
    pub algorithm: Algorithm,
    /// Estimated number of floating-point operations.
    pub estimated_flops: u64,
    /// Estimated number of iterations (for iterative methods).
    pub estimated_iterations: usize,
    /// Estimated peak memory usage in bytes.
    pub estimated_memory_bytes: usize,
    /// A qualitative complexity class label.
    pub complexity_class: ComplexityClass,
}
528
/// Qualitative complexity class, ordered roughly from cheapest to most
/// expensive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComplexityClass {
    /// O(nnz * log(1/eps)) -- sublinear in matrix dimension.
    SublinearNnz,
    /// O(n * sqrt(kappa)) -- CG-like, driven by the condition number kappa.
    SqrtCondition,
    /// O(n * nnz_per_row) -- linear scan.
    Linear,
    /// O(n^2) or worse -- superlinear.
    Quadratic,
    /// O(n^3) -- dense factorisation.
    Cubic,
}
543
/// Compute lane priority for solver scheduling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComputeLane {
    /// Low-latency lane for small problems.
    Fast,
    /// Default throughput lane.
    Normal,
    /// Batch lane for large problems where latency matters less than
    /// throughput.
    Batch,
}
554
/// Budget constraints for solver execution.
///
/// A solver run stops when any one of the three limits is reached: wall-clock
/// time, iteration count, or the residual dropping below `tolerance`.
#[derive(Debug, Clone)]
pub struct ComputeBudget {
    /// Maximum wall-clock time allowed.
    pub max_time: Duration,
    /// Maximum number of iterations.
    pub max_iterations: usize,
    /// Target residual tolerance.
    pub tolerance: f64,
}

impl Default for ComputeBudget {
    /// Defaults: 30 s wall clock, 1000 iterations, 1e-6 residual tolerance.
    fn default() -> Self {
        ComputeBudget {
            tolerance: 1e-6,
            max_iterations: 1000,
            max_time: Duration::from_secs(30),
        }
    }
}
575
/// Per-iteration convergence snapshot, collected into
/// [`SolverResult::convergence_history`].
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Iteration index (0-based).
    pub iteration: usize,
    /// Residual L2 norm at this iteration.
    pub residual_norm: f64,
}
584
/// Result returned by a successful solver invocation.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Solution vector x (single-precision; solvers in this crate operate
    /// on `f32` solutions).
    pub solution: Vec<f32>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Final residual L2 norm.
    pub residual_norm: f64,
    /// Wall-clock time taken.
    pub wall_time: Duration,
    /// Per-iteration convergence history (one [`ConvergenceInfo`] per
    /// recorded iteration).
    pub convergence_history: Vec<ConvergenceInfo>,
    /// Algorithm used.
    pub algorithm: Algorithm,
}