// ruvector_solver/router.rs

1//! Algorithm router and solver orchestrator.
2//!
3//! The [`SolverRouter`] inspects a matrix's [`SparsityProfile`] and the
4//! caller's [`QueryType`] to select the optimal [`Algorithm`] for each solve
5//! request. The [`SolverOrchestrator`] wraps the router together with concrete
6//! solver instances and provides high-level `solve` / `solve_with_fallback`
7//! entry points.
8//!
9//! # Routing decision tree
10//!
11//! | Query | Condition | Algorithm |
12//! |-------|-----------|-----------|
13//! | `LinearSystem` | diag-dominant + very sparse | Neumann |
14//! | `LinearSystem` | low condition number | CG |
15//! | `LinearSystem` | else | BMSSP |
16//! | `PageRankSingle` | always | ForwardPush |
17//! | `PageRankPairwise` | large graph | HybridRandomWalk |
18//! | `PageRankPairwise` | small graph | ForwardPush |
19//! | `SpectralFilter` | always | Neumann |
20//! | `BatchLinearSystem` | large batch | TRUE |
21//! | `BatchLinearSystem` | small batch | CG |
22//!
23//! # Fallback chain
24//!
25//! When the selected algorithm fails (non-convergence, numerical instability),
26//! [`SolverOrchestrator::solve_with_fallback`] tries a deterministic chain:
27//!
28//! **selected algorithm -> CG -> Dense**
29
30use std::time::Instant;
31
32use tracing::{debug, info, warn};
33
34use crate::error::SolverError;
35use crate::traits::SolverEngine;
36use crate::types::{
37    Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
38    QueryType, SolverResult, SparsityProfile,
39};
40
41// ---------------------------------------------------------------------------
42// RouterConfig
43// ---------------------------------------------------------------------------
44
/// Configuration thresholds that govern the routing decision tree.
///
/// All thresholds have sensible defaults; override them when benchmarks on
/// your workload indicate a different crossover point.
///
/// # Example
///
/// ```rust
/// use ruvector_solver::router::RouterConfig;
///
/// let config = RouterConfig {
///     cg_condition_threshold: 50.0,
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Maximum spectral radius for which the Neumann series is attempted.
    ///
    /// If the estimated spectral radius exceeds this value the router will
    /// not select Neumann even for diagonally dominant matrices, since the
    /// series converges only for spectral radius < 1.
    ///
    /// Default: `0.95`.
    pub neumann_spectral_radius_threshold: f64,

    /// Maximum condition number for which CG is preferred over BMSSP.
    ///
    /// CG converges in O(sqrt(kappa)) iterations; when kappa is too large
    /// a preconditioned method (BMSSP) is cheaper.
    ///
    /// Default: `100.0`.
    pub cg_condition_threshold: f64,

    /// Maximum density (fraction of non-zeros) for the Neumann sublinear
    /// fast-path.
    ///
    /// Neumann is only worthwhile when the matrix is truly sparse; the
    /// router requires `density` to be strictly below this value.
    ///
    /// Default: `0.05` (5%).
    pub sparsity_sublinear_threshold: f64,

    /// Batch-size threshold for the TRUE solver in `BatchLinearSystem`
    /// queries.
    ///
    /// Batch sizes *strictly greater* than this value route to TRUE; a
    /// batch of exactly this size still uses CG.
    ///
    /// Default: `100`.
    pub true_batch_threshold: usize,

    /// Graph size threshold (number of rows) above which
    /// `PageRankPairwise` switches from ForwardPush to HybridRandomWalk.
    /// The comparison is strict: `rows` must exceed this value.
    ///
    /// Default: `1_000`.
    pub push_graph_size_threshold: usize,
}
98
99impl Default for RouterConfig {
100    fn default() -> Self {
101        Self {
102            neumann_spectral_radius_threshold: 0.95,
103            cg_condition_threshold: 100.0,
104            sparsity_sublinear_threshold: 0.05,
105            true_batch_threshold: 100,
106            push_graph_size_threshold: 1_000,
107        }
108    }
109}
110
111// ---------------------------------------------------------------------------
112// SolverRouter
113// ---------------------------------------------------------------------------
114
/// Stateless algorithm selector.
///
/// Given a [`SparsityProfile`] and a [`QueryType`], the router walks a
/// decision tree (documented in the [module-level docs](self)) to pick the
/// [`Algorithm`] with the best expected cost.
///
/// Routing reads only the precomputed profile — never the matrix data —
/// so selection is cheap enough to run on every solve request.
///
/// # Example
///
/// ```rust
/// use ruvector_solver::router::{SolverRouter, RouterConfig};
/// use ruvector_solver::types::{Algorithm, QueryType, SparsityProfile};
///
/// let router = SolverRouter::new(RouterConfig::default());
/// let profile = SparsityProfile {
///     rows: 500,
///     cols: 500,
///     nnz: 1200,
///     density: 0.0048,
///     is_diag_dominant: true,
///     estimated_spectral_radius: 0.4,
///     estimated_condition: 10.0,
///     is_symmetric_structure: true,
///     avg_nnz_per_row: 2.4,
///     max_nnz_per_row: 5,
/// };
///
/// let algo = router.select_algorithm(&profile, &QueryType::LinearSystem);
/// assert_eq!(algo, Algorithm::Neumann);
/// ```
#[derive(Debug, Clone)]
pub struct SolverRouter {
    /// Thresholds driving the decision tree; see [`RouterConfig`].
    config: RouterConfig,
}
148
149impl SolverRouter {
150    /// Create a new router with the provided configuration.
151    pub fn new(config: RouterConfig) -> Self {
152        Self { config }
153    }
154
155    /// Return a shared reference to the active configuration.
156    pub fn config(&self) -> &RouterConfig {
157        &self.config
158    }
159
160    /// Select the optimal algorithm for the given matrix profile and query.
161    ///
162    /// This is a pure function with no side effects -- it does not touch the
163    /// matrix data, only the precomputed profile.
164    pub fn select_algorithm(
165        &self,
166        profile: &SparsityProfile,
167        query: &QueryType,
168    ) -> Algorithm {
169        match query {
170            // ----------------------------------------------------------
171            // Linear system: Neumann > CG > BMSSP
172            // ----------------------------------------------------------
173            QueryType::LinearSystem => self.route_linear_system(profile),
174
175            // ----------------------------------------------------------
176            // Single-source PageRank: always ForwardPush
177            // ----------------------------------------------------------
178            QueryType::PageRankSingle { .. } => {
179                debug!("routing to ForwardPush (single-source PageRank)");
180                Algorithm::ForwardPush
181            }
182
183            // ----------------------------------------------------------
184            // Pairwise PageRank: ForwardPush or HybridRandomWalk
185            // ----------------------------------------------------------
186            QueryType::PageRankPairwise { .. } => {
187                if profile.rows > self.config.push_graph_size_threshold {
188                    debug!(
189                        rows = profile.rows,
190                        threshold = self.config.push_graph_size_threshold,
191                        "routing to HybridRandomWalk (large graph pairwise PPR)"
192                    );
193                    Algorithm::HybridRandomWalk
194                } else {
195                    debug!(
196                        rows = profile.rows,
197                        "routing to ForwardPush (small graph pairwise PPR)"
198                    );
199                    Algorithm::ForwardPush
200                }
201            }
202
203            // ----------------------------------------------------------
204            // Spectral filter: always Neumann
205            // ----------------------------------------------------------
206            QueryType::SpectralFilter { .. } => {
207                debug!("routing to Neumann (spectral filter)");
208                Algorithm::Neumann
209            }
210
211            // ----------------------------------------------------------
212            // Batch linear system: TRUE or CG
213            // ----------------------------------------------------------
214            QueryType::BatchLinearSystem { batch_size } => {
215                if *batch_size > self.config.true_batch_threshold {
216                    debug!(
217                        batch_size,
218                        threshold = self.config.true_batch_threshold,
219                        "routing to TRUE (large batch)"
220                    );
221                    Algorithm::TRUE
222                } else {
223                    debug!(batch_size, "routing to CG (small batch)");
224                    Algorithm::CG
225                }
226            }
227        }
228    }
229
230    /// Internal routing logic for `LinearSystem` queries.
231    fn route_linear_system(&self, profile: &SparsityProfile) -> Algorithm {
232        if profile.is_diag_dominant
233            && profile.density < self.config.sparsity_sublinear_threshold
234            && profile.estimated_spectral_radius
235                < self.config.neumann_spectral_radius_threshold
236        {
237            debug!(
238                density = profile.density,
239                spectral_radius = profile.estimated_spectral_radius,
240                "routing to Neumann (diag-dominant, sparse, low spectral radius)"
241            );
242            Algorithm::Neumann
243        } else if profile.estimated_condition < self.config.cg_condition_threshold {
244            debug!(
245                condition = profile.estimated_condition,
246                "routing to CG (well-conditioned)"
247            );
248            Algorithm::CG
249        } else {
250            debug!(
251                condition = profile.estimated_condition,
252                "routing to BMSSP (ill-conditioned)"
253            );
254            Algorithm::BMSSP
255        }
256    }
257}
258
259impl Default for SolverRouter {
260    fn default() -> Self {
261        Self::new(RouterConfig::default())
262    }
263}
264
265// ---------------------------------------------------------------------------
266// SolverOrchestrator
267// ---------------------------------------------------------------------------
268
/// High-level solver facade that combines routing with execution.
///
/// Owns a [`SolverRouter`] and delegates to the appropriate solver backend.
/// Provides a [`solve_with_fallback`](Self::solve_with_fallback) method that
/// automatically retries with progressively more robust (but slower)
/// algorithms when the first choice fails.
///
/// The orchestrator itself is stateless apart from its router, so it can
/// be cloned freely and shared across solve requests.
///
/// # Example
///
/// ```rust
/// use ruvector_solver::router::{SolverOrchestrator, RouterConfig};
/// use ruvector_solver::types::{ComputeBudget, CsrMatrix, QueryType};
///
/// let orchestrator = SolverOrchestrator::new(RouterConfig::default());
///
/// let matrix = CsrMatrix::<f64>::from_coo(3, 3, vec![
///     (0, 0, 2.0), (0, 1, -0.5),
///     (1, 0, -0.5), (1, 1, 2.0), (1, 2, -0.5),
///     (2, 1, -0.5), (2, 2, 2.0),
/// ]);
/// let rhs = vec![1.0, 0.0, 1.0];
/// let budget = ComputeBudget::default();
///
/// let result = orchestrator
///     .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
///     .unwrap();
/// assert!(result.residual_norm < 1e-6);
/// ```
#[derive(Debug, Clone)]
pub struct SolverOrchestrator {
    /// Algorithm selector consulted on every solve request.
    router: SolverRouter,
}
301
302impl SolverOrchestrator {
303    /// Create a new orchestrator with the provided routing configuration.
304    pub fn new(config: RouterConfig) -> Self {
305        Self {
306            router: SolverRouter::new(config),
307        }
308    }
309
    /// Return a reference to the inner router.
    ///
    /// Useful for inspecting the active [`RouterConfig`] or for calling
    /// [`SolverRouter::select_algorithm`] directly without solving.
    pub fn router(&self) -> &SolverRouter {
        &self.router
    }
314
315    // -----------------------------------------------------------------------
316    // Public API
317    // -----------------------------------------------------------------------
318
319    /// Auto-select the best algorithm and solve `Ax = b`.
320    ///
321    /// Analyses the sparsity profile of `matrix`, routes to the best
322    /// algorithm via [`SolverRouter::select_algorithm`], and dispatches.
323    ///
324    /// # Errors
325    ///
326    /// Returns [`SolverError`] if the selected solver fails (e.g.
327    /// non-convergence, dimension mismatch, numerical instability).
328    pub fn solve(
329        &self,
330        matrix: &CsrMatrix<f64>,
331        rhs: &[f64],
332        query: QueryType,
333        budget: &ComputeBudget,
334    ) -> Result<SolverResult, SolverError> {
335        let profile = Self::analyze_sparsity(matrix);
336        let algorithm = self.router.select_algorithm(&profile, &query);
337
338        info!(%algorithm, rows = matrix.rows, nnz = matrix.nnz(), "solve: selected algorithm");
339
340        self.dispatch(algorithm, matrix, rhs, budget)
341    }
342
343    /// Solve with a deterministic fallback chain.
344    ///
345    /// Tries the routed algorithm first. On failure, falls back through:
346    ///
347    /// 1. **Selected algorithm** (from routing)
348    /// 2. **CG** (robust iterative)
349    /// 3. **Dense** (direct, always works for small systems)
350    ///
351    /// Each step is only attempted if the previous one returned an error.
352    ///
353    /// # Errors
354    ///
355    /// Returns the error from the *last* fallback attempt if all fail.
356    pub fn solve_with_fallback(
357        &self,
358        matrix: &CsrMatrix<f64>,
359        rhs: &[f64],
360        query: QueryType,
361        budget: &ComputeBudget,
362    ) -> Result<SolverResult, SolverError> {
363        let profile = Self::analyze_sparsity(matrix);
364        let primary = self.router.select_algorithm(&profile, &query);
365
366        let chain = Self::build_fallback_chain(primary);
367
368        info!(
369            ?chain,
370            rows = matrix.rows,
371            nnz = matrix.nnz(),
372            "solve_with_fallback: attempting chain"
373        );
374
375        let mut last_err: Option<SolverError> = None;
376
377        for (idx, &algorithm) in chain.iter().enumerate() {
378            match self.dispatch(algorithm, matrix, rhs, budget) {
379                Ok(result) => {
380                    if idx > 0 {
381                        info!(
382                            %algorithm,
383                            "fallback succeeded on attempt {}",
384                            idx + 1
385                        );
386                    }
387                    return Ok(result);
388                }
389                Err(e) => {
390                    warn!(
391                        %algorithm,
392                        error = %e,
393                        "algorithm failed, trying next in fallback chain"
394                    );
395                    last_err = Some(e);
396                }
397            }
398        }
399
400        Err(last_err.unwrap_or_else(|| {
401            SolverError::BackendError("fallback chain was empty".into())
402        }))
403    }
404
405    /// Estimate the computational complexity of solving with the routed
406    /// algorithm, without actually solving.
407    ///
408    /// Useful for admission control, cost estimation, or deciding whether
409    /// to batch multiple queries.
410    pub fn estimate_complexity(
411        &self,
412        matrix: &CsrMatrix<f64>,
413        query: &QueryType,
414    ) -> ComplexityEstimate {
415        let profile = Self::analyze_sparsity(matrix);
416        let algorithm = self.router.select_algorithm(&profile, query);
417        let n = profile.rows;
418
419        let (estimated_iterations, complexity_class) = match algorithm {
420            Algorithm::Neumann => {
421                let k = if profile.estimated_spectral_radius > 0.0
422                    && profile.estimated_spectral_radius < 1.0
423                {
424                    let log_inv_eps = (1.0 / 1e-8_f64).ln();
425                    let log_inv_rho =
426                        (1.0 / profile.estimated_spectral_radius).ln();
427                    (log_inv_eps / log_inv_rho).ceil() as usize
428                } else {
429                    1000
430                };
431                (k.min(1000), ComplexityClass::SublinearNnz)
432            }
433            Algorithm::CG => {
434                let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
435                (iters.min(1000), ComplexityClass::SqrtCondition)
436            }
437            Algorithm::ForwardPush | Algorithm::BackwardPush => {
438                let iters = ((n as f64).sqrt()).ceil() as usize;
439                (iters, ComplexityClass::SublinearNnz)
440            }
441            Algorithm::HybridRandomWalk => {
442                (n.min(1000), ComplexityClass::Linear)
443            }
444            Algorithm::TRUE => {
445                let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
446                (iters.min(1000), ComplexityClass::SqrtCondition)
447            }
448            Algorithm::BMSSP => {
449                let iters = (profile.estimated_condition.sqrt().ln())
450                    .ceil() as usize;
451                (iters.max(1).min(1000), ComplexityClass::Linear)
452            }
453            Algorithm::Dense => (1, ComplexityClass::Cubic),
454            Algorithm::Jacobi | Algorithm::GaussSeidel => {
455                (1000, ComplexityClass::Linear)
456            }
457        };
458
459        let estimated_flops = match algorithm {
460            Algorithm::Dense => {
461                let dim = n as u64;
462                (2 * dim * dim * dim) / 3
463            }
464            _ => {
465                (estimated_iterations as u64)
466                    * (2 * profile.nnz as u64 + n as u64)
467            }
468        };
469
470        let estimated_memory_bytes = match algorithm {
471            Algorithm::Dense => {
472                n * profile.cols * std::mem::size_of::<f64>()
473            }
474            _ => {
475                // CSR storage + 3 work vectors.
476                let csr = profile.nnz
477                    * (std::mem::size_of::<f64>() + std::mem::size_of::<usize>())
478                    + (n + 1) * std::mem::size_of::<usize>();
479                let work = 3 * n * std::mem::size_of::<f64>();
480                csr + work
481            }
482        };
483
484        ComplexityEstimate {
485            algorithm,
486            estimated_flops,
487            estimated_iterations,
488            estimated_memory_bytes,
489            complexity_class,
490        }
491    }
492
493    /// Analyse the sparsity profile of a CSR matrix.
494    ///
495    /// Performs a single O(nnz) pass over the matrix to compute structural
496    /// and numerical properties used by the router. This is intentionally
497    /// cheap so it can be called on every solve request.
498    pub fn analyze_sparsity(matrix: &CsrMatrix<f64>) -> SparsityProfile {
499        let n = matrix.rows;
500        let m = matrix.cols;
501        let nnz = matrix.nnz();
502        let total_entries = (n as f64) * (m as f64);
503        let density = if total_entries > 0.0 {
504            nnz as f64 / total_entries
505        } else {
506            0.0
507        };
508
509        let mut is_diag_dominant = true;
510        let mut max_nnz_per_row: usize = 0;
511        let mut sum_off_diag_ratio = 0.0_f64;
512        let mut diag_min = f64::INFINITY;
513        let mut diag_max = 0.0_f64;
514        let mut symmetric_mismatches: usize = 0;
515
516        // Only check symmetry for small-to-medium matrices to keep O(nnz).
517        let check_symmetry = nnz <= 100_000;
518
519        for row in 0..n {
520            let start = matrix.row_ptr[row];
521            let end = matrix.row_ptr[row + 1];
522            let row_nnz = end - start;
523            max_nnz_per_row = max_nnz_per_row.max(row_nnz);
524
525            let mut diag_val: f64 = 0.0;
526            let mut off_diag_sum: f64 = 0.0;
527
528            for idx in start..end {
529                let col = matrix.col_indices[idx];
530                let val = matrix.values[idx];
531
532                if col == row {
533                    diag_val = val.abs();
534                } else {
535                    off_diag_sum += val.abs();
536                }
537
538                // Structural symmetry check: look for (col, row) entry.
539                if check_symmetry && col != row && col < n {
540                    let col_start = matrix.row_ptr[col];
541                    let col_end = matrix.row_ptr[col + 1];
542                    let found = matrix.col_indices[col_start..col_end]
543                        .binary_search(&row)
544                        .is_ok();
545                    if !found {
546                        symmetric_mismatches += 1;
547                    }
548                }
549            }
550
551            if diag_val <= off_diag_sum {
552                is_diag_dominant = false;
553            }
554
555            if diag_val > 0.0 {
556                let ratio = off_diag_sum / diag_val;
557                sum_off_diag_ratio += ratio;
558                diag_min = diag_min.min(diag_val);
559                diag_max = diag_max.max(diag_val);
560            } else if n > 0 {
561                is_diag_dominant = false;
562                sum_off_diag_ratio += 1.0;
563            }
564        }
565
566        let avg_nnz_per_row = if n > 0 {
567            nnz as f64 / n as f64
568        } else {
569            0.0
570        };
571
572        // Spectral radius of Jacobi iteration matrix D^{-1}(L+U).
573        let estimated_spectral_radius = if n > 0 {
574            sum_off_diag_ratio / n as f64
575        } else {
576            0.0
577        };
578
579        // Rough condition number from diagonal range.
580        let estimated_condition = if diag_min > 0.0 && diag_min.is_finite() {
581            diag_max / diag_min
582        } else {
583            f64::INFINITY
584        };
585
586        let is_symmetric_structure = if check_symmetry {
587            symmetric_mismatches == 0
588        } else {
589            n == m
590        };
591
592        SparsityProfile {
593            rows: n,
594            cols: m,
595            nnz,
596            density,
597            is_diag_dominant,
598            estimated_spectral_radius,
599            estimated_condition,
600            is_symmetric_structure,
601            avg_nnz_per_row,
602            max_nnz_per_row,
603        }
604    }
605
606    // -----------------------------------------------------------------------
607    // Internal helpers
608    // -----------------------------------------------------------------------
609
610    /// Build a deduplicated fallback chain: `[primary, CG, Dense]`.
611    fn build_fallback_chain(primary: Algorithm) -> Vec<Algorithm> {
612        let mut chain = Vec::with_capacity(3);
613        chain.push(primary);
614
615        if primary != Algorithm::CG {
616            chain.push(Algorithm::CG);
617        }
618        if primary != Algorithm::Dense {
619            chain.push(Algorithm::Dense);
620        }
621
622        chain
623    }
624
    /// Dispatch a solve request to the concrete solver for `algorithm`.
    ///
    /// Each arm is compiled conditionally: when the backing feature is
    /// enabled the concrete solver runs; otherwise the arm returns a
    /// `BackendError`, which lets the fallback chain in
    /// [`solve_with_fallback`](Self::solve_with_fallback) proceed to the
    /// next algorithm instead of aborting.
    fn dispatch(
        &self,
        algorithm: Algorithm,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        match algorithm {
            // ----- Neumann series ------------------------------------------
            Algorithm::Neumann => {
                #[cfg(feature = "neumann")]
                {
                    let solver = crate::neumann::NeumannSolver::new(
                        budget.tolerance,
                        budget.max_iterations,
                    );
                    SolverEngine::solve(&solver, matrix, rhs, budget)
                }
                #[cfg(not(feature = "neumann"))]
                {
                    Err(SolverError::BackendError(
                        "neumann feature is not enabled".into(),
                    ))
                }
            }

            // ----- Conjugate Gradient --------------------------------------
            Algorithm::CG => {
                #[cfg(feature = "cg")]
                {
                    // NOTE(review): third constructor argument is a bool —
                    // presumably a preconditioning toggle; confirm against
                    // crate::cg::ConjugateGradientSolver::new.
                    let solver =
                        crate::cg::ConjugateGradientSolver::new(budget.tolerance, budget.max_iterations, false);
                    solver.solve(matrix, rhs, budget)
                }
                #[cfg(not(feature = "cg"))]
                {
                    // Inline CG when the feature crate is not available.
                    self.solve_cg_inline(matrix, rhs, budget)
                }
            }

            // ----- ForwardPush ---------------------------------------------
            // NOTE(review): routes through the Jacobi-style fallback even
            // when the feature is enabled — presumably a placeholder until
            // a dedicated push solver lands; confirm.
            Algorithm::ForwardPush => {
                #[cfg(feature = "forward-push")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::ForwardPush,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "forward-push"))]
                {
                    Err(SolverError::BackendError(
                        "forward-push feature is not enabled".into(),
                    ))
                }
            }

            // ----- BackwardPush --------------------------------------------
            Algorithm::BackwardPush => {
                #[cfg(feature = "backward-push")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::BackwardPush,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "backward-push"))]
                {
                    Err(SolverError::BackendError(
                        "backward-push feature is not enabled".into(),
                    ))
                }
            }

            // ----- HybridRandomWalk ----------------------------------------
            Algorithm::HybridRandomWalk => {
                #[cfg(feature = "hybrid-random-walk")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::HybridRandomWalk,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "hybrid-random-walk"))]
                {
                    Err(SolverError::BackendError(
                        "hybrid-random-walk feature is not enabled".into(),
                    ))
                }
            }

            // ----- TRUE batch solver ---------------------------------------
            Algorithm::TRUE => {
                #[cfg(feature = "true-solver")]
                {
                    // TRUE for a single RHS degrades to Neumann; the result
                    // is relabelled so callers still see Algorithm::TRUE.
                    let solver = crate::neumann::NeumannSolver::new(
                        budget.tolerance,
                        budget.max_iterations,
                    );
                    let mut result = SolverEngine::solve(&solver, matrix, rhs, budget)?;
                    result.algorithm = Algorithm::TRUE;
                    Ok(result)
                }
                #[cfg(not(feature = "true-solver"))]
                {
                    Err(SolverError::BackendError(
                        "true-solver feature is not enabled".into(),
                    ))
                }
            }

            // ----- BMSSP ---------------------------------------------------
            Algorithm::BMSSP => {
                #[cfg(feature = "bmssp")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::BMSSP,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "bmssp"))]
                {
                    Err(SolverError::BackendError(
                        "bmssp feature is not enabled".into(),
                    ))
                }
            }

            // ----- Dense direct solver -------------------------------------
            // Always compiled in: the last link of the fallback chain.
            Algorithm::Dense => self.solve_dense(matrix, rhs, budget),

            // ----- Legacy iterative solvers --------------------------------
            Algorithm::Jacobi => {
                self.solve_jacobi_fallback(Algorithm::Jacobi, matrix, rhs, budget)
            }
            Algorithm::GaussSeidel => {
                self.solve_jacobi_fallback(
                    Algorithm::GaussSeidel,
                    matrix,
                    rhs,
                    budget,
                )
            }
        }
    }
784
785    /// Inline Conjugate Gradient for symmetric positive-definite systems.
786    ///
787    /// Standard unpreconditioned CG. Used when the `cg` feature crate is
788    /// not compiled in but CG is needed (e.g. as a fallback).
789    #[allow(dead_code)]
790    fn solve_cg_inline(
791        &self,
792        matrix: &CsrMatrix<f64>,
793        rhs: &[f64],
794        budget: &ComputeBudget,
795    ) -> Result<SolverResult, SolverError> {
796        let n = matrix.rows;
797        validate_square(matrix)?;
798        validate_rhs_len(matrix, rhs)?;
799
800        let max_iters = budget.max_iterations;
801        let tol = budget.tolerance;
802        let start = Instant::now();
803
804        let mut x = vec![0.0_f64; n];
805        let mut r: Vec<f64> = rhs.to_vec();
806        let mut p = r.clone();
807        let mut ap = vec![0.0_f64; n];
808        let mut convergence_history = Vec::new();
809
810        let mut r_dot_r = dot(&r, &r);
811
812        for iter in 0..max_iters {
813            let residual_norm = r_dot_r.sqrt();
814
815            convergence_history.push(ConvergenceInfo {
816                iteration: iter,
817                residual_norm,
818            });
819
820            if residual_norm.is_nan() || residual_norm.is_infinite() {
821                return Err(SolverError::NumericalInstability {
822                    iteration: iter,
823                    detail: format!("CG residual became {}", residual_norm),
824                });
825            }
826
827            if residual_norm < tol {
828                return Ok(SolverResult {
829                    solution: x.iter().map(|&v| v as f32).collect(),
830                    iterations: iter,
831                    residual_norm,
832                    wall_time: start.elapsed(),
833                    convergence_history,
834                    algorithm: Algorithm::CG,
835                });
836            }
837
838            // ap = A * p
839            matrix.spmv(&p, &mut ap);
840
841            let p_dot_ap = dot(&p, &ap);
842            if p_dot_ap.abs() < 1e-30 {
843                return Err(SolverError::NumericalInstability {
844                    iteration: iter,
845                    detail: "CG: p^T A p near zero (matrix may not be SPD)"
846                        .into(),
847                });
848            }
849
850            let alpha = r_dot_r / p_dot_ap;
851
852            for i in 0..n {
853                x[i] += alpha * p[i];
854                r[i] -= alpha * ap[i];
855            }
856
857            let new_r_dot_r = dot(&r, &r);
858            let beta = new_r_dot_r / r_dot_r;
859
860            for i in 0..n {
861                p[i] = r[i] + beta * p[i];
862            }
863
864            r_dot_r = new_r_dot_r;
865
866            if start.elapsed() > budget.max_time {
867                return Err(SolverError::BudgetExhausted {
868                    reason: "wall-clock time limit exceeded".into(),
869                    elapsed: start.elapsed(),
870                });
871            }
872        }
873
874        let final_residual = convergence_history
875            .last()
876            .map(|c| c.residual_norm)
877            .unwrap_or(f64::INFINITY);
878
879        Err(SolverError::NonConvergence {
880            iterations: max_iters,
881            residual: final_residual,
882            tolerance: tol,
883        })
884    }
885
    /// Dense direct solver via Gaussian elimination with partial pivoting.
    ///
    /// O(n^3) time and O(n^2) memory. Only used as a last-resort fallback.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` for non-square matrices, mismatched RHS
    /// lengths, or dimensions above `MAX_DENSE_DIM`, and
    /// `NumericalInstability` when no usable pivot exists (matrix is
    /// numerically singular).
    fn solve_dense(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        _budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let n = matrix.rows;
        validate_square(matrix)?;
        validate_rhs_len(matrix, rhs)?;

        // Guard against materialising a huge dense system: the augmented
        // matrix below allocates n * (n + 1) f64 values.
        const MAX_DENSE_DIM: usize = 4096;
        if n > MAX_DENSE_DIM {
            return Err(SolverError::InvalidInput(
                crate::error::ValidationError::MatrixTooLarge {
                    rows: n,
                    cols: n,
                    max_dim: MAX_DENSE_DIM,
                },
            ));
        }

        let start = Instant::now();

        // Expand CSR to dense augmented matrix [A | b].
        let stride = n + 1;
        let mut aug = vec![0.0_f64; n * stride];
        for row in 0..n {
            let rs = matrix.row_ptr[row];
            let re = matrix.row_ptr[row + 1];
            for idx in rs..re {
                let col = matrix.col_indices[idx];
                aug[row * stride + col] = matrix.values[idx];
            }
            aug[row * stride + n] = rhs[row];
        }

        // Gaussian elimination with partial pivoting.
        for col in 0..n {
            // Select the largest-magnitude entry in this column as pivot;
            // partial pivoting keeps the elimination numerically stable.
            let mut max_val = aug[col * stride + col].abs();
            let mut max_row = col;
            for row in (col + 1)..n {
                let val = aug[row * stride + col].abs();
                if val > max_val {
                    max_val = val;
                    max_row = row;
                }
            }

            // No usable pivot => the matrix is (numerically) singular.
            if max_val < 1e-12 {
                return Err(SolverError::NumericalInstability {
                    iteration: 0,
                    detail: format!(
                        "dense solver: near-zero pivot ({:.2e}) at column {}",
                        max_val, col
                    ),
                });
            }

            // Swap the pivot row into place (row interchange).
            if max_row != col {
                for j in 0..stride {
                    aug.swap(col * stride + j, max_row * stride + j);
                }
            }

            // Eliminate all entries below the pivot.
            let pivot = aug[col * stride + col];
            for row in (col + 1)..n {
                let factor = aug[row * stride + col] / pivot;
                aug[row * stride + col] = 0.0;
                for j in (col + 1)..stride {
                    let above = aug[col * stride + j];
                    aug[row * stride + j] -= factor * above;
                }
            }
        }

        // Back-substitution.
        let mut solution_f64 = vec![0.0_f64; n];
        for row in (0..n).rev() {
            let mut sum = aug[row * stride + n];
            for col in (row + 1)..n {
                sum -= aug[row * stride + col] * solution_f64[col];
            }
            solution_f64[row] = sum / aug[row * stride + row];
        }

        // Compute residual ||b - A x||_2 against the original sparse matrix
        // (not the eliminated dense copy), so the report is trustworthy.
        let mut ax = vec![0.0_f64; n];
        matrix.spmv(&solution_f64, &mut ax);
        let residual_norm: f64 = (0..n)
            .map(|i| {
                let r = rhs[i] - ax[i];
                r * r
            })
            .sum::<f64>()
            .sqrt();

        let solution: Vec<f32> = solution_f64.iter().map(|&v| v as f32).collect();

        Ok(SolverResult {
            solution,
            iterations: 1,
            residual_norm,
            wall_time: start.elapsed(),
            convergence_history: vec![ConvergenceInfo {
                iteration: 0,
                residual_norm,
            }],
            algorithm: Algorithm::Dense,
        })
    }
999
1000    /// Generic Jacobi-iteration fallback for algorithms whose specialised
1001    /// backends are not yet implemented.
1002    ///
1003    /// Tags the result with the requested `algorithm` label so callers see
1004    /// the correct algorithm in the result.
1005    fn solve_jacobi_fallback(
1006        &self,
1007        algorithm: Algorithm,
1008        matrix: &CsrMatrix<f64>,
1009        rhs: &[f64],
1010        budget: &ComputeBudget,
1011    ) -> Result<SolverResult, SolverError> {
1012        let n = matrix.rows;
1013        validate_square(matrix)?;
1014        validate_rhs_len(matrix, rhs)?;
1015
1016        let max_iters = budget.max_iterations;
1017        let tol = budget.tolerance;
1018        let start = Instant::now();
1019
1020        // Extract diagonal.
1021        let mut diag = vec![0.0_f64; n];
1022        for row in 0..n {
1023            let rs = matrix.row_ptr[row];
1024            let re = matrix.row_ptr[row + 1];
1025            for idx in rs..re {
1026                if matrix.col_indices[idx] == row {
1027                    diag[row] = matrix.values[idx];
1028                    break;
1029                }
1030            }
1031        }
1032
1033        for (i, &d) in diag.iter().enumerate() {
1034            if d.abs() < 1e-30 {
1035                return Err(SolverError::NumericalInstability {
1036                    iteration: 0,
1037                    detail: format!(
1038                        "zero or near-zero diagonal at row {} (val={:.2e})",
1039                        i, d
1040                    ),
1041                });
1042            }
1043        }
1044
1045        let mut x = vec![0.0_f64; n];
1046        let mut x_new = vec![0.0_f64; n];
1047        let mut temp = vec![0.0_f64; n];
1048        let mut convergence_history = Vec::new();
1049
1050        for iter in 0..max_iters {
1051            for row in 0..n {
1052                let rs = matrix.row_ptr[row];
1053                let re = matrix.row_ptr[row + 1];
1054                let mut sum = 0.0_f64;
1055                for idx in rs..re {
1056                    let col = matrix.col_indices[idx];
1057                    if col != row {
1058                        sum += matrix.values[idx] * x[col];
1059                    }
1060                }
1061                x_new[row] = (rhs[row] - sum) / diag[row];
1062            }
1063
1064            matrix.spmv(&x_new, &mut temp);
1065            let residual_norm: f64 = (0..n)
1066                .map(|i| {
1067                    let r = rhs[i] - temp[i];
1068                    r * r
1069                })
1070                .sum::<f64>()
1071                .sqrt();
1072
1073            convergence_history.push(ConvergenceInfo {
1074                iteration: iter,
1075                residual_norm,
1076            });
1077
1078            if residual_norm.is_nan() || residual_norm.is_infinite() {
1079                return Err(SolverError::NumericalInstability {
1080                    iteration: iter,
1081                    detail: format!("residual became {}", residual_norm),
1082                });
1083            }
1084
1085            if residual_norm < tol {
1086                return Ok(SolverResult {
1087                    solution: x_new.iter().map(|&v| v as f32).collect(),
1088                    iterations: iter + 1,
1089                    residual_norm,
1090                    wall_time: start.elapsed(),
1091                    convergence_history,
1092                    algorithm,
1093                });
1094            }
1095
1096            std::mem::swap(&mut x, &mut x_new);
1097
1098            if start.elapsed() > budget.max_time {
1099                return Err(SolverError::BudgetExhausted {
1100                    reason: "wall-clock time limit exceeded".into(),
1101                    elapsed: start.elapsed(),
1102                });
1103            }
1104        }
1105
1106        let final_residual = convergence_history
1107            .last()
1108            .map(|c| c.residual_norm)
1109            .unwrap_or(f64::INFINITY);
1110
1111        Err(SolverError::NonConvergence {
1112            iterations: max_iters,
1113            residual: final_residual,
1114            tolerance: tol,
1115        })
1116    }
1117}
1118
1119impl Default for SolverOrchestrator {
1120    fn default() -> Self {
1121        Self::new(RouterConfig::default())
1122    }
1123}
1124
1125// ---------------------------------------------------------------------------
1126// Utility functions
1127// ---------------------------------------------------------------------------
1128
/// Dot product of two f64 slices.
///
/// # Panics
///
/// Panics when the slices differ in length. Callers are internal and always
/// pass equally-sized vectors, so a mismatch indicates a bug.
//
// The stale `#[allow(dead_code)]` was removed: this helper is exercised by
// the inline CG solver, so the suppression no longer applied.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot: length mismatch {} vs {}", a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .map(|(&ai, &bi)| ai * bi)
        .sum()
}
1139
1140/// Validate that a matrix is square.
1141fn validate_square(matrix: &CsrMatrix<f64>) -> Result<(), SolverError> {
1142    if matrix.rows != matrix.cols {
1143        return Err(SolverError::InvalidInput(
1144            crate::error::ValidationError::DimensionMismatch(format!(
1145                "matrix must be square, got {}x{}",
1146                matrix.rows, matrix.cols
1147            )),
1148        ));
1149    }
1150    Ok(())
1151}
1152
1153/// Validate that the RHS vector length matches the matrix dimension.
1154fn validate_rhs_len(
1155    matrix: &CsrMatrix<f64>,
1156    rhs: &[f64],
1157) -> Result<(), SolverError> {
1158    if rhs.len() != matrix.rows {
1159        return Err(SolverError::InvalidInput(
1160            crate::error::ValidationError::DimensionMismatch(format!(
1161                "rhs length {} does not match matrix dimension {}",
1162                rhs.len(),
1163                matrix.rows
1164            )),
1165        ));
1166    }
1167    Ok(())
1168}
1169
1170// ---------------------------------------------------------------------------
1171// Tests
1172// ---------------------------------------------------------------------------
1173
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 3x3 diagonally dominant SPD matrix.
    ///
    /// Tridiagonal: 4.0 on the diagonal, -1.0 off-diagonal — strictly
    /// diagonally dominant, so it is safe for CG, Jacobi, and Dense alike.
    fn diag_dominant_3x3() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(
            3,
            3,
            vec![
                (0, 0, 4.0),
                (0, 1, -1.0),
                (1, 0, -1.0),
                (1, 1, 4.0),
                (1, 2, -1.0),
                (2, 1, -1.0),
                (2, 2, 4.0),
            ],
        )
    }

    /// Default compute budget tightened to a 1e-8 tolerance.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            tolerance: 1e-8,
            ..Default::default()
        }
    }

    // -----------------------------------------------------------------------
    // Router tests
    // -----------------------------------------------------------------------

    // Diag-dominant + very sparse + low spectral radius => Neumann row of
    // the routing table in the module docs.
    #[test]
    fn routes_diag_dominant_sparse_to_neumann() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 3000,
            density: 0.003,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 3.0,
            max_nnz_per_row: 5,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::Neumann
        );
    }

    #[test]
    fn routes_well_conditioned_non_diag_dominant_to_cg() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::CG
        );
    }

    #[test]
    fn routes_ill_conditioned_to_bmssp() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.99,
            estimated_condition: 500.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::BMSSP
        );
    }

    #[test]
    fn routes_single_pagerank_to_forward_push() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 5000,
            cols: 5000,
            nnz: 20_000,
            density: 0.0008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 50,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::PageRankSingle { source: 0 }
            ),
            Algorithm::ForwardPush
        );
    }

    // Same graph profile as above; only the query kind changes, which is
    // enough to flip the choice to HybridRandomWalk for pairwise queries.
    #[test]
    fn routes_large_pairwise_to_hybrid_random_walk() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 5000,
            cols: 5000,
            nnz: 20_000,
            density: 0.0008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 50,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::PageRankPairwise {
                    source: 0,
                    target: 100,
                }
            ),
            Algorithm::HybridRandomWalk
        );
    }

    #[test]
    fn routes_small_pairwise_to_forward_push() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 500,
            cols: 500,
            nnz: 2000,
            density: 0.008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::PageRankPairwise {
                    source: 0,
                    target: 10,
                }
            ),
            Algorithm::ForwardPush
        );
    }

    #[test]
    fn routes_spectral_filter_to_neumann() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 500,
            density: 0.05,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.3,
            estimated_condition: 5.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 8,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::SpectralFilter {
                    polynomial_degree: 10,
                }
            ),
            Algorithm::Neumann
        );
    }

    #[test]
    fn routes_large_batch_to_true() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::BatchLinearSystem { batch_size: 200 }
            ),
            Algorithm::TRUE
        );
    }

    #[test]
    fn routes_small_batch_to_cg() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::BatchLinearSystem { batch_size: 50 }
            ),
            Algorithm::CG
        );
    }

    // Lowering cg_condition_threshold below the profile's condition number
    // should push an otherwise-CG case into BMSSP.
    #[test]
    fn custom_config_overrides_thresholds() {
        let config = RouterConfig {
            cg_condition_threshold: 10.0,
            ..Default::default()
        };
        let router = SolverRouter::new(config);

        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::BMSSP
        );
    }

    #[test]
    fn neumann_requires_low_spectral_radius() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 3000,
            density: 0.003,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.96, // above 0.95 threshold
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 3.0,
            max_nnz_per_row: 5,
        };

        // Should fall through to CG, not Neumann.
        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::CG
        );
    }

    // -----------------------------------------------------------------------
    // SparsityProfile analysis tests
    // -----------------------------------------------------------------------

    #[test]
    fn analyze_identity_matrix() {
        let matrix = CsrMatrix::<f64>::identity(5);
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert_eq!(profile.rows, 5);
        assert_eq!(profile.cols, 5);
        assert_eq!(profile.nnz, 5);
        assert!(profile.is_diag_dominant);
        assert!((profile.density - 0.2).abs() < 1e-10);
        assert!(profile.estimated_spectral_radius.abs() < 1e-10);
        assert!((profile.estimated_condition - 1.0).abs() < 1e-10);
        assert!(profile.is_symmetric_structure);
        assert_eq!(profile.max_nnz_per_row, 1);
    }

    #[test]
    fn analyze_diag_dominant() {
        let matrix = diag_dominant_3x3();
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert!(profile.is_diag_dominant);
        assert!(profile.estimated_spectral_radius < 1.0);
        assert!(profile.is_symmetric_structure);
    }

    // 0x0 matrix must be analyzable without panicking (no divide-by-zero in
    // the density computation).
    #[test]
    fn analyze_empty_matrix() {
        let matrix = CsrMatrix::<f64> {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert_eq!(profile.rows, 0);
        assert_eq!(profile.nnz, 0);
        assert_eq!(profile.density, 0.0);
    }

    // -----------------------------------------------------------------------
    // Orchestrator solve tests
    // -----------------------------------------------------------------------

    // For I x = b the solution is b itself, whichever backend is routed to.
    #[test]
    fn orchestrator_solve_identity() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64>::identity(4);
        let rhs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let budget = default_budget();

        let result = orchestrator
            .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
            .unwrap();

        for (x, b) in result.solution.iter().zip(rhs.iter()) {
            assert!(
                (*x as f64 - b).abs() < 1e-4,
                "expected {}, got {}",
                b,
                x
            );
        }
    }

    #[test]
    fn orchestrator_solve_diag_dominant() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 0.0, 1.0];
        let budget = default_budget();

        let result = orchestrator
            .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_solve_with_fallback_succeeds() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 0.0, 1.0];
        let budget = default_budget();

        let result = orchestrator
            .solve_with_fallback(
                &matrix,
                &rhs,
                QueryType::LinearSystem,
                &budget,
            )
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_dimension_mismatch() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64>::identity(3);
        let rhs = vec![1.0_f64, 2.0]; // wrong length
        let budget = default_budget();

        let result = orchestrator.solve(
            &matrix,
            &rhs,
            QueryType::LinearSystem,
            &budget,
        );
        assert!(result.is_err());
    }

    #[test]
    fn estimate_complexity_returns_reasonable_values() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();

        let estimate =
            orchestrator.estimate_complexity(&matrix, &QueryType::LinearSystem);

        assert!(estimate.estimated_flops > 0);
        assert!(estimate.estimated_memory_bytes > 0);
        assert!(estimate.estimated_iterations > 0);
    }

    // The chain is "selected -> CG -> Dense" with duplicates collapsed, so
    // starting from CG or Dense must not repeat that entry.
    #[test]
    fn fallback_chain_deduplicates() {
        let chain = SolverOrchestrator::build_fallback_chain(Algorithm::CG);
        assert_eq!(chain, vec![Algorithm::CG, Algorithm::Dense]);

        let chain = SolverOrchestrator::build_fallback_chain(Algorithm::Dense);
        assert_eq!(chain, vec![Algorithm::Dense, Algorithm::CG]);

        let chain =
            SolverOrchestrator::build_fallback_chain(Algorithm::Neumann);
        assert_eq!(
            chain,
            vec![Algorithm::Neumann, Algorithm::CG, Algorithm::Dense]
        );
    }

    #[test]
    fn cg_inline_solves_spd_system() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 2.0, 3.0];
        let budget = default_budget();

        let result = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-6);
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    #[test]
    fn dense_solves_small_system() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 2.0, 3.0];
        let budget = default_budget();

        let result = orchestrator
            .solve_dense(&matrix, &rhs, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-4);
        assert_eq!(result.algorithm, Algorithm::Dense);
    }

    #[test]
    fn dense_rejects_non_square() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64> {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0_f64, 1.0];
        let budget = default_budget();

        assert!(orchestrator.solve_dense(&matrix, &rhs, &budget).is_err());
    }

    // Cross-check: the iterative and direct paths must agree on the same
    // well-conditioned system to within the f32 cast noise.
    #[test]
    fn cg_and_dense_agree_on_solution() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![3.0_f64, -1.0, 2.0];
        let budget = default_budget();

        let cg_result = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .unwrap();
        let dense_result = orchestrator
            .solve_dense(&matrix, &rhs, &budget)
            .unwrap();

        for (cg_x, dense_x) in cg_result
            .solution
            .iter()
            .zip(dense_result.solution.iter())
        {
            assert!(
                (cg_x - dense_x).abs() < 1e-3,
                "CG={} vs Dense={}",
                cg_x,
                dense_x
            );
        }
    }
}