Skip to main content

scirs2_graph/diffusion/
influence_max.rs

1//! Influence Maximization algorithms
2//!
3//! This module provides algorithms for finding the top-k seed nodes that
4//! maximise information spread under a given diffusion model:
5//!
6//! | Function | Description |
7//! |----------|-------------|
8//! | [`greedy_influence_max`] | Greedy hill-climbing with Monte-Carlo estimates (Kempe 2003) |
9//! | [`celf_influence_max`] | CELF – lazy evaluation cuts MC calls dramatically |
10//! | [`celf_plus_plus`] | CELF++ – one additional optimisation over CELF |
11//! | [`degree_heuristic`] | Fast O(n log n) heuristic: pick highest-degree nodes |
12//! | [`pagerank_heuristic`] | PageRank-based seed selection (directed influence proxy) |
13//!
14//! # References
15//! - Kempe, Kleinberg & Tardos (2003) – *KDD 2003*
16//! - Leskovec et al. (2007) – CELF, *KDD 2007*
17//! - Goyal, Lu & Lakshmanan (2011) – CELF++, *WWW 2011*
18
19use std::collections::{BinaryHeap, HashMap};
20
21use crate::diffusion::models::{simulate_ic, AdjList};
22use crate::error::{GraphError, Result};
23
24// ---------------------------------------------------------------------------
25// Configuration & result types
26// ---------------------------------------------------------------------------
27
/// Configuration shared by the influence maximization algorithms in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InfluenceMaxConfig {
    /// Number of Monte-Carlo simulations averaged for every spread estimate.
    ///
    /// Must be at least 1 whenever a spread estimate is actually requested;
    /// `0` is rejected with `GraphError::InvalidParameter`.
    pub num_simulations: usize,
    /// Diffusion model simulated by the spread estimator
    /// ([`DiffusionModel::IC`] or [`DiffusionModel::LT`]).
    pub model: DiffusionModel,
}

/// Selector for the diffusion model simulated during influence maximization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DiffusionModel {
    /// Independent Cascade model (the default).
    #[default]
    IC,
    /// Linear Threshold model.
    LT,
}

impl Default for InfluenceMaxConfig {
    /// 100 Monte-Carlo simulations with the Independent Cascade model.
    fn default() -> Self {
        InfluenceMaxConfig {
            num_simulations: 100,
            model: DiffusionModel::IC,
        }
    }
}

