ruvector_solver/types.rs
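//! Core types for sparse linear solvers.
//!
//! Provides [`CsrMatrix`] for compressed sparse row storage, the routing types
//! used to select a solver ([`Algorithm`], [`QueryType`], [`SparsityProfile`],
//! [`ComplexityEstimate`]), execution budgets ([`ComputeBudget`]), and result
//! types for solver convergence tracking ([`SolverResult`]).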
5
6use std::time::Duration;
7
8// ---------------------------------------------------------------------------
9// CsrMatrix<T>
10// ---------------------------------------------------------------------------
11
/// Sparse matrix stored in Compressed Sparse Row (CSR) form.
///
/// Only the non-zero entries are kept, laid out row after row, so a
/// matrix-vector product visits each stored entry exactly once (O(nnz)) while
/// walking `col_indices` and `values` front to back.
///
/// # Layout of a well-formed matrix
///
/// With `m = rows` and `nnz` stored entries:
/// - `row_ptr.len() == m + 1`;
/// - `col_indices.len() == values.len() == nnz`;
/// - the entries of row `i` occupy positions `row_ptr[i]..row_ptr[i + 1]`.
///
/// All fields are public, so the type itself does not enforce this layout.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
    /// Offsets into `col_indices`/`values`: row `i` begins at `row_ptr[i]`.
    pub row_ptr: Vec<usize>,
    /// Column index of every stored entry.
    pub col_indices: Vec<usize>,
    /// Value of every stored entry, parallel to `col_indices`.
    pub values: Vec<T>,
    /// Row count.
    pub rows: usize,
    /// Column count.
    pub cols: usize,
}
37
38impl<T: Copy + Default + std::ops::Mul<Output = T> + std::ops::AddAssign> CsrMatrix<T> {
39 /// Sparse matrix-vector multiply: `y = A * x`.
40 ///
41 /// # Panics
42 ///
43 /// Debug-asserts that `x.len() >= self.cols` and `y.len() >= self.rows`.
44 #[inline]
45 pub fn spmv(&self, x: &[T], y: &mut [T]) {
46 debug_assert!(
47 x.len() >= self.cols,
48 "spmv: x.len()={} < cols={}",
49 x.len(),
50 self.cols,
51 );
52 debug_assert!(
53 y.len() >= self.rows,
54 "spmv: y.len()={} < rows={}",
55 y.len(),
56 self.rows,
57 );
58
59 for i in 0..self.rows {
60 let mut sum = T::default();
61 let start = self.row_ptr[i];
62 let end = self.row_ptr[i + 1];
63
64 for idx in start..end {
65 sum += self.values[idx] * x[self.col_indices[idx]];
66 }
67 y[i] = sum;
68 }
69 }
70}
71
72impl CsrMatrix<f32> {
73 /// High-performance SpMV with bounds-check elimination.
74 ///
75 /// Identical to [`spmv`](Self::spmv) but uses `unsafe` indexing to
76 /// eliminate per-element bounds checks in the inner loop, which is the
77 /// single hottest path in all iterative solvers.
78 ///
79 /// # Safety contract
80 ///
81 /// The caller must ensure the CSR structure is valid (use
82 /// [`validate_csr_matrix`](crate::validation::validate_csr_matrix) once
83 /// before entering the solve loop). The `x` and `y` slices must have
84 /// lengths `>= cols` and `>= rows` respectively.
85 #[inline]
86 pub fn spmv_unchecked(&self, x: &[f32], y: &mut [f32]) {
87 debug_assert!(x.len() >= self.cols);
88 debug_assert!(y.len() >= self.rows);
89
90 let vals = self.values.as_ptr();
91 let cols = self.col_indices.as_ptr();
92 let rp = self.row_ptr.as_ptr();
93
94 for i in 0..self.rows {
95 // SAFETY: row_ptr has length rows+1, so i and i+1 are in bounds.
96 let start = unsafe { *rp.add(i) };
97 let end = unsafe { *rp.add(i + 1) };
98 let mut sum = 0.0f32;
99
100 for idx in start..end {
101 // SAFETY: idx < nnz (enforced by valid CSR structure),
102 // col_indices[idx] < cols <= x.len() (enforced by validation).
103 unsafe {
104 let v = *vals.add(idx);
105 let c = *cols.add(idx);
106 sum += v * *x.get_unchecked(c);
107 }
108 }
109 // SAFETY: i < rows <= y.len()
110 unsafe { *y.get_unchecked_mut(i) = sum };
111 }
112 }
113
114 /// Fused SpMV + residual computation: computes `r[j] = rhs[j] - (A*x)[j]`
115 /// and returns `||r||^2` in a single pass, avoiding a separate allocation
116 /// for `Ax`.
117 ///
118 /// This eliminates one full memory traversal per iteration compared to
119 /// separate `spmv` + vector subtraction.
120 #[inline]
121 pub fn fused_residual_norm_sq(&self, x: &[f32], rhs: &[f32], residual: &mut [f32]) -> f64 {
122 debug_assert!(x.len() >= self.cols);
123 debug_assert!(rhs.len() >= self.rows);
124 debug_assert!(residual.len() >= self.rows);
125
126 let vals = self.values.as_ptr();
127 let cols = self.col_indices.as_ptr();
128 let rp = self.row_ptr.as_ptr();
129 let mut norm_sq = 0.0f64;
130
131 for i in 0..self.rows {
132 let start = unsafe { *rp.add(i) };
133 let end = unsafe { *rp.add(i + 1) };
134 let mut ax_i = 0.0f32;
135
136 for idx in start..end {
137 unsafe {
138 let v = *vals.add(idx);
139 let c = *cols.add(idx);
140 ax_i += v * *x.get_unchecked(c);
141 }
142 }
143
144 let r_i = rhs[i] - ax_i;
145 residual[i] = r_i;
146 norm_sq += (r_i as f64) * (r_i as f64);
147 }
148
149 norm_sq
150 }
151}
152
153impl CsrMatrix<f64> {
154 /// High-performance SpMV for f64 with bounds-check elimination.
155 #[inline]
156 pub fn spmv_unchecked(&self, x: &[f64], y: &mut [f64]) {
157 debug_assert!(x.len() >= self.cols);
158 debug_assert!(y.len() >= self.rows);
159
160 let vals = self.values.as_ptr();
161 let cols = self.col_indices.as_ptr();
162 let rp = self.row_ptr.as_ptr();
163
164 for i in 0..self.rows {
165 let start = unsafe { *rp.add(i) };
166 let end = unsafe { *rp.add(i + 1) };
167 let mut sum = 0.0f64;
168
169 for idx in start..end {
170 unsafe {
171 let v = *vals.add(idx);
172 let c = *cols.add(idx);
173 sum += v * *x.get_unchecked(c);
174 }
175 }
176 unsafe { *y.get_unchecked_mut(i) = sum };
177 }
178 }
179}
180
181impl<T> CsrMatrix<T> {
182 /// Number of non-zero entries.
183 #[inline]
184 pub fn nnz(&self) -> usize {
185 self.values.len()
186 }
187
188 /// Number of non-zeros in a specific row (i.e. the row degree for an
189 /// adjacency matrix).
190 #[inline]
191 pub fn row_degree(&self, row: usize) -> usize {
192 self.row_ptr[row + 1] - self.row_ptr[row]
193 }
194
195 /// Iterate over `(col_index, &value)` pairs for the given row.
196 #[inline]
197 pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, &T)> {
198 let start = self.row_ptr[row];
199 let end = self.row_ptr[row + 1];
200 self.col_indices[start..end]
201 .iter()
202 .copied()
203 .zip(self.values[start..end].iter())
204 }
205}
206
207impl<T: Copy + Default> CsrMatrix<T> {
208 /// Transpose: produces `A^T` in CSR form.
209 ///
210 /// Uses a two-pass counting sort in O(nnz + rows + cols) time and
211 /// O(nnz) extra memory. Required by backward push which operates on
212 /// the reversed adjacency structure.
213 pub fn transpose(&self) -> CsrMatrix<T> {
214 let nnz = self.nnz();
215 let t_rows = self.cols;
216 let t_cols = self.rows;
217
218 // Pass 1: count entries per new row (= old column).
219 let mut row_ptr = vec![0usize; t_rows + 1];
220 for &c in &self.col_indices {
221 row_ptr[c + 1] += 1;
222 }
223 for i in 1..=t_rows {
224 row_ptr[i] += row_ptr[i - 1];
225 }
226
227 // Pass 2: scatter entries into the transposed arrays.
228 let mut col_indices = vec![0usize; nnz];
229 let mut values = vec![T::default(); nnz];
230 let mut cursor = row_ptr.clone();
231
232 for row in 0..self.rows {
233 let start = self.row_ptr[row];
234 let end = self.row_ptr[row + 1];
235 for idx in start..end {
236 let c = self.col_indices[idx];
237 let dest = cursor[c];
238 col_indices[dest] = row;
239 values[dest] = self.values[idx];
240 cursor[c] += 1;
241 }
242 }
243
244 CsrMatrix {
245 row_ptr,
246 col_indices,
247 values,
248 rows: t_rows,
249 cols: t_cols,
250 }
251 }
252}
253
254impl<T: Copy + Default + std::ops::AddAssign> CsrMatrix<T> {
255 /// Build a CSR matrix from COO (coordinate) triplets.
256 ///
257 /// Entries are sorted by (row, col) internally. Duplicate positions at the
258 /// same (row, col) are kept as separate entries (caller should pre-merge if
259 /// needed).
260 pub fn from_coo_generic(
261 rows: usize,
262 cols: usize,
263 entries: impl IntoIterator<Item = (usize, usize, T)>,
264 ) -> Self {
265 let mut sorted: Vec<_> = entries.into_iter().collect();
266 sorted.sort_unstable_by_key(|(r, c, _)| (*r, *c));
267
268 let nnz = sorted.len();
269 let mut row_ptr = vec![0usize; rows + 1];
270 let mut col_indices = Vec::with_capacity(nnz);
271 let mut values = Vec::with_capacity(nnz);
272
273 for &(r, _, _) in &sorted {
274 assert!(r < rows, "row index {} out of bounds (rows={})", r, rows);
275 row_ptr[r + 1] += 1;
276 }
277 for i in 1..=rows {
278 row_ptr[i] += row_ptr[i - 1];
279 }
280
281 for (_, c, v) in sorted {
282 assert!(c < cols, "col index {} out of bounds (cols={})", c, cols);
283 col_indices.push(c);
284 values.push(v);
285 }
286
287 Self {
288 row_ptr,
289 col_indices,
290 values,
291 rows,
292 cols,
293 }
294 }
295}
296
297impl CsrMatrix<f32> {
298 /// Build a CSR matrix from COO (coordinate) triplets.
299 ///
300 /// Entries are sorted by (row, col) internally. Duplicate positions are
301 /// summed.
302 pub fn from_coo(
303 rows: usize,
304 cols: usize,
305 entries: impl IntoIterator<Item = (usize, usize, f32)>,
306 ) -> Self {
307 Self::from_coo_generic(rows, cols, entries)
308 }
309
310 /// Build a square identity matrix of dimension `n` in CSR format.
311 pub fn identity(n: usize) -> Self {
312 let row_ptr: Vec<usize> = (0..=n).collect();
313 let col_indices: Vec<usize> = (0..n).collect();
314 let values = vec![1.0f32; n];
315
316 Self {
317 row_ptr,
318 col_indices,
319 values,
320 rows: n,
321 cols: n,
322 }
323 }
324}
325
326impl CsrMatrix<f64> {
327 /// Build a CSR matrix from COO (coordinate) triplets (f64 variant).
328 ///
329 /// Entries are sorted by (row, col) internally.
330 pub fn from_coo(
331 rows: usize,
332 cols: usize,
333 entries: impl IntoIterator<Item = (usize, usize, f64)>,
334 ) -> Self {
335 Self::from_coo_generic(rows, cols, entries)
336 }
337
338 /// Build a square identity matrix of dimension `n` in CSR format (f64).
339 pub fn identity(n: usize) -> Self {
340 let row_ptr: Vec<usize> = (0..=n).collect();
341 let col_indices: Vec<usize> = (0..n).collect();
342 let values = vec![1.0f64; n];
343
344 Self {
345 row_ptr,
346 col_indices,
347 values,
348 rows: n,
349 cols: n,
350 }
351 }
352}
353
354// ---------------------------------------------------------------------------
355// Solver result types
356// ---------------------------------------------------------------------------
357
358/// Algorithm identifier for solver selection and routing.
359///
360/// Each variant corresponds to a solver strategy with different complexity
361/// characteristics and applicability constraints. The [`SolverRouter`] selects
362/// the best algorithm based on the matrix [`SparsityProfile`] and [`QueryType`].
363///
364/// [`SolverRouter`]: crate::router::SolverRouter
365#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
366pub enum Algorithm {
367 /// Neumann series: `x = sum_{k=0}^{K} (I - A)^k * b`.
368 ///
369 /// Requires spectral radius < 1. Best for diagonally dominant, very sparse
370 /// systems where the series converges in O(log(1/eps)) terms.
371 Neumann,
372 /// Jacobi iterative solver.
373 Jacobi,
374 /// Gauss-Seidel iterative solver.
375 GaussSeidel,
376 /// Forward Push (Andersen-Chung-Lang) for Personalized PageRank.
377 ///
378 /// Computes an approximate PPR vector by pushing residual mass forward
379 /// along edges. Sublinear in graph size for single-source queries.
380 ForwardPush,
381 /// Backward Push for target-centric PPR.
382 ///
383 /// Dual of Forward Push: propagates contributions backward from a target
384 /// node.
385 BackwardPush,
386 /// Conjugate Gradient (CG) iterative solver.
387 ///
388 /// Optimal for symmetric positive-definite systems. Converges in at most
389 /// `n` steps; practical convergence depends on the condition number.
390 CG,
391 /// Hybrid random-walk approach combining push with Monte Carlo sampling.
392 ///
393 /// For large graphs where pure push is too expensive, this approach uses
394 /// random walks to estimate the tail of the PageRank distribution.
395 HybridRandomWalk,
396 /// TRUE (Topology-aware Reduction for Updating Equations) batch solver.
397 ///
398 /// Exploits shared sparsity structure across a batch of right-hand sides
399 /// to amortise factorisation cost. Best when `batch_size` is large.
400 TRUE,
401 /// Block Maximum Spanning Subgraph Preconditioned solver.
402 ///
403 /// Uses a maximum spanning tree preconditioner for ill-conditioned systems
404 /// where CG and Neumann both struggle.
405 BMSSP,
406 /// Dense direct solver (LU/Cholesky fallback).
407 ///
408 /// Last-resort O(n^3) solver used when iterative methods fail. Only
409 /// practical for small matrices.
410 Dense,
411}
412
413impl std::fmt::Display for Algorithm {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 match self {
416 Algorithm::Neumann => write!(f, "neumann"),
417 Algorithm::Jacobi => write!(f, "jacobi"),
418 Algorithm::GaussSeidel => write!(f, "gauss-seidel"),
419 Algorithm::ForwardPush => write!(f, "forward-push"),
420 Algorithm::BackwardPush => write!(f, "backward-push"),
421 Algorithm::CG => write!(f, "cg"),
422 Algorithm::HybridRandomWalk => write!(f, "hybrid-random-walk"),
423 Algorithm::TRUE => write!(f, "true-solver"),
424 Algorithm::BMSSP => write!(f, "bmssp"),
425 Algorithm::Dense => write!(f, "dense"),
426 }
427 }
428}
429
430// ---------------------------------------------------------------------------
431// Query & profile types for routing
432// ---------------------------------------------------------------------------
433
/// What the caller is asking the solver stack to compute.
///
/// The [`SolverRouter`] reads this alongside the matrix's [`SparsityProfile`]
/// when choosing an [`Algorithm`].
///
/// [`SolverRouter`]: crate::router::SolverRouter
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Solve the sparse linear system `Ax = b`.
    LinearSystem,

    /// Personalized PageRank from a single source node.
    PageRankSingle {
        /// Index of the source node.
        source: usize,
    },

    /// Personalized PageRank between a specific pair of nodes.
    PageRankPairwise {
        /// Index of the source node.
        source: usize,
        /// Index of the target node.
        target: usize,
    },

    /// Spectral graph filter evaluated as a polynomial expansion.
    SpectralFilter {
        /// Degree of the Chebyshev/polynomial expansion.
        polynomial_degree: usize,
    },

    /// Several linear systems sharing one matrix `A` but with different
    /// right-hand sides.
    BatchLinearSystem {
        /// How many right-hand sides the batch contains.
        batch_size: usize,
    },
}
472
/// Structural and numerical summary of a matrix, used to pick a solver.
///
/// Produced once by [`SolverOrchestrator::analyze_sparsity`] and meant to be
/// reused for every solve against the same matrix.
///
/// [`SolverOrchestrator::analyze_sparsity`]: crate::router::SolverOrchestrator::analyze_sparsity
#[derive(Debug, Clone)]
pub struct SparsityProfile {
    /// Row count.
    pub rows: usize,
    /// Column count.
    pub cols: usize,
    /// Number of stored (non-zero) entries.
    pub nnz: usize,
    /// Fill ratio `nnz / (rows * cols)`.
    pub density: f64,
    /// Whether every row satisfies `|a_ii| > sum_{j != i} |a_ij|`.
    pub is_diag_dominant: bool,
    /// Estimated spectral radius of the Jacobi iteration matrix `D^{-1}(L+U)`.
    pub estimated_spectral_radius: f64,
    /// Rough estimate of the 2-norm condition number.
    pub estimated_condition: f64,
    /// Whether the sparsity pattern appears symmetric (values are not compared).
    pub is_symmetric_structure: bool,
    /// Mean number of stored entries per row.
    pub avg_nnz_per_row: f64,
    /// Largest number of stored entries found in any one row.
    pub max_nnz_per_row: usize,
}
503
504/// Estimated computational complexity for a solve.
505///
506/// Returned by [`SolverOrchestrator::estimate_complexity`] to let callers
507/// decide whether to proceed, batch, or reject a query.
508///
509/// [`SolverOrchestrator::estimate_complexity`]: crate::router::SolverOrchestrator::estimate_complexity
510#[derive(Debug, Clone)]
511pub struct ComplexityEstimate {
512 /// Algorithm that would be selected.
513 pub algorithm: Algorithm,
514 /// Estimated number of floating-point operations.
515 pub estimated_flops: u64,
516 /// Estimated number of iterations (for iterative methods).
517 pub estimated_iterations: usize,
518 /// Estimated peak memory usage in bytes.
519 pub estimated_memory_bytes: usize,
520 /// A qualitative complexity class label.
521 pub complexity_class: ComplexityClass,
522}
523
524/// Qualitative complexity class.
525#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
526pub enum ComplexityClass {
527 /// O(nnz * log(1/eps)) -- sublinear in matrix dimension.
528 SublinearNnz,
529 /// O(n * sqrt(kappa)) -- CG-like.
530 SqrtCondition,
531 /// O(n * nnz_per_row) -- linear scan.
532 Linear,
533 /// O(n^2) or worse -- superlinear.
534 Quadratic,
535 /// O(n^3) -- dense factorisation.
536 Cubic,
537}
538
539/// Compute lane priority for solver scheduling.
540#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
541pub enum ComputeLane {
542 /// Low-latency lane for small problems.
543 Fast,
544 /// Default throughput lane.
545 Normal,
546 /// Batch lane for large problems.
547 Batch,
548}
549
/// Resource limits a solver run is expected to respect.
#[derive(Debug, Clone)]
pub struct ComputeBudget {
    /// Wall-clock time limit.
    pub max_time: Duration,
    /// Upper bound on the number of iterations.
    pub max_iterations: usize,
    /// Residual tolerance the solve aims for.
    pub tolerance: f64,
}

impl Default for ComputeBudget {
    /// 30 seconds of wall-clock time, 1000 iterations, tolerance `1e-6`.
    fn default() -> Self {
        let max_time = Duration::from_secs(30);
        Self {
            tolerance: 1e-6,
            max_iterations: 1000,
            max_time,
        }
    }
}
570
/// Residual measurement recorded at one iteration of a solve.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Zero-based iteration number.
    pub iteration: usize,
    /// L2 norm of the residual at this iteration.
    pub residual_norm: f64,
}
579
580/// Result returned by a successful solver invocation.
581#[derive(Debug, Clone)]
582pub struct SolverResult {
583 /// Solution vector x.
584 pub solution: Vec<f32>,
585 /// Number of iterations performed.
586 pub iterations: usize,
587 /// Final residual L2 norm.
588 pub residual_norm: f64,
589 /// Wall-clock time taken.
590 pub wall_time: Duration,
591 /// Per-iteration convergence history.
592 pub convergence_history: Vec<ConvergenceInfo>,
593 /// Algorithm used.
594 pub algorithm: Algorithm,
595}