1use std::time::Instant;
31
32use tracing::{debug, info, warn};
33
34use crate::error::SolverError;
35use crate::traits::SolverEngine;
36use crate::types::{
37 Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
38 QueryType, SolverResult, SparsityProfile,
39};
40
/// Thresholds that steer `SolverRouter` between solver algorithms.
///
/// Every comparison the router makes against these values is strict
/// (`<` or `>`), so a value exactly at a threshold takes the "else" route.
#[derive(Debug, Clone, PartialEq)]
pub struct RouterConfig {
    /// `LinearSystem` queries are routed to Neumann only when the profile's
    /// `estimated_spectral_radius` is strictly below this value.
    pub neumann_spectral_radius_threshold: f64,

    /// `LinearSystem` queries that do not qualify for Neumann go to CG when the
    /// profile's `estimated_condition` is strictly below this value, and to
    /// BMSSP otherwise.
    pub cg_condition_threshold: f64,

    /// `LinearSystem` queries are routed to Neumann only when the matrix
    /// `density` (nnz / (rows * cols)) is strictly below this value.
    pub sparsity_sublinear_threshold: f64,

    /// `BatchLinearSystem` queries whose `batch_size` is strictly greater than
    /// this go to TRUE; smaller batches go to CG.
    pub true_batch_threshold: usize,

