Skip to main content

ruvector_solver/
router.rs

1//! Algorithm router and solver orchestrator.
2//!
3//! The [`SolverRouter`] inspects a matrix's [`SparsityProfile`] and the
4//! caller's [`QueryType`] to select the optimal [`Algorithm`] for each solve
5//! request. The [`SolverOrchestrator`] wraps the router together with concrete
6//! solver instances and provides high-level `solve` / `solve_with_fallback`
7//! entry points.
8//!
9//! # Routing decision tree
10//!
11//! | Query | Condition | Algorithm |
12//! |-------|-----------|-----------|
//! | `LinearSystem` | diag-dominant + very sparse + low spectral radius | Neumann |
14//! | `LinearSystem` | low condition number | CG |
15//! | `LinearSystem` | else | BMSSP |
16//! | `PageRankSingle` | always | ForwardPush |
17//! | `PageRankPairwise` | large graph | HybridRandomWalk |
18//! | `PageRankPairwise` | small graph | ForwardPush |
19//! | `SpectralFilter` | always | Neumann |
20//! | `BatchLinearSystem` | large batch | TRUE |
21//! | `BatchLinearSystem` | small batch | CG |
22//!
23//! # Fallback chain
24//!
25//! When the selected algorithm fails (non-convergence, numerical instability),
26//! [`SolverOrchestrator::solve_with_fallback`] tries a deterministic chain:
27//!
28//! **selected algorithm -> CG -> Dense**
29
30use std::time::Instant;
31
32use tracing::{debug, info, warn};
33
34use crate::error::SolverError;
35use crate::traits::SolverEngine;
36use crate::types::{
37    Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
38    QueryType, SolverResult, SparsityProfile,
39};
40
41// ---------------------------------------------------------------------------
42// RouterConfig
43// ---------------------------------------------------------------------------
44
/// Configuration thresholds that govern the routing decision tree.
///
/// All thresholds have sensible defaults; override them when benchmarks on
/// your workload indicate a different crossover point.
///
/// Every threshold is compared *strictly* by [`SolverRouter`]: a value that
/// lands exactly on a threshold takes the "otherwise" branch of the
/// corresponding rule.
///
/// # Example
///
/// ```rust
/// use ruvector_solver::router::RouterConfig;
///
/// let config = RouterConfig {
///     cg_condition_threshold: 50.0,
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Spectral-radius bound for the Neumann series.
    ///
    /// Neumann is only selected when the estimated spectral radius is
    /// strictly below this value (in addition to the matrix being diagonally
    /// dominant and sparse enough). A radius equal to or above it rules
    /// Neumann out, even for diagonally dominant matrices.
    ///
    /// Default: `0.95`.
    pub neumann_spectral_radius_threshold: f64,

    /// Condition-number bound below which CG is preferred over BMSSP.
    ///
    /// CG converges in O(sqrt(kappa)) iterations; once kappa reaches this
    /// value a preconditioned method (BMSSP) is expected to be cheaper. Only
    /// an estimated condition number strictly below the threshold selects CG.
    ///
    /// Default: `100.0`.
    pub cg_condition_threshold: f64,

    /// Density (fraction of non-zeros) bound for the Neumann sublinear
    /// fast-path.
    ///
    /// Neumann is only worthwhile when the matrix is truly sparse, so it is
    /// considered only when the density is strictly below this value.
    ///
    /// Default: `0.05` (5%).
    pub sparsity_sublinear_threshold: f64,

    /// Batch-size threshold for `BatchLinearSystem` queries.
    ///
    /// The TRUE solver is selected only when the batch size is strictly
    /// greater than this value; a batch of exactly this size (or smaller) is
    /// routed to CG.
    ///
    /// Default: `100`.
    pub true_batch_threshold: usize,

    /// Graph-size threshold (number of rows) for `PageRankPairwise` queries.
    ///
    /// Graphs with strictly more rows than this use HybridRandomWalk; graphs
    /// at or below it use ForwardPush.
    ///
    /// Default: `1_000`.
    pub push_graph_size_threshold: usize,
}
98
99impl Default for RouterConfig {
100    fn default() -> Self {
101        Self {
102            neumann_spectral_radius_threshold: 0.95,
103            cg_condition_threshold: 100.0,
104            sparsity_sublinear_threshold: 0.05,
105            true_batch_threshold: 100,
106            push_graph_size_threshold: 1_000,
107        }
108    }
109}
110
111// ---------------------------------------------------------------------------
112// SolverRouter
113// ---------------------------------------------------------------------------
114
115/// Stateless algorithm selector.
116///
117/// Given a [`SparsityProfile`] and a [`QueryType`], the router walks a
118/// decision tree (documented in the [module-level docs](self)) to pick the
119/// [`Algorithm`] with the best expected cost.
120///
121/// # Example
122///
123/// ```rust
124/// use ruvector_solver::router::{SolverRouter, RouterConfig};
125/// use ruvector_solver::types::{Algorithm, QueryType, SparsityProfile};
126///
127/// let router = SolverRouter::new(RouterConfig::default());
128/// let profile = SparsityProfile {
129///     rows: 500,
130///     cols: 500,
131///     nnz: 1200,
132///     density: 0.0048,
133///     is_diag_dominant: true,
134///     estimated_spectral_radius: 0.4,
135///     estimated_condition: 10.0,
136///     is_symmetric_structure: true,
137///     avg_nnz_per_row: 2.4,
138///     max_nnz_per_row: 5,
139/// };
140///
141/// let algo = router.select_algorithm(&profile, &QueryType::LinearSystem);
142/// assert_eq!(algo, Algorithm::Neumann);
143/// ```
144#[derive(Debug, Clone)]
145pub struct SolverRouter {
146    config: RouterConfig,
147}
148
149impl SolverRouter {
150    /// Create a new router with the provided configuration.
151    pub fn new(config: RouterConfig) -> Self {
152        Self { config }
153    }
154
155    /// Return a shared reference to the active configuration.
156    pub fn config(&self) -> &RouterConfig {
157        &self.config
158    }
159
160    /// Select the optimal algorithm for the given matrix profile and query.
161    ///
162    /// This is a pure function with no side effects -- it does not touch the
163    /// matrix data, only the precomputed profile.
164    pub fn select_algorithm(&self, profile: &SparsityProfile, query: &QueryType) -> Algorithm {
165        match query {
166            // ----------------------------------------------------------
167            // Linear system: Neumann > CG > BMSSP
168            // ----------------------------------------------------------
169            QueryType::LinearSystem => self.route_linear_system(profile),
170
171            // ----------------------------------------------------------
172            // Single-source PageRank: always ForwardPush
173            // ----------------------------------------------------------
174            QueryType::PageRankSingle { .. } => {
175                debug!("routing to ForwardPush (single-source PageRank)");
176                Algorithm::ForwardPush
177            }
178
179            // ----------------------------------------------------------
180            // Pairwise PageRank: ForwardPush or HybridRandomWalk
181            // ----------------------------------------------------------
182            QueryType::PageRankPairwise { .. } => {
183                if profile.rows > self.config.push_graph_size_threshold {
184                    debug!(
185                        rows = profile.rows,
186                        threshold = self.config.push_graph_size_threshold,
187                        "routing to HybridRandomWalk (large graph pairwise PPR)"
188                    );
189                    Algorithm::HybridRandomWalk
190                } else {
191                    debug!(
192                        rows = profile.rows,
193                        "routing to ForwardPush (small graph pairwise PPR)"
194                    );
195                    Algorithm::ForwardPush
196                }
197            }
198
199            // ----------------------------------------------------------
200            // Spectral filter: always Neumann
201            // ----------------------------------------------------------
202            QueryType::SpectralFilter { .. } => {
203                debug!("routing to Neumann (spectral filter)");
204                Algorithm::Neumann
205            }
206
207            // ----------------------------------------------------------
208            // Batch linear system: TRUE or CG
209            // ----------------------------------------------------------
210            QueryType::BatchLinearSystem { batch_size } => {
211                if *batch_size > self.config.true_batch_threshold {
212                    debug!(
213                        batch_size,
214                        threshold = self.config.true_batch_threshold,
215                        "routing to TRUE (large batch)"
216                    );
217                    Algorithm::TRUE
218                } else {
219                    debug!(batch_size, "routing to CG (small batch)");
220                    Algorithm::CG
221                }
222            }
223        }
224    }
225
226    /// Internal routing logic for `LinearSystem` queries.
227    fn route_linear_system(&self, profile: &SparsityProfile) -> Algorithm {
228        if profile.is_diag_dominant
229            && profile.density < self.config.sparsity_sublinear_threshold
230            && profile.estimated_spectral_radius < self.config.neumann_spectral_radius_threshold
231        {
232            debug!(
233                density = profile.density,
234                spectral_radius = profile.estimated_spectral_radius,
235                "routing to Neumann (diag-dominant, sparse, low spectral radius)"
236            );
237            Algorithm::Neumann
238        } else if profile.estimated_condition < self.config.cg_condition_threshold {
239            debug!(
240                condition = profile.estimated_condition,
241                "routing to CG (well-conditioned)"
242            );
243            Algorithm::CG
244        } else {
245            debug!(
246                condition = profile.estimated_condition,
247                "routing to BMSSP (ill-conditioned)"
248            );
249            Algorithm::BMSSP
250        }
251    }
252}
253
254impl Default for SolverRouter {
255    fn default() -> Self {
256        Self::new(RouterConfig::default())
257    }
258}
259
260// ---------------------------------------------------------------------------
261// SolverOrchestrator
262// ---------------------------------------------------------------------------
263
264/// High-level solver facade that combines routing with execution.
265///
266/// Owns a [`SolverRouter`] and delegates to the appropriate solver backend.
267/// Provides a [`solve_with_fallback`](Self::solve_with_fallback) method that
268/// automatically retries with progressively more robust (but slower)
269/// algorithms when the first choice fails.
270///
271/// # Example
272///
273/// ```rust
274/// use ruvector_solver::router::{SolverOrchestrator, RouterConfig};
275/// use ruvector_solver::types::{ComputeBudget, CsrMatrix, QueryType};
276///
277/// let orchestrator = SolverOrchestrator::new(RouterConfig::default());
278///
279/// let matrix = CsrMatrix::<f64>::from_coo(3, 3, vec![
280///     (0, 0, 2.0), (0, 1, -0.5),
281///     (1, 0, -0.5), (1, 1, 2.0), (1, 2, -0.5),
282///     (2, 1, -0.5), (2, 2, 2.0),
283/// ]);
284/// let rhs = vec![1.0, 0.0, 1.0];
285/// let budget = ComputeBudget::default();
286///
287/// let result = orchestrator
288///     .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
289///     .unwrap();
290/// assert!(result.residual_norm < 1e-6);
291/// ```
292#[derive(Debug, Clone)]
293pub struct SolverOrchestrator {
294    router: SolverRouter,
295}
296
297impl SolverOrchestrator {
298    /// Create a new orchestrator with the provided routing configuration.
299    pub fn new(config: RouterConfig) -> Self {
300        Self {
301            router: SolverRouter::new(config),
302        }
303    }
304
305    /// Return a reference to the inner router.
306    pub fn router(&self) -> &SolverRouter {
307        &self.router
308    }
309
310    // -----------------------------------------------------------------------
311    // Public API
312    // -----------------------------------------------------------------------
313
314    /// Auto-select the best algorithm and solve `Ax = b`.
315    ///
316    /// Analyses the sparsity profile of `matrix`, routes to the best
317    /// algorithm via [`SolverRouter::select_algorithm`], and dispatches.
318    ///
319    /// # Errors
320    ///
321    /// Returns [`SolverError`] if the selected solver fails (e.g.
322    /// non-convergence, dimension mismatch, numerical instability).
323    pub fn solve(
324        &self,
325        matrix: &CsrMatrix<f64>,
326        rhs: &[f64],
327        query: QueryType,
328        budget: &ComputeBudget,
329    ) -> Result<SolverResult, SolverError> {
330        let profile = Self::analyze_sparsity(matrix);
331        let algorithm = self.router.select_algorithm(&profile, &query);
332
333        info!(%algorithm, rows = matrix.rows, nnz = matrix.nnz(), "solve: selected algorithm");
334
335        self.dispatch(algorithm, matrix, rhs, budget)
336    }
337
338    /// Solve with a deterministic fallback chain.
339    ///
340    /// Tries the routed algorithm first. On failure, falls back through:
341    ///
342    /// 1. **Selected algorithm** (from routing)
343    /// 2. **CG** (robust iterative)
344    /// 3. **Dense** (direct, always works for small systems)
345    ///
346    /// Each step is only attempted if the previous one returned an error.
347    ///
348    /// # Errors
349    ///
350    /// Returns the error from the *last* fallback attempt if all fail.
351    pub fn solve_with_fallback(
352        &self,
353        matrix: &CsrMatrix<f64>,
354        rhs: &[f64],
355        query: QueryType,
356        budget: &ComputeBudget,
357    ) -> Result<SolverResult, SolverError> {
358        let profile = Self::analyze_sparsity(matrix);
359        let primary = self.router.select_algorithm(&profile, &query);
360
361        let chain = Self::build_fallback_chain(primary);
362
363        info!(
364            ?chain,
365            rows = matrix.rows,
366            nnz = matrix.nnz(),
367            "solve_with_fallback: attempting chain"
368        );
369
370        let mut last_err: Option<SolverError> = None;
371
372        for (idx, &algorithm) in chain.iter().enumerate() {
373            match self.dispatch(algorithm, matrix, rhs, budget) {
374                Ok(result) => {
375                    if idx > 0 {
376                        info!(
377                            %algorithm,
378                            "fallback succeeded on attempt {}",
379                            idx + 1
380                        );
381                    }
382                    return Ok(result);
383                }
384                Err(e) => {
385                    warn!(
386                        %algorithm,
387                        error = %e,
388                        "algorithm failed, trying next in fallback chain"
389                    );
390                    last_err = Some(e);
391                }
392            }
393        }
394
395        Err(last_err
396            .unwrap_or_else(|| SolverError::BackendError("fallback chain was empty".into())))
397    }
398
399    /// Estimate the computational complexity of solving with the routed
400    /// algorithm, without actually solving.
401    ///
402    /// Useful for admission control, cost estimation, or deciding whether
403    /// to batch multiple queries.
404    pub fn estimate_complexity(
405        &self,
406        matrix: &CsrMatrix<f64>,
407        query: &QueryType,
408    ) -> ComplexityEstimate {
409        let profile = Self::analyze_sparsity(matrix);
410        let algorithm = self.router.select_algorithm(&profile, query);
411        let n = profile.rows;
412
413        let (estimated_iterations, complexity_class) = match algorithm {
414            Algorithm::Neumann => {
415                let k = if profile.estimated_spectral_radius > 0.0
416                    && profile.estimated_spectral_radius < 1.0
417                {
418                    let log_inv_eps = (1.0 / 1e-8_f64).ln();
419                    let log_inv_rho = (1.0 / profile.estimated_spectral_radius).ln();
420                    (log_inv_eps / log_inv_rho).ceil() as usize
421                } else {
422                    1000
423                };
424                (k.min(1000), ComplexityClass::SublinearNnz)
425            }
426            Algorithm::CG => {
427                let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
428                (iters.min(1000), ComplexityClass::SqrtCondition)
429            }
430            Algorithm::ForwardPush | Algorithm::BackwardPush => {
431                let iters = ((n as f64).sqrt()).ceil() as usize;
432                (iters, ComplexityClass::SublinearNnz)
433            }
434            Algorithm::HybridRandomWalk => (n.min(1000), ComplexityClass::Linear),
435            Algorithm::TRUE => {
436                let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
437                (iters.min(1000), ComplexityClass::SqrtCondition)
438            }
439            Algorithm::BMSSP => {
440                let iters = (profile.estimated_condition.sqrt().ln()).ceil() as usize;
441                (iters.max(1).min(1000), ComplexityClass::Linear)
442            }
443            Algorithm::Dense => (1, ComplexityClass::Cubic),
444            Algorithm::Jacobi | Algorithm::GaussSeidel => (1000, ComplexityClass::Linear),
445        };
446
447        let estimated_flops = match algorithm {
448            Algorithm::Dense => {
449                let dim = n as u64;
450                (2 * dim * dim * dim) / 3
451            }
452            _ => (estimated_iterations as u64) * (2 * profile.nnz as u64 + n as u64),
453        };
454
455        let estimated_memory_bytes = match algorithm {
456            Algorithm::Dense => n * profile.cols * std::mem::size_of::<f64>(),
457            _ => {
458                // CSR storage + 3 work vectors.
459                let csr = profile.nnz * (std::mem::size_of::<f64>() + std::mem::size_of::<usize>())
460                    + (n + 1) * std::mem::size_of::<usize>();
461                let work = 3 * n * std::mem::size_of::<f64>();
462                csr + work
463            }
464        };
465
466        ComplexityEstimate {
467            algorithm,
468            estimated_flops,
469            estimated_iterations,
470            estimated_memory_bytes,
471            complexity_class,
472        }
473    }
474
475    /// Analyse the sparsity profile of a CSR matrix.
476    ///
477    /// Performs a single O(nnz) pass over the matrix to compute structural
478    /// and numerical properties used by the router. This is intentionally
479    /// cheap so it can be called on every solve request.
480    pub fn analyze_sparsity(matrix: &CsrMatrix<f64>) -> SparsityProfile {
481        let n = matrix.rows;
482        let m = matrix.cols;
483        let nnz = matrix.nnz();
484        let total_entries = (n as f64) * (m as f64);
485        let density = if total_entries > 0.0 {
486            nnz as f64 / total_entries
487        } else {
488            0.0
489        };
490
491        let mut is_diag_dominant = true;
492        let mut max_nnz_per_row: usize = 0;
493        let mut sum_off_diag_ratio = 0.0_f64;
494        let mut diag_min = f64::INFINITY;
495        let mut diag_max = 0.0_f64;
496        let mut symmetric_mismatches: usize = 0;
497
498        // Only check symmetry for small-to-medium matrices to keep O(nnz).
499        let check_symmetry = nnz <= 100_000;
500
501        for row in 0..n {
502            let start = matrix.row_ptr[row];
503            let end = matrix.row_ptr[row + 1];
504            let row_nnz = end - start;
505            max_nnz_per_row = max_nnz_per_row.max(row_nnz);
506
507            let mut diag_val: f64 = 0.0;
508            let mut off_diag_sum: f64 = 0.0;
509
510            for idx in start..end {
511                let col = matrix.col_indices[idx];
512                let val = matrix.values[idx];
513
514                if col == row {
515                    diag_val = val.abs();
516                } else {
517                    off_diag_sum += val.abs();
518                }
519
520                // Structural symmetry check: look for (col, row) entry.
521                if check_symmetry && col != row && col < n {
522                    let col_start = matrix.row_ptr[col];
523                    let col_end = matrix.row_ptr[col + 1];
524                    let found = matrix.col_indices[col_start..col_end]
525                        .binary_search(&row)
526                        .is_ok();
527                    if !found {
528                        symmetric_mismatches += 1;
529                    }
530                }
531            }
532
533            if diag_val <= off_diag_sum {
534                is_diag_dominant = false;
535            }
536
537            if diag_val > 0.0 {
538                let ratio = off_diag_sum / diag_val;
539                sum_off_diag_ratio += ratio;
540                diag_min = diag_min.min(diag_val);
541                diag_max = diag_max.max(diag_val);
542            } else if n > 0 {
543                is_diag_dominant = false;
544                sum_off_diag_ratio += 1.0;
545            }
546        }
547
548        let avg_nnz_per_row = if n > 0 { nnz as f64 / n as f64 } else { 0.0 };
549
550        // Spectral radius of Jacobi iteration matrix D^{-1}(L+U).
551        let estimated_spectral_radius = if n > 0 {
552            sum_off_diag_ratio / n as f64
553        } else {
554            0.0
555        };
556
557        // Rough condition number from diagonal range.
558        let estimated_condition = if diag_min > 0.0 && diag_min.is_finite() {
559            diag_max / diag_min
560        } else {
561            f64::INFINITY
562        };
563
564        let is_symmetric_structure = if check_symmetry {
565            symmetric_mismatches == 0
566        } else {
567            n == m
568        };
569
570        SparsityProfile {
571            rows: n,
572            cols: m,
573            nnz,
574            density,
575            is_diag_dominant,
576            estimated_spectral_radius,
577            estimated_condition,
578            is_symmetric_structure,
579            avg_nnz_per_row,
580            max_nnz_per_row,
581        }
582    }
583
584    // -----------------------------------------------------------------------
585    // Internal helpers
586    // -----------------------------------------------------------------------
587
588    /// Build a deduplicated fallback chain: `[primary, CG, Dense]`.
589    fn build_fallback_chain(primary: Algorithm) -> Vec<Algorithm> {
590        let mut chain = Vec::with_capacity(3);
591        chain.push(primary);
592
593        if primary != Algorithm::CG {
594            chain.push(Algorithm::CG);
595        }
596        if primary != Algorithm::Dense {
597            chain.push(Algorithm::Dense);
598        }
599
600        chain
601    }
602
    /// Dispatch a solve request to the concrete solver for `algorithm`.
    ///
    /// Feature-gated solvers return a `BackendError` when the feature is
    /// not compiled in, allowing the fallback chain to proceed.
    ///
    /// Backend mapping as written in this function:
    /// - `Neumann` uses `crate::neumann::NeumannSolver` (feature `neumann`).
    /// - `CG` uses `crate::cg::ConjugateGradientSolver` (feature `cg`).
    ///   Without the feature it runs `solve_cg_inline` instead of erroring.
    /// - `ForwardPush`, `BackwardPush`, `HybridRandomWalk` and `BMSSP` all
    ///   delegate to `solve_jacobi_fallback` when their feature is enabled,
    ///   passing their own `Algorithm` variant.
    /// - `TRUE` runs a single-RHS `NeumannSolver` and relabels the result as
    ///   `Algorithm::TRUE` (feature `true-solver`).
    /// - `Dense` uses `solve_dense` and is always available.
    /// - `Jacobi` and `GaussSeidel` use `solve_jacobi_fallback` and are
    ///   always available.
    fn dispatch(
        &self,
        algorithm: Algorithm,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        match algorithm {
            // ----- Neumann series ------------------------------------------
            Algorithm::Neumann => {
                #[cfg(feature = "neumann")]
                {
                    // Solver tolerance / iteration cap come from the budget.
                    let solver =
                        crate::neumann::NeumannSolver::new(budget.tolerance, budget.max_iterations);
                    SolverEngine::solve(&solver, matrix, rhs, budget)
                }
                #[cfg(not(feature = "neumann"))]
                {
                    Err(SolverError::BackendError(
                        "neumann feature is not enabled".into(),
                    ))
                }
            }

            // ----- Conjugate Gradient --------------------------------------
            Algorithm::CG => {
                #[cfg(feature = "cg")]
                {
                    // NOTE(review): the meaning of the third constructor
                    // argument (`false`) is defined in `crate::cg`, not here.
                    let solver = crate::cg::ConjugateGradientSolver::new(
                        budget.tolerance,
                        budget.max_iterations,
                        false,
                    );
                    solver.solve(matrix, rhs, budget)
                }
                #[cfg(not(feature = "cg"))]
                {
                    // Inline CG when the feature crate is not available.
                    self.solve_cg_inline(matrix, rhs, budget)
                }
            }

            // ----- ForwardPush ---------------------------------------------
            Algorithm::ForwardPush => {
                #[cfg(feature = "forward-push")]
                {
                    // No dedicated push backend is wired in here; this arm
                    // delegates to the shared Jacobi-style fallback.
                    self.solve_jacobi_fallback(Algorithm::ForwardPush, matrix, rhs, budget)
                }
                #[cfg(not(feature = "forward-push"))]
                {
                    Err(SolverError::BackendError(
                        "forward-push feature is not enabled".into(),
                    ))
                }
            }

            // ----- BackwardPush --------------------------------------------
            Algorithm::BackwardPush => {
                #[cfg(feature = "backward-push")]
                {
                    // Same delegation as ForwardPush above.
                    self.solve_jacobi_fallback(Algorithm::BackwardPush, matrix, rhs, budget)
                }
                #[cfg(not(feature = "backward-push"))]
                {
                    Err(SolverError::BackendError(
                        "backward-push feature is not enabled".into(),
                    ))
                }
            }

            // ----- HybridRandomWalk ----------------------------------------
            Algorithm::HybridRandomWalk => {
                #[cfg(feature = "hybrid-random-walk")]
                {
                    self.solve_jacobi_fallback(Algorithm::HybridRandomWalk, matrix, rhs, budget)
                }
                #[cfg(not(feature = "hybrid-random-walk"))]
                {
                    Err(SolverError::BackendError(
                        "hybrid-random-walk feature is not enabled".into(),
                    ))
                }
            }

            // ----- TRUE batch solver ---------------------------------------
            Algorithm::TRUE => {
                #[cfg(feature = "true-solver")]
                {
                    // TRUE for a single RHS degrades to Neumann.
                    // NOTE(review): this arm references `crate::neumann` while
                    // gated only on `true-solver`; confirm that feature
                    // implies `neumann` in Cargo.toml.
                    let solver =
                        crate::neumann::NeumannSolver::new(budget.tolerance, budget.max_iterations);
                    let mut result = SolverEngine::solve(&solver, matrix, rhs, budget)?;
                    // Report the algorithm the caller routed to, not the
                    // underlying Neumann run.
                    result.algorithm = Algorithm::TRUE;
                    Ok(result)
                }
                #[cfg(not(feature = "true-solver"))]
                {
                    Err(SolverError::BackendError(
                        "true-solver feature is not enabled".into(),
                    ))
                }
            }

            // ----- BMSSP ---------------------------------------------------
            Algorithm::BMSSP => {
                #[cfg(feature = "bmssp")]
                {
                    self.solve_jacobi_fallback(Algorithm::BMSSP, matrix, rhs, budget)
                }
                #[cfg(not(feature = "bmssp"))]
                {
                    Err(SolverError::BackendError(
                        "bmssp feature is not enabled".into(),
                    ))
                }
            }

            // ----- Dense direct solver -------------------------------------
            // Always compiled in; last link of the fallback chain.
            Algorithm::Dense => self.solve_dense(matrix, rhs, budget),

            // ----- Legacy iterative solvers --------------------------------
            Algorithm::Jacobi => self.solve_jacobi_fallback(Algorithm::Jacobi, matrix, rhs, budget),
            Algorithm::GaussSeidel => {
                self.solve_jacobi_fallback(Algorithm::GaussSeidel, matrix, rhs, budget)
            }
        }
    }
734
735    /// Inline Conjugate Gradient for symmetric positive-definite systems.
736    ///
737    /// Standard unpreconditioned CG. Used when the `cg` feature crate is
738    /// not compiled in but CG is needed (e.g. as a fallback).
739    #[allow(dead_code)]
740    fn solve_cg_inline(
741        &self,
742        matrix: &CsrMatrix<f64>,
743        rhs: &[f64],
744        budget: &ComputeBudget,
745    ) -> Result<SolverResult, SolverError> {
746        let n = matrix.rows;
747        validate_square(matrix)?;
748        validate_rhs_len(matrix, rhs)?;
749
750        let max_iters = budget.max_iterations;
751        let tol = budget.tolerance;
752        let start = Instant::now();
753
754        let mut x = vec![0.0_f64; n];
755        let mut r: Vec<f64> = rhs.to_vec();
756        let mut p = r.clone();
757        let mut ap = vec![0.0_f64; n];
758        let mut convergence_history = Vec::new();
759
760        let mut r_dot_r = dot(&r, &r);
761
762        for iter in 0..max_iters {
763            let residual_norm = r_dot_r.sqrt();
764
765            convergence_history.push(ConvergenceInfo {
766                iteration: iter,
767                residual_norm,
768            });
769
770            if residual_norm.is_nan() || residual_norm.is_infinite() {
771                return Err(SolverError::NumericalInstability {
772                    iteration: iter,
773                    detail: format!("CG residual became {}", residual_norm),
774                });
775            }
776
777            if residual_norm < tol {
778                return Ok(SolverResult {
779                    solution: x.iter().map(|&v| v as f32).collect(),
780                    iterations: iter,
781                    residual_norm,
782                    wall_time: start.elapsed(),
783                    convergence_history,
784                    algorithm: Algorithm::CG,
785                });
786            }
787
788            // ap = A * p
789            matrix.spmv(&p, &mut ap);
790
791            let p_dot_ap = dot(&p, &ap);
792            if p_dot_ap.abs() < 1e-30 {
793                return Err(SolverError::NumericalInstability {
794                    iteration: iter,
795                    detail: "CG: p^T A p near zero (matrix may not be SPD)".into(),
796                });
797            }
798
799            let alpha = r_dot_r / p_dot_ap;
800
801            for i in 0..n {
802                x[i] += alpha * p[i];
803                r[i] -= alpha * ap[i];
804            }
805
806            let new_r_dot_r = dot(&r, &r);
807            let beta = new_r_dot_r / r_dot_r;
808
809            for i in 0..n {
810                p[i] = r[i] + beta * p[i];
811            }
812
813            r_dot_r = new_r_dot_r;
814
815            if start.elapsed() > budget.max_time {
816                return Err(SolverError::BudgetExhausted {
817                    reason: "wall-clock time limit exceeded".into(),
818                    elapsed: start.elapsed(),
819                });
820            }
821        }
822
823        let final_residual = convergence_history
824            .last()
825            .map(|c| c.residual_norm)
826            .unwrap_or(f64::INFINITY);
827
828        Err(SolverError::NonConvergence {
829            iterations: max_iters,
830            residual: final_residual,
831            tolerance: tol,
832        })
833    }
834
835    /// Dense direct solver via Gaussian elimination with partial pivoting.
836    ///
837    /// O(n^3) time and O(n^2) memory. Only used as a last-resort fallback.
838    fn solve_dense(
839        &self,
840        matrix: &CsrMatrix<f64>,
841        rhs: &[f64],
842        _budget: &ComputeBudget,
843    ) -> Result<SolverResult, SolverError> {
844        let n = matrix.rows;
845        validate_square(matrix)?;
846        validate_rhs_len(matrix, rhs)?;
847
848        const MAX_DENSE_DIM: usize = 4096;
849        if n > MAX_DENSE_DIM {
850            return Err(SolverError::InvalidInput(
851                crate::error::ValidationError::MatrixTooLarge {
852                    rows: n,
853                    cols: n,
854                    max_dim: MAX_DENSE_DIM,
855                },
856            ));
857        }
858
859        let start = Instant::now();
860
861        // Expand CSR to dense augmented matrix [A | b].
862        let stride = n + 1;
863        let mut aug = vec![0.0_f64; n * stride];
864        for row in 0..n {
865            let rs = matrix.row_ptr[row];
866            let re = matrix.row_ptr[row + 1];
867            for idx in rs..re {
868                let col = matrix.col_indices[idx];
869                aug[row * stride + col] = matrix.values[idx];
870            }
871            aug[row * stride + n] = rhs[row];
872        }
873
874        // Gaussian elimination with partial pivoting.
875        for col in 0..n {
876            let mut max_val = aug[col * stride + col].abs();
877            let mut max_row = col;
878            for row in (col + 1)..n {
879                let val = aug[row * stride + col].abs();
880                if val > max_val {
881                    max_val = val;
882                    max_row = row;
883                }
884            }
885
886            if max_val < 1e-12 {
887                return Err(SolverError::NumericalInstability {
888                    iteration: 0,
889                    detail: format!(
890                        "dense solver: near-zero pivot ({:.2e}) at column {}",
891                        max_val, col
892                    ),
893                });
894            }
895
896            if max_row != col {
897                for j in 0..stride {
898                    aug.swap(col * stride + j, max_row * stride + j);
899                }
900            }
901
902            let pivot = aug[col * stride + col];
903            for row in (col + 1)..n {
904                let factor = aug[row * stride + col] / pivot;
905                aug[row * stride + col] = 0.0;
906                for j in (col + 1)..stride {
907                    let above = aug[col * stride + j];
908                    aug[row * stride + j] -= factor * above;
909                }
910            }
911        }
912
913        // Back-substitution.
914        let mut solution_f64 = vec![0.0_f64; n];
915        for row in (0..n).rev() {
916            let mut sum = aug[row * stride + n];
917            for col in (row + 1)..n {
918                sum -= aug[row * stride + col] * solution_f64[col];
919            }
920            solution_f64[row] = sum / aug[row * stride + row];
921        }
922
923        // Compute residual.
924        let mut ax = vec![0.0_f64; n];
925        matrix.spmv(&solution_f64, &mut ax);
926        let residual_norm: f64 = (0..n)
927            .map(|i| {
928                let r = rhs[i] - ax[i];
929                r * r
930            })
931            .sum::<f64>()
932            .sqrt();
933
934        let solution: Vec<f32> = solution_f64.iter().map(|&v| v as f32).collect();
935
936        Ok(SolverResult {
937            solution,
938            iterations: 1,
939            residual_norm,
940            wall_time: start.elapsed(),
941            convergence_history: vec![ConvergenceInfo {
942                iteration: 0,
943                residual_norm,
944            }],
945            algorithm: Algorithm::Dense,
946        })
947    }
948
949    /// Generic Jacobi-iteration fallback for algorithms whose specialised
950    /// backends are not yet implemented.
951    ///
952    /// Tags the result with the requested `algorithm` label so callers see
953    /// the correct algorithm in the result.
954    fn solve_jacobi_fallback(
955        &self,
956        algorithm: Algorithm,
957        matrix: &CsrMatrix<f64>,
958        rhs: &[f64],
959        budget: &ComputeBudget,
960    ) -> Result<SolverResult, SolverError> {
961        let n = matrix.rows;
962        validate_square(matrix)?;
963        validate_rhs_len(matrix, rhs)?;
964
965        let max_iters = budget.max_iterations;
966        let tol = budget.tolerance;
967        let start = Instant::now();
968
969        // Extract diagonal.
970        let mut diag = vec![0.0_f64; n];
971        for row in 0..n {
972            let rs = matrix.row_ptr[row];
973            let re = matrix.row_ptr[row + 1];
974            for idx in rs..re {
975                if matrix.col_indices[idx] == row {
976                    diag[row] = matrix.values[idx];
977                    break;
978                }
979            }
980        }
981
982        for (i, &d) in diag.iter().enumerate() {
983            if d.abs() < 1e-30 {
984                return Err(SolverError::NumericalInstability {
985                    iteration: 0,
986                    detail: format!("zero or near-zero diagonal at row {} (val={:.2e})", i, d),
987                });
988            }
989        }
990
991        let mut x = vec![0.0_f64; n];
992        let mut x_new = vec![0.0_f64; n];
993        let mut temp = vec![0.0_f64; n];
994        let mut convergence_history = Vec::new();
995
996        for iter in 0..max_iters {
997            for row in 0..n {
998                let rs = matrix.row_ptr[row];
999                let re = matrix.row_ptr[row + 1];
1000                let mut sum = 0.0_f64;
1001                for idx in rs..re {
1002                    let col = matrix.col_indices[idx];
1003                    if col != row {
1004                        sum += matrix.values[idx] * x[col];
1005                    }
1006                }
1007                x_new[row] = (rhs[row] - sum) / diag[row];
1008            }
1009
1010            matrix.spmv(&x_new, &mut temp);
1011            let residual_norm: f64 = (0..n)
1012                .map(|i| {
1013                    let r = rhs[i] - temp[i];
1014                    r * r
1015                })
1016                .sum::<f64>()
1017                .sqrt();
1018
1019            convergence_history.push(ConvergenceInfo {
1020                iteration: iter,
1021                residual_norm,
1022            });
1023
1024            if residual_norm.is_nan() || residual_norm.is_infinite() {
1025                return Err(SolverError::NumericalInstability {
1026                    iteration: iter,
1027                    detail: format!("residual became {}", residual_norm),
1028                });
1029            }
1030
1031            if residual_norm < tol {
1032                return Ok(SolverResult {
1033                    solution: x_new.iter().map(|&v| v as f32).collect(),
1034                    iterations: iter + 1,
1035                    residual_norm,
1036                    wall_time: start.elapsed(),
1037                    convergence_history,
1038                    algorithm,
1039                });
1040            }
1041
1042            std::mem::swap(&mut x, &mut x_new);
1043
1044            if start.elapsed() > budget.max_time {
1045                return Err(SolverError::BudgetExhausted {
1046                    reason: "wall-clock time limit exceeded".into(),
1047                    elapsed: start.elapsed(),
1048                });
1049            }
1050        }
1051
1052        let final_residual = convergence_history
1053            .last()
1054            .map(|c| c.residual_norm)
1055            .unwrap_or(f64::INFINITY);
1056
1057        Err(SolverError::NonConvergence {
1058            iterations: max_iters,
1059            residual: final_residual,
1060            tolerance: tol,
1061        })
1062    }
1063}
1064
1065impl Default for SolverOrchestrator {
1066    fn default() -> Self {
1067        Self::new(RouterConfig::default())
1068    }
1069}
1070
1071// ---------------------------------------------------------------------------
1072// Utility functions
1073// ---------------------------------------------------------------------------
1074
/// Dot product of two f64 slices: `sum_i a[i] * b[i]`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
// NOTE(review): `dot` is called by the inline CG solver in this file, so this
// allow may be stale — confirm no cfg-gated build leaves it unused before removing.
#[allow(dead_code)]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(len_a, len_b, "dot: length mismatch {} vs {}", len_a, len_b);
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
1088
1089/// Validate that a matrix is square.
1090fn validate_square(matrix: &CsrMatrix<f64>) -> Result<(), SolverError> {
1091    if matrix.rows != matrix.cols {
1092        return Err(SolverError::InvalidInput(
1093            crate::error::ValidationError::DimensionMismatch(format!(
1094                "matrix must be square, got {}x{}",
1095                matrix.rows, matrix.cols
1096            )),
1097        ));
1098    }
1099    Ok(())
1100}
1101
1102/// Validate that the RHS vector length matches the matrix dimension.
1103fn validate_rhs_len(matrix: &CsrMatrix<f64>, rhs: &[f64]) -> Result<(), SolverError> {
1104    if rhs.len() != matrix.rows {
1105        return Err(SolverError::InvalidInput(
1106            crate::error::ValidationError::DimensionMismatch(format!(
1107                "rhs length {} does not match matrix dimension {}",
1108                rhs.len(),
1109                matrix.rows
1110            )),
1111        ));
1112    }
1113    Ok(())
1114}
1115
1116// ---------------------------------------------------------------------------
1117// Tests
1118// ---------------------------------------------------------------------------
1119
#[cfg(test)]
// Unit tests covering three areas: the routing decision tree
// (`SolverRouter::select_algorithm`), sparsity analysis
// (`SolverOrchestrator::analyze_sparsity`), and the orchestrator's solve entry
// points plus its inline CG and dense backends.
mod tests {
    use super::*;

