1use std::collections::VecDeque;
18
19use crate::error::SolverError;
20use crate::traits::{SolverEngine, SublinearPageRank};
21use crate::types::{
22 Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, CsrMatrix, SolverResult,
23 SparsityProfile,
24};
25
/// Configuration for the forward-push personalized PageRank solver.
///
/// Both fields are public and are checked by the solver's entry points
/// before any work is done.
#[derive(Debug, Clone)]
pub struct ForwardPushSolver {
    /// Fraction of a vertex's residual that is moved into its estimate on
    /// every push; the remaining `1 - alpha` is forwarded to its
    /// out-neighbours. Accepted range is `(0.0, 1.0)` exclusive.
    pub alpha: f64,
    /// Push threshold. A vertex is processed while its residual exceeds
    /// `epsilon * degree` (or `epsilon` for a vertex without out-edges).
    /// Must be greater than `0.0`.
    pub epsilon: f64,
}
53
54impl ForwardPushSolver {
55 pub fn new(alpha: f64, epsilon: f64) -> Self {
60 Self { alpha, epsilon }
61 }
62
63 fn validate_params(&self) -> Result<(), SolverError> {
70 if self.alpha <= 0.0 || self.alpha >= 1.0 {
71 return Err(SolverError::InvalidInput(
72 crate::error::ValidationError::ParameterOutOfRange {
73 name: "alpha".into(),
74 value: self.alpha.to_string(),
75 expected: "(0.0, 1.0) exclusive".into(),
76 },
77 ));
78 }
79 if self.epsilon <= 0.0 {
80 return Err(SolverError::InvalidInput(
81 crate::error::ValidationError::ParameterOutOfRange {
82 name: "epsilon".into(),
83 value: self.epsilon.to_string(),
84 expected: "> 0.0".into(),
85 },
86 ));
87 }
88 Ok(())
89 }
90
91 pub fn default_params() -> Self {
94 Self {
95 alpha: 0.85,
96 epsilon: 1e-6,
97 }
98 }
99
100 pub fn ppr_from_source(
109 &self,
110 graph: &CsrMatrix<f64>,
111 source: usize,
112 ) -> Result<Vec<(usize, f64)>, SolverError> {
113 self.validate_params()?;
114 validate_vertex(graph, source, "source")?;
115 self.forward_push_core(graph, &[(source, 1.0)])
116 }
117
118 pub fn top_k(
122 &self,
123 graph: &CsrMatrix<f64>,
124 source: usize,
125 k: usize,
126 ) -> Result<Vec<(usize, f64)>, SolverError> {
127 let mut result = self.ppr_from_source(graph, source)?;
128 result.truncate(k);
129 Ok(result)
130 }
131
132 const MAX_GRAPH_NODES: usize = 100_000_000;
143
144 fn forward_push_core(
145 &self,
146 graph: &CsrMatrix<f64>,
147 seeds: &[(usize, f64)],
148 ) -> Result<Vec<(usize, f64)>, SolverError> {
149 self.validate_params()?;
150
151 let n = graph.rows;
152 if n > Self::MAX_GRAPH_NODES {
153 return Err(SolverError::InvalidInput(
154 crate::error::ValidationError::MatrixTooLarge {
155 rows: n,
156 cols: graph.cols,
157 max_dim: Self::MAX_GRAPH_NODES,
158 },
159 ));
160 }
161
162 let mut estimate = vec![0.0f64; n];
163 let mut residual = vec![0.0f64; n];
164
165 let mut in_queue = vec![false; n];
167 let mut queue: VecDeque<usize> = VecDeque::new();
168
169 for &(v, mass) in seeds {
171 residual[v] += mass;
172 if !in_queue[v] && should_push(residual[v], graph.row_degree(v), self.epsilon) {
173 queue.push_back(v);
174 in_queue[v] = true;
175 }
176 }
177
178 while let Some(u) = queue.pop_front() {
180 in_queue[u] = false;
181
182 let r_u = residual[u];
183
184 if !should_push(r_u, graph.row_degree(u), self.epsilon) {
186 continue;
187 }
188
189 estimate[u] += self.alpha * r_u;
191
192 let degree = graph.row_degree(u);
193 if degree > 0 {
194 let push_amount = (1.0 - self.alpha) * r_u / degree as f64;
195
196 residual[u] = 0.0;
201
202 for (v, _weight) in graph.row_entries(u) {
203 residual[v] += push_amount;
204
205 if !in_queue[v] && should_push(residual[v], graph.row_degree(v), self.epsilon) {
206 queue.push_back(v);
207 in_queue[v] = true;
208 }
209 }
210 } else {
211 let leftover = (1.0 - self.alpha) * r_u;
218 residual[u] = leftover;
219
220 if !in_queue[u] && should_push(leftover, 0, self.epsilon) {
221 queue.push_back(u);
222 in_queue[u] = true;
223 }
224 }
225 }
226
227 let total_seed_mass: f64 = seeds.iter().map(|(_, m)| *m).sum();
230 check_mass_invariant(&estimate, &residual, total_seed_mass)?;
231
232 let mut result: Vec<(usize, f64)> = estimate
234 .iter()
235 .enumerate()
236 .filter(|(_, val)| **val > 0.0)
237 .map(|(i, val)| (i, *val))
238 .collect();
239
240 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
241
242 Ok(result)
243 }
244}
245
246pub fn forward_push_with_residuals(
251 matrix: &CsrMatrix<f64>,
252 source: usize,
253 alpha: f64,
254 epsilon: f64,
255) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
256 validate_vertex(matrix, source, "source")?;
257
258 let n = matrix.rows;
259 let mut estimate = vec![0.0f64; n];
260 let mut residual = vec![0.0f64; n];
261
262 residual[source] = 1.0;
263
264 let mut in_queue = vec![false; n];
265 let mut queue: VecDeque<usize> = VecDeque::new();
266
267 if should_push(1.0, matrix.row_degree(source), epsilon) {
268 queue.push_back(source);
269 in_queue[source] = true;
270 }
271
272 while let Some(u) = queue.pop_front() {
273 in_queue[u] = false;
274 let r_u = residual[u];
275
276 if !should_push(r_u, matrix.row_degree(u), epsilon) {
277 continue;
278 }
279
280 estimate[u] += alpha * r_u;
281
282 let degree = matrix.row_degree(u);
283 if degree > 0 {
284 let push_amount = (1.0 - alpha) * r_u / degree as f64;
285 residual[u] = 0.0;
287 for (v, _) in matrix.row_entries(u) {
288 residual[v] += push_amount;
289 if !in_queue[v] && should_push(residual[v], matrix.row_degree(v), epsilon) {
290 queue.push_back(v);
291 in_queue[v] = true;
292 }
293 }
294 } else {
295 let leftover = (1.0 - alpha) * r_u;
297 residual[u] = leftover;
298 if !in_queue[u] && should_push(leftover, 0, epsilon) {
299 queue.push_back(u);
300 in_queue[u] = true;
301 }
302 }
303 }
304
305 Ok((estimate, residual))
306}
307
/// Decides whether a vertex still holds enough residual to be pushed.
///
/// The residual must be strictly greater than `epsilon * degree`; a vertex
/// with no out-edges is compared against `epsilon` itself.
#[inline]
fn should_push(residual: f64, degree: usize, epsilon: f64) -> bool {
    // Multiplying by exactly 1.0 leaves `epsilon` unchanged, so the
    // zero-degree case folds into the same comparison.
    let scale = if degree == 0 { 1.0 } else { degree as f64 };
    residual > epsilon * scale
}
324
325fn validate_vertex(graph: &CsrMatrix<f64>, vertex: usize, name: &str) -> Result<(), SolverError> {
327 if vertex >= graph.rows {
328 return Err(SolverError::InvalidInput(
329 crate::error::ValidationError::ParameterOutOfRange {
330 name: name.into(),
331 value: vertex.to_string(),
332 expected: format!("0..{}", graph.rows),
333 },
334 ));
335 }
336 Ok(())
337}
338
339fn check_mass_invariant(
341 estimate: &[f64],
342 residual: &[f64],
343 expected_mass: f64,
344) -> Result<(), SolverError> {
345 let mass: f64 = estimate.iter().sum::<f64>() + residual.iter().sum::<f64>();
346 if (mass - expected_mass).abs() > 1e-6 {
347 return Err(SolverError::NumericalInstability {
348 iteration: 0,
349 detail: format!(
350 "mass invariant violated: sum(estimate)+sum(residual) = {mass:.10}, \
351 expected {expected_mass:.10}",
352 ),
353 });
354 }
355 Ok(())
356}
357
358impl SolverEngine for ForwardPushSolver {
363 fn solve(
370 &self,
371 matrix: &CsrMatrix<f64>,
372 rhs: &[f64],
373 _budget: &ComputeBudget,
374 ) -> Result<SolverResult, SolverError> {
375 let start = std::time::Instant::now();
376
377 let source = rhs.iter().position(|&v| v != 0.0).unwrap_or(0);
378 let sparse_result = self.ppr_from_source(matrix, source)?;
379
380 let n = matrix.rows;
381 let mut solution = vec![0.0f32; n];
382 for &(idx, score) in &sparse_result {
383 solution[idx] = score as f32;
384 }
385
386 Ok(SolverResult {
387 solution,
388 iterations: sparse_result.len(),
389 residual_norm: 0.0,
390 wall_time: start.elapsed(),
391 convergence_history: Vec::new(),
392 algorithm: Algorithm::ForwardPush,
393 })
394 }
395
396 fn estimate_complexity(&self, _profile: &SparsityProfile, _n: usize) -> ComplexityEstimate {
397 let est_ops = (1.0 / self.epsilon).min(usize::MAX as f64) as usize;
398 ComplexityEstimate {
399 algorithm: Algorithm::ForwardPush,
400 estimated_flops: est_ops as u64 * 10,
401 estimated_iterations: est_ops,
402 estimated_memory_bytes: est_ops * 16,
403 complexity_class: ComplexityClass::SublinearNnz,
404 }
405 }
406
407 fn algorithm(&self) -> Algorithm {
408 Algorithm::ForwardPush
409 }
410}
411
412impl SublinearPageRank for ForwardPushSolver {
417 fn ppr(
418 &self,
419 matrix: &CsrMatrix<f64>,
420 source: usize,
421 alpha: f64,
422 epsilon: f64,
423 ) -> Result<Vec<(usize, f64)>, SolverError> {
424 let solver = ForwardPushSolver::new(alpha, epsilon);
425 solver.ppr_from_source(matrix, source)
426 }
427
428 fn ppr_multi_seed(
429 &self,
430 matrix: &CsrMatrix<f64>,
431 seeds: &[(usize, f64)],
432 alpha: f64,
433 epsilon: f64,
434 ) -> Result<Vec<(usize, f64)>, SolverError> {
435 for &(v, _) in seeds {
436 validate_vertex(matrix, v, "seed vertex")?;
437 }
438 let solver = ForwardPushSolver::new(alpha, epsilon);
439 solver.forward_push_core(matrix, seeds)
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Compensated (Kahan) running sum; exercised only by `kahan_accuracy`.
    #[derive(Debug, Clone, Copy)]
    struct KahanAccumulator {
        sum: f64,
        compensation: f64,
    }

    impl KahanAccumulator {
        /// Starts an empty sum.
        #[inline]
        const fn new() -> Self {
            Self {
                sum: 0.0,
                compensation: 0.0,
            }
        }

        /// Adds `value`, carrying the rounding error of the previous step.
        #[inline]
        fn add(&mut self, value: f64) {
            let corrected = value - self.compensation;
            let updated = self.sum + corrected;
            self.compensation = (updated - self.sum) - corrected;
            self.sum = updated;
        }

        /// Current total.
        #[inline]
        fn value(&self) -> f64 {
            self.sum
        }
    }

    /// Turns `(from, to)` pairs into unit-weight COO triplets, keeping order.
    fn unit_edges(pairs: &[(usize, usize)]) -> Vec<(usize, usize, f64)> {
        pairs.iter().map(|&(from, to)| (from, to, 1.0f64)).collect()
    }

    /// Triangle 0-1-2 plus vertex 3 attached to vertex 1; every edge is
    /// stored in both directions.
    fn triangle_graph() -> CsrMatrix<f64> {
        let edges = unit_edges(&[
            (0, 1),
            (0, 2),
            (1, 0),
            (1, 2),
            (1, 3),
            (2, 0),
            (2, 1),
            (3, 1),
        ]);
        CsrMatrix::<f64>::from_coo(4, 4, edges)
    }

    /// Directed path 0 -> 1 -> 2 -> 3; vertex 3 has no out-entries.
    fn path_graph() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(4, 4, unit_edges(&[(0, 1), (1, 2), (2, 3)]))
    }

    /// Hub 0 connected to leaves 1..=5, every edge in both directions.
    fn star_graph() -> CsrMatrix<f64> {
        let n = 6;
        let entries: Vec<(usize, usize, f64)> = (1..n)
            .flat_map(|leaf| [(0, leaf, 1.0f64), (leaf, 0, 1.0f64)])
            .collect();
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }

    /// Undirected path 0-1-2 plus vertex 3, which has no entries at all.
    fn path_with_isolated_vertex() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(4, 4, unit_edges(&[(0, 1), (1, 0), (1, 2), (2, 1)]))
    }

    #[test]
    fn basic_ppr_triangle() {
        let ranked = ForwardPushSolver::default_params()
            .ppr_from_source(&triangle_graph(), 0)
            .unwrap();

        assert!(!ranked.is_empty());
        let (top_vertex, top_score) = ranked[0];
        assert_eq!(top_vertex, 0, "source should be top-ranked");
        assert!(top_score > 0.0);

        for &(_, score) in &ranked {
            assert!(score > 0.0);
        }

        for pair in ranked.windows(2) {
            assert!(pair[0].1 >= pair[1].1, "results should be sorted descending");
        }
    }

    #[test]
    fn ppr_path_graph_monotone_decay() {
        let ranked = ForwardPushSolver::new(0.85, 1e-8)
            .ppr_from_source(&path_graph(), 0)
            .unwrap();

        // Scatter the sparse result into a dense per-vertex table.
        let mut scores = [0.0f64; 4];
        for &(vertex, score) in &ranked {
            scores[vertex] = score;
        }

        assert!(scores[0] > scores[1], "score[0] > score[1]");
        assert!(scores[1] > scores[2], "score[1] > score[2]");
        assert!(scores[2] > scores[3], "score[2] > score[3]");
    }

    #[test]
    fn ppr_star_symmetry() {
        let ranked = ForwardPushSolver::new(0.85, 1e-8)
            .ppr_from_source(&star_graph(), 0)
            .unwrap();

        let mut leaf_scores: Vec<f64> = Vec::new();
        for &(vertex, score) in &ranked {
            if vertex != 0 {
                leaf_scores.push(score);
            }
        }
        assert_eq!(leaf_scores.len(), 5);

        let mean = leaf_scores.iter().sum::<f64>() / leaf_scores.len() as f64;
        for &s in &leaf_scores {
            assert!(
                (s - mean).abs() < 1e-6,
                "leaf scores should be equal: got {s} vs mean {mean}",
            );
        }
    }

    #[test]
    fn top_k_truncates() {
        let top_two = ForwardPushSolver::default_params()
            .top_k(&triangle_graph(), 0, 2)
            .unwrap();

        assert!(top_two.len() <= 2);
        assert_eq!(top_two[0].0, 0);
    }

    #[test]
    fn mass_invariant_holds() {
        // The solver checks mass conservation internally and errors on drift.
        let outcome = ForwardPushSolver::default_params().ppr_from_source(&triangle_graph(), 0);
        assert!(outcome.is_ok());
    }

    #[test]
    fn invalid_source_errors() {
        let outcome = ForwardPushSolver::default_params().ppr_from_source(&triangle_graph(), 100);
        assert!(outcome.is_err());
    }

    #[test]
    fn isolated_vertex_receives_zero() {
        let ranked = ForwardPushSolver::default_params()
            .ppr_from_source(&path_with_isolated_vertex(), 0)
            .unwrap();

        let v3_score = ranked
            .iter()
            .find(|&&(vertex, _)| vertex == 3)
            .map_or(0.0, |&(_, score)| score);
        assert!(
            v3_score.abs() < 1e-10,
            "isolated vertex should have ~zero PPR",
        );
    }

    #[test]
    fn isolated_source_converges_to_one() {
        let ranked = ForwardPushSolver::default_params()
            .ppr_from_source(&path_with_isolated_vertex(), 3)
            .unwrap();

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0, 3);
        assert!(
            (ranked[0].1 - 1.0).abs() < 1e-4,
            "isolated source estimate should converge near 1.0: got {}",
            ranked[0].1,
        );
    }

    #[test]
    fn single_vertex_graph() {
        let lone = CsrMatrix::<f64>::from_coo(1, 1, Vec::<(usize, usize, f64)>::new());
        let ranked = ForwardPushSolver::default_params()
            .ppr_from_source(&lone, 0)
            .unwrap();

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0, 0);
        assert!(
            (ranked[0].1 - 1.0).abs() < 1e-4,
            "single vertex PPR should converge near 1.0: got {}",
            ranked[0].1,
        );
    }

    #[test]
    fn solver_engine_trait() {
        let graph = triangle_graph();
        let solver = ForwardPushSolver::default_params();

        // Unit right-hand side selecting vertex 1 as the source.
        let rhs = [0.0f64, 1.0, 0.0, 0.0];
        let budget = ComputeBudget::default();

        let outcome = solver.solve(&graph, &rhs, &budget).unwrap();
        assert_eq!(outcome.algorithm, Algorithm::ForwardPush);
        assert_eq!(outcome.solution.len(), 4);

        let (max_idx, _) = outcome
            .solution
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        assert_eq!(max_idx, 1);
    }

    #[test]
    fn sublinear_ppr_trait() {
        let ranked = ForwardPushSolver::default_params()
            .ppr(&triangle_graph(), 0, 0.85, 1e-6)
            .unwrap();

        assert!(!ranked.is_empty());
        assert_eq!(ranked[0].0, 0, "source should rank first via ppr trait");
    }

    #[test]
    fn multi_seed_ppr() {
        let seeds = [(0, 0.5), (1, 0.5)];
        let ranked = ForwardPushSolver::default_params()
            .ppr_multi_seed(&triangle_graph(), &seeds, 0.85, 1e-6)
            .unwrap();

        assert!(!ranked.is_empty());
        let contains = |wanted: usize| ranked.iter().any(|&(vertex, _)| vertex == wanted);
        let has_0 = contains(0);
        let has_1 = contains(1);
        assert!(has_0 && has_1, "both seeds should appear in output");
    }

    #[test]
    fn forward_push_with_residuals_mass_conservation() {
        let (estimates, residuals) =
            forward_push_with_residuals(&triangle_graph(), 0, 0.85, 1e-6).unwrap();

        let total: f64 = estimates.iter().sum::<f64>() + residuals.iter().sum::<f64>();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "mass should be conserved: got {total}",
        );
    }

    #[test]
    fn kahan_accuracy() {
        let n = 1_000_000;
        let small = 1e-10;

        let mut acc = KahanAccumulator::new();
        (0..n).for_each(|_| acc.add(small));

        let expected = n as f64 * small;
        let relative_error = (acc.value() - expected).abs() / expected;
        assert!(
            relative_error < 1e-10,
            "Kahan relative error {relative_error} should be tiny",
        );
    }

    #[test]
    fn self_loop_graph() {
        // Directed 3-cycle in which every vertex also has a self-loop.
        let graph = CsrMatrix::<f64>::from_coo(
            3,
            3,
            unit_edges(&[(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 0)]),
        );
        let result = ForwardPushSolver::default_params().ppr_from_source(&graph, 0);
        assert!(result.is_ok(), "self-loop graph failed: {:?}", result.err());
    }

    #[test]
    fn complete_graph_symmetry() {
        let n = 4;
        let entries: Vec<(usize, usize, f64)> = (0..n)
            .flat_map(|i| {
                (0..n)
                    .filter(move |&j| j != i)
                    .map(move |j| (i, j, 1.0f64))
            })
            .collect();
        let graph = CsrMatrix::<f64>::from_coo(n, n, entries);

        let ranked = ForwardPushSolver::new(0.85, 1e-8)
            .ppr_from_source(&graph, 0)
            .unwrap();

        assert_eq!(ranked[0].0, 0);

        let mut other_scores: Vec<f64> = Vec::new();
        for &(vertex, score) in &ranked {
            if vertex != 0 {
                other_scores.push(score);
            }
        }
        assert_eq!(other_scores.len(), 3);

        let mean = other_scores.iter().sum::<f64>() / 3.0;
        for &score in &other_scores {
            assert!((score - mean).abs() < 1e-6);
        }
    }

    #[test]
    fn estimate_complexity_sublinear() {
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: false,
            estimated_spectral_radius: 0.9,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 10,
        };

        let est = ForwardPushSolver::new(0.85, 1e-4).estimate_complexity(&profile, 1000);
        assert_eq!(est.algorithm, Algorithm::ForwardPush);
        assert_eq!(est.complexity_class, ComplexityClass::SublinearNnz);
        assert!(est.estimated_iterations > 0);
    }
}