Skip to main content

ruvector_sparsifier/
importance.rs

//! Edge importance scoring via random walks.
//!
//! Estimates effective resistance using short, length-capped random walks (a
//! practical approximation to the Johnson-Lindenstrauss projection approach).
//! The importance score `w(e) * R_eff(e)` determines each edge's sampling
//! probability.
6
7use rand::prelude::*;
8use serde::{Deserialize, Serialize};
9
10use crate::graph::SparseGraph;
11use crate::traits::ImportanceScorer;
12use crate::types::EdgeImportance;
13
14// ---------------------------------------------------------------------------
15// EffectiveResistanceEstimator
16// ---------------------------------------------------------------------------
17
18/// Estimates effective resistance between two vertices via random walks.
19///
20/// Uses the commute-time identity: `R_eff(u,v) = commute_time(u,v) / (2m)`
21/// where `m` is the total edge weight. The commute time is estimated by
22/// running random walks from `u` until they hit `v` and back.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct EffectiveResistanceEstimator {
25    /// Maximum walk length before giving up.
26    pub max_walk_length: usize,
27    /// Number of walks to average over.
28    pub num_walks: usize,
29}
30
31impl Default for EffectiveResistanceEstimator {
32    fn default() -> Self {
33        Self {
34            max_walk_length: 100,
35            num_walks: 10,
36        }
37    }
38}
39
40impl EffectiveResistanceEstimator {
41    /// Create a new estimator with the given parameters.
42    pub fn new(max_walk_length: usize, num_walks: usize) -> Self {
43        Self {
44            max_walk_length,
45            num_walks,
46        }
47    }
48
49    /// Estimate the effective resistance between `u` and `v` in `graph`.
50    ///
51    /// Returns a value in `[0, +inf)`. For disconnected pairs the
52    /// estimate may be very large (capped at `max_walk_length / total_weight`).
53    pub fn estimate(&self, graph: &SparseGraph, u: usize, v: usize) -> f64 {
54        if u == v {
55            return 0.0;
56        }
57        let total_w = graph.total_weight();
58        if total_w <= 0.0 {
59            return f64::MAX;
60        }
61        if graph.degree(u) == 0 || graph.degree(v) == 0 {
62            return f64::MAX;
63        }
64
65        let mut rng = rand::thread_rng();
66        let mut total_steps = 0u64;
67
68        for _ in 0..self.num_walks {
69            // Walk from u to v.
70            total_steps += self.walk_to_target(graph, u, v, &mut rng) as u64;
71            // Walk from v to u.
72            total_steps += self.walk_to_target(graph, v, u, &mut rng) as u64;
73        }
74
75        let avg_commute = total_steps as f64 / self.num_walks as f64;
76        // R_eff ~ commute_time / (2 * total_weight)
77        avg_commute / (2.0 * total_w)
78    }
79
80    /// Random walk from `start` to `target`, returning the number of steps.
81    fn walk_to_target<R: Rng>(
82        &self,
83        graph: &SparseGraph,
84        start: usize,
85        target: usize,
86        rng: &mut R,
87    ) -> usize {
88        let mut current = start;
89        for step in 1..=self.max_walk_length {
90            current = self.random_neighbor(graph, current, rng);
91            if current == target {
92                return step;
93            }
94        }
95        self.max_walk_length
96    }
97
98    /// Pick a random neighbour of `u` with probability proportional to weight.
99    fn random_neighbor<R: Rng>(&self, graph: &SparseGraph, u: usize, rng: &mut R) -> usize {
100        let w_deg = graph.weighted_degree(u);
101        if w_deg <= 0.0 {
102            return u; // isolated vertex stays put
103        }
104        let threshold = rng.gen::<f64>() * w_deg;
105        let mut cumulative = 0.0;
106        for (v, w) in graph.neighbors(u) {
107            cumulative += w;
108            if cumulative >= threshold {
109                return v;
110            }
111        }
112        // Fallback (numerical edge case).
113        u
114    }
115}
116
117// ---------------------------------------------------------------------------
118// LocalImportanceScorer
119// ---------------------------------------------------------------------------
120
121/// Scores edge importance using localized random walks.
122///
123/// For each edge `(u, v, w)`, the score is `w * R_eff_estimate(u, v)`.
124/// High-importance edges (bridges, cut edges) get high scores and are
125/// more likely to be kept in the sparsifier.
126#[derive(Debug, Clone, Default, Serialize, Deserialize)]
127pub struct LocalImportanceScorer {
128    /// The underlying resistance estimator.
129    pub estimator: EffectiveResistanceEstimator,
130}
131
132impl LocalImportanceScorer {
133    /// Create a scorer with custom walk parameters.
134    pub fn new(walk_length: usize, num_walks: usize) -> Self {
135        Self {
136            estimator: EffectiveResistanceEstimator::new(walk_length, num_walks),
137        }
138    }
139
140    /// Compute the importance score for a single edge.
141    pub fn importance_score(&self, graph: &SparseGraph, u: usize, v: usize, weight: f64) -> f64 {
142        let r_eff = self.estimator.estimate(graph, u, v);
143        weight * r_eff
144    }
145}
146
147impl ImportanceScorer for LocalImportanceScorer {
148    fn score(
149        &self,
150        graph: &SparseGraph,
151        u: usize,
152        v: usize,
153        weight: f64,
154    ) -> EdgeImportance {
155        let r_eff = self.estimator.estimate(graph, u, v);
156        EdgeImportance::new(u, v, weight, r_eff)
157    }
158
159    fn score_all(&self, graph: &SparseGraph) -> Vec<EdgeImportance> {
160        // Collect edges first, then parallel-score them for large graphs.
161        let edges: Vec<(usize, usize, f64)> = graph.edges().collect();
162
163        if edges.len() > 100 {
164            use rayon::prelude::*;
165            edges
166                .par_iter()
167                .map(|&(u, v, w)| {
168                    let r_eff = self.estimator.estimate(graph, u, v);
169                    EdgeImportance::new(u, v, w, r_eff)
170                })
171                .collect()
172        } else {
173            edges
174                .iter()
175                .map(|&(u, v, w)| self.score(graph, u, v, w))
176                .collect()
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A vertex has zero effective resistance to itself.
    #[test]
    fn test_self_loop_resistance() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        let resistance = EffectiveResistanceEstimator::new(50, 5).estimate(&graph, 0, 0);
        assert!(resistance.abs() < 1e-12);
    }

    /// Opposite corners of a unit-weight 4-cycle have positive resistance.
    #[test]
    fn test_resistance_positive() {
        let cycle = [(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0), (3, 0, 1.0)];
        let graph = SparseGraph::from_edges(&cycle);
        let resistance = EffectiveResistanceEstimator::new(200, 20).estimate(&graph, 0, 2);
        assert!(resistance > 0.0);
    }

    /// `score_all` returns one strictly positive score per edge.
    #[test]
    fn test_scorer_all() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0)]);
        let scores = LocalImportanceScorer::new(50, 5).score_all(&graph);
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|s| s.score > 0.0));
    }
}