1use std::time::Instant;
31
32use tracing::{debug, info, warn};
33
34use crate::error::SolverError;
35use crate::traits::SolverEngine;
36use crate::types::{
37 Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
38 QueryType, SolverResult, SparsityProfile,
39};
40
/// Tunable thresholds consulted by [`SolverRouter`] when choosing an algorithm.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Exclusive upper bound on the estimated spectral radius for routing a
    /// linear system to the Neumann series solver.
    pub neumann_spectral_radius_threshold: f64,

    /// Exclusive upper bound on the estimated condition number for routing a
    /// linear system to conjugate gradient; anything at or above it goes to
    /// BMSSP.
    pub cg_condition_threshold: f64,

    /// Exclusive upper bound on matrix density (`nnz / (rows * cols)`) for the
    /// Neumann route.
    pub sparsity_sublinear_threshold: f64,

    /// Batches strictly larger than this are routed to the TRUE solver;
    /// smaller or equal batches go to CG.
    pub true_batch_threshold: usize,

    /// Graphs with strictly more rows than this use HybridRandomWalk for
    /// pairwise PageRank; smaller graphs use ForwardPush.
    pub push_graph_size_threshold: usize,
}

impl Default for RouterConfig {
    /// Stock thresholds: spectral radius < 0.95, condition < 100, density < 5%,
    /// batches above 100 use TRUE, graphs above 1000 rows use random walks.
    fn default() -> Self {
        RouterConfig {
            true_batch_threshold: 100,
            push_graph_size_threshold: 1000,
            neumann_spectral_radius_threshold: 0.95,
            cg_condition_threshold: 1e2,
            sparsity_sublinear_threshold: 5e-2,
        }
    }
}
110
111#[derive(Debug, Clone)]
145pub struct SolverRouter {
146 config: RouterConfig,
147}
148
149impl SolverRouter {
150 pub fn new(config: RouterConfig) -> Self {
152 Self { config }
153 }
154
155 pub fn config(&self) -> &RouterConfig {
157 &self.config
158 }
159
160 pub fn select_algorithm(
165 &self,
166 profile: &SparsityProfile,
167 query: &QueryType,
168 ) -> Algorithm {
169 match query {
170 QueryType::LinearSystem => self.route_linear_system(profile),
174
175 QueryType::PageRankSingle { .. } => {
179 debug!("routing to ForwardPush (single-source PageRank)");
180 Algorithm::ForwardPush
181 }
182
183 QueryType::PageRankPairwise { .. } => {
187 if profile.rows > self.config.push_graph_size_threshold {
188 debug!(
189 rows = profile.rows,
190 threshold = self.config.push_graph_size_threshold,
191 "routing to HybridRandomWalk (large graph pairwise PPR)"
192 );
193 Algorithm::HybridRandomWalk
194 } else {
195 debug!(
196 rows = profile.rows,
197 "routing to ForwardPush (small graph pairwise PPR)"
198 );
199 Algorithm::ForwardPush
200 }
201 }
202
203 QueryType::SpectralFilter { .. } => {
207 debug!("routing to Neumann (spectral filter)");
208 Algorithm::Neumann
209 }
210
211 QueryType::BatchLinearSystem { batch_size } => {
215 if *batch_size > self.config.true_batch_threshold {
216 debug!(
217 batch_size,
218 threshold = self.config.true_batch_threshold,
219 "routing to TRUE (large batch)"
220 );
221 Algorithm::TRUE
222 } else {
223 debug!(batch_size, "routing to CG (small batch)");
224 Algorithm::CG
225 }
226 }
227 }
228 }
229
230 fn route_linear_system(&self, profile: &SparsityProfile) -> Algorithm {
232 if profile.is_diag_dominant
233 && profile.density < self.config.sparsity_sublinear_threshold
234 && profile.estimated_spectral_radius
235 < self.config.neumann_spectral_radius_threshold
236 {
237 debug!(
238 density = profile.density,
239 spectral_radius = profile.estimated_spectral_radius,
240 "routing to Neumann (diag-dominant, sparse, low spectral radius)"
241 );
242 Algorithm::Neumann
243 } else if profile.estimated_condition < self.config.cg_condition_threshold {
244 debug!(
245 condition = profile.estimated_condition,
246 "routing to CG (well-conditioned)"
247 );
248 Algorithm::CG
249 } else {
250 debug!(
251 condition = profile.estimated_condition,
252 "routing to BMSSP (ill-conditioned)"
253 );
254 Algorithm::BMSSP
255 }
256 }
257}
258
259impl Default for SolverRouter {
260 fn default() -> Self {
261 Self::new(RouterConfig::default())
262 }
263}
264
/// Front door for solving linear systems: profiles the matrix with
/// `analyze_sparsity`, asks its [`SolverRouter`] for an algorithm, and
/// dispatches to the matching solver (optionally walking a fallback chain
/// when a solver fails).
#[derive(Debug, Clone)]
pub struct SolverOrchestrator {
    /// Routing policy used to pick an algorithm for each request.
    router: SolverRouter,
}
301
302impl SolverOrchestrator {
303 pub fn new(config: RouterConfig) -> Self {
305 Self {
306 router: SolverRouter::new(config),
307 }
308 }
309
310 pub fn router(&self) -> &SolverRouter {
312 &self.router
313 }
314
315 pub fn solve(
329 &self,
330 matrix: &CsrMatrix<f64>,
331 rhs: &[f64],
332 query: QueryType,
333 budget: &ComputeBudget,
334 ) -> Result<SolverResult, SolverError> {
335 let profile = Self::analyze_sparsity(matrix);
336 let algorithm = self.router.select_algorithm(&profile, &query);
337
338 info!(%algorithm, rows = matrix.rows, nnz = matrix.nnz(), "solve: selected algorithm");
339
340 self.dispatch(algorithm, matrix, rhs, budget)
341 }
342
343 pub fn solve_with_fallback(
357 &self,
358 matrix: &CsrMatrix<f64>,
359 rhs: &[f64],
360 query: QueryType,
361 budget: &ComputeBudget,
362 ) -> Result<SolverResult, SolverError> {
363 let profile = Self::analyze_sparsity(matrix);
364 let primary = self.router.select_algorithm(&profile, &query);
365
366 let chain = Self::build_fallback_chain(primary);
367
368 info!(
369 ?chain,
370 rows = matrix.rows,
371 nnz = matrix.nnz(),
372 "solve_with_fallback: attempting chain"
373 );
374
375 let mut last_err: Option<SolverError> = None;
376
377 for (idx, &algorithm) in chain.iter().enumerate() {
378 match self.dispatch(algorithm, matrix, rhs, budget) {
379 Ok(result) => {
380 if idx > 0 {
381 info!(
382 %algorithm,
383 "fallback succeeded on attempt {}",
384 idx + 1
385 );
386 }
387 return Ok(result);
388 }
389 Err(e) => {
390 warn!(
391 %algorithm,
392 error = %e,
393 "algorithm failed, trying next in fallback chain"
394 );
395 last_err = Some(e);
396 }
397 }
398 }
399
400 Err(last_err.unwrap_or_else(|| {
401 SolverError::BackendError("fallback chain was empty".into())
402 }))
403 }
404
405 pub fn estimate_complexity(
411 &self,
412 matrix: &CsrMatrix<f64>,
413 query: &QueryType,
414 ) -> ComplexityEstimate {
415 let profile = Self::analyze_sparsity(matrix);
416 let algorithm = self.router.select_algorithm(&profile, query);
417 let n = profile.rows;
418
419 let (estimated_iterations, complexity_class) = match algorithm {
420 Algorithm::Neumann => {
421 let k = if profile.estimated_spectral_radius > 0.0
422 && profile.estimated_spectral_radius < 1.0
423 {
424 let log_inv_eps = (1.0 / 1e-8_f64).ln();
425 let log_inv_rho =
426 (1.0 / profile.estimated_spectral_radius).ln();
427 (log_inv_eps / log_inv_rho).ceil() as usize
428 } else {
429 1000
430 };
431 (k.min(1000), ComplexityClass::SublinearNnz)
432 }
433 Algorithm::CG => {
434 let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
435 (iters.min(1000), ComplexityClass::SqrtCondition)
436 }
437 Algorithm::ForwardPush | Algorithm::BackwardPush => {
438 let iters = ((n as f64).sqrt()).ceil() as usize;
439 (iters, ComplexityClass::SublinearNnz)
440 }
441 Algorithm::HybridRandomWalk => {
442 (n.min(1000), ComplexityClass::Linear)
443 }
444 Algorithm::TRUE => {
445 let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
446 (iters.min(1000), ComplexityClass::SqrtCondition)
447 }
448 Algorithm::BMSSP => {
449 let iters = (profile.estimated_condition.sqrt().ln())
450 .ceil() as usize;
451 (iters.max(1).min(1000), ComplexityClass::Linear)
452 }
453 Algorithm::Dense => (1, ComplexityClass::Cubic),
454 Algorithm::Jacobi | Algorithm::GaussSeidel => {
455 (1000, ComplexityClass::Linear)
456 }
457 };
458
459 let estimated_flops = match algorithm {
460 Algorithm::Dense => {
461 let dim = n as u64;
462 (2 * dim * dim * dim) / 3
463 }
464 _ => {
465 (estimated_iterations as u64)
466 * (2 * profile.nnz as u64 + n as u64)
467 }
468 };
469
470 let estimated_memory_bytes = match algorithm {
471 Algorithm::Dense => {
472 n * profile.cols * std::mem::size_of::<f64>()
473 }
474 _ => {
475 let csr = profile.nnz
477 * (std::mem::size_of::<f64>() + std::mem::size_of::<usize>())
478 + (n + 1) * std::mem::size_of::<usize>();
479 let work = 3 * n * std::mem::size_of::<f64>();
480 csr + work
481 }
482 };
483
484 ComplexityEstimate {
485 algorithm,
486 estimated_flops,
487 estimated_iterations,
488 estimated_memory_bytes,
489 complexity_class,
490 }
491 }
492
493 pub fn analyze_sparsity(matrix: &CsrMatrix<f64>) -> SparsityProfile {
499 let n = matrix.rows;
500 let m = matrix.cols;
501 let nnz = matrix.nnz();
502 let total_entries = (n as f64) * (m as f64);
503 let density = if total_entries > 0.0 {
504 nnz as f64 / total_entries
505 } else {
506 0.0
507 };
508
509 let mut is_diag_dominant = true;
510 let mut max_nnz_per_row: usize = 0;
511 let mut sum_off_diag_ratio = 0.0_f64;
512 let mut diag_min = f64::INFINITY;
513 let mut diag_max = 0.0_f64;
514 let mut symmetric_mismatches: usize = 0;
515
516 let check_symmetry = nnz <= 100_000;
518
519 for row in 0..n {
520 let start = matrix.row_ptr[row];
521 let end = matrix.row_ptr[row + 1];
522 let row_nnz = end - start;
523 max_nnz_per_row = max_nnz_per_row.max(row_nnz);
524
525 let mut diag_val: f64 = 0.0;
526 let mut off_diag_sum: f64 = 0.0;
527
528 for idx in start..end {
529 let col = matrix.col_indices[idx];
530 let val = matrix.values[idx];
531
532 if col == row {
533 diag_val = val.abs();
534 } else {
535 off_diag_sum += val.abs();
536 }
537
538 if check_symmetry && col != row && col < n {
540 let col_start = matrix.row_ptr[col];
541 let col_end = matrix.row_ptr[col + 1];
542 let found = matrix.col_indices[col_start..col_end]
543 .binary_search(&row)
544 .is_ok();
545 if !found {
546 symmetric_mismatches += 1;
547 }
548 }
549 }
550
551 if diag_val <= off_diag_sum {
552 is_diag_dominant = false;
553 }
554
555 if diag_val > 0.0 {
556 let ratio = off_diag_sum / diag_val;
557 sum_off_diag_ratio += ratio;
558 diag_min = diag_min.min(diag_val);
559 diag_max = diag_max.max(diag_val);
560 } else if n > 0 {
561 is_diag_dominant = false;
562 sum_off_diag_ratio += 1.0;
563 }
564 }
565
566 let avg_nnz_per_row = if n > 0 {
567 nnz as f64 / n as f64
568 } else {
569 0.0
570 };
571
572 let estimated_spectral_radius = if n > 0 {
574 sum_off_diag_ratio / n as f64
575 } else {
576 0.0
577 };
578
579 let estimated_condition = if diag_min > 0.0 && diag_min.is_finite() {
581 diag_max / diag_min
582 } else {
583 f64::INFINITY
584 };
585
586 let is_symmetric_structure = if check_symmetry {
587 symmetric_mismatches == 0
588 } else {
589 n == m
590 };
591
592 SparsityProfile {
593 rows: n,
594 cols: m,
595 nnz,
596 density,
597 is_diag_dominant,
598 estimated_spectral_radius,
599 estimated_condition,
600 is_symmetric_structure,
601 avg_nnz_per_row,
602 max_nnz_per_row,
603 }
604 }
605
606 fn build_fallback_chain(primary: Algorithm) -> Vec<Algorithm> {
612 let mut chain = Vec::with_capacity(3);
613 chain.push(primary);
614
615 if primary != Algorithm::CG {
616 chain.push(Algorithm::CG);
617 }
618 if primary != Algorithm::Dense {
619 chain.push(Algorithm::Dense);
620 }
621
622 chain
623 }
624
    /// Runs the concrete solver implementation for `algorithm`.
    ///
    /// Several algorithms are routed to stand-in implementations here: with
    /// their features enabled, `ForwardPush`, `BackwardPush`,
    /// `HybridRandomWalk` and `BMSSP` run `solve_jacobi_fallback`, and `TRUE`
    /// runs the Neumann solver. Their results are labelled with the requested
    /// algorithm.
    ///
    /// # Errors
    /// Returns `SolverError::BackendError` when the algorithm's cargo feature
    /// is not compiled in (except `CG`, which falls back to
    /// `solve_cg_inline`); otherwise propagates the chosen solver's error.
    fn dispatch(
        &self,
        algorithm: Algorithm,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        match algorithm {
            Algorithm::Neumann => {
                // Dedicated Neumann-series solver, only when the feature is on.
                #[cfg(feature = "neumann")]
                {
                    let solver = crate::neumann::NeumannSolver::new(
                        budget.tolerance,
                        budget.max_iterations,
                    );
                    SolverEngine::solve(&solver, matrix, rhs, budget)
                }
                #[cfg(not(feature = "neumann"))]
                {
                    Err(SolverError::BackendError(
                        "neumann feature is not enabled".into(),
                    ))
                }
            }

            Algorithm::CG => {
                // Dedicated CG solver when available; otherwise the
                // dependency-free inline implementation in this file.
                // NOTE(review): meaning of the third `false` argument is not
                // visible here — confirm against crate::cg.
                #[cfg(feature = "cg")]
                {
                    let solver =
                        crate::cg::ConjugateGradientSolver::new(budget.tolerance, budget.max_iterations, false);
                    solver.solve(matrix, rhs, budget)
                }
                #[cfg(not(feature = "cg"))]
                {
                    self.solve_cg_inline(matrix, rhs, budget)
                }
            }

            Algorithm::ForwardPush => {
                // No dedicated push solver is wired in here: with the feature
                // on, this runs the Jacobi stand-in labelled as ForwardPush.
                #[cfg(feature = "forward-push")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::ForwardPush,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "forward-push"))]
                {
                    Err(SolverError::BackendError(
                        "forward-push feature is not enabled".into(),
                    ))
                }
            }

            Algorithm::BackwardPush => {
                // Jacobi stand-in, labelled as BackwardPush.
                #[cfg(feature = "backward-push")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::BackwardPush,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "backward-push"))]
                {
                    Err(SolverError::BackendError(
                        "backward-push feature is not enabled".into(),
                    ))
                }
            }

            Algorithm::HybridRandomWalk => {
                // Jacobi stand-in, labelled as HybridRandomWalk.
                #[cfg(feature = "hybrid-random-walk")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::HybridRandomWalk,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "hybrid-random-walk"))]
                {
                    Err(SolverError::BackendError(
                        "hybrid-random-walk feature is not enabled".into(),
                    ))
                }
            }

            Algorithm::TRUE => {
                // Runs the Neumann solver and relabels the result as TRUE so
                // callers see the algorithm they were routed to.
                #[cfg(feature = "true-solver")]
                {
                    let solver = crate::neumann::NeumannSolver::new(
                        budget.tolerance,
                        budget.max_iterations,
                    );
                    let mut result = SolverEngine::solve(&solver, matrix, rhs, budget)?;
                    result.algorithm = Algorithm::TRUE;
                    Ok(result)
                }
                #[cfg(not(feature = "true-solver"))]
                {
                    Err(SolverError::BackendError(
                        "true-solver feature is not enabled".into(),
                    ))
                }
            }

            Algorithm::BMSSP => {
                // Jacobi stand-in, labelled as BMSSP.
                #[cfg(feature = "bmssp")]
                {
                    self.solve_jacobi_fallback(
                        Algorithm::BMSSP,
                        matrix,
                        rhs,
                        budget,
                    )
                }
                #[cfg(not(feature = "bmssp"))]
                {
                    Err(SolverError::BackendError(
                        "bmssp feature is not enabled".into(),
                    ))
                }
            }

            // Always available; rejects matrices larger than its size cap.
            Algorithm::Dense => self.solve_dense(matrix, rhs, budget),

            // Both run plain Jacobi sweeps (no Gauss-Seidel implementation here).
            Algorithm::Jacobi => {
                self.solve_jacobi_fallback(Algorithm::Jacobi, matrix, rhs, budget)
            }
            Algorithm::GaussSeidel => {
                self.solve_jacobi_fallback(
                    Algorithm::GaussSeidel,
                    matrix,
                    rhs,
                    budget,
                )
            }
        }
    }
784
785 #[allow(dead_code)]
790 fn solve_cg_inline(
791 &self,
792 matrix: &CsrMatrix<f64>,
793 rhs: &[f64],
794 budget: &ComputeBudget,
795 ) -> Result<SolverResult, SolverError> {
796 let n = matrix.rows;
797 validate_square(matrix)?;
798 validate_rhs_len(matrix, rhs)?;
799
800 let max_iters = budget.max_iterations;
801 let tol = budget.tolerance;
802 let start = Instant::now();
803
804 let mut x = vec![0.0_f64; n];
805 let mut r: Vec<f64> = rhs.to_vec();
806 let mut p = r.clone();
807 let mut ap = vec![0.0_f64; n];
808 let mut convergence_history = Vec::new();
809
810 let mut r_dot_r = dot(&r, &r);
811
812 for iter in 0..max_iters {
813 let residual_norm = r_dot_r.sqrt();
814
815 convergence_history.push(ConvergenceInfo {
816 iteration: iter,
817 residual_norm,
818 });
819
820 if residual_norm.is_nan() || residual_norm.is_infinite() {
821 return Err(SolverError::NumericalInstability {
822 iteration: iter,
823 detail: format!("CG residual became {}", residual_norm),
824 });
825 }
826
827 if residual_norm < tol {
828 return Ok(SolverResult {
829 solution: x.iter().map(|&v| v as f32).collect(),
830 iterations: iter,
831 residual_norm,
832 wall_time: start.elapsed(),
833 convergence_history,
834 algorithm: Algorithm::CG,
835 });
836 }
837
838 matrix.spmv(&p, &mut ap);
840
841 let p_dot_ap = dot(&p, &ap);
842 if p_dot_ap.abs() < 1e-30 {
843 return Err(SolverError::NumericalInstability {
844 iteration: iter,
845 detail: "CG: p^T A p near zero (matrix may not be SPD)"
846 .into(),
847 });
848 }
849
850 let alpha = r_dot_r / p_dot_ap;
851
852 for i in 0..n {
853 x[i] += alpha * p[i];
854 r[i] -= alpha * ap[i];
855 }
856
857 let new_r_dot_r = dot(&r, &r);
858 let beta = new_r_dot_r / r_dot_r;
859
860 for i in 0..n {
861 p[i] = r[i] + beta * p[i];
862 }
863
864 r_dot_r = new_r_dot_r;
865
866 if start.elapsed() > budget.max_time {
867 return Err(SolverError::BudgetExhausted {
868 reason: "wall-clock time limit exceeded".into(),
869 elapsed: start.elapsed(),
870 });
871 }
872 }
873
874 let final_residual = convergence_history
875 .last()
876 .map(|c| c.residual_norm)
877 .unwrap_or(f64::INFINITY);
878
879 Err(SolverError::NonConvergence {
880 iterations: max_iters,
881 residual: final_residual,
882 tolerance: tol,
883 })
884 }
885
886 fn solve_dense(
890 &self,
891 matrix: &CsrMatrix<f64>,
892 rhs: &[f64],
893 _budget: &ComputeBudget,
894 ) -> Result<SolverResult, SolverError> {
895 let n = matrix.rows;
896 validate_square(matrix)?;
897 validate_rhs_len(matrix, rhs)?;
898
899 const MAX_DENSE_DIM: usize = 4096;
900 if n > MAX_DENSE_DIM {
901 return Err(SolverError::InvalidInput(
902 crate::error::ValidationError::MatrixTooLarge {
903 rows: n,
904 cols: n,
905 max_dim: MAX_DENSE_DIM,
906 },
907 ));
908 }
909
910 let start = Instant::now();
911
912 let stride = n + 1;
914 let mut aug = vec![0.0_f64; n * stride];
915 for row in 0..n {
916 let rs = matrix.row_ptr[row];
917 let re = matrix.row_ptr[row + 1];
918 for idx in rs..re {
919 let col = matrix.col_indices[idx];
920 aug[row * stride + col] = matrix.values[idx];
921 }
922 aug[row * stride + n] = rhs[row];
923 }
924
925 for col in 0..n {
927 let mut max_val = aug[col * stride + col].abs();
928 let mut max_row = col;
929 for row in (col + 1)..n {
930 let val = aug[row * stride + col].abs();
931 if val > max_val {
932 max_val = val;
933 max_row = row;
934 }
935 }
936
937 if max_val < 1e-12 {
938 return Err(SolverError::NumericalInstability {
939 iteration: 0,
940 detail: format!(
941 "dense solver: near-zero pivot ({:.2e}) at column {}",
942 max_val, col
943 ),
944 });
945 }
946
947 if max_row != col {
948 for j in 0..stride {
949 aug.swap(col * stride + j, max_row * stride + j);
950 }
951 }
952
953 let pivot = aug[col * stride + col];
954 for row in (col + 1)..n {
955 let factor = aug[row * stride + col] / pivot;
956 aug[row * stride + col] = 0.0;
957 for j in (col + 1)..stride {
958 let above = aug[col * stride + j];
959 aug[row * stride + j] -= factor * above;
960 }
961 }
962 }
963
964 let mut solution_f64 = vec![0.0_f64; n];
966 for row in (0..n).rev() {
967 let mut sum = aug[row * stride + n];
968 for col in (row + 1)..n {
969 sum -= aug[row * stride + col] * solution_f64[col];
970 }
971 solution_f64[row] = sum / aug[row * stride + row];
972 }
973
974 let mut ax = vec![0.0_f64; n];
976 matrix.spmv(&solution_f64, &mut ax);
977 let residual_norm: f64 = (0..n)
978 .map(|i| {
979 let r = rhs[i] - ax[i];
980 r * r
981 })
982 .sum::<f64>()
983 .sqrt();
984
985 let solution: Vec<f32> = solution_f64.iter().map(|&v| v as f32).collect();
986
987 Ok(SolverResult {
988 solution,
989 iterations: 1,
990 residual_norm,
991 wall_time: start.elapsed(),
992 convergence_history: vec![ConvergenceInfo {
993 iteration: 0,
994 residual_norm,
995 }],
996 algorithm: Algorithm::Dense,
997 })
998 }
999
1000 fn solve_jacobi_fallback(
1006 &self,
1007 algorithm: Algorithm,
1008 matrix: &CsrMatrix<f64>,
1009 rhs: &[f64],
1010 budget: &ComputeBudget,
1011 ) -> Result<SolverResult, SolverError> {
1012 let n = matrix.rows;
1013 validate_square(matrix)?;
1014 validate_rhs_len(matrix, rhs)?;
1015
1016 let max_iters = budget.max_iterations;
1017 let tol = budget.tolerance;
1018 let start = Instant::now();
1019
1020 let mut diag = vec![0.0_f64; n];
1022 for row in 0..n {
1023 let rs = matrix.row_ptr[row];
1024 let re = matrix.row_ptr[row + 1];
1025 for idx in rs..re {
1026 if matrix.col_indices[idx] == row {
1027 diag[row] = matrix.values[idx];
1028 break;
1029 }
1030 }
1031 }
1032
1033 for (i, &d) in diag.iter().enumerate() {
1034 if d.abs() < 1e-30 {
1035 return Err(SolverError::NumericalInstability {
1036 iteration: 0,
1037 detail: format!(
1038 "zero or near-zero diagonal at row {} (val={:.2e})",
1039 i, d
1040 ),
1041 });
1042 }
1043 }
1044
1045 let mut x = vec![0.0_f64; n];
1046 let mut x_new = vec![0.0_f64; n];
1047 let mut temp = vec![0.0_f64; n];
1048 let mut convergence_history = Vec::new();
1049
1050 for iter in 0..max_iters {
1051 for row in 0..n {
1052 let rs = matrix.row_ptr[row];
1053 let re = matrix.row_ptr[row + 1];
1054 let mut sum = 0.0_f64;
1055 for idx in rs..re {
1056 let col = matrix.col_indices[idx];
1057 if col != row {
1058 sum += matrix.values[idx] * x[col];
1059 }
1060 }
1061 x_new[row] = (rhs[row] - sum) / diag[row];
1062 }
1063
1064 matrix.spmv(&x_new, &mut temp);
1065 let residual_norm: f64 = (0..n)
1066 .map(|i| {
1067 let r = rhs[i] - temp[i];
1068 r * r
1069 })
1070 .sum::<f64>()
1071 .sqrt();
1072
1073 convergence_history.push(ConvergenceInfo {
1074 iteration: iter,
1075 residual_norm,
1076 });
1077
1078 if residual_norm.is_nan() || residual_norm.is_infinite() {
1079 return Err(SolverError::NumericalInstability {
1080 iteration: iter,
1081 detail: format!("residual became {}", residual_norm),
1082 });
1083 }
1084
1085 if residual_norm < tol {
1086 return Ok(SolverResult {
1087 solution: x_new.iter().map(|&v| v as f32).collect(),
1088 iterations: iter + 1,
1089 residual_norm,
1090 wall_time: start.elapsed(),
1091 convergence_history,
1092 algorithm,
1093 });
1094 }
1095
1096 std::mem::swap(&mut x, &mut x_new);
1097
1098 if start.elapsed() > budget.max_time {
1099 return Err(SolverError::BudgetExhausted {
1100 reason: "wall-clock time limit exceeded".into(),
1101 elapsed: start.elapsed(),
1102 });
1103 }
1104 }
1105
1106 let final_residual = convergence_history
1107 .last()
1108 .map(|c| c.residual_norm)
1109 .unwrap_or(f64::INFINITY);
1110
1111 Err(SolverError::NonConvergence {
1112 iterations: max_iters,
1113 residual: final_residual,
1114 tolerance: tol,
1115 })
1116 }
1117}
1118
1119impl Default for SolverOrchestrator {
1120 fn default() -> Self {
1121 Self::new(RouterConfig::default())
1122 }
1123}
1124
/// Euclidean inner product `sum_i a[i] * b[i]`, accumulated left to right.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
// Only called from `solve_cg_inline`, which is unused in some feature sets.
#[allow(dead_code)]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "dot: length mismatch {} vs {}", a.len(), b.len());
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
1139
1140fn validate_square(matrix: &CsrMatrix<f64>) -> Result<(), SolverError> {
1142 if matrix.rows != matrix.cols {
1143 return Err(SolverError::InvalidInput(
1144 crate::error::ValidationError::DimensionMismatch(format!(
1145 "matrix must be square, got {}x{}",
1146 matrix.rows, matrix.cols
1147 )),
1148 ));
1149 }
1150 Ok(())
1151}
1152
1153fn validate_rhs_len(
1155 matrix: &CsrMatrix<f64>,
1156 rhs: &[f64],
1157) -> Result<(), SolverError> {
1158 if rhs.len() != matrix.rows {
1159 return Err(SolverError::InvalidInput(
1160 crate::error::ValidationError::DimensionMismatch(format!(
1161 "rhs length {} does not match matrix dimension {}",
1162 rhs.len(),
1163 matrix.rows
1164 )),
1165 ));
1166 }
1167 Ok(())
1168}
1169
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Symmetric, strictly diagonally dominant 3x3 tridiagonal matrix
    /// (4 on the diagonal, -1 on the first off-diagonals).
    fn diag_dominant_3x3() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(
            3,
            3,
            vec![
                (0, 0, 4.0),
                (0, 1, -1.0),
                (1, 0, -1.0),
                (1, 1, 4.0),
                (1, 2, -1.0),
                (2, 1, -1.0),
                (2, 2, 4.0),
            ],
        )
    }

    /// Default compute budget tightened to a 1e-8 tolerance.
    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            tolerance: 1e-8,
            ..Default::default()
        }
    }

    /// Square `rows x rows` profile with `nnz` nonzeros. `density` and
    /// `avg_nnz_per_row` are derived from `rows`/`nnz`; the other fields are
    /// placeholders that each test overrides with struct-update syntax.
    fn square_profile(rows: usize, nnz: usize) -> SparsityProfile {
        SparsityProfile {
            rows,
            cols: rows,
            nnz,
            density: nnz as f64 / (rows as f64 * rows as f64),
            is_diag_dominant: false,
            estimated_spectral_radius: 0.0,
            estimated_condition: 1.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: nnz as f64 / rows as f64,
            max_nnz_per_row: 0,
        }
    }

    /// 1000x1000, 3000 nonzeros, symmetric and diagonally dominant with
    /// condition 10; only the spectral radius varies between tests.
    fn sparse_dominant_profile(spectral_radius: f64) -> SparsityProfile {
        SparsityProfile {
            is_diag_dominant: true,
            estimated_spectral_radius: spectral_radius,
            estimated_condition: 10.0,
            max_nnz_per_row: 5,
            ..square_profile(1000, 3000)
        }
    }

    /// 1000x1000, 50k nonzeros, symmetric, not diagonally dominant,
    /// condition 50.
    fn coupled_profile() -> SparsityProfile {
        SparsityProfile {
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            max_nnz_per_row: 80,
            ..square_profile(1000, 50_000)
        }
    }

    /// Directed (structurally asymmetric) graph with damping-like spectral
    /// radius 0.85 and condition 100.
    fn graph_profile(rows: usize, nnz: usize, max_nnz_per_row: usize) -> SparsityProfile {
        SparsityProfile {
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            max_nnz_per_row,
            ..square_profile(rows, nnz)
        }
    }

    /// 1000x1000, 5000 nonzeros, diagonally dominant profile for batch routing.
    fn batch_profile() -> SparsityProfile {
        SparsityProfile {
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            max_nnz_per_row: 10,
            ..square_profile(1000, 5000)
        }
    }

    /// Routes `query` for `profile` with the default thresholds.
    fn route(profile: &SparsityProfile, query: QueryType) -> Algorithm {
        SolverRouter::default().select_algorithm(profile, &query)
    }

    // ------------------------------------------------------------------
    // Router
    // ------------------------------------------------------------------

    #[test]
    fn routes_diag_dominant_sparse_to_neumann() {
        let profile = sparse_dominant_profile(0.5);
        assert_eq!(route(&profile, QueryType::LinearSystem), Algorithm::Neumann);
    }

    #[test]
    fn routes_well_conditioned_non_diag_dominant_to_cg() {
        assert_eq!(
            route(&coupled_profile(), QueryType::LinearSystem),
            Algorithm::CG
        );
    }

    #[test]
    fn routes_ill_conditioned_to_bmssp() {
        let profile = SparsityProfile {
            estimated_spectral_radius: 0.99,
            estimated_condition: 500.0,
            ..coupled_profile()
        };
        assert_eq!(route(&profile, QueryType::LinearSystem), Algorithm::BMSSP);
    }

    #[test]
    fn routes_single_pagerank_to_forward_push() {
        let profile = graph_profile(5000, 20_000, 50);
        assert_eq!(
            route(&profile, QueryType::PageRankSingle { source: 0 }),
            Algorithm::ForwardPush
        );
    }

    #[test]
    fn routes_large_pairwise_to_hybrid_random_walk() {
        let profile = graph_profile(5000, 20_000, 50);
        let query = QueryType::PageRankPairwise {
            source: 0,
            target: 100,
        };
        assert_eq!(route(&profile, query), Algorithm::HybridRandomWalk);
    }

    #[test]
    fn routes_small_pairwise_to_forward_push() {
        let profile = graph_profile(500, 2000, 10);
        let query = QueryType::PageRankPairwise {
            source: 0,
            target: 10,
        };
        assert_eq!(route(&profile, query), Algorithm::ForwardPush);
    }

    #[test]
    fn routes_spectral_filter_to_neumann() {
        let profile = SparsityProfile {
            is_diag_dominant: true,
            estimated_spectral_radius: 0.3,
            estimated_condition: 5.0,
            max_nnz_per_row: 8,
            ..square_profile(100, 500)
        };
        let query = QueryType::SpectralFilter {
            polynomial_degree: 10,
        };
        assert_eq!(route(&profile, query), Algorithm::Neumann);
    }

    #[test]
    fn routes_large_batch_to_true() {
        let query = QueryType::BatchLinearSystem { batch_size: 200 };
        assert_eq!(route(&batch_profile(), query), Algorithm::TRUE);
    }

    #[test]
    fn routes_small_batch_to_cg() {
        let query = QueryType::BatchLinearSystem { batch_size: 50 };
        assert_eq!(route(&batch_profile(), query), Algorithm::CG);
    }

    #[test]
    fn custom_config_overrides_thresholds() {
        let router = SolverRouter::new(RouterConfig {
            cg_condition_threshold: 10.0,
            ..Default::default()
        });

        // Condition 50 is now above the CG threshold.
        assert_eq!(
            router.select_algorithm(&coupled_profile(), &QueryType::LinearSystem),
            Algorithm::BMSSP
        );
    }

    #[test]
    fn neumann_requires_low_spectral_radius() {
        // Just above the default 0.95 spectral-radius threshold.
        let profile = sparse_dominant_profile(0.96);
        assert_eq!(route(&profile, QueryType::LinearSystem), Algorithm::CG);
    }

    // ------------------------------------------------------------------
    // Sparsity analysis
    // ------------------------------------------------------------------

    #[test]
    fn analyze_identity_matrix() {
        let profile = SolverOrchestrator::analyze_sparsity(&CsrMatrix::<f64>::identity(5));

        assert_eq!(profile.rows, 5);
        assert_eq!(profile.cols, 5);
        assert_eq!(profile.nnz, 5);
        assert!(profile.is_diag_dominant);
        assert!((profile.density - 0.2).abs() < 1e-10);
        assert!(profile.estimated_spectral_radius.abs() < 1e-10);
        assert!((profile.estimated_condition - 1.0).abs() < 1e-10);
        assert!(profile.is_symmetric_structure);
        assert_eq!(profile.max_nnz_per_row, 1);
    }

    #[test]
    fn analyze_diag_dominant() {
        let profile = SolverOrchestrator::analyze_sparsity(&diag_dominant_3x3());

        assert!(profile.is_diag_dominant);
        assert!(profile.estimated_spectral_radius < 1.0);
        assert!(profile.is_symmetric_structure);
    }

    #[test]
    fn analyze_empty_matrix() {
        let empty = CsrMatrix::<f64> {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        let profile = SolverOrchestrator::analyze_sparsity(&empty);

        assert_eq!(profile.rows, 0);
        assert_eq!(profile.nnz, 0);
        assert_eq!(profile.density, 0.0);
    }

    // ------------------------------------------------------------------
    // Orchestrator end-to-end
    // ------------------------------------------------------------------

    #[test]
    fn orchestrator_solve_identity() {
        let orchestrator = SolverOrchestrator::default();
        let matrix = CsrMatrix::<f64>::identity(4);
        let rhs = vec![1.0_f64, 2.0, 3.0, 4.0];

        let result = orchestrator
            .solve(&matrix, &rhs, QueryType::LinearSystem, &default_budget())
            .unwrap();

        for (x, b) in result.solution.iter().zip(rhs.iter()) {
            assert!(
                (*x as f64 - b).abs() < 1e-4,
                "expected {}, got {}",
                b,
                x
            );
        }
    }

    #[test]
    fn orchestrator_solve_diag_dominant() {
        let orchestrator = SolverOrchestrator::default();
        let rhs = vec![1.0_f64, 0.0, 1.0];

        let result = orchestrator
            .solve(
                &diag_dominant_3x3(),
                &rhs,
                QueryType::LinearSystem,
                &default_budget(),
            )
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_solve_with_fallback_succeeds() {
        let orchestrator = SolverOrchestrator::default();
        let rhs = vec![1.0_f64, 0.0, 1.0];

        let result = orchestrator
            .solve_with_fallback(
                &diag_dominant_3x3(),
                &rhs,
                QueryType::LinearSystem,
                &default_budget(),
            )
            .unwrap();

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_dimension_mismatch() {
        let orchestrator = SolverOrchestrator::default();
        let matrix = CsrMatrix::<f64>::identity(3);
        // One entry short of the matrix dimension.
        let rhs = vec![1.0_f64, 2.0];

        let result = orchestrator.solve(
            &matrix,
            &rhs,
            QueryType::LinearSystem,
            &default_budget(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn estimate_complexity_returns_reasonable_values() {
        let estimate = SolverOrchestrator::default()
            .estimate_complexity(&diag_dominant_3x3(), &QueryType::LinearSystem);

        assert!(estimate.estimated_flops > 0);
        assert!(estimate.estimated_memory_bytes > 0);
        assert!(estimate.estimated_iterations > 0);
    }

    #[test]
    fn fallback_chain_deduplicates() {
        assert_eq!(
            SolverOrchestrator::build_fallback_chain(Algorithm::CG),
            vec![Algorithm::CG, Algorithm::Dense]
        );
        assert_eq!(
            SolverOrchestrator::build_fallback_chain(Algorithm::Dense),
            vec![Algorithm::Dense, Algorithm::CG]
        );
        assert_eq!(
            SolverOrchestrator::build_fallback_chain(Algorithm::Neumann),
            vec![Algorithm::Neumann, Algorithm::CG, Algorithm::Dense]
        );
    }

    // ------------------------------------------------------------------
    // Individual solvers
    // ------------------------------------------------------------------

    #[test]
    fn cg_inline_solves_spd_system() {
        let orchestrator = SolverOrchestrator::default();
        let rhs = vec![1.0_f64, 2.0, 3.0];

        let result = orchestrator
            .solve_cg_inline(&diag_dominant_3x3(), &rhs, &default_budget())
            .unwrap();

        assert!(result.residual_norm < 1e-6);
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    #[test]
    fn dense_solves_small_system() {
        let orchestrator = SolverOrchestrator::default();
        let rhs = vec![1.0_f64, 2.0, 3.0];

        let result = orchestrator
            .solve_dense(&diag_dominant_3x3(), &rhs, &default_budget())
            .unwrap();

        assert!(result.residual_norm < 1e-4);
        assert_eq!(result.algorithm, Algorithm::Dense);
    }

    #[test]
    fn dense_rejects_non_square() {
        let orchestrator = SolverOrchestrator::default();
        let wide = CsrMatrix::<f64> {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0_f64, 1.0];

        assert!(orchestrator
            .solve_dense(&wide, &rhs, &default_budget())
            .is_err());
    }

    #[test]
    fn cg_and_dense_agree_on_solution() {
        let orchestrator = SolverOrchestrator::default();
        let matrix = diag_dominant_3x3();
        let rhs = vec![3.0_f64, -1.0, 2.0];
        let budget = default_budget();

        let cg_result = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .unwrap();
        let dense_result = orchestrator
            .solve_dense(&matrix, &rhs, &budget)
            .unwrap();

        for (cg_x, dense_x) in cg_result.solution.iter().zip(dense_result.solution.iter()) {
            assert!(
                (cg_x - dense_x).abs() < 1e-3,
                "CG={} vs Dense={}",
                cg_x,
                dense_x
            );
        }
    }
}