/// Result returned by every influence maximization routine in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct InfluenceMaxResult {
    /// Selected seed nodes, in the order they were chosen.
    pub seeds: Vec<usize>,
    /// Monte-Carlo estimate of the expected spread of `seeds`.
    pub estimated_spread: f64,
    /// Total number of individual diffusion simulations run; every spread
    /// estimate contributes `num_simulations` to this count.
    pub oracle_calls: usize,
}
65
66// ---------------------------------------------------------------------------
67// Internal spread estimator
68// ---------------------------------------------------------------------------
69
70/// Estimate expected spread for a candidate seed set.
71///
72/// Uses Monte-Carlo averaging with the IC model (LT support can be added
73/// similarly).  Returns `(spread_estimate, oracle_call_count)`.
74fn estimate_spread(
75    adjacency: &AdjList,
76    num_nodes: usize,
77    seeds: &[usize],
78    config: &InfluenceMaxConfig,
79) -> Result<(f64, usize)> {
80    let n = config.num_simulations;
81    if n == 0 {
82        return Err(GraphError::InvalidParameter {
83            param: "num_simulations".to_string(),
84            value: "0".to_string(),
85            expected: ">= 1".to_string(),
86            context: "estimate_spread".to_string(),
87        });
88    }
89
90    let spread = match config.model {
91        DiffusionModel::IC => {
92            let mut total = 0.0_f64;
93            for _ in 0..n {
94                total += simulate_ic(adjacency, seeds)?.spread as f64;
95            }
96            total / n as f64
97        }
98        DiffusionModel::LT => {
99            use crate::diffusion::models::simulate_lt;
100            let mut total = 0.0_f64;
101            for _ in 0..n {
102                total += simulate_lt(adjacency, num_nodes, seeds, None)?.spread as f64;
103            }
104            total / n as f64
105        }
106    };
107
108    Ok((spread, n))
109}
110
111// ---------------------------------------------------------------------------
112// Greedy (Kempe et al. 2003)
113// ---------------------------------------------------------------------------
114
115/// Greedy influence maximization using Monte-Carlo spread estimates.
116///
117/// At each of the `k` iterations the algorithm evaluates every non-seed node
118/// as a candidate addition and picks the one with the highest *marginal gain*.
119/// This is the algorithm of Kempe, Kleinberg & Tardos (KDD 2003) with a
120/// `(1 – 1/e)`-approximation guarantee for submodular diffusion models.
121///
122/// **Complexity**: `O(k · n · num_simulations)` MC simulations.
123///
124/// # Arguments
125/// * `adjacency` — directed adjacency list with propagation probabilities.
126/// * `num_nodes` — total number of nodes.
127/// * `k` — desired seed set size.
128/// * `config` — number of MC simulations and model choice.
129///
130/// # Errors
131/// Returns an error when `k > num_nodes` or `num_simulations == 0`.
132pub fn greedy_influence_max(
133    adjacency: &AdjList,
134    num_nodes: usize,
135    k: usize,
136    config: &InfluenceMaxConfig,
137) -> Result<InfluenceMaxResult> {
138    if k == 0 {
139        return Ok(InfluenceMaxResult {
140            seeds: Vec::new(),
141            estimated_spread: 0.0,
142            oracle_calls: 0,
143        });
144    }
145    if k > num_nodes {
146        return Err(GraphError::InvalidParameter {
147            param: "k".to_string(),
148            value: k.to_string(),
149            expected: format!("<= num_nodes={num_nodes}"),
150            context: "greedy_influence_max".to_string(),
151        });
152    }
153
154    let mut seeds: Vec<usize> = Vec::with_capacity(k);
155    let mut current_spread = 0.0_f64;
156    let mut oracle_calls = 0_usize;
157    let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
158
159    for _round in 0..k {
160        let mut best_node = None;
161        let mut best_gain = f64::NEG_INFINITY;
162
163        for candidate in 0..num_nodes {
164            if selected.contains(&candidate) {
165                continue;
166            }
167            let mut trial_seeds = seeds.clone();
168            trial_seeds.push(candidate);
169            let (spread, calls) = estimate_spread(adjacency, num_nodes, &trial_seeds, config)?;
170            oracle_calls += calls;
171
172            let gain = spread - current_spread;
173            if gain > best_gain {
174                best_gain = gain;
175                best_node = Some((candidate, spread));
176            }
177        }
178
179        match best_node {
180            Some((node, spread)) => {
181                seeds.push(node);
182                selected.insert(node);
183                current_spread = spread;
184            }
185            None => break,
186        }
187    }
188
189    Ok(InfluenceMaxResult {
190        estimated_spread: current_spread,
191        seeds,
192        oracle_calls,
193    })
194}
195
196// ---------------------------------------------------------------------------
197// CELF (lazy evaluation)
198// ---------------------------------------------------------------------------
199
/// Priority-queue entry shared by CELF and CELF++.
///
/// Entries are ordered by `marginal_gain` using `f64::total_cmp`, so the
/// ordering is total even if a gain is NaN. In a max-[`BinaryHeap`] the
/// largest gain pops first. Equal gains pop the *smaller* node index first,
/// mirroring the exhaustive scan in `greedy_influence_max`, which keeps the
/// first (lowest-index) maximum. Equality is defined through this ordering so
/// that `PartialEq`/`Eq` and `Ord` agree.
#[derive(Debug, Clone)]
struct CelfEntry {
    /// Candidate node.
    node: usize,
    /// Most recently computed marginal gain of `node`.
    marginal_gain: f64,
    /// Round in which `marginal_gain` was last computed.
    round: usize,
    /// Auxiliary flag for CELF++ bookkeeping; it does not take part in
    /// ordering or equality.
    prev_best: bool,
}

impl PartialEq for CelfEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for CelfEntry {}

