ruvector_sparsifier/
audit.rs1use rand::prelude::*;
8use serde::{Deserialize, Serialize};
9
10use crate::graph::SparseGraph;
11use crate::types::AuditResult;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SpectralAuditor {
21 pub n_probes: usize,
23 pub threshold: f64,
25}
26
27impl SpectralAuditor {
28 pub fn new(n_probes: usize, threshold: f64) -> Self {
30 Self {
31 n_probes,
32 threshold,
33 }
34 }
35
36 pub fn audit_quadratic_form(
41 &self,
42 g_full: &SparseGraph,
43 g_spec: &SparseGraph,
44 n_probes: usize,
45 ) -> AuditResult {
46 let n = g_full.num_vertices();
47 if n == 0 {
48 return AuditResult::trivial_pass(self.threshold);
49 }
50
51 let n_spec = g_spec.num_vertices();
52 let dim = n.max(n_spec);
53
54 let mut rng = rand::thread_rng();
55 let mut max_error = 0.0f64;
56 let mut sum_error = 0.0f64;
57 let probes = if n_probes > 0 { n_probes } else { self.n_probes };
58
59 for _ in 0..probes {
60 let x: Vec<f64> = (0..dim).map(|_| rng.gen::<f64>() * 2.0 - 1.0).collect();
62
63 let val_full = g_full.laplacian_quadratic_form(&x[..n.max(1)]);
64 let val_spec = if n_spec > 0 {
65 g_spec.laplacian_quadratic_form(&x[..n_spec.max(1)])
66 } else {
67 0.0
68 };
69
70 let denom = val_full.abs().max(1e-15);
71 let rel_error = (val_full - val_spec).abs() / denom;
72
73 max_error = max_error.max(rel_error);
74 sum_error += rel_error;
75 }
76
77 let avg_error = if probes > 0 {
78 sum_error / probes as f64
79 } else {
80 0.0
81 };
82
83 AuditResult {
84 max_error,
85 avg_error,
86 passed: max_error <= self.threshold,
87 n_probes: probes,
88 threshold: self.threshold,
89 }
90 }
91
92 pub fn audit_cuts(
98 &self,
99 g_full: &SparseGraph,
100 g_spec: &SparseGraph,
101 n_cuts: usize,
102 ) -> AuditResult {
103 let n = g_full.num_vertices();
104 if n == 0 {
105 return AuditResult::trivial_pass(self.threshold);
106 }
107
108 let mut rng = rand::thread_rng();
109 let mut max_error = 0.0f64;
110 let mut sum_error = 0.0f64;
111
112 for _ in 0..n_cuts {
113 let x: Vec<f64> = (0..n)
115 .map(|_| if rng.gen::<bool>() { 1.0 } else { -1.0 })
116 .collect();
117
118 let cut_full = g_full.laplacian_quadratic_form(&x);
120 let cut_spec = if g_spec.num_vertices() >= n {
121 g_spec.laplacian_quadratic_form(&x)
122 } else {
123 let mut x_padded = x.clone();
125 x_padded.resize(g_spec.num_vertices().max(n), 0.0);
126 g_spec.laplacian_quadratic_form(&x_padded)
127 };
128
129 let denom = cut_full.abs().max(1e-15);
130 let rel_error = (cut_full - cut_spec).abs() / denom;
131 max_error = max_error.max(rel_error);
132 sum_error += rel_error;
133 }
134
135 let avg_error = if n_cuts > 0 {
136 sum_error / n_cuts as f64
137 } else {
138 0.0
139 };
140
141 AuditResult {
142 max_error,
143 avg_error,
144 passed: max_error <= self.threshold,
145 n_probes: n_cuts,
146 threshold: self.threshold,
147 }
148 }
149
150 pub fn audit_conductance(
155 &self,
156 g_full: &SparseGraph,
157 g_spec: &SparseGraph,
158 k_clusters: usize,
159 ) -> AuditResult {
160 let n = g_full.num_vertices();
161 if n == 0 {
162 return AuditResult::trivial_pass(self.threshold);
163 }
164
165 let mut rng = rand::thread_rng();
166 let mut max_error = 0.0f64;
167 let mut sum_error = 0.0f64;
168
169 for _ in 0..k_clusters {
170 let cluster_id: Vec<usize> = (0..n).map(|_| rng.gen_range(0..k_clusters.max(2))).collect();
172
173 for c in 0..k_clusters.max(2) {
175 let x: Vec<f64> = cluster_id
176 .iter()
177 .map(|&cid| if cid == c { 1.0 } else { 0.0 })
178 .collect();
179
180 let val_full = g_full.laplacian_quadratic_form(&x);
181 let val_spec = if g_spec.num_vertices() >= n {
182 g_spec.laplacian_quadratic_form(&x)
183 } else {
184 let mut x_padded = x.clone();
185 x_padded.resize(g_spec.num_vertices().max(n), 0.0);
186 g_spec.laplacian_quadratic_form(&x_padded)
187 };
188
189 let denom = val_full.abs().max(1e-15);
190 let rel_error = (val_full - val_spec).abs() / denom;
191 max_error = max_error.max(rel_error);
192 sum_error += rel_error;
193 }
194 }
195
196 let total_probes = k_clusters * k_clusters.max(2);
197 let avg_error = if total_probes > 0 {
198 sum_error / total_probes as f64
199 } else {
200 0.0
201 };
202
203 AuditResult {
204 max_error,
205 avg_error,
206 passed: max_error <= self.threshold,
207 n_probes: total_probes,
208 threshold: self.threshold,
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    // Auditing a graph against itself must report (near-)zero error.
    #[test]
    fn test_audit_identical_graphs() {
        let triangle = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)]);
        let report = SpectralAuditor::new(20, 0.01).audit_quadratic_form(&triangle, &triangle, 20);

        assert!(report.passed);
        assert!(report.max_error < 1e-10);
    }

    // A graph with no vertices is a trivial pass.
    #[test]
    fn test_audit_empty_graph() {
        let empty = SparseGraph::new();
        let report = SpectralAuditor::new(10, 0.2).audit_quadratic_form(&empty, &empty, 10);

        assert!(report.passed);
    }

    // Random cuts of a weighted path agree with themselves.
    #[test]
    fn test_audit_cuts_identical() {
        let path = SparseGraph::from_edges(&[(0, 1, 2.0), (1, 2, 3.0)]);
        let report = SpectralAuditor::new(10, 0.01).audit_cuts(&path, &path, 10);

        assert!(report.passed);
    }

    // Random cluster indicators on a path agree with themselves.
    #[test]
    fn test_audit_conductance_identical() {
        let path = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]);
        let report = SpectralAuditor::new(10, 0.01).audit_conductance(&path, &path, 3);

        assert!(report.passed);
    }
}