    /// `PageRankPairwise` queries on graphs with strictly more rows than this
    /// go to HybridRandomWalk; smaller graphs use ForwardPush.
    pub push_graph_size_threshold: usize,
}
98
99impl Default for RouterConfig {
100 fn default() -> Self {
101 Self {
102 neumann_spectral_radius_threshold: 0.95,
103 cg_condition_threshold: 100.0,
104 sparsity_sublinear_threshold: 0.05,
105 true_batch_threshold: 100,
106 push_graph_size_threshold: 1_000,
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
145pub struct SolverRouter {
146 config: RouterConfig,
147}
148
149impl SolverRouter {
150 pub fn new(config: RouterConfig) -> Self {
152 Self { config }
153 }
154
155 pub fn config(&self) -> &RouterConfig {
157 &self.config
158 }
159
160 pub fn select_algorithm(&self, profile: &SparsityProfile, query: &QueryType) -> Algorithm {
165 match query {
166 QueryType::LinearSystem => self.route_linear_system(profile),
170
171 QueryType::PageRankSingle { .. } => {
175 debug!("routing to ForwardPush (single-source PageRank)");
176 Algorithm::ForwardPush
177 }
178
179 QueryType::PageRankPairwise { .. } => {
183 if profile.rows > self.config.push_graph_size_threshold {
184 debug!(
185 rows = profile.rows,
186 threshold = self.config.push_graph_size_threshold,
187 "routing to HybridRandomWalk (large graph pairwise PPR)"
188 );
189 Algorithm::HybridRandomWalk
190 } else {
191 debug!(
192 rows = profile.rows,
193 "routing to ForwardPush (small graph pairwise PPR)"
194 );
195 Algorithm::ForwardPush
196 }
197 }
198
199 QueryType::SpectralFilter { .. } => {
203 debug!("routing to Neumann (spectral filter)");
204 Algorithm::Neumann
205 }
206
207 QueryType::BatchLinearSystem { batch_size } => {
211 if *batch_size > self.config.true_batch_threshold {
212 debug!(
213 batch_size,
214 threshold = self.config.true_batch_threshold,
215 "routing to TRUE (large batch)"
216 );
217 Algorithm::TRUE
218 } else {
219 debug!(batch_size, "routing to CG (small batch)");
220 Algorithm::CG
221 }
222 }
223 }
224 }
225
226 fn route_linear_system(&self, profile: &SparsityProfile) -> Algorithm {
228 if profile.is_diag_dominant
229 && profile.density < self.config.sparsity_sublinear_threshold
230 && profile.estimated_spectral_radius < self.config.neumann_spectral_radius_threshold
231 {
232 debug!(
233 density = profile.density,
234 spectral_radius = profile.estimated_spectral_radius,
235 "routing to Neumann (diag-dominant, sparse, low spectral radius)"
236 );
237 Algorithm::Neumann
238 } else if profile.estimated_condition < self.config.cg_condition_threshold {
239 debug!(
240 condition = profile.estimated_condition,
241 "routing to CG (well-conditioned)"
242 );
243 Algorithm::CG
244 } else {
245 debug!(
246 condition = profile.estimated_condition,
247 "routing to BMSSP (ill-conditioned)"
248 );
249 Algorithm::BMSSP
250 }
251 }
252}
253
254impl Default for SolverRouter {
255 fn default() -> Self {
256 Self::new(RouterConfig::default())
257 }
258}
259
260#[derive(Debug, Clone)]
293pub struct SolverOrchestrator {
294 router: SolverRouter,
295}
296
297impl SolverOrchestrator {
298 pub fn new(config: RouterConfig) -> Self {
300 Self {
301 router: SolverRouter::new(config),
302 }
303 }
304
305 pub fn router(&self) -> &SolverRouter {
307 &self.router
308 }
309
310 pub fn solve(
324 &self,
325 matrix: &CsrMatrix<f64>,
326 rhs: &[f64],
327 query: QueryType,
328 budget: &ComputeBudget,
329 ) -> Result<SolverResult, SolverError> {
330 let profile = Self::analyze_sparsity(matrix);
331 let algorithm = self.router.select_algorithm(&profile, &query);
332
333 info!(%algorithm, rows = matrix.rows, nnz = matrix.nnz(), "solve: selected algorithm");
334
335 self.dispatch(algorithm, matrix, rhs, budget)
336 }
337
338 pub fn solve_with_fallback(
352 &self,
353 matrix: &CsrMatrix<f64>,
354 rhs: &[f64],
355 query: QueryType,
356 budget: &ComputeBudget,
357 ) -> Result<SolverResult, SolverError> {
358 let profile = Self::analyze_sparsity(matrix);
359 let primary = self.router.select_algorithm(&profile, &query);
360
361 let chain = Self::build_fallback_chain(primary);
362
363 info!(
364 ?chain,
365 rows = matrix.rows,
366 nnz = matrix.nnz(),
367 "solve_with_fallback: attempting chain"
368 );
369
370 let mut last_err: Option<SolverError> = None;
371
372 for (idx, &algorithm) in chain.iter().enumerate() {
373 match self.dispatch(algorithm, matrix, rhs, budget) {
374 Ok(result) => {
375 if idx > 0 {
376 info!(
377 %algorithm,
378 "fallback succeeded on attempt {}",
379 idx + 1
380 );
381 }
382 return Ok(result);
383 }
384 Err(e) => {
385 warn!(
386 %algorithm,
387 error = %e,
388 "algorithm failed, trying next in fallback chain"
389 );
390 last_err = Some(e);
391 }
392 }
393 }
394
395 Err(last_err
396 .unwrap_or_else(|| SolverError::BackendError("fallback chain was empty".into())))
397 }
398
399 pub fn estimate_complexity(
405 &self,
406 matrix: &CsrMatrix<f64>,
407 query: &QueryType,
408 ) -> ComplexityEstimate {
409 let profile = Self::analyze_sparsity(matrix);
410 let algorithm = self.router.select_algorithm(&profile, query);
411 let n = profile.rows;
412
413 let (estimated_iterations, complexity_class) = match algorithm {
414 Algorithm::Neumann => {
415 let k = if profile.estimated_spectral_radius > 0.0
416 && profile.estimated_spectral_radius < 1.0
417 {
418 let log_inv_eps = (1.0 / 1e-8_f64).ln();
419 let log_inv_rho = (1.0 / profile.estimated_spectral_radius).ln();
420 (log_inv_eps / log_inv_rho).ceil() as usize
421 } else {
422 1000
423 };
424 (k.min(1000), ComplexityClass::SublinearNnz)
425 }
426 Algorithm::CG => {
427 let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
428 (iters.min(1000), ComplexityClass::SqrtCondition)
429 }
430 Algorithm::ForwardPush | Algorithm::BackwardPush => {
431 let iters = ((n as f64).sqrt()).ceil() as usize;
432 (iters, ComplexityClass::SublinearNnz)
433 }
434 Algorithm::HybridRandomWalk => (n.min(1000), ComplexityClass::Linear),
435 Algorithm::TRUE => {
436 let iters = (profile.estimated_condition.sqrt()).ceil() as usize;
437 (iters.min(1000), ComplexityClass::SqrtCondition)
438 }
439 Algorithm::BMSSP => {
440 let iters = (profile.estimated_condition.sqrt().ln()).ceil() as usize;
441 (iters.max(1).min(1000), ComplexityClass::Linear)
442 }
443 Algorithm::Dense => (1, ComplexityClass::Cubic),
444 Algorithm::Jacobi | Algorithm::GaussSeidel => (1000, ComplexityClass::Linear),
445 };
446
447 let estimated_flops = match algorithm {
448 Algorithm::Dense => {
449 let dim = n as u64;
450 (2 * dim * dim * dim) / 3
451 }
452 _ => (estimated_iterations as u64) * (2 * profile.nnz as u64 + n as u64),
453 };
454
455 let estimated_memory_bytes = match algorithm {
456 Algorithm::Dense => n * profile.cols * std::mem::size_of::<f64>(),
457 _ => {
458 let csr = profile.nnz * (std::mem::size_of::<f64>() + std::mem::size_of::<usize>())
460 + (n + 1) * std::mem::size_of::<usize>();
461 let work = 3 * n * std::mem::size_of::<f64>();
462 csr + work
463 }
464 };
465
466 ComplexityEstimate {
467 algorithm,
468 estimated_flops,
469 estimated_iterations,
470 estimated_memory_bytes,
471 complexity_class,
472 }
473 }
474
475 pub fn analyze_sparsity(matrix: &CsrMatrix<f64>) -> SparsityProfile {
481 let n = matrix.rows;
482 let m = matrix.cols;
483 let nnz = matrix.nnz();
484 let total_entries = (n as f64) * (m as f64);
485 let density = if total_entries > 0.0 {
486 nnz as f64 / total_entries
487 } else {
488 0.0
489 };
490
491 let mut is_diag_dominant = true;
492 let mut max_nnz_per_row: usize = 0;
493 let mut sum_off_diag_ratio = 0.0_f64;
494 let mut diag_min = f64::INFINITY;
495 let mut diag_max = 0.0_f64;
496 let mut symmetric_mismatches: usize = 0;
497
498 let check_symmetry = nnz <= 100_000;
500
501 for row in 0..n {
502 let start = matrix.row_ptr[row];
503 let end = matrix.row_ptr[row + 1];
504 let row_nnz = end - start;
505 max_nnz_per_row = max_nnz_per_row.max(row_nnz);
506
507 let mut diag_val: f64 = 0.0;
508 let mut off_diag_sum: f64 = 0.0;
509
510 for idx in start..end {
511 let col = matrix.col_indices[idx];
512 let val = matrix.values[idx];
513
514 if col == row {
515 diag_val = val.abs();
516 } else {
517 off_diag_sum += val.abs();
518 }
519
520 if check_symmetry && col != row && col < n {
522 let col_start = matrix.row_ptr[col];
523 let col_end = matrix.row_ptr[col + 1];
524 let found = matrix.col_indices[col_start..col_end]
525 .binary_search(&row)
526 .is_ok();
527 if !found {
528 symmetric_mismatches += 1;
529 }
530 }
531 }
532
533 if diag_val <= off_diag_sum {
534 is_diag_dominant = false;
535 }
536
537 if diag_val > 0.0 {
538 let ratio = off_diag_sum / diag_val;
539 sum_off_diag_ratio += ratio;
540 diag_min = diag_min.min(diag_val);
541 diag_max = diag_max.max(diag_val);
542 } else if n > 0 {
543 is_diag_dominant = false;
544 sum_off_diag_ratio += 1.0;
545 }
546 }
547
548 let avg_nnz_per_row = if n > 0 { nnz as f64 / n as f64 } else { 0.0 };
549
550 let estimated_spectral_radius = if n > 0 {
552 sum_off_diag_ratio / n as f64
553 } else {
554 0.0
555 };
556
557 let estimated_condition = if diag_min > 0.0 && diag_min.is_finite() {
559 diag_max / diag_min
560 } else {
561 f64::INFINITY
562 };
563
564 let is_symmetric_structure = if check_symmetry {
565 symmetric_mismatches == 0
566 } else {
567 n == m
568 };
569
570 SparsityProfile {
571 rows: n,
572 cols: m,
573 nnz,
574 density,
575 is_diag_dominant,
576 estimated_spectral_radius,
577 estimated_condition,
578 is_symmetric_structure,
579 avg_nnz_per_row,
580 max_nnz_per_row,
581 }
582 }
583
584 fn build_fallback_chain(primary: Algorithm) -> Vec<Algorithm> {
590 let mut chain = Vec::with_capacity(3);
591 chain.push(primary);
592
593 if primary != Algorithm::CG {
594 chain.push(Algorithm::CG);
595 }
596 if primary != Algorithm::Dense {
597 chain.push(Algorithm::Dense);
598 }
599
600 chain
601 }
602
603 fn dispatch(
608 &self,
609 algorithm: Algorithm,
610 matrix: &CsrMatrix<f64>,
611 rhs: &[f64],
612 budget: &ComputeBudget,
613 ) -> Result<SolverResult, SolverError> {
614 match algorithm {
615 Algorithm::Neumann => {
617 #[cfg(feature = "neumann")]
618 {
619 let solver =
620 crate::neumann::NeumannSolver::new(budget.tolerance, budget.max_iterations);
621 SolverEngine::solve(&solver, matrix, rhs, budget)
622 }
623 #[cfg(not(feature = "neumann"))]
624 {
625 Err(SolverError::BackendError(
626 "neumann feature is not enabled".into(),
627 ))
628 }
629 }
630
631 Algorithm::CG => {
633 #[cfg(feature = "cg")]
634 {
635 let solver = crate::cg::ConjugateGradientSolver::new(
636 budget.tolerance,
637 budget.max_iterations,
638 false,
639 );
640 solver.solve(matrix, rhs, budget)
641 }
642 #[cfg(not(feature = "cg"))]
643 {
644 self.solve_cg_inline(matrix, rhs, budget)
646 }
647 }
648
649 Algorithm::ForwardPush => {
651 #[cfg(feature = "forward-push")]
652 {
653 self.solve_jacobi_fallback(Algorithm::ForwardPush, matrix, rhs, budget)
654 }
655 #[cfg(not(feature = "forward-push"))]
656 {
657 Err(SolverError::BackendError(
658 "forward-push feature is not enabled".into(),
659 ))
660 }
661 }
662
663 Algorithm::BackwardPush => {
665 #[cfg(feature = "backward-push")]
666 {
667 self.solve_jacobi_fallback(Algorithm::BackwardPush, matrix, rhs, budget)
668 }
669 #[cfg(not(feature = "backward-push"))]
670 {
671 Err(SolverError::BackendError(
672 "backward-push feature is not enabled".into(),
673 ))
674 }
675 }
676
677 Algorithm::HybridRandomWalk => {
679 #[cfg(feature = "hybrid-random-walk")]
680 {
681 self.solve_jacobi_fallback(Algorithm::HybridRandomWalk, matrix, rhs, budget)
682 }
683 #[cfg(not(feature = "hybrid-random-walk"))]
684 {
685 Err(SolverError::BackendError(
686 "hybrid-random-walk feature is not enabled".into(),
687 ))
688 }
689 }
690
691 Algorithm::TRUE => {
693 #[cfg(feature = "true-solver")]
694 {
695 let solver =
697 crate::neumann::NeumannSolver::new(budget.tolerance, budget.max_iterations);
698 let mut result = SolverEngine::solve(&solver, matrix, rhs, budget)?;
699 result.algorithm = Algorithm::TRUE;
700 Ok(result)
701 }
702 #[cfg(not(feature = "true-solver"))]
703 {
704 Err(SolverError::BackendError(
705 "true-solver feature is not enabled".into(),
706 ))
707 }
708 }
709
710 Algorithm::BMSSP => {
712 #[cfg(feature = "bmssp")]
713 {
714 self.solve_jacobi_fallback(Algorithm::BMSSP, matrix, rhs, budget)
715 }
716 #[cfg(not(feature = "bmssp"))]
717 {
718 Err(SolverError::BackendError(
719 "bmssp feature is not enabled".into(),
720 ))
721 }
722 }
723
724 Algorithm::Dense => self.solve_dense(matrix, rhs, budget),
726
727 Algorithm::Jacobi => self.solve_jacobi_fallback(Algorithm::Jacobi, matrix, rhs, budget),
729 Algorithm::GaussSeidel => {
730 self.solve_jacobi_fallback(Algorithm::GaussSeidel, matrix, rhs, budget)
731 }
732 }
733 }
734
735 #[allow(dead_code)]
740 fn solve_cg_inline(
741 &self,
742 matrix: &CsrMatrix<f64>,
743 rhs: &[f64],
744 budget: &ComputeBudget,
745 ) -> Result<SolverResult, SolverError> {
746 let n = matrix.rows;
747 validate_square(matrix)?;
748 validate_rhs_len(matrix, rhs)?;
749
750 let max_iters = budget.max_iterations;
751 let tol = budget.tolerance;
752 let start = Instant::now();
753
754 let mut x = vec![0.0_f64; n];
755 let mut r: Vec<f64> = rhs.to_vec();
756 let mut p = r.clone();
757 let mut ap = vec![0.0_f64; n];
758 let mut convergence_history = Vec::new();
759
760 let mut r_dot_r = dot(&r, &r);
761
762 for iter in 0..max_iters {
763 let residual_norm = r_dot_r.sqrt();
764
765 convergence_history.push(ConvergenceInfo {
766 iteration: iter,
767 residual_norm,
768 });
769
770 if residual_norm.is_nan() || residual_norm.is_infinite() {
771 return Err(SolverError::NumericalInstability {
772 iteration: iter,
773 detail: format!("CG residual became {}", residual_norm),
774 });
775 }
776
777 if residual_norm < tol {
778 return Ok(SolverResult {
779 solution: x.iter().map(|&v| v as f32).collect(),
780 iterations: iter,
781 residual_norm,
782 wall_time: start.elapsed(),
783 convergence_history,
784 algorithm: Algorithm::CG,
785 });
786 }
787
788 matrix.spmv(&p, &mut ap);
790
791 let p_dot_ap = dot(&p, &ap);
792 if p_dot_ap.abs() < 1e-30 {
793 return Err(SolverError::NumericalInstability {
794 iteration: iter,
795 detail: "CG: p^T A p near zero (matrix may not be SPD)".into(),
796 });
797 }
798
799 let alpha = r_dot_r / p_dot_ap;
800
801 for i in 0..n {
802 x[i] += alpha * p[i];
803 r[i] -= alpha * ap[i];
804 }
805
806 let new_r_dot_r = dot(&r, &r);
807 let beta = new_r_dot_r / r_dot_r;
808
809 for i in 0..n {
810 p[i] = r[i] + beta * p[i];
811 }
812
813 r_dot_r = new_r_dot_r;
814
815 if start.elapsed() > budget.max_time {
816 return Err(SolverError::BudgetExhausted {
817 reason: "wall-clock time limit exceeded".into(),
818 elapsed: start.elapsed(),
819 });
820 }
821 }
822
823 let final_residual = convergence_history
824 .last()
825 .map(|c| c.residual_norm)
826 .unwrap_or(f64::INFINITY);
827
828 Err(SolverError::NonConvergence {
829 iterations: max_iters,
830 residual: final_residual,
831 tolerance: tol,
832 })
833 }
834
835 fn solve_dense(
839 &self,
840 matrix: &CsrMatrix<f64>,
841 rhs: &[f64],
842 _budget: &ComputeBudget,
843 ) -> Result<SolverResult, SolverError> {
844 let n = matrix.rows;
845 validate_square(matrix)?;
846 validate_rhs_len(matrix, rhs)?;
847
848 const MAX_DENSE_DIM: usize = 4096;
849 if n > MAX_DENSE_DIM {
850 return Err(SolverError::InvalidInput(
851 crate::error::ValidationError::MatrixTooLarge {
852 rows: n,
853 cols: n,
854 max_dim: MAX_DENSE_DIM,
855 },
856 ));
857 }
858
859 let start = Instant::now();
860
861 let stride = n + 1;
863 let mut aug = vec![0.0_f64; n * stride];
864 for row in 0..n {
865 let rs = matrix.row_ptr[row];
866 let re = matrix.row_ptr[row + 1];
867 for idx in rs..re {
868 let col = matrix.col_indices[idx];
869 aug[row * stride + col] = matrix.values[idx];
870 }
871 aug[row * stride + n] = rhs[row];
872 }
873
874 for col in 0..n {
876 let mut max_val = aug[col * stride + col].abs();
877 let mut max_row = col;
878 for row in (col + 1)..n {
879 let val = aug[row * stride + col].abs();
880 if val > max_val {
881 max_val = val;
882 max_row = row;
883 }
884 }
885
886 if max_val < 1e-12 {
887 return Err(SolverError::NumericalInstability {
888 iteration: 0,
889 detail: format!(
890 "dense solver: near-zero pivot ({:.2e}) at column {}",
891 max_val, col
892 ),
893 });
894 }
895
896 if max_row != col {
897 for j in 0..stride {
898 aug.swap(col * stride + j, max_row * stride + j);
899 }
900 }
901
902 let pivot = aug[col * stride + col];
903 for row in (col + 1)..n {
904 let factor = aug[row * stride + col] / pivot;
905 aug[row * stride + col] = 0.0;
906 for j in (col + 1)..stride {
907 let above = aug[col * stride + j];
908 aug[row * stride + j] -= factor * above;
909 }
910 }
911 }
912
913 let mut solution_f64 = vec![0.0_f64; n];
915 for row in (0..n).rev() {
916 let mut sum = aug[row * stride + n];
917 for col in (row + 1)..n {
918 sum -= aug[row * stride + col] * solution_f64[col];
919 }
920 solution_f64[row] = sum / aug[row * stride + row];
921 }
922
923 let mut ax = vec![0.0_f64; n];
925 matrix.spmv(&solution_f64, &mut ax);
926 let residual_norm: f64 = (0..n)
927 .map(|i| {
928 let r = rhs[i] - ax[i];
929 r * r
930 })
931 .sum::<f64>()
932 .sqrt();
933
934 let solution: Vec<f32> = solution_f64.iter().map(|&v| v as f32).collect();
935
936 Ok(SolverResult {
937 solution,
938 iterations: 1,
939 residual_norm,
940 wall_time: start.elapsed(),
941 convergence_history: vec![ConvergenceInfo {
942 iteration: 0,
943 residual_norm,
944 }],
945 algorithm: Algorithm::Dense,
946 })
947 }
948
949 fn solve_jacobi_fallback(
955 &self,
956 algorithm: Algorithm,
957 matrix: &CsrMatrix<f64>,
958 rhs: &[f64],
959 budget: &ComputeBudget,
960 ) -> Result<SolverResult, SolverError> {
961 let n = matrix.rows;
962 validate_square(matrix)?;
963 validate_rhs_len(matrix, rhs)?;
964
965 let max_iters = budget.max_iterations;
966 let tol = budget.tolerance;
967 let start = Instant::now();
968
969 let mut diag = vec![0.0_f64; n];
971 for row in 0..n {
972 let rs = matrix.row_ptr[row];
973 let re = matrix.row_ptr[row + 1];
974 for idx in rs..re {
975 if matrix.col_indices[idx] == row {
976 diag[row] = matrix.values[idx];
977 break;
978 }
979 }
980 }
981
982 for (i, &d) in diag.iter().enumerate() {
983 if d.abs() < 1e-30 {
984 return Err(SolverError::NumericalInstability {
985 iteration: 0,
986 detail: format!("zero or near-zero diagonal at row {} (val={:.2e})", i, d),
987 });
988 }
989 }
990
991 let mut x = vec![0.0_f64; n];
992 let mut x_new = vec![0.0_f64; n];
993 let mut temp = vec![0.0_f64; n];
994 let mut convergence_history = Vec::new();
995
996 for iter in 0..max_iters {
997 for row in 0..n {
998 let rs = matrix.row_ptr[row];
999 let re = matrix.row_ptr[row + 1];
1000 let mut sum = 0.0_f64;
1001 for idx in rs..re {
1002 let col = matrix.col_indices[idx];
1003 if col != row {
1004 sum += matrix.values[idx] * x[col];
1005 }
1006 }
1007 x_new[row] = (rhs[row] - sum) / diag[row];
1008 }
1009
1010 matrix.spmv(&x_new, &mut temp);
1011 let residual_norm: f64 = (0..n)
1012 .map(|i| {
1013 let r = rhs[i] - temp[i];
1014 r * r
1015 })
1016 .sum::<f64>()
1017 .sqrt();
1018
1019 convergence_history.push(ConvergenceInfo {
1020 iteration: iter,
1021 residual_norm,
1022 });
1023
1024 if residual_norm.is_nan() || residual_norm.is_infinite() {
1025 return Err(SolverError::NumericalInstability {
1026 iteration: iter,
1027 detail: format!("residual became {}", residual_norm),
1028 });
1029 }
1030
1031 if residual_norm < tol {
1032 return Ok(SolverResult {
1033 solution: x_new.iter().map(|&v| v as f32).collect(),
1034 iterations: iter + 1,
1035 residual_norm,
1036 wall_time: start.elapsed(),
1037 convergence_history,
1038 algorithm,
1039 });
1040 }
1041
1042 std::mem::swap(&mut x, &mut x_new);
1043
1044 if start.elapsed() > budget.max_time {
1045 return Err(SolverError::BudgetExhausted {
1046 reason: "wall-clock time limit exceeded".into(),
1047 elapsed: start.elapsed(),
1048 });
1049 }
1050 }
1051
1052 let final_residual = convergence_history
1053 .last()
1054 .map(|c| c.residual_norm)
1055 .unwrap_or(f64::INFINITY);
1056
1057 Err(SolverError::NonConvergence {
1058 iterations: max_iters,
1059 residual: final_residual,
1060 tolerance: tol,
1061 })
1062 }
1063}
1064
1065impl Default for SolverOrchestrator {
1066 fn default() -> Self {
1067 Self::new(RouterConfig::default())
1068 }
1069}
1070
/// Inner product `Σ a[i] · b[i]` of two equal-length vectors.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length.
// Only `solve_cg_inline` calls this, and `dispatch` reaches that path only
// when the `cg` feature is disabled — hence the dead-code allowance.
#[inline]
#[allow(dead_code)]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(len_a, len_b, "dot: length mismatch {} vs {}", len_a, len_b);
    a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum()
}
1088
1089fn validate_square(matrix: &CsrMatrix<f64>) -> Result<(), SolverError> {
1091 if matrix.rows != matrix.cols {
1092 return Err(SolverError::InvalidInput(
1093 crate::error::ValidationError::DimensionMismatch(format!(
1094 "matrix must be square, got {}x{}",
1095 matrix.rows, matrix.cols
1096 )),
1097 ));
1098 }
1099 Ok(())
1100}
1101
1102fn validate_rhs_len(matrix: &CsrMatrix<f64>, rhs: &[f64]) -> Result<(), SolverError> {
1104 if rhs.len() != matrix.rows {
1105 return Err(SolverError::InvalidInput(
1106 crate::error::ValidationError::DimensionMismatch(format!(
1107 "rhs length {} does not match matrix dimension {}",
1108 rhs.len(),
1109 matrix.rows
1110 )),
1111 ));
1112 }
1113 Ok(())
1114}
1115
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures -----------------------------------------------------------

    /// Symmetric tridiagonal 3×3 matrix: 4 on the diagonal, -1 beside it.
    fn diag_dominant_3x3() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(
            3,
            3,
            vec![
                (0, 0, 4.0),
                (0, 1, -1.0),
                (1, 0, -1.0),
                (1, 1, 4.0),
                (1, 2, -1.0),
                (2, 1, -1.0),
                (2, 2, 4.0),
            ],
        )
    }

    fn default_budget() -> ComputeBudget {
        ComputeBudget {
            tolerance: 1e-8,
            ..Default::default()
        }
    }

    fn default_router() -> SolverRouter {
        SolverRouter::new(RouterConfig::default())
    }

    fn default_orchestrator() -> SolverOrchestrator {
        SolverOrchestrator::new(RouterConfig::default())
    }

    /// 1000×1000, 5 % dense, not diagonally dominant, condition estimate 50.
    fn general_linear_profile() -> SparsityProfile {
        SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 50_000,
            density: 0.05,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 50.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 50.0,
            max_nnz_per_row: 80,
        }
    }

    /// 1000×1000, very sparse, diagonally dominant, spectral radius 0.5.
    fn sparse_diag_dominant_profile() -> SparsityProfile {
        SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 3000,
            density: 0.003,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 3.0,
            max_nnz_per_row: 5,
        }
    }

    /// 5000-node directed graph (non-symmetric pattern).
    fn large_graph_profile() -> SparsityProfile {
        SparsityProfile {
            rows: 5000,
            cols: 5000,
            nnz: 20_000,
            density: 0.0008,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.85,
            estimated_condition: 100.0,
            is_symmetric_structure: false,
            avg_nnz_per_row: 4.0,
            max_nnz_per_row: 50,
        }
    }

    /// Profile shared by the batch-routing tests.
    fn batch_profile() -> SparsityProfile {
        SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        }
    }

    // ---- routing ------------------------------------------------------------

    #[test]
    fn routes_diag_dominant_sparse_to_neumann() {
        let chosen =
            default_router().select_algorithm(&sparse_diag_dominant_profile(), &QueryType::LinearSystem);
        assert_eq!(chosen, Algorithm::Neumann);
    }

    #[test]
    fn routes_well_conditioned_non_diag_dominant_to_cg() {
        let chosen =
            default_router().select_algorithm(&general_linear_profile(), &QueryType::LinearSystem);
        assert_eq!(chosen, Algorithm::CG);
    }

    #[test]
    fn routes_ill_conditioned_to_bmssp() {
        let profile = SparsityProfile {
            estimated_spectral_radius: 0.99,
            estimated_condition: 500.0,
            ..general_linear_profile()
        };
        let chosen = default_router().select_algorithm(&profile, &QueryType::LinearSystem);
        assert_eq!(chosen, Algorithm::BMSSP);
    }

    #[test]
    fn routes_single_pagerank_to_forward_push() {
        let query = QueryType::PageRankSingle { source: 0 };
        let chosen = default_router().select_algorithm(&large_graph_profile(), &query);
        assert_eq!(chosen, Algorithm::ForwardPush);
    }

    #[test]
    fn routes_large_pairwise_to_hybrid_random_walk() {
        let query = QueryType::PageRankPairwise {
            source: 0,
            target: 100,
        };
        let chosen = default_router().select_algorithm(&large_graph_profile(), &query);
        assert_eq!(chosen, Algorithm::HybridRandomWalk);
    }

    #[test]
    fn routes_small_pairwise_to_forward_push() {
        let profile = SparsityProfile {
            rows: 500,
            cols: 500,
            nnz: 2000,
            density: 0.008,
            max_nnz_per_row: 10,
            ..large_graph_profile()
        };
        let query = QueryType::PageRankPairwise {
            source: 0,
            target: 10,
        };
        let chosen = default_router().select_algorithm(&profile, &query);
        assert_eq!(chosen, Algorithm::ForwardPush);
    }

    #[test]
    fn routes_spectral_filter_to_neumann() {
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 500,
            density: 0.05,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.3,
            estimated_condition: 5.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 8,
        };
        let query = QueryType::SpectralFilter {
            polynomial_degree: 10,
        };
        let chosen = default_router().select_algorithm(&profile, &query);
        assert_eq!(chosen, Algorithm::Neumann);
    }

    #[test]
    fn routes_large_batch_to_true() {
        let query = QueryType::BatchLinearSystem { batch_size: 200 };
        let chosen = default_router().select_algorithm(&batch_profile(), &query);
        assert_eq!(chosen, Algorithm::TRUE);
    }

    #[test]
    fn routes_small_batch_to_cg() {
        let query = QueryType::BatchLinearSystem { batch_size: 50 };
        let chosen = default_router().select_algorithm(&batch_profile(), &query);
        assert_eq!(chosen, Algorithm::CG);
    }

    #[test]
    fn custom_config_overrides_thresholds() {
        // Lowering the CG threshold below the profile's condition (50) forces BMSSP.
        let router = SolverRouter::new(RouterConfig {
            cg_condition_threshold: 10.0,
            ..Default::default()
        });
        let chosen = router.select_algorithm(&general_linear_profile(), &QueryType::LinearSystem);
        assert_eq!(chosen, Algorithm::BMSSP);
    }

    #[test]
    fn neumann_requires_low_spectral_radius() {
        // 0.96 sits just above the default 0.95 Neumann threshold.
        let profile = SparsityProfile {
            estimated_spectral_radius: 0.96,
            ..sparse_diag_dominant_profile()
        };
        let chosen = default_router().select_algorithm(&profile, &QueryType::LinearSystem);
        assert_eq!(chosen, Algorithm::CG);
    }

    // ---- profiling ----------------------------------------------------------

    #[test]
    fn analyze_identity_matrix() {
        let profile = SolverOrchestrator::analyze_sparsity(&CsrMatrix::<f64>::identity(5));

        assert_eq!((profile.rows, profile.cols, profile.nnz), (5, 5, 5));
        assert!(profile.is_diag_dominant);
        assert!((profile.density - 0.2).abs() < 1e-10);
        assert!(profile.estimated_spectral_radius.abs() < 1e-10);
        assert!((profile.estimated_condition - 1.0).abs() < 1e-10);
        assert!(profile.is_symmetric_structure);
        assert_eq!(profile.max_nnz_per_row, 1);
    }

    #[test]
    fn analyze_diag_dominant() {
        let profile = SolverOrchestrator::analyze_sparsity(&diag_dominant_3x3());

        assert!(profile.is_diag_dominant);
        assert!(profile.estimated_spectral_radius < 1.0);
        assert!(profile.is_symmetric_structure);
    }

    #[test]
    fn analyze_empty_matrix() {
        let empty = CsrMatrix::<f64> {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        };
        let profile = SolverOrchestrator::analyze_sparsity(&empty);

        assert_eq!(profile.rows, 0);
        assert_eq!(profile.nnz, 0);
        assert_eq!(profile.density, 0.0);
    }

    // ---- end-to-end solves --------------------------------------------------

    #[test]
    fn orchestrator_solve_identity() {
        let rhs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let result = default_orchestrator()
            .solve(
                &CsrMatrix::<f64>::identity(4),
                &rhs,
                QueryType::LinearSystem,
                &default_budget(),
            )
            .expect("identity system should solve");

        for (got, want) in result.solution.iter().zip(&rhs) {
            assert!(
                (f64::from(*got) - want).abs() < 1e-4,
                "expected {}, got {}",
                want,
                got
            );
        }
    }

    #[test]
    fn orchestrator_solve_diag_dominant() {
        let rhs = [1.0_f64, 0.0, 1.0];
        let result = default_orchestrator()
            .solve(
                &diag_dominant_3x3(),
                &rhs,
                QueryType::LinearSystem,
                &default_budget(),
            )
            .expect("diag-dominant system should solve");

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_solve_with_fallback_succeeds() {
        let rhs = [1.0_f64, 0.0, 1.0];
        let result = default_orchestrator()
            .solve_with_fallback(
                &diag_dominant_3x3(),
                &rhs,
                QueryType::LinearSystem,
                &default_budget(),
            )
            .expect("fallback chain should succeed");

        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn orchestrator_dimension_mismatch() {
        // 3×3 matrix paired with a length-2 rhs.
        let rhs = [1.0_f64, 2.0];
        let outcome = default_orchestrator().solve(
            &CsrMatrix::<f64>::identity(3),
            &rhs,
            QueryType::LinearSystem,
            &default_budget(),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn estimate_complexity_returns_reasonable_values() {
        let estimate = default_orchestrator()
            .estimate_complexity(&diag_dominant_3x3(), &QueryType::LinearSystem);

        assert!(estimate.estimated_flops > 0);
        assert!(estimate.estimated_memory_bytes > 0);
        assert!(estimate.estimated_iterations > 0);
    }

    #[test]
    fn fallback_chain_deduplicates() {
        let cases = [
            (Algorithm::CG, vec![Algorithm::CG, Algorithm::Dense]),
            (Algorithm::Dense, vec![Algorithm::Dense, Algorithm::CG]),
            (
                Algorithm::Neumann,
                vec![Algorithm::Neumann, Algorithm::CG, Algorithm::Dense],
            ),
        ];
        for (primary, expected) in cases.iter() {
            assert_eq!(&SolverOrchestrator::build_fallback_chain(*primary), expected);
        }
    }

    // ---- individual backends ------------------------------------------------

    #[test]
    fn cg_inline_solves_spd_system() {
        let rhs = [1.0_f64, 2.0, 3.0];
        let result = default_orchestrator()
            .solve_cg_inline(&diag_dominant_3x3(), &rhs, &default_budget())
            .expect("inline CG should converge on an SPD system");

        assert!(result.residual_norm < 1e-6);
        assert_eq!(result.algorithm, Algorithm::CG);
    }

    #[test]
    fn dense_solves_small_system() {
        let rhs = [1.0_f64, 2.0, 3.0];
        let result = default_orchestrator()
            .solve_dense(&diag_dominant_3x3(), &rhs, &default_budget())
            .expect("dense solve should succeed");

        assert!(result.residual_norm < 1e-4);
        assert_eq!(result.algorithm, Algorithm::Dense);
    }

    #[test]
    fn dense_rejects_non_square() {
        let wide = CsrMatrix::<f64> {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = [1.0_f64, 1.0];

        assert!(default_orchestrator()
            .solve_dense(&wide, &rhs, &default_budget())
            .is_err());
    }

    #[test]
    fn cg_and_dense_agree_on_solution() {
        let orchestrator = default_orchestrator();
        let matrix = diag_dominant_3x3();
        let rhs = [3.0_f64, -1.0, 2.0];
        let budget = default_budget();

        let via_cg = orchestrator
            .solve_cg_inline(&matrix, &rhs, &budget)
            .expect("CG should converge");
        let via_dense = orchestrator
            .solve_dense(&matrix, &rhs, &budget)
            .expect("dense should succeed");

        for (cg_x, dense_x) in via_cg.solution.iter().zip(&via_dense.solution) {
            assert!(
                (cg_x - dense_x).abs() < 1e-3,
                "CG={} vs Dense={}",
                cg_x,
                dense_x
            );
        }
    }
}