1use rand::prelude::*;
8use serde::{Deserialize, Serialize};
9
10use crate::graph::SparseGraph;
11use crate::traits::ImportanceScorer;
12use crate::types::EdgeImportance;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct EffectiveResistanceEstimator {
25 pub max_walk_length: usize,
27 pub num_walks: usize,
29}
30
31impl Default for EffectiveResistanceEstimator {
32 fn default() -> Self {
33 Self {
34 max_walk_length: 100,
35 num_walks: 10,
36 }
37 }
38}
39
40impl EffectiveResistanceEstimator {
41 pub fn new(max_walk_length: usize, num_walks: usize) -> Self {
43 Self {
44 max_walk_length,
45 num_walks,
46 }
47 }
48
49 pub fn estimate(&self, graph: &SparseGraph, u: usize, v: usize) -> f64 {
54 if u == v {
55 return 0.0;
56 }
57 let total_w = graph.total_weight();
58 if total_w <= 0.0 {
59 return f64::MAX;
60 }
61 if graph.degree(u) == 0 || graph.degree(v) == 0 {
62 return f64::MAX;
63 }
64
65 let mut rng = rand::thread_rng();
66 let mut total_steps = 0u64;
67
68 for _ in 0..self.num_walks {
69 total_steps += self.walk_to_target(graph, u, v, &mut rng) as u64;
71 total_steps += self.walk_to_target(graph, v, u, &mut rng) as u64;
73 }
74
75 let avg_commute = total_steps as f64 / self.num_walks as f64;
76 avg_commute / (2.0 * total_w)
78 }
79
80 fn walk_to_target<R: Rng>(
82 &self,
83 graph: &SparseGraph,
84 start: usize,
85 target: usize,
86 rng: &mut R,
87 ) -> usize {
88 let mut current = start;
89 for step in 1..=self.max_walk_length {
90 current = self.random_neighbor(graph, current, rng);
91 if current == target {
92 return step;
93 }
94 }
95 self.max_walk_length
96 }
97
98 fn random_neighbor<R: Rng>(&self, graph: &SparseGraph, u: usize, rng: &mut R) -> usize {
100 let w_deg = graph.weighted_degree(u);
101 if w_deg <= 0.0 {
102 return u; }
104 let threshold = rng.gen::<f64>() * w_deg;
105 let mut cumulative = 0.0;
106 for (v, w) in graph.neighbors(u) {
107 cumulative += w;
108 if cumulative >= threshold {
109 return v;
110 }
111 }
112 u
114 }
115}
116
117#[derive(Debug, Clone, Default, Serialize, Deserialize)]
127pub struct LocalImportanceScorer {
128 pub estimator: EffectiveResistanceEstimator,
130}
131
132impl LocalImportanceScorer {
133 pub fn new(walk_length: usize, num_walks: usize) -> Self {
135 Self {
136 estimator: EffectiveResistanceEstimator::new(walk_length, num_walks),
137 }
138 }
139
140 pub fn importance_score(&self, graph: &SparseGraph, u: usize, v: usize, weight: f64) -> f64 {
142 let r_eff = self.estimator.estimate(graph, u, v);
143 weight * r_eff
144 }
145}
146
147impl ImportanceScorer for LocalImportanceScorer {
148 fn score(
149 &self,
150 graph: &SparseGraph,
151 u: usize,
152 v: usize,
153 weight: f64,
154 ) -> EdgeImportance {
155 let r_eff = self.estimator.estimate(graph, u, v);
156 EdgeImportance::new(u, v, weight, r_eff)
157 }
158
159 fn score_all(&self, graph: &SparseGraph) -> Vec<EdgeImportance> {
160 let edges: Vec<(usize, usize, f64)> = graph.edges().collect();
162
163 if edges.len() > 100 {
164 use rayon::prelude::*;
165 edges
166 .par_iter()
167 .map(|&(u, v, w)| {
168 let r_eff = self.estimator.estimate(graph, u, v);
169 EdgeImportance::new(u, v, w, r_eff)
170 })
171 .collect()
172 } else {
173 edges
174 .iter()
175 .map(|&(u, v, w)| self.score(graph, u, v, w))
176 .collect()
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Querying a vertex against itself must yield zero resistance.
    #[test]
    fn test_self_loop_resistance() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        let estimator = EffectiveResistanceEstimator::new(50, 5);

        let resistance = estimator.estimate(&graph, 0, 0);

        assert!(resistance.abs() < 1e-12);
    }

    /// Opposite corners of a unit-weight 4-cycle have positive resistance.
    #[test]
    fn test_resistance_positive() {
        let graph = SparseGraph::from_edges(&[
            (0, 1, 1.0),
            (1, 2, 1.0),
            (2, 3, 1.0),
            (3, 0, 1.0),
        ]);
        let estimator = EffectiveResistanceEstimator::new(200, 20);

        let resistance = estimator.estimate(&graph, 0, 2);

        assert!(resistance > 0.0);
    }

    /// `score_all` returns one strictly positive score per edge.
    #[test]
    fn test_scorer_all() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0)]);
        let scorer = LocalImportanceScorer::new(50, 5);

        let all_scores = scorer.score_all(&graph);

        assert_eq!(all_scores.len(), 2);
        assert!(all_scores.iter().all(|edge| edge.score > 0.0));
    }
}