impl PartialOrd for CelfEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for CelfEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.marginal_gain
            .total_cmp(&other.marginal_gain)
            // Reversed on purpose: the smaller index compares as "greater"
            // and therefore pops first from the max-heap.
            .then_with(|| other.node.cmp(&self.node))
    }
}
233
234/// CELF influence maximization (lazy evaluation).
235///
236/// CELF exploits the *submodularity* of influence spread: the marginal gain of
237/// a node can only decrease as the seed set grows.  A node whose stored
238/// marginal gain is an upper-bound (from a previous round) does not need
239/// re-evaluation until it reaches the top of the heap.
240///
241/// **Expected complexity**: `O(k · num_simulations)` MC calls (empirically
242/// much fewer than the naive greedy).
243///
244/// # Arguments
245/// * `adjacency` — directed adjacency list.
246/// * `num_nodes` — total number of nodes.
247/// * `k` — seed set size.
248/// * `config` — MC count and model.
249pub fn celf_influence_max(
250    adjacency: &AdjList,
251    num_nodes: usize,
252    k: usize,
253    config: &InfluenceMaxConfig,
254) -> Result<InfluenceMaxResult> {
255    if k == 0 {
256        return Ok(InfluenceMaxResult {
257            seeds: Vec::new(),
258            estimated_spread: 0.0,
259            oracle_calls: 0,
260        });
261    }
262    if k > num_nodes {
263        return Err(GraphError::InvalidParameter {
264            param: "k".to_string(),
265            value: k.to_string(),
266            expected: format!("<= num_nodes={num_nodes}"),
267            context: "celf_influence_max".to_string(),
268        });
269    }
270
271    let mut oracle_calls = 0_usize;
272
273    // Initialise heap with marginal gains of singletons
274    let mut heap: BinaryHeap<CelfEntry> = BinaryHeap::new();
275    for node in 0..num_nodes {
276        let (gain, calls) = estimate_spread(adjacency, num_nodes, &[node], config)?;
277        oracle_calls += calls;
278        heap.push(CelfEntry {
279            node,
280            marginal_gain: gain,
281            round: 0,
282            prev_best: false,
283        });
284    }
285
286    let mut seeds: Vec<usize> = Vec::with_capacity(k);
287    let mut current_spread = 0.0_f64;
288    let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
289
290    let mut round = 0_usize;
291    while seeds.len() < k {
292        let entry = loop {
293            let top = heap.pop().ok_or_else(|| GraphError::AlgorithmFailure {
294                algorithm: "celf_influence_max".to_string(),
295                reason: "priority queue exhausted before k seeds selected".to_string(),
296                iterations: seeds.len(),
297                tolerance: 0.0,
298            })?;
299
300            if selected.contains(&top.node) {
301                continue;
302            }
303
304            if top.round == round {
305                // Already evaluated in this round — guaranteed optimal by submodularity
306                break top;
307            }
308
309            // Re-evaluate marginal gain
310            let mut trial = seeds.clone();
311            trial.push(top.node);
312            let (new_spread, calls) = estimate_spread(adjacency, num_nodes, &trial, config)?;
313            oracle_calls += calls;
314
315            let updated = CelfEntry {
316                node: top.node,
317                marginal_gain: new_spread - current_spread,
318                round,
319                prev_best: false,
320            };
321            heap.push(updated);
322        };
323
324        seeds.push(entry.node);
325        selected.insert(entry.node);
326        current_spread += entry.marginal_gain;
327        round += 1;
328    }
329
330    // Final spread estimate with the full seed set
331    let (final_spread, calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
332    oracle_calls += calls;
333
334    Ok(InfluenceMaxResult {
335        seeds,
336        estimated_spread: final_spread,
337        oracle_calls,
338    })
339}
340
341// ---------------------------------------------------------------------------
342// CELF++ (Goyal et al. 2011)
343// ---------------------------------------------------------------------------
344
345/// CELF++ influence maximization.
346///
347/// CELF++ adds one optimisation over CELF: within each round it tracks the
348/// *second* node that was the best in the previous iteration.  If a node at
349/// the top of the heap was also the best in the previous round (flag
350/// `prev_best`), its marginal gain with the current seed set has already been
351/// computed and can be used without re-evaluation.
352///
353/// In practice this reduces oracle calls by roughly 35–55 % compared to CELF.
354pub fn celf_plus_plus(
355    adjacency: &AdjList,
356    num_nodes: usize,
357    k: usize,
358    config: &InfluenceMaxConfig,
359) -> Result<InfluenceMaxResult> {
360    if k == 0 {
361        return Ok(InfluenceMaxResult {
362            seeds: Vec::new(),
363            estimated_spread: 0.0,
364            oracle_calls: 0,
365        });
366    }
367    if k > num_nodes {
368        return Err(GraphError::InvalidParameter {
369            param: "k".to_string(),
370            value: k.to_string(),
371            expected: format!("<= num_nodes={num_nodes}"),
372            context: "celf_plus_plus".to_string(),
373        });
374    }
375
376    let mut oracle_calls = 0_usize;
377
378    // ------ initialise heap with singleton marginal gains ------
379    let mut heap: BinaryHeap<CelfEntry> = BinaryHeap::new();
380    // Also track per-node cached gains for the CELF++ prev_best optimisation
381    let mut cached_gain: HashMap<usize, f64> = HashMap::new();
382
383    for node in 0..num_nodes {
384        let (gain, calls) = estimate_spread(adjacency, num_nodes, &[node], config)?;
385        oracle_calls += calls;
386        cached_gain.insert(node, gain);
387        heap.push(CelfEntry {
388            node,
389            marginal_gain: gain,
390            round: 0,
391            prev_best: false,
392        });
393    }
394
395    let mut seeds: Vec<usize> = Vec::with_capacity(k);
396    let mut current_spread = 0.0_f64;
397    let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
398    let mut prev_best_node: Option<usize> = None;
399
400    let mut round = 0_usize;
401    while seeds.len() < k {
402        // ------ find best candidate for this round ------
403        let chosen = loop {
404            let top = heap.pop().ok_or_else(|| GraphError::AlgorithmFailure {
405                algorithm: "celf_plus_plus".to_string(),
406                reason: "priority queue exhausted".to_string(),
407                iterations: seeds.len(),
408                tolerance: 0.0,
409            })?;
410
411            if selected.contains(&top.node) {
412                continue;
413            }
414
415            // CELF++ optimisation: if this node was the best in the previous
416            // round AND its gain was already re-evaluated w.r.t. the *current*
417            // seed set, skip re-evaluation.
418            if top.prev_best && top.round == round {
419                break top;
420            }
421
422            if top.round == round {
423                // Already updated this round
424                break top;
425            }
426
427            // Re-evaluate marginal gain w.r.t. current seed set
428            let mut trial = seeds.clone();
429            trial.push(top.node);
430            let (new_spread, calls) = estimate_spread(adjacency, num_nodes, &trial, config)?;
431            oracle_calls += calls;
432
433            let gain = new_spread - current_spread;
434            *cached_gain.entry(top.node).or_insert(gain) = gain;
435
436            // CELF++ optimisation: also evaluate w.r.t. seeds + prev_best
437            let is_prev_best = prev_best_node.map(|pb| pb == top.node).unwrap_or(false);
438            let prev_best_flag = if let Some(pb) = prev_best_node {
439                if !selected.contains(&pb) && !is_prev_best {
440                    let mut trial2 = seeds.clone();
441                    trial2.push(pb);
442                    trial2.push(top.node);
443                    let (spread2, calls2) = estimate_spread(adjacency, num_nodes, &trial2, config)?;
444                    oracle_calls += calls2;
445                    let gain2 =
446                        spread2 - current_spread - cached_gain.get(&pb).cloned().unwrap_or(0.0);
447                    // If gain2 >= gain the node is still best even after adding prev_best
448                    gain2 >= gain
449                } else {
450                    false
451                }
452            } else {
453                false
454            };
455
456            let updated = CelfEntry {
457                node: top.node,
458                marginal_gain: gain,
459                round,
460                prev_best: prev_best_flag,
461            };
462            heap.push(updated);
463        };
464
465        prev_best_node = Some(chosen.node);
466        seeds.push(chosen.node);
467        selected.insert(chosen.node);
468        current_spread += chosen.marginal_gain;
469        round += 1;
470    }
471
472    let (final_spread, calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
473    oracle_calls += calls;
474
475    Ok(InfluenceMaxResult {
476        seeds,
477        estimated_spread: final_spread,
478        oracle_calls,
479    })
480}
481
482// ---------------------------------------------------------------------------
483// Degree heuristic
484// ---------------------------------------------------------------------------
485
486/// High-degree seed selection heuristic.
487///
488/// Selects the `k` nodes with highest out-degree as the seed set.  This is a
489/// fast `O(n log n)` heuristic that often performs surprisingly well in
490/// practice.
491///
492/// # Arguments
493/// * `adjacency` — directed adjacency list.
494/// * `num_nodes` — total number of nodes.
495/// * `k` — seed set size.
496/// * `config` — used only to compute the spread estimate at the end.
497pub fn degree_heuristic(
498    adjacency: &AdjList,
499    num_nodes: usize,
500    k: usize,
501    config: &InfluenceMaxConfig,
502) -> Result<InfluenceMaxResult> {
503    if k == 0 {
504        return Ok(InfluenceMaxResult {
505            seeds: Vec::new(),
506            estimated_spread: 0.0,
507            oracle_calls: 0,
508        });
509    }
510    if k > num_nodes {
511        return Err(GraphError::InvalidParameter {
512            param: "k".to_string(),
513            value: k.to_string(),
514            expected: format!("<= num_nodes={num_nodes}"),
515            context: "degree_heuristic".to_string(),
516        });
517    }
518
519    // Compute out-degree for every node
520    let mut degrees: Vec<(usize, usize)> = (0..num_nodes)
521        .map(|node| {
522            let deg = adjacency.get(&node).map(|nbrs| nbrs.len()).unwrap_or(0);
523            (node, deg)
524        })
525        .collect();
526
527    // Sort descending by degree
528    degrees.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
529
530    let seeds: Vec<usize> = degrees.iter().take(k).map(|&(node, _)| node).collect();
531
532    let (estimated_spread, oracle_calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
533
534    Ok(InfluenceMaxResult {
535        seeds,
536        estimated_spread,
537        oracle_calls,
538    })
539}
540
541// ---------------------------------------------------------------------------
542// PageRank heuristic
543// ---------------------------------------------------------------------------
544
545/// PageRank-based seed selection heuristic.
546///
547/// Runs a lightweight power-iteration PageRank on the propagation graph and
548/// selects the `k` highest-ranked nodes as seeds.  PageRank captures both
549/// degree and structural position, making it a stronger proxy for influence
550/// than raw degree.
551///
552/// # Arguments
553/// * `adjacency` — directed adjacency list.
554/// * `num_nodes` — total number of nodes.
555/// * `k` — seed set size.
556/// * `config` — MC config used for the spread estimate.
557/// * `damping` — PageRank damping factor (typically 0.85).
558/// * `max_iter` — maximum power-iteration steps.
559/// * `tol` — convergence tolerance (L1 norm of score delta).
560pub fn pagerank_heuristic(
561    adjacency: &AdjList,
562    num_nodes: usize,
563    k: usize,
564    config: &InfluenceMaxConfig,
565    damping: f64,
566    max_iter: usize,
567    tol: f64,
568) -> Result<InfluenceMaxResult> {
569    if k == 0 {
570        return Ok(InfluenceMaxResult {
571            seeds: Vec::new(),
572            estimated_spread: 0.0,
573            oracle_calls: 0,
574        });
575    }
576    if k > num_nodes {
577        return Err(GraphError::InvalidParameter {
578            param: "k".to_string(),
579            value: k.to_string(),
580            expected: format!("<= num_nodes={num_nodes}"),
581            context: "pagerank_heuristic".to_string(),
582        });
583    }
584    if !(0.0..=1.0).contains(&damping) {
585        return Err(GraphError::InvalidParameter {
586            param: "damping".to_string(),
587            value: damping.to_string(),
588            expected: "[0, 1]".to_string(),
589            context: "pagerank_heuristic".to_string(),
590        });
591    }
592
593    // ------- compute out-degree for normalisation -------
594    let out_degree: Vec<f64> = (0..num_nodes)
595        .map(|n| adjacency.get(&n).map(|v| v.len() as f64).unwrap_or(0.0))
596        .collect();
597
598    // ------- power iteration -------
599    let base_score = (1.0 - damping) / num_nodes as f64;
600    let mut scores: Vec<f64> = vec![1.0 / num_nodes as f64; num_nodes];
601
602    for _ in 0..max_iter {
603        let mut new_scores: Vec<f64> = vec![base_score; num_nodes];
604
605        // Dangling nodes contribute uniformly
606        let dangling_sum: f64 = (0..num_nodes)
607            .filter(|&n| out_degree[n] == 0.0)
608            .map(|n| scores[n])
609            .sum::<f64>()
610            * damping
611            / num_nodes as f64;
612
613        for n in 0..num_nodes {
614            new_scores[n] += dangling_sum;
615        }
616
617        // Regular contributions
618        for (src, nbrs) in adjacency {
619            let contrib = damping * scores[*src] / out_degree[*src];
620            for &(tgt, _) in nbrs {
621                if tgt < num_nodes {
622                    new_scores[tgt] += contrib;
623                }
624            }
625        }
626
627        // Convergence check
628        let delta: f64 = scores
629            .iter()
630            .zip(new_scores.iter())
631            .map(|(a, b)| (a - b).abs())
632            .sum();
633        scores = new_scores;
634        if delta < tol {
635            break;
636        }
637    }
638
639    // ------- select top-k -------
640    let mut ranked: Vec<(usize, f64)> = scores.iter().cloned().enumerate().collect();
641    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
642    let seeds: Vec<usize> = ranked.iter().take(k).map(|&(node, _)| node).collect();
643
644    let (estimated_spread, oracle_calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
645
646    Ok(InfluenceMaxResult {
647        seeds,
648        estimated_spread,
649        oracle_calls,
650    })
651}
652
653// ---------------------------------------------------------------------------
654// Tests
655// ---------------------------------------------------------------------------
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a directed path `0 → 1 → … → n-1` where every edge has
    /// probability `p`. `n <= 1` yields an empty adjacency list.
    fn path_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        // Iterate over edge heads so that `n == 0` cannot underflow.
        for i in 1..n {
            adj.entry(i - 1).or_default().push((i, p));
        }
        adj
    }

    /// Star: hub `0` with an edge to every spoke in `1..n` (end exclusive),
    /// each with probability `p`.
    fn star_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 1..n {
            adj.entry(0).or_default().push((i, p));
        }
        adj
    }

    #[test]
    fn test_path_adj_handles_tiny_sizes() {
        assert!(path_adj(0, 1.0).is_empty());
        assert!(path_adj(1, 1.0).is_empty());
        assert_eq!(path_adj(3, 0.5).len(), 2);
    }

    #[test]
    fn test_greedy_k1_selects_hub() {
        let adj = star_adj(6, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 20,
            model: DiffusionModel::IC,
        };
        let result = greedy_influence_max(&adj, 6, 1, &config).expect("greedy");
        assert_eq!(result.seeds.len(), 1);
        // Hub (node 0) should be selected since it activates all 5 spokes
        assert_eq!(result.seeds[0], 0);
    }

    #[test]
    fn test_greedy_k0() {
        let adj = star_adj(5, 1.0);
        let config = InfluenceMaxConfig::default();
        let result = greedy_influence_max(&adj, 5, 0, &config).expect("k=0");
        assert!(result.seeds.is_empty());
        assert_eq!(result.estimated_spread, 0.0);
    }

    #[test]
    fn test_greedy_k_too_large() {
        let adj = star_adj(3, 1.0);
        let config = InfluenceMaxConfig::default();
        let err = greedy_influence_max(&adj, 3, 10, &config);
        assert!(err.is_err());
    }

    #[test]
    fn test_greedy_k_equals_num_nodes_selects_every_node() {
        let adj = star_adj(4, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 5,
            model: DiffusionModel::IC,
        };
        let result = greedy_influence_max(&adj, 4, 4, &config).expect("greedy k=n");
        assert_eq!(result.seeds[0], 0);
        let mut seeds = result.seeds.clone();
        seeds.sort_unstable();
        assert_eq!(seeds, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_celf_selects_hub() {
        let adj = star_adj(6, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 20,
            model: DiffusionModel::IC,
        };
        let result = celf_influence_max(&adj, 6, 1, &config).expect("celf");
        assert_eq!(result.seeds.len(), 1);
        assert_eq!(result.seeds[0], 0);
    }

    #[test]
    fn test_celf_pp_selects_hub() {
        let adj = star_adj(6, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 20,
            model: DiffusionModel::IC,
        };
        let result = celf_plus_plus(&adj, 6, 1, &config).expect("celf++");
        assert_eq!(result.seeds.len(), 1);
        assert_eq!(result.seeds[0], 0);
    }

    #[test]
    fn test_celf_variants_return_distinct_seeds() {
        let adj = path_adj(10, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 5,
            model: DiffusionModel::IC,
        };
        let results = [
            celf_influence_max(&adj, 10, 3, &config).expect("celf"),
            celf_plus_plus(&adj, 10, 3, &config).expect("celf++"),
        ];
        for result in &results {
            assert_eq!(result.seeds.len(), 3);
            // With probability 1.0, node 0 activates the whole chain.
            assert_eq!(result.seeds[0], 0);
            let unique: std::collections::HashSet<usize> = result.seeds.iter().copied().collect();
            assert_eq!(unique.len(), 3);
            // Heap initialisation alone runs one estimate per node.
            assert!(result.oracle_calls >= 10 * config.num_simulations);
        }
    }

    #[test]
    fn test_degree_heuristic() {
        let adj = star_adj(6, 0.5);
        let config = InfluenceMaxConfig::default();
        let result = degree_heuristic(&adj, 6, 1, &config).expect("degree heuristic");
        // Node 0 has degree 5, all others 0
        assert_eq!(result.seeds[0], 0);
    }

    #[test]
    fn test_pagerank_heuristic() {
        let adj = star_adj(6, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 20,
            model: DiffusionModel::IC,
        };
        let result =
            pagerank_heuristic(&adj, 6, 1, &config, 0.85, 100, 1e-6).expect("pagerank heuristic");
        assert_eq!(result.seeds.len(), 1);
    }

    #[test]
    fn test_pagerank_rejects_invalid_damping() {
        let adj = star_adj(4, 1.0);
        let config = InfluenceMaxConfig::default();
        assert!(pagerank_heuristic(&adj, 4, 1, &config, 1.5, 10, 1e-6).is_err());
        assert!(pagerank_heuristic(&adj, 4, 1, &config, f64::NAN, 10, 1e-6).is_err());
    }

    #[test]
    fn test_zero_simulations_is_rejected() {
        let adj = star_adj(4, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 0,
            model: DiffusionModel::IC,
        };
        assert!(degree_heuristic(&adj, 4, 1, &config).is_err());
        assert!(greedy_influence_max(&adj, 4, 1, &config).is_err());
    }

    #[test]
    fn test_degree_heuristic_k2() {
        // Two hubs: 0 has 4 spokes, 1 has 3 spokes
        let mut adj: AdjList = HashMap::new();
        for i in 2..6 {
            adj.entry(0).or_default().push((i, 0.5));
        }
        for i in 6..9 {
            adj.entry(1).or_default().push((i, 0.5));
        }
        let config = InfluenceMaxConfig::default();
        let result = degree_heuristic(&adj, 9, 2, &config).expect("degree k=2");
        assert_eq!(result.seeds.len(), 2);
        assert!(result.seeds.contains(&0));
        assert!(result.seeds.contains(&1));
    }

    #[test]
    fn test_greedy_path_k2() {
        let adj = path_adj(10, 1.0);
        let config = InfluenceMaxConfig {
            num_simulations: 30,
            model: DiffusionModel::IC,
        };
        let result = greedy_influence_max(&adj, 10, 2, &config).expect("greedy path");
        assert_eq!(result.seeds.len(), 2);
        // With prob 1.0, node 0 activates entire chain; node 0 should be chosen first
        assert!(result.seeds.contains(&0));
    }
}