Skip to main content

scirs2_cluster/community/
sbm.rs

1//! Stochastic Block Model (SBM) for community detection.
2//!
3//! The SBM is a generative model for networks with community structure.
4//! Given K blocks (communities), the probability of an edge between nodes i
5//! and j depends only on their block assignments z_i and z_j through a
6//! K x K probability matrix B.
7//!
8//! This module provides:
9//!
10//! - **Fitting via variational EM**: infer block assignments and the B matrix
11//! - **Degree-corrected SBM**: accounts for degree heterogeneity within blocks
12//! - **Model selection**: choose K via Integrated Classification Likelihood (ICL)
13//! - **Network generation**: sample random graphs from a fitted or specified SBM
14
15use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19use super::{AdjacencyGraph, CommunityResult};
20use crate::error::{ClusteringError, Result};
21
22// ---------------------------------------------------------------------------
23// Configuration
24// ---------------------------------------------------------------------------
25
26/// Configuration for SBM fitting.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct StochasticBlockModelConfig {
29    /// Number of blocks (communities). If `None`, model selection is used.
30    pub num_blocks: Option<usize>,
31    /// Range of K values to try when `num_blocks` is `None`.
32    pub k_range: (usize, usize),
33    /// Maximum EM iterations.
34    pub max_iterations: usize,
35    /// Convergence tolerance on log-likelihood change.
36    pub convergence_threshold: f64,
37    /// Whether to use degree-corrected SBM.
38    pub degree_corrected: bool,
39    /// Random seed.
40    pub seed: u64,
41}
42
43impl Default for StochasticBlockModelConfig {
44    fn default() -> Self {
45        Self {
46            num_blocks: None,
47            k_range: (2, 8),
48            max_iterations: 100,
49            convergence_threshold: 1e-6,
50            degree_corrected: false,
51            seed: 42,
52        }
53    }
54}
55
56/// Result of SBM fitting.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SBMResult {
59    /// Community detection result (labels, num_communities, quality_score).
60    pub community: CommunityResult,
61    /// Block probability matrix B (K x K), row-major.
62    pub block_matrix: Vec<f64>,
63    /// Number of blocks K.
64    pub k: usize,
65    /// Log-likelihood of the fitted model.
66    pub log_likelihood: f64,
67    /// ICL score (for model comparison).
68    pub icl_score: f64,
69    /// Degree correction factors (only for degree-corrected SBM).
70    pub degree_corrections: Option<Vec<f64>>,
71}
72
73// ---------------------------------------------------------------------------
74// PRNG
75// ---------------------------------------------------------------------------
76
/// Minimal xorshift64 generator (shift triple 13/7/17) used for reproducible
/// initialisation and graph sampling.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Scale mapping a 53-bit integer onto `[0, 1)`; equals `2^-53` exactly.
    const UNIT_SCALE: f64 = 1.0 / (1u64 << 53) as f64;

    /// Seed the generator. Zero would lock xorshift at zero forever, so it is
    /// replaced by 1.
    fn new(seed: u64) -> Self {
        Xorshift64 { state: seed.max(1) }
    }

    /// Advance the state and return it.
    fn next_u64(&mut self) -> u64 {
        let s = self.state;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.state = s;
        s
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits of the next output.
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 * Self::UNIT_SCALE
    }
}
96
97// ---------------------------------------------------------------------------
98// SBM Core
99// ---------------------------------------------------------------------------
100
101/// The Stochastic Block Model.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct StochasticBlockModel {
104    /// Configuration.
105    pub config: StochasticBlockModelConfig,
106}
107
108impl StochasticBlockModel {
109    /// Create a new SBM with the given configuration.
110    pub fn new(config: StochasticBlockModelConfig) -> Self {
111        Self { config }
112    }
113
114    /// Fit the SBM to the observed adjacency graph.
115    ///
116    /// If `config.num_blocks` is set, fits with that K.
117    /// Otherwise, tries each K in `config.k_range` and picks the one
118    /// with the best ICL score.
119    pub fn fit(&self, graph: &AdjacencyGraph) -> Result<SBMResult> {
120        let n = graph.n_nodes;
121        if n == 0 {
122            return Err(ClusteringError::InvalidInput(
123                "Graph has zero nodes".to_string(),
124            ));
125        }
126
127        if let Some(k) = self.config.num_blocks {
128            if k == 0 || k > n {
129                return Err(ClusteringError::InvalidInput(format!(
130                    "num_blocks ({}) must be in [1, {}]",
131                    k, n
132                )));
133            }
134            self.fit_k(graph, k)
135        } else {
136            let k_min = self.config.k_range.0.max(1);
137            let k_max = self.config.k_range.1.min(n);
138            if k_min > k_max {
139                return Err(ClusteringError::InvalidInput("Invalid k_range".to_string()));
140            }
141
142            let mut best: Option<SBMResult> = None;
143            for k in k_min..=k_max {
144                let result = self.fit_k(graph, k)?;
145                let is_better = best
146                    .as_ref()
147                    .map(|b| result.icl_score > b.icl_score)
148                    .unwrap_or(true);
149                if is_better {
150                    best = Some(result);
151                }
152            }
153
154            best.ok_or_else(|| ClusteringError::ComputationError("No valid K found".to_string()))
155        }
156    }
157
158    /// Fit SBM with a specific K.
159    fn fit_k(&self, graph: &AdjacencyGraph, k: usize) -> Result<SBMResult> {
160        let n = graph.n_nodes;
161        let mut rng = Xorshift64::new(self.config.seed.wrapping_add(k as u64));
162
163        // Build dense adjacency for fast lookup.
164        let mut adj_matrix = vec![0.0_f64; n * n];
165        for i in 0..n {
166            for &(j, w) in &graph.adjacency[i] {
167                adj_matrix[i * n + j] = w;
168            }
169        }
170
171        // Initialise tau (soft assignment matrix): n x k, row-major.
172        // Start with a noisy K-means-like initialisation.
173        let mut tau = vec![0.0_f64; n * k];
174        for i in 0..n {
175            let assigned = (i * k / n) % k; // spread nodes across blocks
176            for r in 0..k {
177                tau[i * k + r] = if r == assigned {
178                    0.8
179                } else {
180                    0.2 / ((k - 1).max(1) as f64)
181                };
182            }
183            // Add noise.
184            let noise_sum: f64 = (0..k).map(|_| rng.next_f64() * 0.1).sum();
185            for r in 0..k {
186                tau[i * k + r] += rng.next_f64() * 0.1;
187            }
188            // Normalise.
189            let row_sum: f64 = (0..k).map(|r| tau[i * k + r]).sum();
190            if row_sum > 0.0 {
191                for r in 0..k {
192                    tau[i * k + r] /= row_sum;
193                }
194            }
195            let _ = noise_sum; // suppress warning
196        }
197
198        // Block probability matrix B: k x k, row-major.
199        let mut b_matrix = vec![0.0_f64; k * k];
200        // Degree corrections.
201        let mut theta = vec![1.0_f64; n];
202
203        let mut prev_ll = f64::NEG_INFINITY;
204
205        for _iter in 0..self.config.max_iterations {
206            // --- M-step: update B and optionally theta ---
207            self.m_step(graph, &adj_matrix, &tau, &mut b_matrix, &mut theta, n, k);
208
209            // --- E-step: update tau ---
210            self.e_step(graph, &adj_matrix, &b_matrix, &theta, &mut tau, n, k);
211
212            // --- Compute log-likelihood ---
213            let ll = self.log_likelihood(&adj_matrix, &b_matrix, &theta, &tau, n, k);
214
215            if (ll - prev_ll).abs() < self.config.convergence_threshold {
216                break;
217            }
218            prev_ll = ll;
219        }
220
221        // Hard assignment: argmax of tau.
222        let mut labels = vec![0usize; n];
223        for i in 0..n {
224            let mut best_r = 0;
225            let mut best_val = f64::NEG_INFINITY;
226            for r in 0..k {
227                if tau[i * k + r] > best_val {
228                    best_val = tau[i * k + r];
229                    best_r = r;
230                }
231            }
232            labels[i] = best_r;
233        }
234
235        // Compact labels (some blocks may be empty).
236        let mut mapping: HashMap<usize, usize> = HashMap::new();
237        let mut next_id = 0usize;
238        for lbl in &labels {
239            if !mapping.contains_key(lbl) {
240                mapping.insert(*lbl, next_id);
241                next_id += 1;
242            }
243        }
244        let compacted: Vec<usize> = labels
245            .iter()
246            .map(|l| mapping.get(l).copied().unwrap_or(0))
247            .collect();
248        let num_communities = next_id;
249
250        let ll = self.log_likelihood(&adj_matrix, &b_matrix, &theta, &tau, n, k);
251        let icl = self.compute_icl(ll, &compacted, n, k);
252        let quality = graph.modularity(&compacted);
253
254        let degree_corrections = if self.config.degree_corrected {
255            Some(theta)
256        } else {
257            None
258        };
259
260        Ok(SBMResult {
261            community: CommunityResult {
262                labels: compacted,
263                num_communities,
264                quality_score: Some(quality),
265            },
266            block_matrix: b_matrix,
267            k,
268            log_likelihood: ll,
269            icl_score: icl,
270            degree_corrections,
271        })
272    }
273
274    /// M-step: update B matrix and degree corrections.
275    fn m_step(
276        &self,
277        _graph: &AdjacencyGraph,
278        adj_matrix: &[f64],
279        tau: &[f64],
280        b_matrix: &mut [f64],
281        theta: &mut [f64],
282        n: usize,
283        k: usize,
284    ) {
285        // B_{rs} = sum_{i,j} tau_{ir} * A_{ij} * tau_{js} / sum_{i,j} tau_{ir} * tau_{js}
286        for r in 0..k {
287            for s in 0..k {
288                let mut numerator = 0.0;
289                let mut denominator = 0.0;
290                for i in 0..n {
291                    let tau_ir = tau[i * k + r];
292                    if tau_ir < 1e-15 {
293                        continue;
294                    }
295                    for j in 0..n {
296                        if i == j {
297                            continue;
298                        }
299                        let tau_js = tau[j * k + s];
300                        if tau_js < 1e-15 {
301                            continue;
302                        }
303                        numerator += tau_ir * adj_matrix[i * n + j] * tau_js;
304                        denominator += tau_ir * tau_js;
305                    }
306                }
307                // Clamp to [epsilon, 1 - epsilon] for numerical stability.
308                let val = if denominator > 1e-15 {
309                    numerator / denominator
310                } else {
311                    0.5
312                };
313                b_matrix[r * k + s] = val.clamp(1e-10, 1.0 - 1e-10);
314            }
315        }
316
317        // Degree corrections (degree-corrected SBM).
318        if self.config.degree_corrected {
319            // theta_i = (actual degree of i) / (expected degree under SBM)
320            for i in 0..n {
321                let actual_deg: f64 = (0..n)
322                    .filter(|&j| j != i)
323                    .map(|j| adj_matrix[i * n + j])
324                    .sum();
325
326                let mut expected = 0.0;
327                for j in 0..n {
328                    if j == i {
329                        continue;
330                    }
331                    for r in 0..k {
332                        for s in 0..k {
333                            expected += tau[i * k + r] * tau[j * k + s] * b_matrix[r * k + s];
334                        }
335                    }
336                }
337                theta[i] = if expected > 1e-15 {
338                    (actual_deg / expected).max(1e-10)
339                } else {
340                    1.0
341                };
342            }
343        }
344    }
345
346    /// E-step: update posterior block probabilities tau.
347    fn e_step(
348        &self,
349        _graph: &AdjacencyGraph,
350        adj_matrix: &[f64],
351        b_matrix: &[f64],
352        theta: &[f64],
353        tau: &mut [f64],
354        n: usize,
355        k: usize,
356    ) {
357        // pi_r = proportion of nodes in block r.
358        let mut pi = vec![0.0_f64; k];
359        for i in 0..n {
360            for r in 0..k {
361                pi[r] += tau[i * k + r];
362            }
363        }
364        let pi_sum: f64 = pi.iter().sum();
365        if pi_sum > 0.0 {
366            for r in 0..k {
367                pi[r] = (pi[r] / pi_sum).max(1e-10);
368            }
369        }
370
371        for i in 0..n {
372            let mut log_probs = vec![0.0_f64; k];
373            for r in 0..k {
374                log_probs[r] = pi[r].ln();
375
376                for j in 0..n {
377                    if j == i {
378                        continue;
379                    }
380                    // Use the current hard assignment of j approximation
381                    // for efficiency, or sum over s.
382                    for s in 0..k {
383                        let tau_js = tau[j * k + s];
384                        if tau_js < 1e-15 {
385                            continue;
386                        }
387
388                        let mut p_rs = b_matrix[r * k + s];
389                        if self.config.degree_corrected {
390                            p_rs *= theta[i] * theta[j];
391                        }
392                        p_rs = p_rs.clamp(1e-15, 1.0 - 1e-15);
393
394                        let a_ij = adj_matrix[i * n + j];
395                        if a_ij > 0.0 {
396                            log_probs[r] += tau_js * (a_ij * p_rs.ln());
397                        } else {
398                            log_probs[r] += tau_js * ((1.0 - p_rs).ln());
399                        }
400                    }
401                }
402            }
403
404            // Log-sum-exp normalisation.
405            let max_lp = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
406            let mut sum_exp = 0.0;
407            for r in 0..k {
408                log_probs[r] = (log_probs[r] - max_lp).exp();
409                sum_exp += log_probs[r];
410            }
411            if sum_exp > 0.0 {
412                for r in 0..k {
413                    tau[i * k + r] = (log_probs[r] / sum_exp).max(1e-15);
414                }
415            }
416        }
417    }
418
419    /// Compute the log-likelihood of the model.
420    fn log_likelihood(
421        &self,
422        adj_matrix: &[f64],
423        b_matrix: &[f64],
424        theta: &[f64],
425        tau: &[f64],
426        n: usize,
427        k: usize,
428    ) -> f64 {
429        let mut ll = 0.0;
430        for i in 0..n {
431            for j in (i + 1)..n {
432                let a_ij = adj_matrix[i * n + j];
433                for r in 0..k {
434                    let tau_ir = tau[i * k + r];
435                    if tau_ir < 1e-15 {
436                        continue;
437                    }
438                    for s in 0..k {
439                        let tau_js = tau[j * k + s];
440                        if tau_js < 1e-15 {
441                            continue;
442                        }
443                        let mut p = b_matrix[r * k + s];
444                        if self.config.degree_corrected {
445                            p *= theta[i] * theta[j];
446                        }
447                        p = p.clamp(1e-15, 1.0 - 1e-15);
448
449                        if a_ij > 0.0 {
450                            ll += tau_ir * tau_js * a_ij * p.ln();
451                        } else {
452                            ll += tau_ir * tau_js * (1.0 - p).ln();
453                        }
454                    }
455                }
456            }
457        }
458        ll
459    }
460
461    /// Compute the Integrated Classification Likelihood (ICL) score.
462    ///
463    /// ICL = LL - penalty
464    /// penalty = (K*(K+1)/2) * ln(n*(n-1)/2) / 2 + (K-1) * ln(n) / 2
465    fn compute_icl(&self, ll: f64, labels: &[usize], n: usize, k: usize) -> f64 {
466        let n_f = n as f64;
467        let k_f = k as f64;
468        // Number of B matrix parameters.
469        let n_b_params = k_f * (k_f + 1.0) / 2.0;
470        // Number of possible edges.
471        let n_pairs = n_f * (n_f - 1.0) / 2.0;
472        let penalty =
473            n_b_params * n_pairs.max(1.0).ln() / 2.0 + (k_f - 1.0) * n_f.max(1.0).ln() / 2.0;
474
475        // Entropy of tau (classification entropy).
476        // Since we use hard labels here, the classification entropy contribution
477        // is captured by the block sizes.
478        let mut block_sizes = vec![0usize; k];
479        for &l in labels {
480            if l < k {
481                block_sizes[l] += 1;
482            }
483        }
484        let entropy_correction: f64 = block_sizes
485            .iter()
486            .filter(|&&s| s > 0)
487            .map(|&s| {
488                let p = s as f64 / n_f;
489                -(s as f64) * p.ln()
490            })
491            .sum();
492
493        ll - penalty - entropy_correction
494    }
495
496    /// Predict block assignments for a given graph using a fitted model.
497    ///
498    /// This re-runs the E-step with the given B matrix to assign labels.
499    pub fn predict(
500        &self,
501        graph: &AdjacencyGraph,
502        b_matrix: &[f64],
503        k: usize,
504    ) -> Result<Vec<usize>> {
505        let n = graph.n_nodes;
506        if n == 0 {
507            return Err(ClusteringError::InvalidInput(
508                "Graph has zero nodes".to_string(),
509            ));
510        }
511        if b_matrix.len() != k * k {
512            return Err(ClusteringError::InvalidInput(
513                "B matrix size mismatch".to_string(),
514            ));
515        }
516
517        // Build dense adjacency.
518        let mut adj_matrix = vec![0.0_f64; n * n];
519        for i in 0..n {
520            for &(j, w) in &graph.adjacency[i] {
521                adj_matrix[i * n + j] = w;
522            }
523        }
524
525        // Uniform initialisation.
526        let uniform = 1.0 / k as f64;
527        let mut tau = vec![uniform; n * k];
528        let theta = vec![1.0_f64; n];
529
530        for _iter in 0..self.config.max_iterations {
531            self.e_step(graph, &adj_matrix, b_matrix, &theta, &mut tau, n, k);
532        }
533
534        // Hard assignment.
535        let mut labels = vec![0usize; n];
536        for i in 0..n {
537            let mut best_r = 0;
538            let mut best_val = f64::NEG_INFINITY;
539            for r in 0..k {
540                if tau[i * k + r] > best_val {
541                    best_val = tau[i * k + r];
542                    best_r = r;
543                }
544            }
545            labels[i] = best_r;
546        }
547
548        Ok(labels)
549    }
550
551    /// Generate a random graph from SBM parameters.
552    ///
553    /// - `n`: number of nodes
554    /// - `k`: number of blocks
555    /// - `b_matrix`: K x K probability matrix (row-major)
556    /// - `block_sizes`: sizes of each block (must sum to n)
557    pub fn generate(
558        n: usize,
559        k: usize,
560        b_matrix: &[f64],
561        block_sizes: &[usize],
562        seed: u64,
563    ) -> Result<(AdjacencyGraph, Vec<usize>)> {
564        if b_matrix.len() != k * k {
565            return Err(ClusteringError::InvalidInput(
566                "B matrix size must be k*k".to_string(),
567            ));
568        }
569        if block_sizes.len() != k {
570            return Err(ClusteringError::InvalidInput(
571                "block_sizes length must equal k".to_string(),
572            ));
573        }
574        let total: usize = block_sizes.iter().sum();
575        if total != n {
576            return Err(ClusteringError::InvalidInput(format!(
577                "block_sizes sum ({}) must equal n ({})",
578                total, n
579            )));
580        }
581
582        let mut rng = Xorshift64::new(seed);
583
584        // Assign nodes to blocks.
585        let mut labels = Vec::with_capacity(n);
586        for (block, &size) in block_sizes.iter().enumerate() {
587            for _ in 0..size {
588                labels.push(block);
589            }
590        }
591
592        // Generate edges.
593        let mut graph = AdjacencyGraph::new(n);
594        for i in 0..n {
595            for j in (i + 1)..n {
596                let r = labels[i];
597                let s = labels[j];
598                let p = b_matrix[r * k + s];
599                if rng.next_f64() < p {
600                    let _ = graph.add_edge(i, j, 1.0);
601                }
602            }
603        }
604
605        Ok((graph, labels))
606    }
607}
608
609// ---------------------------------------------------------------------------
610// Tests
611// ---------------------------------------------------------------------------
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample a planted-partition graph, panicking if the parameters are rejected.
    fn sample(
        n: usize,
        k: usize,
        b_matrix: &[f64],
        block_sizes: &[usize],
        seed: u64,
    ) -> (AdjacencyGraph, Vec<usize>) {
        StochasticBlockModel::generate(n, k, b_matrix, block_sizes, seed)
            .expect("generate should succeed")
    }

    /// All permutations of `0..k`, built by inserting each new value at every
    /// position of the permutations of the smaller prefix.
    fn generate_permutations(k: usize) -> Vec<Vec<usize>> {
        let mut perms: Vec<Vec<usize>> = vec![Vec::new()];
        for value in 0..k {
            perms = perms
                .into_iter()
                .flat_map(|p| {
                    (0..=p.len()).map(move |pos| {
                        let mut extended = p.clone();
                        extended.insert(pos, value);
                        extended
                    })
                })
                .collect();
        }
        perms
    }

    /// Fraction of nodes whose predicted label matches the truth under the best
    /// relabelling of predicted labels. Brute force over `k!` permutations, so
    /// only meant for small `k`. Labels outside the permutation map to themselves.
    fn compute_accuracy(true_labels: &[usize], pred_labels: &[usize], k: usize) -> f64 {
        let n = true_labels.len();
        if n == 0 {
            return 1.0;
        }
        let best_correct = generate_permutations(k)
            .iter()
            .map(|perm| {
                (0..n)
                    .filter(|&i| {
                        let pred = pred_labels[i];
                        perm.get(pred).copied().unwrap_or(pred) == true_labels[i]
                    })
                    .count()
            })
            .max()
            .unwrap_or(0);
        best_correct as f64 / n as f64
    }

    /// Fitting a sampled planted partition recovers the blocks.
    #[test]
    fn test_sbm_generate_and_fit() {
        let (n, k) = (20, 2);
        // Dense inside blocks, sparse between them.
        let b_matrix = [0.8, 0.05, 0.05, 0.8];
        let (graph, true_labels) = sample(n, k, &b_matrix, &[10, 10], 123);

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            max_iterations: 50,
            seed: 42,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        assert_eq!(result.community.num_communities, 2);
        assert_eq!(result.community.labels.len(), n);

        // Up to label permutation, at least 70% of nodes must be recovered.
        let accuracy = compute_accuracy(&true_labels, &result.community.labels, k);
        assert!(accuracy >= 0.7, "Accuracy {} is too low", accuracy);
    }

    /// The degree-corrected fit reports one positive factor per node.
    #[test]
    fn test_sbm_degree_corrected() {
        let (n, k) = (20, 2);
        let b_matrix = [0.7, 0.1, 0.1, 0.7];
        let (graph, _) = sample(n, k, &b_matrix, &[10, 10], 456);

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            degree_corrected: true,
            max_iterations: 30,
            seed: 789,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        assert!(result.degree_corrections.is_some());
        let corrections = result
            .degree_corrections
            .as_ref()
            .expect("should have degree corrections");
        assert_eq!(corrections.len(), n);
        assert!(corrections.iter().all(|&d| d > 0.0));
    }

    /// ICL-based model selection lands near the true K.
    #[test]
    fn test_sbm_model_selection() {
        let (n, k) = (30, 2);
        let b_matrix = [0.9, 0.05, 0.05, 0.9];
        let (graph, _) = sample(n, k, &b_matrix, &[15, 15], 111);

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: None,
            k_range: (2, 5),
            max_iterations: 30,
            seed: 222,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        // Accept 2 or 3 blocks.
        assert!(
            (2..=3).contains(&result.k),
            "Selected K={} seems wrong",
            result.k
        );
    }

    /// Prediction with the generating B matrix recovers most labels.
    #[test]
    fn test_sbm_predict() {
        let (n, k) = (20, 2);
        let b_matrix = [0.8, 0.05, 0.05, 0.8];
        let (graph, true_labels) = sample(n, k, &b_matrix, &[10, 10], 333);

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            max_iterations: 30,
            seed: 444,
            ..Default::default()
        });
        let predicted = sbm
            .predict(&graph, &b_matrix, k)
            .expect("predict should succeed");

        assert_eq!(predicted.len(), n);
        let accuracy = compute_accuracy(&true_labels, &predicted, k);
        assert!(accuracy >= 0.6, "Predict accuracy {} is too low", accuracy);
    }

    /// Block sizes that do not add up to `n` are rejected.
    #[test]
    fn test_sbm_generate_invalid() {
        let outcome = StochasticBlockModel::generate(10, 2, &[0.5, 0.1, 0.1, 0.5], &[4, 4], 0);
        assert!(outcome.is_err());
    }

    /// Fitting a graph without nodes is an error.
    #[test]
    fn test_sbm_empty_graph() {
        let empty = AdjacencyGraph::new(0);
        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            ..Default::default()
        });
        assert!(sbm.fit(&empty).is_err());
    }

    /// A complete graph fitted with K = 1 yields a single community.
    #[test]
    fn test_sbm_single_block() {
        let n = 10;
        let mut g = AdjacencyGraph::new(n);
        for (i, j) in (0..n).flat_map(|i| ((i + 1)..n).map(move |j| (i, j))) {
            let _ = g.add_edge(i, j, 1.0);
        }

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(1),
            max_iterations: 20,
            seed: 555,
            ..Default::default()
        });
        let result = sbm.fit(&g).expect("fit should succeed");
        assert_eq!(result.community.num_communities, 1);
    }
}