ruvector_sparsifier/
sampler.rs1use rand::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::graph::SparseGraph;
12use crate::types::EdgeImportance;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct SpectralSampler {
31 pub epsilon: f64,
33}
34
35impl SpectralSampler {
36 pub fn new(epsilon: f64) -> Self {
38 Self { epsilon }
39 }
40
41 pub fn sample_edges(
47 &self,
48 scores: &[EdgeImportance],
49 budget: usize,
50 backbone_edges: &std::collections::HashSet<(usize, usize)>,
51 ) -> SparseGraph {
52 if scores.is_empty() {
53 return SparseGraph::new();
54 }
55
56 let mut rng = rand::thread_rng();
57 let n_vertices = scores
58 .iter()
59 .map(|s| s.u.max(s.v) + 1)
60 .max()
61 .unwrap_or(0);
62 let log_n = (n_vertices as f64).ln().max(1.0);
63
64 let total_importance: f64 = scores.iter().map(|s| s.score).sum();
66 if total_importance <= 0.0 {
67 return self.backbone_only_graph(scores, backbone_edges);
69 }
70
71 let backbone_count = scores
73 .iter()
74 .filter(|s| {
75 let key = Self::edge_key(s.u, s.v);
76 backbone_edges.contains(&key)
77 })
78 .count();
79 let sample_budget = budget.saturating_sub(backbone_count);
80
81 let c = if total_importance > 0.0 {
83 sample_budget as f64 / (total_importance * log_n / (self.epsilon * self.epsilon))
84 } else {
85 1.0
86 };
87
88 let mut g = SparseGraph::with_capacity(n_vertices);
89
90 for s in scores {
91 let key = Self::edge_key(s.u, s.v);
92 let is_backbone = backbone_edges.contains(&key);
93
94 if is_backbone {
95 let _ = g.insert_or_update_edge(s.u, s.v, s.weight);
97 continue;
98 }
99
100 let p = (c * s.score * log_n / (self.epsilon * self.epsilon)).min(1.0);
102
103 if p >= 1.0 || rng.gen::<f64>() < p {
104 let reweighted = if p > 0.0 { s.weight / p } else { s.weight };
106 let _ = g.insert_or_update_edge(s.u, s.v, reweighted);
107 }
108 }
109
110 g
111 }
112
113 pub fn sample_single_edge(
116 &self,
117 importance: &EdgeImportance,
118 n_vertices: usize,
119 total_importance: f64,
120 budget: usize,
121 ) -> Option<(usize, usize, f64)> {
122 let log_n = (n_vertices as f64).ln().max(1.0);
123 let c = if total_importance > 0.0 {
124 budget as f64 / (total_importance * log_n / (self.epsilon * self.epsilon))
125 } else {
126 1.0
127 };
128 let p = (c * importance.score * log_n / (self.epsilon * self.epsilon)).min(1.0);
129
130 let mut rng = rand::thread_rng();
131 if p >= 1.0 || rng.gen::<f64>() < p {
132 let reweighted = if p > 0.0 {
133 importance.weight / p
134 } else {
135 importance.weight
136 };
137 Some((importance.u, importance.v, reweighted))
138 } else {
139 None
140 }
141 }
142
143 fn edge_key(u: usize, v: usize) -> (usize, usize) {
146 if u <= v { (u, v) } else { (v, u) }
147 }
148
149 fn backbone_only_graph(
150 &self,
151 scores: &[EdgeImportance],
152 backbone_edges: &std::collections::HashSet<(usize, usize)>,
153 ) -> SparseGraph {
154 let n = scores
155 .iter()
156 .map(|s| s.u.max(s.v) + 1)
157 .max()
158 .unwrap_or(0);
159 let mut g = SparseGraph::with_capacity(n);
160 for s in scores {
161 let key = Self::edge_key(s.u, s.v);
162 if backbone_edges.contains(&key) {
163 let _ = g.insert_or_update_edge(s.u, s.v, s.weight);
164 }
165 }
166 g
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// A backbone edge must be present in the result regardless of sampling.
    #[test]
    fn test_sample_with_backbone() {
        let scores = vec![
            EdgeImportance::new(0, 1, 1.0, 1.0),
            EdgeImportance::new(1, 2, 1.0, 1.0),
            EdgeImportance::new(0, 2, 1.0, 1.0),
        ];
        let mut backbone = std::collections::HashSet::new();
        backbone.insert((0, 1));

        let sampler = SpectralSampler::new(0.2);
        let g = sampler.sample_edges(&scores, 10, &backbone);

        assert!(g.has_edge(0, 1));
    }

    /// An empty score list yields an empty graph.
    #[test]
    fn test_sample_empty() {
        let sampler = SpectralSampler::new(0.2);
        let g = sampler.sample_edges(&[], 10, &Default::default());
        assert_eq!(g.num_edges(), 0);
    }

    /// With a budget far above the edge count every edge is kept.
    #[test]
    fn test_high_budget_keeps_all() {
        let scores = vec![
            EdgeImportance::new(0, 1, 1.0, 10.0),
            EdgeImportance::new(1, 2, 1.0, 10.0),
        ];
        let sampler = SpectralSampler::new(0.01);
        let g = sampler.sample_edges(&scores, 1000, &Default::default());

        // Both edges are constructed identically, so each keep probability
        // is clamped to exactly 1.0 and the outcome does not depend on the
        // random draw.
        assert!(g.has_edge(0, 1));
        assert!(g.has_edge(1, 2));
        assert!(g.num_edges() >= 1);
    }
}