    /// Build a 3x3 diagonally dominant SPD matrix.
    // Tridiagonal [[4,-1,0],[-1,4,-1],[0,-1,4]]: symmetric, strictly diagonally
    // dominant with a positive diagonal, hence SPD.
    fn diag_dominant_3x3() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(
            3,
            3,
            vec![
                (0, 0, 4.0),
                (0, 1, -1.0),
                (1, 0, -1.0),
                (1, 1, 4.0),
                (1, 2, -1.0),
                (2, 1, -1.0),
                (2, 2, 4.0),
            ],
        )
    }

    // Default budget with a tightened tolerance of 1e-8; all other fields come
    // from `ComputeBudget::default()`.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            tolerance: 1e-8,
            ..Default::default()
        }
    }

    // -----------------------------------------------------------------------
    // Router tests
    // -----------------------------------------------------------------------
    // Each test hand-builds a `SparsityProfile` rather than analysing a real
    // matrix, so the fixture values directly pin one branch of the decision tree.

    #[test]
    fn routes_diag_dominant_sparse_to_neumann() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 3000,
            density: 0.003,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 3.0,
            max_nnz_per_row: 5,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::Neumann
        );
    }

    // Not diagonally dominant, but condition 50 is expected to be under the
    // default CG threshold.
    #[test]
    fn routes_well_conditioned_non_diag_dominant_to_cg() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::CG
        );
    }

    // Same shape as the CG case, but condition 500 is expected to exceed the
    // default CG threshold.
    #[test]
    fn routes_ill_conditioned_to_bmssp() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.99,
            estimated_condition: 500.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::BMSSP
        );
    }

    #[test]
    fn routes_single_pagerank_to_forward_push() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 5000,
            cols: 5000,
            nnz: 20_000,
            density: 0.0008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 50,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::PageRankSingle { source: 0 }),
            Algorithm::ForwardPush
        );
    }

    // 5000 rows is expected to count as a "large graph" for pairwise queries.
    #[test]
    fn routes_large_pairwise_to_hybrid_random_walk() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 5000,
            cols: 5000,
            nnz: 20_000,
            density: 0.0008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 50,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::PageRankPairwise {
                    source: 0,
                    target: 100,
                }
            ),
            Algorithm::HybridRandomWalk
        );
    }

    // 500 rows is expected to count as a "small graph" for pairwise queries.
    #[test]
    fn routes_small_pairwise_to_forward_push() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 500,
            cols: 500,
            nnz: 2000,
            density: 0.008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::PageRankPairwise {
                    source: 0,
                    target: 10,
                }
            ),
            Algorithm::ForwardPush
        );
    }

    #[test]
    fn routes_spectral_filter_to_neumann() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 500,
            density: 0.05,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.3,
            estimated_condition: 5.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 8,
        };

        assert_eq!(
            router.select_algorithm(
                &profile,
                &QueryType::SpectralFilter {
                    polynomial_degree: 10,
                }
            ),
            Algorithm::Neumann
        );
    }

    // Batch of 200 is expected to be above the default "large batch" cutoff;
    // the small-batch test below uses the same profile with a batch of 50.
    #[test]
    fn routes_large_batch_to_true() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::BatchLinearSystem { batch_size: 200 }),
            Algorithm::TRUE
        );
    }

    #[test]
    fn routes_small_batch_to_cg() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::BatchLinearSystem { batch_size: 50 }),
            Algorithm::CG
        );
    }

    // The profile is identical to `routes_well_conditioned_non_diag_dominant_to_cg`
    // (condition 50), but lowering `cg_condition_threshold` to 10 must push the
    // decision past CG to BMSSP.
    #[test]
    fn custom_config_overrides_thresholds() {
        let config = RouterConfig {
            cg_condition_threshold: 10.0,
            ..Default::default()
        };
        let router = SolverRouter::new(config);

        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        };

        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::BMSSP
        );
    }

    // Same profile as `routes_diag_dominant_sparse_to_neumann` except for the
    // spectral radius, isolating that one condition of the Neumann branch.
    #[test]
    fn neumann_requires_low_spectral_radius() {
        let router = SolverRouter::new(RouterConfig::default());
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 3000,
            density: 0.003,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.96, // above 0.95 threshold
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 3.0,
            max_nnz_per_row: 5,
        };

        // Should fall through to CG, not Neumann.
        assert_eq!(
            router.select_algorithm(&profile, &QueryType::LinearSystem),
            Algorithm::CG
        );
    }

    // -----------------------------------------------------------------------
    // SparsityProfile analysis tests
    // -----------------------------------------------------------------------

    // Identity is the simplest fixture: one entry per row, density 1/n, and
    // expected spectral-radius/condition estimates of exactly 0 and 1.
    #[test]
    fn analyze_identity_matrix() {
        let matrix = CsrMatrix::<f64>::identity(5);
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert_eq!(profile.rows, 5);
        assert_eq!(profile.cols, 5);
        assert_eq!(profile.nnz, 5);
        assert!(profile.is_diag_dominant);
        assert!((profile.density - 0.2).abs() < 1e-10);
        assert!(profile.estimated_spectral_radius.abs() < 1e-10);
        assert!((profile.estimated_condition - 1.0).abs() < 1e-10);
        assert!(profile.is_symmetric_structure);
        assert_eq!(profile.max_nnz_per_row, 1);
    }

    #[test]
    fn analyze_diag_dominant() {
        let matrix = diag_dominant_3x3();
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert!(profile.is_diag_dominant);
        assert!(profile.estimated_spectral_radius < 1.0);
        assert!(profile.is_symmetric_structure);
    }

    // A 0x0 matrix must not divide by zero when computing density.
    #[test]
    fn analyze_empty_matrix() {
        let matrix = CsrMatrix::<f64> {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        let profile = SolverOrchestrator::analyze_sparsity(&matrix);

        assert_eq!(profile.rows, 0);
        assert_eq!(profile.nnz, 0);
        assert_eq!(profile.density, 0.0);
    }

    // -----------------------------------------------------------------------
    // Orchestrator solve tests
    // -----------------------------------------------------------------------

    // For A = I the solution equals the RHS; the tolerance allows for the
    // f32 conversion of the returned solution.
    #[test]
    fn orchestrator_solve_identity() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64>::identity(4);
        let rhs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let budget = default_budget();

        let result = orchestrator
            .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
            .unwrap();

        for (x, b) in result.solution.iter().zip(rhs.iter()) {
            assert!((*x as f64 - b).abs() < 1e-4, "expected {}, got {}", b, x);
        }
    }

    #[test]
    fn orchestrator_solve_diag_dominant() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 0.0, 1.0];
        let budget = default_budget();

        let result = orchestrator
            .solve(&matrix, &rhs, QueryType::LinearSystem, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_solve_with_fallback_succeeds() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 0.0, 1.0];
        let budget = default_budget();

        let result = orchestrator
            .solve_with_fallback(&matrix, &rhs, QueryType::LinearSystem, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_dimension_mismatch() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64>::identity(3);
        let rhs = vec![1.0_f64, 2.0]; // wrong length
        let budget = default_budget();

        let result = orchestrator.solve(&matrix, &rhs, QueryType::LinearSystem, &budget);
        assert!(result.is_err());
    }

    #[test]
    fn estimate_complexity_returns_reasonable_values() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();

        let estimate = orchestrator.estimate_complexity(&matrix, &QueryType::LinearSystem);

        assert!(estimate.estimated_flops > 0);
        assert!(estimate.estimated_memory_bytes > 0);
        assert!(estimate.estimated_iterations > 0);
    }

    // The chain is always the selected algorithm first, followed by CG and
    // Dense with no repeats. When Dense is selected, CG still follows it.
    #[test]
    fn fallback_chain_deduplicates() {
        let chain = SolverOrchestrator::build_fallback_chain(Algorithm::CG);
        assert_eq!(chain, vec![Algorithm::CG, Algorithm::Dense]);

        let chain = SolverOrchestrator::build_fallback_chain(Algorithm::Dense);
        assert_eq!(chain, vec![Algorithm::Dense, Algorithm::CG]);

        let chain = SolverOrchestrator::build_fallback_chain(Algorithm::Neumann);
        assert_eq!(
            chain,
            vec![Algorithm::Neumann, Algorithm::CG, Algorithm::Dense]
        );
    }

    #[test]
    fn cg_inline_solves_spd_system() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 2.0, 3.0];
        let budget = default_budget();

        let result = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .unwrap();

        assert!(result.residual_norm < 1e-6);
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    #[test]
    fn dense_solves_small_system() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0_f64, 2.0, 3.0];
        let budget = default_budget();

        let result = orchestrator.solve_dense(&matrix, &rhs, &budget).unwrap();

        assert!(result.residual_norm < 1e-4);
        assert_eq!(result.algorithm, Algorithm::Dense);
    }

    // 2x3 matrix: must be rejected by the squareness check before any work.
    #[test]
    fn dense_rejects_non_square() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = CsrMatrix::<f64> {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0_f64, 1.0];
        let budget = default_budget();

        assert!(orchestrator.solve_dense(&matrix, &rhs, &budget).is_err());
    }

    // Cross-checks the iterative and direct backends against each other; the
    // 1e-3 tolerance absorbs the f32 rounding of both returned solutions.
    #[test]
    fn cg_and_dense_agree_on_solution() {
        let orchestrator = SolverOrchestrator::new(RouterConfig::default());
        let matrix = diag_dominant_3x3();
        let rhs = vec![3.0_f64, -1.0, 2.0];
        let budget = default_budget();

        let cg_result = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .unwrap();
        let dense_result = orchestrator.solve_dense(&matrix, &rhs, &budget).unwrap();

        for (cg_x, dense_x) in cg_result.solution.iter().zip(dense_result.solution.iter()) {
            assert!(
                (cg_x - dense_x).abs() < 1e-3,
                "CG={} vs Dense={}",
                cg_x,
                dense_x
            );
        }
    }
}