1use std::collections::HashMap;
8
9use ordered_float::OrderedFloat;
10use serde::{Deserialize, Serialize};
11
12use crate::error::{Result, SparsifierError};
13
/// Undirected weighted graph stored as one adjacency hash map per vertex.
///
/// Vertices are the indices `0..num_vertices()`. An edge `{u, v}` with
/// `u != v` is stored twice, as `adj[u][v]` and `adj[v][u]` with the same
/// weight; a self-loop `{u, u}` is stored once, in `adj[u]`.
///
/// NOTE: the derived `Deserialize` does not re-check these invariants (or the
/// cached counters below) against the deserialized adjacency data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseGraph {
    /// `adj[u]` maps each neighbour `v` of `u` to the weight of edge `{u, v}`.
    adj: Vec<HashMap<usize, f64>>,
    /// Number of undirected edges; each `{u, v}` (self-loops included) counts once.
    num_edges: usize,
    /// Sum of all edge weights, each edge counted once. Updated incrementally
    /// on every mutation, so it can carry floating-point rounding drift.
    total_weight: f64,
}
38
39impl Default for SparseGraph {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparseGraph {
46 pub fn new() -> Self {
50 Self {
51 adj: Vec::new(),
52 num_edges: 0,
53 total_weight: 0.0,
54 }
55 }
56
57 pub fn with_capacity(n: usize) -> Self {
59 Self {
61 adj: (0..n).map(|_| HashMap::with_capacity(4)).collect(),
62 num_edges: 0,
63 total_weight: 0.0,
64 }
65 }
66
67 pub fn from_edges(edges: &[(usize, usize, f64)]) -> Self {
71 let n = edges
72 .iter()
73 .map(|&(u, v, _)| u.max(v) + 1)
74 .max()
75 .unwrap_or(0);
76 let mut g = Self::with_capacity(n);
77 for &(u, v, w) in edges {
78 let _ = g.insert_edge(u, v, w);
80 }
81 g
82 }
83
84 pub fn ensure_capacity(&mut self, n: usize) {
88 if n > self.adj.len() {
89 self.adj.resize_with(n, HashMap::new);
90 }
91 }
92
93 #[inline]
97 pub fn num_vertices(&self) -> usize {
98 self.adj.len()
99 }
100
101 #[inline]
103 pub fn num_edges(&self) -> usize {
104 self.num_edges
105 }
106
107 #[inline]
109 pub fn total_weight(&self) -> f64 {
110 self.total_weight
111 }
112
113 #[inline]
117 pub fn degree(&self, u: usize) -> usize {
118 self.adj.get(u).map_or(0, |m| m.len())
119 }
120
121 #[inline]
126 pub fn weighted_degree(&self, u: usize) -> f64 {
127 self.adj
128 .get(u)
129 .map_or(0.0, |m| m.values().copied().sum())
130 }
131
132 pub fn neighbors(&self, u: usize) -> impl Iterator<Item = (usize, f64)> + '_ {
138 self.adj[u].iter().map(|(&v, &w)| (v, w))
139 }
140
141 #[inline]
143 pub fn edge_weight(&self, u: usize, v: usize) -> Option<f64> {
144 self.adj.get(u).and_then(|m| m.get(&v).copied())
145 }
146
147 #[inline]
149 pub fn has_edge(&self, u: usize, v: usize) -> bool {
150 self.adj
151 .get(u)
152 .is_some_and(|m| m.contains_key(&v))
153 }
154
155 pub fn edges(&self) -> impl Iterator<Item = (usize, usize, f64)> + '_ {
157 self.adj.iter().enumerate().flat_map(|(u, nbrs)| {
158 nbrs.iter()
159 .filter(move |(&v, _)| v > u)
160 .map(move |(&v, &w)| (u, v, w))
161 })
162 }
163
164 pub fn insert_edge(&mut self, u: usize, v: usize, weight: f64) -> Result<()> {
175 if !weight.is_finite() || weight <= 0.0 {
176 return Err(SparsifierError::InvalidWeight(weight));
177 }
178 let n = u.max(v) + 1;
179 self.ensure_capacity(n);
180
181 if self.adj[u].contains_key(&v) {
182 return Err(SparsifierError::EdgeAlreadyExists(u, v));
183 }
184
185 self.adj[u].insert(v, weight);
186 if u != v {
187 self.adj[v].insert(u, weight);
188 }
189 self.num_edges += 1;
190 self.total_weight += weight;
191 Ok(())
192 }
193
194 pub fn insert_or_update_edge(
198 &mut self,
199 u: usize,
200 v: usize,
201 weight: f64,
202 ) -> Result<Option<f64>> {
203 if !weight.is_finite() || weight <= 0.0 {
204 return Err(SparsifierError::InvalidWeight(weight));
205 }
206 let n = u.max(v) + 1;
207 self.ensure_capacity(n);
208
209 let old = self.adj[u].insert(v, weight);
210 if u != v {
211 self.adj[v].insert(u, weight);
212 }
213
214 if let Some(old_w) = old {
215 self.total_weight += weight - old_w;
216 Ok(Some(old_w))
217 } else {
218 self.num_edges += 1;
219 self.total_weight += weight;
220 Ok(None)
221 }
222 }
223
224 pub fn delete_edge(&mut self, u: usize, v: usize) -> Result<f64> {
230 let w = self
231 .adj
232 .get_mut(u)
233 .and_then(|m| m.remove(&v))
234 .ok_or(SparsifierError::EdgeNotFound(u, v))?;
235 if u != v {
236 self.adj[v].remove(&u);
237 }
238 self.num_edges -= 1;
239 self.total_weight -= w;
240 Ok(w)
241 }
242
243 pub fn update_weight(&mut self, u: usize, v: usize, new_weight: f64) -> Result<f64> {
249 if !new_weight.is_finite() || new_weight <= 0.0 {
250 return Err(SparsifierError::InvalidWeight(new_weight));
251 }
252 let old = self
253 .adj
254 .get_mut(u)
255 .and_then(|m| m.get_mut(&v))
256 .ok_or(SparsifierError::EdgeNotFound(u, v))?;
257 let old_w = *old;
258 *old = new_weight;
259 if u != v {
260 if let Some(entry) = self.adj[v].get_mut(&u) {
261 *entry = new_weight;
262 }
263 }
264 self.total_weight += new_weight - old_w;
265 Ok(old_w)
266 }
267
268 pub fn clear(&mut self) {
270 self.adj.clear();
271 self.num_edges = 0;
272 self.total_weight = 0.0;
273 }
274
275 pub fn laplacian_quadratic_form(&self, x: &[f64]) -> f64 {
285 assert!(
286 x.len() >= self.num_vertices(),
287 "x.len()={} < num_vertices={}",
288 x.len(),
289 self.num_vertices()
290 );
291 let mut sum = 0.0;
292 for (u, nbrs) in self.adj.iter().enumerate() {
293 for (&v, &w) in nbrs {
294 if v > u {
295 let diff = x[u] - x[v];
296 sum += w * diff * diff;
297 }
298 }
299 }
300 sum
301 }
302
303 pub fn to_csr(&self) -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
311 let n = self.num_vertices();
312 let mut row_ptr = Vec::with_capacity(n + 1);
313 let mut col_indices = Vec::new();
314 let mut values = Vec::new();
315
316 row_ptr.push(0);
317 for u in 0..n {
318 let mut entries: Vec<(usize, f64)> = self.adj[u]
320 .iter()
321 .map(|(&v, &w)| (v, w))
322 .collect();
323 entries.sort_by_key(|&(v, w)| (v, OrderedFloat(w)));
324 for (v, w) in entries {
325 col_indices.push(v);
326 values.push(w);
327 }
328 row_ptr.push(col_indices.len());
329 }
330
331 (row_ptr, col_indices, values, n)
332 }
333
334 pub fn from_csr(
339 row_ptr: &[usize],
340 col_indices: &[usize],
341 values: &[f64],
342 n: usize,
343 ) -> Self {
344 let mut g = Self::with_capacity(n);
345 for u in 0..n {
346 let start = row_ptr[u];
347 let end = row_ptr[u + 1];
348 for idx in start..end {
349 let v = col_indices[idx];
350 let w = values[idx];
351 if v >= u && !g.has_edge(u, v) {
352 let _ = g.insert_edge(u, v, w);
353 }
354 }
355 }
356 g
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated floating-point weights.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_insert_and_query() {
        let mut g = SparseGraph::new();
        g.insert_edge(0, 1, 2.0).unwrap();
        g.insert_edge(1, 2, 3.0).unwrap();

        assert_eq!(g.num_vertices(), 3);
        assert_eq!(g.num_edges(), 2);
        assert!((g.total_weight() - 5.0).abs() < EPS);
        assert_eq!(g.degree(0), 1);
        assert_eq!(g.degree(1), 2);
        assert_eq!(g.edge_weight(0, 1), Some(2.0));
        assert_eq!(g.edge_weight(1, 0), Some(2.0));
        assert_eq!(g.edge_weight(0, 2), None);
    }

    #[test]
    fn test_delete_edge() {
        let mut g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0)]);
        assert_eq!(g.num_edges(), 2);

        let w = g.delete_edge(0, 1).unwrap();
        assert!((w - 1.0).abs() < EPS);
        assert_eq!(g.num_edges(), 1);
        assert!((g.total_weight() - 2.0).abs() < EPS);
        assert!(!g.has_edge(0, 1));
        assert!(!g.has_edge(1, 0));
    }

    #[test]
    fn test_delete_missing_edge() {
        let mut g = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        assert!(matches!(
            g.delete_edge(0, 2),
            Err(SparsifierError::EdgeNotFound(0, 2))
        ));
        // Out-of-range vertices are reported as an error, not a panic.
        assert!(matches!(
            g.delete_edge(5, 6),
            Err(SparsifierError::EdgeNotFound(5, 6))
        ));
        assert_eq!(g.num_edges(), 1);
        assert!((g.total_weight() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_update_weight() {
        let mut g = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        let old = g.update_weight(0, 1, 5.0).unwrap();
        assert!((old - 1.0).abs() < EPS);
        assert_eq!(g.edge_weight(0, 1), Some(5.0));
        assert_eq!(g.edge_weight(1, 0), Some(5.0));
        assert!((g.total_weight() - 5.0).abs() < EPS);
    }

    #[test]
    fn test_update_weight_errors() {
        let mut g = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        assert!(matches!(
            g.update_weight(0, 1, 0.0),
            Err(SparsifierError::InvalidWeight(_))
        ));
        assert!(matches!(
            g.update_weight(0, 2, 1.0),
            Err(SparsifierError::EdgeNotFound(0, 2))
        ));
        // Failed updates leave the graph unchanged.
        assert_eq!(g.edge_weight(0, 1), Some(1.0));
        assert!((g.total_weight() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_insert_or_update_edge() {
        let mut g = SparseGraph::new();
        assert_eq!(g.insert_or_update_edge(0, 1, 2.0).unwrap(), None);
        assert_eq!(g.insert_or_update_edge(1, 0, 3.0).unwrap(), Some(2.0));
        assert_eq!(g.num_edges(), 1);
        assert_eq!(g.edge_weight(0, 1), Some(3.0));
        assert_eq!(g.edge_weight(1, 0), Some(3.0));
        assert!((g.total_weight() - 3.0).abs() < EPS);

        assert!(g.insert_or_update_edge(0, 1, f64::NAN).is_err());
        assert_eq!(g.edge_weight(0, 1), Some(3.0));
    }

    #[test]
    fn test_self_loop() {
        let mut g = SparseGraph::new();
        g.insert_edge(2, 2, 1.5).unwrap();
        assert_eq!(g.num_vertices(), 3);
        assert_eq!(g.num_edges(), 1);
        assert_eq!(g.degree(2), 1);
        assert!((g.weighted_degree(2) - 1.5).abs() < EPS);
        // A self-loop contributes nothing to the Laplacian quadratic form.
        assert_eq!(g.laplacian_quadratic_form(&[0.0, 0.0, 7.0]), 0.0);

        assert!((g.delete_edge(2, 2).unwrap() - 1.5).abs() < EPS);
        assert_eq!(g.num_edges(), 0);
        assert_eq!(g.degree(2), 0);
    }

    #[test]
    fn test_weighted_degree_and_out_of_range_queries() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.0), (0, 2, 2.5)]);
        assert!((g.weighted_degree(0) - 3.5).abs() < EPS);
        assert!((g.weighted_degree(1) - 1.0).abs() < EPS);

        assert_eq!(g.degree(9), 0);
        assert_eq!(g.weighted_degree(9), 0.0);
        assert!(!g.has_edge(9, 0));
        assert_eq!(g.edge_weight(9, 0), None);
    }

    #[test]
    fn test_from_edges_skips_invalid_and_duplicates() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 0, 9.0), (1, 3, -2.0)]);
        // Vertex count covers every triple, even skipped ones.
        assert_eq!(g.num_vertices(), 4);
        assert_eq!(g.num_edges(), 1);
        assert_eq!(g.edge_weight(0, 1), Some(1.0));
        assert!(!g.has_edge(1, 3));
    }

    #[test]
    fn test_clear() {
        let mut g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0)]);
        g.clear();
        assert_eq!(g.num_vertices(), 0);
        assert_eq!(g.num_edges(), 0);
        assert_eq!(g.total_weight(), 0.0);
    }

    #[test]
    fn test_laplacian_quadratic_form() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 1.0), (0, 2, 1.0)]);
        // Triangle with x = e_0: edges (0,1) and (0,2) each contribute 1.
        let x = vec![1.0, 0.0, 0.0];
        let val = g.laplacian_quadratic_form(&x);
        assert!((val - 2.0).abs() < EPS);
    }

    #[test]
    #[should_panic(expected = "num_vertices")]
    fn test_laplacian_quadratic_form_short_input_panics() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 1.0)]);
        let _ = g.laplacian_quadratic_form(&[1.0, 2.0]);
    }

    #[test]
    fn test_csr_roundtrip() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.5), (1, 2, 2.5), (0, 2, 3.5)]);
        let (rp, ci, vals, n) = g.to_csr();

        // Rows are sorted by column index and every edge appears in both rows.
        assert_eq!(n, 3);
        assert_eq!(rp, vec![0, 2, 4, 6]);
        assert_eq!(ci, vec![1, 2, 0, 2, 0, 1]);
        assert_eq!(vals, vec![1.5, 3.5, 1.5, 2.5, 3.5, 2.5]);

        let g2 = SparseGraph::from_csr(&rp, &ci, &vals, n);
        assert_eq!(g2.num_vertices(), g.num_vertices());
        assert_eq!(g2.num_edges(), g.num_edges());
        assert!((g2.total_weight() - g.total_weight()).abs() < EPS);
        for (u, v, w) in g.edges() {
            assert_eq!(g2.edge_weight(u, v), Some(w));
        }
    }

    #[test]
    fn test_edges_iterator() {
        let g = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0), (0, 2, 3.0)]);
        let mut edges: Vec<_> = g.edges().collect();
        assert_eq!(edges.len(), 3);
        assert!(edges.iter().all(|&(u, v, _)| u <= v));

        edges.sort_by_key(|&(u, v, _)| (u, v));
        assert_eq!(edges, vec![(0, 1, 1.0), (0, 2, 3.0), (1, 2, 2.0)]);
    }

    #[test]
    fn test_invalid_weight() {
        let mut g = SparseGraph::new();
        for bad in [-1.0, 0.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(matches!(
                g.insert_edge(0, 1, bad),
                Err(SparsifierError::InvalidWeight(_))
            ));
        }
        // Rejected inserts must not grow the vertex set.
        assert_eq!(g.num_vertices(), 0);
        assert_eq!(g.num_edges(), 0);
    }

    #[test]
    fn test_duplicate_edge() {
        let mut g = SparseGraph::new();
        g.insert_edge(0, 1, 1.0).unwrap();
        assert!(matches!(
            g.insert_edge(0, 1, 2.0),
            Err(SparsifierError::EdgeAlreadyExists(0, 1))
        ));
        // The duplicate is detected from either endpoint and changes nothing.
        assert!(g.insert_edge(1, 0, 2.0).is_err());
        assert_eq!(g.num_edges(), 1);
        assert_eq!(g.edge_weight(0, 1), Some(1.0));
        assert!((g.total_weight() - 1.0).abs() < EPS);
    }
}