Skip to main content

scirs2_graph/hypergraph/
algorithms.rs

1//! Hypergraph algorithms.
2//!
3//! Implements:
4//! * **Spectral clustering** via the normalised Laplacian (Zhou et al. 2006).
5//! * **Hyperedge cuts**: raw cut, ratio cut, normalised cut.
6//! * **Generalised random walk** (Markov chain on nodes through hyperedges).
7//! * **Hypergraph betweenness centrality** based on shortest paths through
8//!   the clique-expansion graph.
9//! * **s-walks and s-paths**: walks/paths between hyperedges that share ≥ s nodes.
10
11use super::core::{clique_expansion, hyperedge_centrality, IndexedHypergraph};
12use crate::error::{GraphError, Result};
13use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::{Rng, RngExt, SeedableRng};
15use std::cmp::Ordering;
16use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
17
18// ============================================================================
19// Internal helpers
20// ============================================================================
21
22/// Returns the normalised Laplacian Θ = I − D_v^{-1/2} B W D_e^{-1} B^T D_v^{-1/2}.
23///
24/// Shape: (n_nodes × n_nodes).  Used for spectral clustering.
25fn normalised_laplacian(hg: &IndexedHypergraph) -> Array2<f64> {
26    let n = hg.n_nodes();
27    let m = hg.n_hyperedges();
28
29    if n == 0 || m == 0 {
30        return Array2::eye(n);
31    }
32
33    // B: n × m incidence (binary, so each entry 0/1)
34    let b = hg.incidence_matrix_binary();
35
36    // D_v[i] = weighted degree of node i
37    let dv: Vec<f64> = (0..n).map(|i| hg.weighted_degree(i)).collect();
38    // D_e[e] = size of hyperedge e (cardinality)
39    let de: Vec<f64> = hg
40        .hyperedges()
41        .iter()
42        .map(|he| he.nodes.len() as f64)
43        .collect();
44    // W[e] = weight of hyperedge e
45    let w: Vec<f64> = hg.hyperedges().iter().map(|he| he.weight).collect();
46
47    // Compute Θ = I − D_v^{-1/2} B W D_e^{-1} B^T D_v^{-1/2}
48    // We build the (n × n) matrix Ω = D_v^{-1/2} B W D_e^{-1} B^T D_v^{-1/2} directly.
49    let mut omega = Array2::<f64>::zeros((n, n));
50    for e in 0..m {
51        if de[e] == 0.0 {
52            continue;
53        }
54        let scale = w[e] / de[e];
55        // Find which nodes belong to this hyperedge
56        let members: Vec<usize> = (0..n)
57            .filter(|&i| (b[[i, e]] - 1.0).abs() < 1e-10)
58            .collect();
59        for &i in &members {
60            let dvi = if dv[i] > 0.0 { dv[i].sqrt() } else { 0.0 };
61            if dvi == 0.0 {
62                continue;
63            }
64            for &j in &members {
65                let dvj = if dv[j] > 0.0 { dv[j].sqrt() } else { 0.0 };
66                if dvj == 0.0 {
67                    continue;
68                }
69                omega[[i, j]] += scale / (dvi * dvj);
70            }
71        }
72    }
73
74    // Θ = I − Ω
75    let mut theta = Array2::<f64>::eye(n);
76    for i in 0..n {
77        for j in 0..n {
78            theta[[i, j]] -= omega[[i, j]];
79        }
80    }
81    theta
82}
83
84// ============================================================================
85// Spectral clustering
86// ============================================================================
87
88/// Result of hypergraph spectral clustering.
89#[derive(Debug, Clone)]
90pub struct SpectralClusteringResult {
91    /// Cluster label for each node (`labels[i] ∈ 0..k`).
92    pub labels: Vec<usize>,
93    /// The first `k` eigenvectors stacked as columns (shape `n × k`).
94    pub embedding: Array2<f64>,
95    /// Number of iterations used by the power-iteration eigensolver.
96    pub eigenvalue_iterations: usize,
97}
98
99/// Perform **spectral clustering** on a hypergraph using the normalised
100/// Laplacian (Zhou et al. NeurIPS 2006).
101///
102/// # Arguments
103/// * `hg`   – the hypergraph
104/// * `k`    – number of clusters
105/// * `seed` – RNG seed for k-means initialisation
106///
107/// # Algorithm
108/// 1. Form the `n × n` normalised Laplacian Θ.
109/// 2. Extract the `k` eigenvectors corresponding to the **smallest** `k`
110///    eigenvalues using deflated power iteration.
111/// 3. Run k-means on the resulting `n × k` embedding.
112///
113/// # Errors
114/// Returns `GraphError::InvalidGraph` when `k > n_nodes` or the hypergraph is
115/// empty.
116pub fn spectral_clustering(
117    hg: &IndexedHypergraph,
118    k: usize,
119    seed: u64,
120) -> Result<SpectralClusteringResult> {
121    use scirs2_core::random::ChaCha20Rng;
122    let n = hg.n_nodes();
123    if n == 0 {
124        return Err(GraphError::InvalidGraph(
125            "hypergraph has no nodes".to_string(),
126        ));
127    }
128    if k == 0 || k > n {
129        return Err(GraphError::InvalidGraph(format!(
130            "k = {k} must be in 1..={n}"
131        )));
132    }
133
134    let theta = normalised_laplacian(hg);
135
136    // Compute the k smallest eigenvectors via deflated power iteration on
137    // (sigma*I - Theta).  We use sigma = 2 so the operator is positive definite.
138    let sigma = 2.0_f64;
139    // shifted: A = sigma*I - Theta  → largest eigenvectors of A = smallest of Theta
140    let mut a = theta.clone();
141    for i in 0..n {
142        a[[i, i]] = sigma - theta[[i, i]];
143        for j in 0..n {
144            if i != j {
145                a[[i, j]] = -theta[[i, j]];
146            }
147        }
148    }
149
150    let mut rng = ChaCha20Rng::seed_from_u64(seed);
151    let mut eigenvecs: Vec<Vec<f64>> = Vec::with_capacity(k);
152    let mut total_iters = 0usize;
153
154    for _ki in 0..k {
155        // Random starting vector
156        let mut v: Vec<f64> = (0..n).map(|_| rng.random::<f64>() - 0.5).collect();
157        // Orthogonalise against already-found eigenvectors
158        for prev in &eigenvecs {
159            let dot: f64 = v.iter().zip(prev.iter()).map(|(a, b)| a * b).sum();
160            for (vi, pi) in v.iter_mut().zip(prev.iter()) {
161                *vi -= dot * pi;
162            }
163        }
164        // Normalise
165        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
166        if norm < 1e-12 {
167            v = vec![0.0; n];
168            if _ki < n {
169                v[_ki] = 1.0;
170            }
171        } else {
172            for vi in &mut v {
173                *vi /= norm;
174            }
175        }
176
177        let max_iter = 2000;
178        let tol = 1e-10;
179        let mut iters = 0usize;
180        for _ in 0..max_iter {
181            iters += 1;
182            // Multiply by A
183            let mut nv: Vec<f64> = vec![0.0; n];
184            for i in 0..n {
185                for j in 0..n {
186                    nv[i] += a[[i, j]] * v[j];
187                }
188            }
189            // Deflation: remove components along already-found eigenvectors
190            for prev in &eigenvecs {
191                let dot: f64 = nv.iter().zip(prev.iter()).map(|(a, b)| a * b).sum();
192                for (nvi, pi) in nv.iter_mut().zip(prev.iter()) {
193                    *nvi -= dot * pi;
194                }
195            }
196            // Normalise
197            let norm: f64 = nv.iter().map(|x| x * x).sum::<f64>().sqrt();
198            if norm < 1e-15 {
199                break;
200            }
201            for vi in &mut nv {
202                *vi /= norm;
203            }
204            // Convergence check
205            let diff: f64 = nv
206                .iter()
207                .zip(v.iter())
208                .map(|(a, b)| (a - b).abs() + (a + b).abs())
209                .fold(0.0_f64, f64::min);
210            // Check if either v==nv or v==-nv
211            let diff_pos: f64 = nv
212                .iter()
213                .zip(v.iter())
214                .map(|(a, b)| (a - b).powi(2))
215                .sum::<f64>()
216                .sqrt();
217            let diff_neg: f64 = nv
218                .iter()
219                .zip(v.iter())
220                .map(|(a, b)| (a + b).powi(2))
221                .sum::<f64>()
222                .sqrt();
223            let _ = diff; // suppress unused warning
224            v = nv;
225            if diff_pos < tol || diff_neg < tol {
226                break;
227            }
228        }
229        total_iters += iters;
230        eigenvecs.push(v);
231    }
232
233    // Build n×k embedding matrix
234    let mut embedding = Array2::<f64>::zeros((n, k));
235    for (ki, ev) in eigenvecs.iter().enumerate() {
236        for (i, &val) in ev.iter().enumerate() {
237            embedding[[i, ki]] = val;
238        }
239    }
240
241    // k-means on embedding rows
242    let labels = kmeans(&embedding, k, seed + 1, 300);
243
244    Ok(SpectralClusteringResult {
245        labels,
246        embedding,
247        eigenvalue_iterations: total_iters,
248    })
249}
250
251/// Simple k-means clustering on the rows of a matrix.
252fn kmeans(data: &Array2<f64>, k: usize, seed: u64, max_iter: usize) -> Vec<usize> {
253    use scirs2_core::random::ChaCha20Rng;
254    let n = data.nrows();
255    let d = data.ncols();
256    if k == 0 || n == 0 {
257        return vec![0; n];
258    }
259    let k = k.min(n);
260
261    let mut rng = ChaCha20Rng::seed_from_u64(seed);
262
263    // k-means++ initialisation
264    let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
265    let first = rng.random_range(0..n);
266    centers.push(data.row(first).to_vec());
267
268    for _ in 1..k {
269        // Probability proportional to distance squared to nearest centre
270        let dists: Vec<f64> = (0..n)
271            .map(|i| {
272                centers
273                    .iter()
274                    .map(|c| {
275                        data.row(i)
276                            .iter()
277                            .zip(c.iter())
278                            .map(|(a, b)| (a - b).powi(2))
279                            .sum::<f64>()
280                    })
281                    .fold(f64::INFINITY, f64::min)
282            })
283            .collect();
284        let total: f64 = dists.iter().sum();
285        if total < 1e-15 {
286            break;
287        }
288        let threshold = rng.random::<f64>() * total;
289        let mut acc = 0.0;
290        let mut chosen = n - 1;
291        for (i, &d) in dists.iter().enumerate() {
292            acc += d;
293            if acc >= threshold {
294                chosen = i;
295                break;
296            }
297        }
298        centers.push(data.row(chosen).to_vec());
299    }
300
301    let mut labels = vec![0usize; n];
302    for _iter in 0..max_iter {
303        // Assignment
304        let mut changed = false;
305        for i in 0..n {
306            let best = (0..centers.len())
307                .min_by(|&a, &b| {
308                    let da: f64 = data
309                        .row(i)
310                        .iter()
311                        .zip(centers[a].iter())
312                        .map(|(x, c)| (x - c).powi(2))
313                        .sum();
314                    let db: f64 = data
315                        .row(i)
316                        .iter()
317                        .zip(centers[b].iter())
318                        .map(|(x, c)| (x - c).powi(2))
319                        .sum();
320                    da.partial_cmp(&db).unwrap_or(Ordering::Equal)
321                })
322                .unwrap_or(0);
323            if labels[i] != best {
324                labels[i] = best;
325                changed = true;
326            }
327        }
328        if !changed {
329            break;
330        }
331        // Update centres
332        let mut sums = vec![vec![0.0f64; d]; centers.len()];
333        let mut counts = vec![0usize; centers.len()];
334        for i in 0..n {
335            let c = labels[i];
336            counts[c] += 1;
337            for j in 0..d {
338                sums[c][j] += data[[i, j]];
339            }
340        }
341        for c in 0..centers.len() {
342            if counts[c] > 0 {
343                for j in 0..d {
344                    centers[c][j] = sums[c][j] / counts[c] as f64;
345                }
346            }
347        }
348    }
349    labels
350}
351
352// ============================================================================
353// Hyperedge cuts
354// ============================================================================
355
/// Cut statistics of a two-way node partition, as produced by
/// [`hyperedge_cut`].
#[derive(Debug, Clone)]
pub struct CutResult {
    /// Number of hyperedges with at least one node on each side.
    pub cut: usize,
    /// `cut / min(|S|, |V\S|)`; infinite when one side is empty.
    pub ratio_cut: f64,
    /// `cut/vol(S) + cut/vol(V\S)` with `vol(X) = Σ_{v ∈ X} deg(v)`; infinite
    /// when either volume is not strictly positive.
    pub normalised_cut: f64,
}
366
367/// Compute the **hyperedge cut**, **ratio cut**, and **normalised cut** for a
368/// binary partition of the node set.
369///
370/// # Arguments
371/// * `hg`        – the hypergraph
372/// * `partition` – boolean slice of length `n_nodes`; `true` → side A, `false` → side B
373///
374/// # Errors
375/// Returns an error if `partition.len() != hg.n_nodes()`.
376pub fn hyperedge_cut(hg: &IndexedHypergraph, partition: &[bool]) -> Result<CutResult> {
377    if partition.len() != hg.n_nodes() {
378        return Err(GraphError::InvalidGraph(format!(
379            "partition length {} != n_nodes {}",
380            partition.len(),
381            hg.n_nodes()
382        )));
383    }
384    let mut cut = 0usize;
385    for he in hg.hyperedges() {
386        let has_true = he.nodes.iter().any(|&n| partition[n]);
387        let has_false = he.nodes.iter().any(|&n| !partition[n]);
388        if has_true && has_false {
389            cut += 1;
390        }
391    }
392
393    let size_a = partition.iter().filter(|&&b| b).count();
394    let size_b = hg.n_nodes() - size_a;
395    let min_side = size_a.min(size_b);
396
397    let ratio_cut = if min_side > 0 {
398        cut as f64 / min_side as f64
399    } else {
400        f64::INFINITY
401    };
402
403    // Volumes
404    let vol_a: f64 = (0..hg.n_nodes())
405        .filter(|&i| partition[i])
406        .map(|i| hg.weighted_degree(i))
407        .sum();
408    let vol_b: f64 = (0..hg.n_nodes())
409        .filter(|&i| !partition[i])
410        .map(|i| hg.weighted_degree(i))
411        .sum();
412
413    let normalised_cut = if vol_a > 0.0 && vol_b > 0.0 {
414        cut as f64 / vol_a + cut as f64 / vol_b
415    } else {
416        f64::INFINITY
417    };
418
419    Ok(CutResult {
420        cut,
421        ratio_cut,
422        normalised_cut,
423    })
424}
425
426// ============================================================================
427// Generalised hypergraph random walk (stationary distribution)
428// ============================================================================
429
430/// Compute the **stationary distribution** of the generalised random walk on a
431/// hypergraph, using power iteration on the transition matrix.
432///
433/// The transition probability P(u→v) follows the Chung–Zhou formulation:
434///
435/// ```text
436/// P(u, v) = Σ_e ∈ E(u) [ w_e / (deg_w(u) * |e|) ]   for v ∈ e
437/// ```
438///
439/// Returns a vector of length `n_nodes` summing to 1.
440///
441/// # Errors
442/// Returns an error if the hypergraph has no nodes or all nodes are isolated.
443pub fn stationary_distribution(hg: &IndexedHypergraph) -> Result<Array1<f64>> {
444    let n = hg.n_nodes();
445    if n == 0 {
446        return Err(GraphError::InvalidGraph(
447            "hypergraph has no nodes".to_string(),
448        ));
449    }
450
451    // Check that at least one node has non-zero weighted degree
452    let any_connected = (0..n).any(|i| hg.weighted_degree(i) > 0.0);
453    if !any_connected {
454        // Uniform distribution over all nodes
455        return Ok(Array1::from_elem(n, 1.0 / n as f64));
456    }
457
458    // Build transition matrix P (n × n)
459    let mut p = Array2::<f64>::zeros((n, n));
460    for he in hg.hyperedges() {
461        let size = he.nodes.len() as f64;
462        if size == 0.0 {
463            continue;
464        }
465        for &u in &he.nodes {
466            let deg_u = hg.weighted_degree(u);
467            if deg_u == 0.0 {
468                continue;
469            }
470            for &v in &he.nodes {
471                p[[u, v]] += he.weight / (deg_u * size);
472            }
473        }
474    }
475
476    // Power iteration: pi * P = pi
477    let mut pi = Array1::from_elem(n, 1.0 / n as f64);
478    let max_iter = 5000;
479    let tol = 1e-10;
480
481    for _ in 0..max_iter {
482        // pi_new = pi . P
483        let mut pi_new = Array1::<f64>::zeros(n);
484        for i in 0..n {
485            for j in 0..n {
486                pi_new[j] += pi[i] * p[[i, j]];
487            }
488        }
489        // Normalise
490        let s: f64 = pi_new.iter().sum();
491        if s > 0.0 {
492            pi_new.mapv_inplace(|x| x / s);
493        }
494        // Convergence
495        let diff: f64 = pi_new
496            .iter()
497            .zip(pi.iter())
498            .map(|(a, b)| (a - b).abs())
499            .sum();
500        pi = pi_new;
501        if diff < tol {
502            break;
503        }
504    }
505    Ok(pi)
506}
507
508// ============================================================================
509// Hypergraph betweenness centrality
510// ============================================================================
511
512/// Compute **hypergraph betweenness centrality** for every node.
513///
514/// We compute betweenness on the **clique-expansion graph** (2-section), where
515/// edge weights are used as distances in Dijkstra's algorithm.  The betweenness
516/// of node `v` is the fraction of shortest paths between all pairs `(s, t)`
517/// (s ≠ v ≠ t) that pass through `v`.
518///
519/// Returns a vector of length `n_nodes`.
520pub fn betweenness_centrality(hg: &IndexedHypergraph) -> Vec<f64> {
521    let n = hg.n_nodes();
522    let mut bc = vec![0.0f64; n];
523    if n < 3 {
524        return bc;
525    }
526
527    let g = clique_expansion(hg);
528
529    // Build adjacency list from graph
530    // We use a simple BFS-based SSSP (unweighted) for correctness, since
531    // the clique-expansion already captures hyperedge topology.
532    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
533    // Pull edges from the clique expansion
534    for he in hg.hyperedges() {
535        let k = he.nodes.len();
536        for i in 0..k {
537            for j in (i + 1)..k {
538                let u = he.nodes[i];
539                let v = he.nodes[j];
540                if !adj[u].contains(&v) {
541                    adj[u].push(v);
542                }
543                if !adj[v].contains(&u) {
544                    adj[v].push(u);
545                }
546            }
547        }
548    }
549    let _ = g; // g used via adj
550
551    // Brandes' algorithm (unweighted BFS version)
552    for s in 0..n {
553        let mut stack: Vec<usize> = Vec::new();
554        let mut pred: Vec<Vec<usize>> = vec![Vec::new(); n];
555        let mut sigma = vec![0.0f64; n];
556        sigma[s] = 1.0;
557        let mut dist = vec![-1i64; n];
558        dist[s] = 0;
559        let mut queue = VecDeque::new();
560        queue.push_back(s);
561
562        while let Some(v) = queue.pop_front() {
563            stack.push(v);
564            for &w in &adj[v] {
565                if dist[w] < 0 {
566                    queue.push_back(w);
567                    dist[w] = dist[v] + 1;
568                }
569                if dist[w] == dist[v] + 1 {
570                    sigma[w] += sigma[v];
571                    pred[w].push(v);
572                }
573            }
574        }
575
576        let mut delta = vec![0.0f64; n];
577        while let Some(w) = stack.pop() {
578            for &v in &pred[w] {
579                if sigma[w] > 0.0 {
580                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
581                }
582            }
583            if w != s {
584                bc[w] += delta[w];
585            }
586        }
587    }
588
589    // Normalise by (n-1)(n-2) for undirected graphs
590    let factor = if n > 2 {
591        1.0 / ((n - 1) as f64 * (n - 2) as f64)
592    } else {
593        1.0
594    };
595    for v in &mut bc {
596        *v *= factor;
597    }
598    bc
599}
600
601// ============================================================================
602// s-walks and s-paths
603// ============================================================================
604
605/// Compute the **s-distance** between two hyperedges: the length of the
606/// shortest s-path (sequence of hyperedges each sharing ≥ `s` nodes with the
607/// next).
608///
609/// Returns `None` if the hyperedges are not s-connected.
610pub fn s_distance(hg: &IndexedHypergraph, e1: usize, e2: usize, s: usize) -> Option<usize> {
611    let m = hg.n_hyperedges();
612    if e1 >= m || e2 >= m {
613        return None;
614    }
615    if e1 == e2 {
616        return Some(0);
617    }
618
619    // BFS on hyperedge-level s-adjacency graph
620    let mut dist = vec![usize::MAX; m];
621    dist[e1] = 0;
622    let mut queue = VecDeque::new();
623    queue.push_back(e1);
624
625    while let Some(cur) = queue.pop_front() {
626        let cur_dist = dist[cur];
627        for next in 0..m {
628            if next == cur {
629                continue;
630            }
631            if dist[next] == usize::MAX {
632                let shared = hg.hyperedges()[cur].intersection_size(&hg.hyperedges()[next]);
633                if shared >= s {
634                    dist[next] = cur_dist + 1;
635                    if next == e2 {
636                        return Some(dist[next]);
637                    }
638                    queue.push_back(next);
639                }
640            }
641        }
642    }
643    None
644}
645
646/// Compute the **s-diameter** of the hypergraph: the maximum s-distance over
647/// all pairs of hyperedges in the same s-connected component.
648///
649/// Returns `0` if there are fewer than 2 hyperedges or the hypergraph is
650/// s-disconnected everywhere.
651pub fn s_diameter(hg: &IndexedHypergraph, s: usize) -> usize {
652    let m = hg.n_hyperedges();
653    let mut max_dist = 0usize;
654    for e1 in 0..m {
655        for e2 in (e1 + 1)..m {
656            if let Some(d) = s_distance(hg, e1, e2, s) {
657                max_dist = max_dist.max(d);
658            }
659        }
660    }
661    max_dist
662}
663
664/// Enumerate all **s-paths** of length ≤ `max_len` starting from hyperedge
665/// `start` as BFS layers.
666///
667/// Returns a `HashMap<usize, usize>` mapping each reachable hyperedge to its
668/// s-distance from `start`.
669pub fn s_reachability(
670    hg: &IndexedHypergraph,
671    start: usize,
672    s: usize,
673    max_len: usize,
674) -> HashMap<usize, usize> {
675    let m = hg.n_hyperedges();
676    let mut dists: HashMap<usize, usize> = HashMap::new();
677    if start >= m {
678        return dists;
679    }
680    dists.insert(start, 0);
681    let mut queue = VecDeque::new();
682    queue.push_back(start);
683
684    while let Some(cur) = queue.pop_front() {
685        let cur_dist = *dists.get(&cur).unwrap_or(&0);
686        if cur_dist >= max_len {
687            continue;
688        }
689        for next in 0..m {
690            if next == cur || dists.contains_key(&next) {
691                continue;
692            }
693            let shared = hg.hyperedges()[cur].intersection_size(&hg.hyperedges()[next]);
694            if shared >= s {
695                dists.insert(next, cur_dist + 1);
696                queue.push_back(next);
697            }
698        }
699    }
700    dists
701}
702
703/// Hyperedge betweenness centrality in the **s-line graph**.
704///
705/// Returns a vector of length `n_hyperedges` where each entry is the fraction
706/// of shortest s-paths (in the s-adjacency graph) passing through that
707/// hyperedge.
708pub fn s_betweenness_centrality(hg: &IndexedHypergraph, s: usize) -> Vec<f64> {
709    let m = hg.n_hyperedges();
710    let mut bc = vec![0.0f64; m];
711    if m < 3 {
712        return bc;
713    }
714
715    // Build s-adjacency list
716    let mut s_adj: Vec<Vec<usize>> = vec![Vec::new(); m];
717    for i in 0..m {
718        for j in (i + 1)..m {
719            let shared = hg.hyperedges()[i].intersection_size(&hg.hyperedges()[j]);
720            if shared >= s {
721                s_adj[i].push(j);
722                s_adj[j].push(i);
723            }
724        }
725    }
726
727    // Brandes on s-adjacency graph
728    for src in 0..m {
729        let mut stack: Vec<usize> = Vec::new();
730        let mut pred: Vec<Vec<usize>> = vec![Vec::new(); m];
731        let mut sigma = vec![0.0f64; m];
732        sigma[src] = 1.0;
733        let mut dist = vec![-1i64; m];
734        dist[src] = 0;
735        let mut queue = VecDeque::new();
736        queue.push_back(src);
737
738        while let Some(v) = queue.pop_front() {
739            stack.push(v);
740            for &w in &s_adj[v] {
741                if dist[w] < 0 {
742                    queue.push_back(w);
743                    dist[w] = dist[v] + 1;
744                }
745                if dist[w] == dist[v] + 1 {
746                    sigma[w] += sigma[v];
747                    pred[w].push(v);
748                }
749            }
750        }
751        let mut delta = vec![0.0f64; m];
752        while let Some(w) = stack.pop() {
753            for &v in &pred[w] {
754                if sigma[w] > 0.0 {
755                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
756                }
757            }
758            if w != src {
759                bc[w] += delta[w];
760            }
761        }
762    }
763
764    let factor = if m > 2 {
765        1.0 / ((m - 1) as f64 * (m - 2) as f64)
766    } else {
767        1.0
768    };
769    for v in &mut bc {
770        *v *= factor;
771    }
772    bc
773}
774
775// ============================================================================
776// Tests
777// ============================================================================
778
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Fixture: 5 nodes and 3 unit-weight hyperedges
    /// e0 = {0,1,2}, e1 = {2,3,4}, e2 = {0,3}.
    /// Every pair of hyperedges shares exactly one node.
    fn make_hg() -> IndexedHypergraph {
        let mut hg = IndexedHypergraph::new(5);
        hg.add_hyperedge(vec![0, 1, 2], 1.0).expect("ok");
        hg.add_hyperedge(vec![2, 3, 4], 1.0).expect("ok");
        hg.add_hyperedge(vec![0, 3], 1.0).expect("ok");
        hg
    }

    #[test]
    fn test_spectral_clustering_labels() {
        let hg = make_hg();
        let res = spectral_clustering(&hg, 2, 42).expect("cluster ok");
        assert_eq!(res.labels.len(), 5);
        // All labels must be in 0..2
        for &l in &res.labels {
            assert!(l < 2);
        }
        // The embedding has one row per node and one column per cluster.
        assert_eq!(res.embedding.nrows(), 5);
        assert_eq!(res.embedding.ncols(), 2);
        assert!(res.eigenvalue_iterations >= 2);
    }

    #[test]
    fn test_spectral_clustering_deterministic() {
        let hg = make_hg();
        let a = spectral_clustering(&hg, 2, 7).expect("cluster ok");
        let b = spectral_clustering(&hg, 2, 7).expect("cluster ok");
        assert_eq!(a.labels, b.labels);
    }

    #[test]
    fn test_spectral_clustering_max_seed() {
        // The k-means seed is derived from `seed`; the largest seed must not
        // cause an arithmetic overflow.
        let hg = make_hg();
        let res = spectral_clustering(&hg, 2, u64::MAX).expect("cluster ok");
        assert_eq!(res.labels.len(), 5);
    }

    #[test]
    fn test_spectral_clustering_invalid_k() {
        let hg = make_hg();
        assert!(spectral_clustering(&hg, 0, 0).is_err());
        assert!(spectral_clustering(&hg, 100, 0).is_err());
    }

    #[test]
    fn test_spectral_clustering_empty() {
        let hg = IndexedHypergraph::new(0);
        assert!(spectral_clustering(&hg, 1, 0).is_err());
    }

    #[test]
    fn test_hyperedge_cut_partition() {
        let mut hg = IndexedHypergraph::new(4);
        hg.add_hyperedge(vec![0, 1], 1.0).expect("ok");
        hg.add_hyperedge(vec![2, 3], 1.0).expect("ok");
        // The only hyperedge that crosses the partition below.
        hg.add_hyperedge(vec![1, 2], 1.0).expect("ok");
        // Partition: {0,1} vs {2,3}
        let part = vec![true, true, false, false];
        let res = hyperedge_cut(&hg, &part).expect("cut ok");
        assert_eq!(res.cut, 1);
        // One cut hyperedge over a smaller side of two nodes.
        assert_relative_eq!(res.ratio_cut, 0.5, epsilon = 1e-12);
        assert!(res.normalised_cut.is_finite());
        assert!(res.normalised_cut > 0.0);
    }

    #[test]
    fn test_hyperedge_cut_all_same_side() {
        let mut hg = IndexedHypergraph::new(4);
        hg.add_hyperedge(vec![0, 1, 2, 3], 1.0).expect("ok");
        let part = vec![true, true, true, true];
        let res = hyperedge_cut(&hg, &part).expect("cut ok");
        assert_eq!(res.cut, 0);
        // An empty side makes both normalised measures infinite.
        assert!(res.ratio_cut.is_infinite());
        assert!(res.normalised_cut.is_infinite());
    }

    #[test]
    fn test_hyperedge_cut_wrong_len() {
        let hg = make_hg();
        assert!(hyperedge_cut(&hg, &[true, false]).is_err());
    }

    #[test]
    fn test_stationary_distribution_sums_to_one() {
        let hg = make_hg();
        let pi = stationary_distribution(&hg).expect("ok");
        assert_eq!(pi.len(), 5);
        let s: f64 = pi.iter().sum();
        assert_relative_eq!(s, 1.0, epsilon = 1e-6);
        for &p in pi.iter() {
            assert!(p >= 0.0);
        }
    }

    #[test]
    fn test_stationary_all_isolated_is_uniform() {
        // No hyperedges at all: the walk is undefined and the uniform
        // distribution is returned instead of an error.
        let hg = IndexedHypergraph::new(4);
        let pi = stationary_distribution(&hg).expect("ok");
        assert_eq!(pi.len(), 4);
        for &p in pi.iter() {
            assert_relative_eq!(p, 0.25, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_stationary_empty() {
        let hg = IndexedHypergraph::new(0);
        assert!(stationary_distribution(&hg).is_err());
    }

    #[test]
    fn test_betweenness_centrality_len() {
        let hg = make_hg();
        let bc = betweenness_centrality(&hg);
        assert_eq!(bc.len(), 5);
        for &v in &bc {
            assert!(v >= 0.0);
            assert!(v <= 1.0);
        }
    }

    #[test]
    fn test_betweenness_centrality_small() {
        // Fewer than three nodes: every score is zero.
        let mut hg = IndexedHypergraph::new(2);
        hg.add_hyperedge(vec![0, 1], 1.0).expect("ok");
        assert_eq!(betweenness_centrality(&hg), vec![0.0, 0.0]);
    }

    #[test]
    fn test_s_distance_same_edge() {
        let hg = make_hg();
        assert_eq!(s_distance(&hg, 0, 0, 1), Some(0));
    }

    #[test]
    fn test_s_distance_adjacent() {
        // Edges 0 and 1 share node 2 → s=1 distance is 1
        let hg = make_hg();
        assert_eq!(s_distance(&hg, 0, 1, 1), Some(1));
        // They share only one node, so they are not 2-adjacent, and no other
        // pair of hyperedges is either.
        assert_eq!(s_distance(&hg, 0, 1, 2), None);
    }

    #[test]
    fn test_s_distance_out_of_range() {
        let hg = make_hg();
        assert_eq!(s_distance(&hg, 0, 99, 1), None);
        assert_eq!(s_distance(&hg, 99, 0, 1), None);
    }

    #[test]
    fn test_s_distance_disconnected() {
        // Hyperedges {0,1} and {3,4} share no node and there is no third
        // hyperedge that could bridge them → no s=1 path.
        let mut hg = IndexedHypergraph::new(5);
        hg.add_hyperedge(vec![0, 1], 1.0).expect("ok");
        hg.add_hyperedge(vec![3, 4], 1.0).expect("ok");
        assert_eq!(s_distance(&hg, 0, 1, 1), None);
    }

    #[test]
    fn test_s_reachability() {
        let hg = make_hg();
        let reach = s_reachability(&hg, 0, 1, 5);
        assert_eq!(reach.get(&0), Some(&0));
        // Edge 0 shares node 2 with edge 1 and node 0 with edge 2.
        assert_eq!(reach.get(&1), Some(&1));
        assert_eq!(reach.get(&2), Some(&1));
        assert_eq!(reach.len(), 3);
    }

    #[test]
    fn test_s_reachability_limits() {
        let hg = make_hg();
        // A zero length budget reaches only the start hyperedge.
        let only_start = s_reachability(&hg, 0, 1, 0);
        assert_eq!(only_start.len(), 1);
        assert_eq!(only_start.get(&0), Some(&0));
        // An invalid start index yields an empty map.
        assert!(s_reachability(&hg, 99, 1, 5).is_empty());
    }

    #[test]
    fn test_s_betweenness_len() {
        let hg = make_hg();
        let sbc = s_betweenness_centrality(&hg, 1);
        assert_eq!(sbc.len(), hg.n_hyperedges());
        // The 1-line graph of the fixture is a triangle: no hyperedge lies
        // strictly between two others.
        assert!(sbc.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_s_diameter() {
        let hg = make_hg();
        // All three hyperedges are pairwise 1-adjacent.
        assert_eq!(s_diameter(&hg, 1), 1);
        // No two hyperedges share two nodes.
        assert_eq!(s_diameter(&hg, 2), 0);
    }

    #[test]
    fn test_s_diameter_no_hyperedges() {
        let hg = IndexedHypergraph::new(3);
        assert_eq!(s_diameter(&hg, 1), 0);
    }
}