1use std::collections::{BinaryHeap, HashMap};
20
21use crate::diffusion::models::{simulate_ic, AdjList};
22use crate::error::{GraphError, Result};
23
/// Settings shared by every influence-maximisation routine in this module.
#[derive(Debug, Clone)]
pub struct InfluenceMaxConfig {
    /// How many diffusion simulations are averaged for one spread estimate.
    /// A value of 0 is rejected by `estimate_spread`.
    pub num_simulations: usize,
    /// Which diffusion model the simulations use.
    pub model: DiffusionModel,
}

/// Selects the diffusion simulator used for spread estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiffusionModel {
    /// Spread is simulated with `simulate_ic`
    /// (presumably the independent-cascade model, going by the name).
    IC,
    /// Spread is simulated with `simulate_lt`
    /// (presumably the linear-threshold model, going by the name).
    LT,
}

impl Default for InfluenceMaxConfig {
    /// 100 simulations per estimate, using [`DiffusionModel::IC`].
    fn default() -> Self {
        Self {
            num_simulations: 100,
            model: DiffusionModel::IC,
        }
    }
}
54
/// Outcome of one influence-maximisation run.
#[derive(Debug, Clone)]
pub struct InfluenceMaxResult {
    /// Chosen seed nodes, in the order the producing routine emitted them.
    pub seeds: Vec<usize>,
    /// Estimated spread of `seeds`, as reported by the producing routine
    /// (0.0 when no seeds were requested).
    pub estimated_spread: f64,
    /// Total number of diffusion simulations the routine ran while estimating spreads.
    pub oracle_calls: usize,
}
65
66fn estimate_spread(
75 adjacency: &AdjList,
76 num_nodes: usize,
77 seeds: &[usize],
78 config: &InfluenceMaxConfig,
79) -> Result<(f64, usize)> {
80 let n = config.num_simulations;
81 if n == 0 {
82 return Err(GraphError::InvalidParameter {
83 param: "num_simulations".to_string(),
84 value: "0".to_string(),
85 expected: ">= 1".to_string(),
86 context: "estimate_spread".to_string(),
87 });
88 }
89
90 let spread = match config.model {
91 DiffusionModel::IC => {
92 let mut total = 0.0_f64;
93 for _ in 0..n {
94 total += simulate_ic(adjacency, seeds)?.spread as f64;
95 }
96 total / n as f64
97 }
98 DiffusionModel::LT => {
99 use crate::diffusion::models::simulate_lt;
100 let mut total = 0.0_f64;
101 for _ in 0..n {
102 total += simulate_lt(adjacency, num_nodes, seeds, None)?.spread as f64;
103 }
104 total / n as f64
105 }
106 };
107
108 Ok((spread, n))
109}
110
111pub fn greedy_influence_max(
133 adjacency: &AdjList,
134 num_nodes: usize,
135 k: usize,
136 config: &InfluenceMaxConfig,
137) -> Result<InfluenceMaxResult> {
138 if k == 0 {
139 return Ok(InfluenceMaxResult {
140 seeds: Vec::new(),
141 estimated_spread: 0.0,
142 oracle_calls: 0,
143 });
144 }
145 if k > num_nodes {
146 return Err(GraphError::InvalidParameter {
147 param: "k".to_string(),
148 value: k.to_string(),
149 expected: format!("<= num_nodes={num_nodes}"),
150 context: "greedy_influence_max".to_string(),
151 });
152 }
153
154 let mut seeds: Vec<usize> = Vec::with_capacity(k);
155 let mut current_spread = 0.0_f64;
156 let mut oracle_calls = 0_usize;
157 let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
158
159 for _round in 0..k {
160 let mut best_node = None;
161 let mut best_gain = f64::NEG_INFINITY;
162
163 for candidate in 0..num_nodes {
164 if selected.contains(&candidate) {
165 continue;
166 }
167 let mut trial_seeds = seeds.clone();
168 trial_seeds.push(candidate);
169 let (spread, calls) = estimate_spread(adjacency, num_nodes, &trial_seeds, config)?;
170 oracle_calls += calls;
171
172 let gain = spread - current_spread;
173 if gain > best_gain {
174 best_gain = gain;
175 best_node = Some((candidate, spread));
176 }
177 }
178
179 match best_node {
180 Some((node, spread)) => {
181 seeds.push(node);
182 selected.insert(node);
183 current_spread = spread;
184 }
185 None => break,
186 }
187 }
188
189 Ok(InfluenceMaxResult {
190 estimated_spread: current_spread,
191 seeds,
192 oracle_calls,
193 })
194}
195
/// Priority-queue entry used by the lazy-forward (CELF-style) selectors.
///
/// Entries are ordered for use in a max-heap: the entry with the largest
/// `marginal_gain` is popped first, and among equal gains the entry with the
/// larger `node` index is popped first. `round` and `prev_best` do not take
/// part in comparisons.
#[derive(Debug, Clone)]
struct CelfEntry {
    /// Candidate node.
    node: usize,
    /// Most recently estimated marginal gain of adding `node` to the seed set.
    marginal_gain: f64,
    /// Selection round in which `marginal_gain` was computed.
    round: usize,
    /// Look-ahead flag recorded by `celf_plus_plus`.
    prev_best: bool,
}

impl PartialEq for CelfEntry {
    /// Two entries are equal exactly when `cmp` says so, which keeps `Eq`
    /// reflexive even if a gain happens to be NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for CelfEntry {}

impl PartialOrd for CelfEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for CelfEntry {
    /// Total order by gain, then by node index.
    ///
    /// Comparable gains use the ordinary float ordering (so `-0.0 == 0.0`).
    /// A NaN gain sorts below every real gain and equal to another NaN, so a
    /// broken estimate can never outrank a real one or corrupt the heap order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_gain = match self.marginal_gain.partial_cmp(&other.marginal_gain) {
            Some(ordering) => ordering,
            // At least one side is NaN: the NaN side is the smaller one.
            None => other
                .marginal_gain
                .is_nan()
                .cmp(&self.marginal_gain.is_nan()),
        };
        by_gain.then_with(|| self.node.cmp(&other.node))
    }
}
233
234pub fn celf_influence_max(
250 adjacency: &AdjList,
251 num_nodes: usize,
252 k: usize,
253 config: &InfluenceMaxConfig,
254) -> Result<InfluenceMaxResult> {
255 if k == 0 {
256 return Ok(InfluenceMaxResult {
257 seeds: Vec::new(),
258 estimated_spread: 0.0,
259 oracle_calls: 0,
260 });
261 }
262 if k > num_nodes {
263 return Err(GraphError::InvalidParameter {
264 param: "k".to_string(),
265 value: k.to_string(),
266 expected: format!("<= num_nodes={num_nodes}"),
267 context: "celf_influence_max".to_string(),
268 });
269 }
270
271 let mut oracle_calls = 0_usize;
272
273 let mut heap: BinaryHeap<CelfEntry> = BinaryHeap::new();
275 for node in 0..num_nodes {
276 let (gain, calls) = estimate_spread(adjacency, num_nodes, &[node], config)?;
277 oracle_calls += calls;
278 heap.push(CelfEntry {
279 node,
280 marginal_gain: gain,
281 round: 0,
282 prev_best: false,
283 });
284 }
285
286 let mut seeds: Vec<usize> = Vec::with_capacity(k);
287 let mut current_spread = 0.0_f64;
288 let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
289
290 let mut round = 0_usize;
291 while seeds.len() < k {
292 let entry = loop {
293 let top = heap.pop().ok_or_else(|| GraphError::AlgorithmFailure {
294 algorithm: "celf_influence_max".to_string(),
295 reason: "priority queue exhausted before k seeds selected".to_string(),
296 iterations: seeds.len(),
297 tolerance: 0.0,
298 })?;
299
300 if selected.contains(&top.node) {
301 continue;
302 }
303
304 if top.round == round {
305 break top;
307 }
308
309 let mut trial = seeds.clone();
311 trial.push(top.node);
312 let (new_spread, calls) = estimate_spread(adjacency, num_nodes, &trial, config)?;
313 oracle_calls += calls;
314
315 let updated = CelfEntry {
316 node: top.node,
317 marginal_gain: new_spread - current_spread,
318 round,
319 prev_best: false,
320 };
321 heap.push(updated);
322 };
323
324 seeds.push(entry.node);
325 selected.insert(entry.node);
326 current_spread += entry.marginal_gain;
327 round += 1;
328 }
329
330 let (final_spread, calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
332 oracle_calls += calls;
333
334 Ok(InfluenceMaxResult {
335 seeds,
336 estimated_spread: final_spread,
337 oracle_calls,
338 })
339}
340
341pub fn celf_plus_plus(
355 adjacency: &AdjList,
356 num_nodes: usize,
357 k: usize,
358 config: &InfluenceMaxConfig,
359) -> Result<InfluenceMaxResult> {
360 if k == 0 {
361 return Ok(InfluenceMaxResult {
362 seeds: Vec::new(),
363 estimated_spread: 0.0,
364 oracle_calls: 0,
365 });
366 }
367 if k > num_nodes {
368 return Err(GraphError::InvalidParameter {
369 param: "k".to_string(),
370 value: k.to_string(),
371 expected: format!("<= num_nodes={num_nodes}"),
372 context: "celf_plus_plus".to_string(),
373 });
374 }
375
376 let mut oracle_calls = 0_usize;
377
378 let mut heap: BinaryHeap<CelfEntry> = BinaryHeap::new();
380 let mut cached_gain: HashMap<usize, f64> = HashMap::new();
382
383 for node in 0..num_nodes {
384 let (gain, calls) = estimate_spread(adjacency, num_nodes, &[node], config)?;
385 oracle_calls += calls;
386 cached_gain.insert(node, gain);
387 heap.push(CelfEntry {
388 node,
389 marginal_gain: gain,
390 round: 0,
391 prev_best: false,
392 });
393 }
394
395 let mut seeds: Vec<usize> = Vec::with_capacity(k);
396 let mut current_spread = 0.0_f64;
397 let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
398 let mut prev_best_node: Option<usize> = None;
399
400 let mut round = 0_usize;
401 while seeds.len() < k {
402 let chosen = loop {
404 let top = heap.pop().ok_or_else(|| GraphError::AlgorithmFailure {
405 algorithm: "celf_plus_plus".to_string(),
406 reason: "priority queue exhausted".to_string(),
407 iterations: seeds.len(),
408 tolerance: 0.0,
409 })?;
410
411 if selected.contains(&top.node) {
412 continue;
413 }
414
415 if top.prev_best && top.round == round {
419 break top;
420 }
421
422 if top.round == round {
423 break top;
425 }
426
427 let mut trial = seeds.clone();
429 trial.push(top.node);
430 let (new_spread, calls) = estimate_spread(adjacency, num_nodes, &trial, config)?;
431 oracle_calls += calls;
432
433 let gain = new_spread - current_spread;
434 *cached_gain.entry(top.node).or_insert(gain) = gain;
435
436 let is_prev_best = prev_best_node.map(|pb| pb == top.node).unwrap_or(false);
438 let prev_best_flag = if let Some(pb) = prev_best_node {
439 if !selected.contains(&pb) && !is_prev_best {
440 let mut trial2 = seeds.clone();
441 trial2.push(pb);
442 trial2.push(top.node);
443 let (spread2, calls2) = estimate_spread(adjacency, num_nodes, &trial2, config)?;
444 oracle_calls += calls2;
445 let gain2 =
446 spread2 - current_spread - cached_gain.get(&pb).cloned().unwrap_or(0.0);
447 gain2 >= gain
449 } else {
450 false
451 }
452 } else {
453 false
454 };
455
456 let updated = CelfEntry {
457 node: top.node,
458 marginal_gain: gain,
459 round,
460 prev_best: prev_best_flag,
461 };
462 heap.push(updated);
463 };
464
465 prev_best_node = Some(chosen.node);
466 seeds.push(chosen.node);
467 selected.insert(chosen.node);
468 current_spread += chosen.marginal_gain;
469 round += 1;
470 }
471
472 let (final_spread, calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
473 oracle_calls += calls;
474
475 Ok(InfluenceMaxResult {
476 seeds,
477 estimated_spread: final_spread,
478 oracle_calls,
479 })
480}
481
482pub fn degree_heuristic(
498 adjacency: &AdjList,
499 num_nodes: usize,
500 k: usize,
501 config: &InfluenceMaxConfig,
502) -> Result<InfluenceMaxResult> {
503 if k == 0 {
504 return Ok(InfluenceMaxResult {
505 seeds: Vec::new(),
506 estimated_spread: 0.0,
507 oracle_calls: 0,
508 });
509 }
510 if k > num_nodes {
511 return Err(GraphError::InvalidParameter {
512 param: "k".to_string(),
513 value: k.to_string(),
514 expected: format!("<= num_nodes={num_nodes}"),
515 context: "degree_heuristic".to_string(),
516 });
517 }
518
519 let mut degrees: Vec<(usize, usize)> = (0..num_nodes)
521 .map(|node| {
522 let deg = adjacency.get(&node).map(|nbrs| nbrs.len()).unwrap_or(0);
523 (node, deg)
524 })
525 .collect();
526
527 degrees.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
529
530 let seeds: Vec<usize> = degrees.iter().take(k).map(|&(node, _)| node).collect();
531
532 let (estimated_spread, oracle_calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
533
534 Ok(InfluenceMaxResult {
535 seeds,
536 estimated_spread,
537 oracle_calls,
538 })
539}
540
541pub fn pagerank_heuristic(
561 adjacency: &AdjList,
562 num_nodes: usize,
563 k: usize,
564 config: &InfluenceMaxConfig,
565 damping: f64,
566 max_iter: usize,
567 tol: f64,
568) -> Result<InfluenceMaxResult> {
569 if k == 0 {
570 return Ok(InfluenceMaxResult {
571 seeds: Vec::new(),
572 estimated_spread: 0.0,
573 oracle_calls: 0,
574 });
575 }
576 if k > num_nodes {
577 return Err(GraphError::InvalidParameter {
578 param: "k".to_string(),
579 value: k.to_string(),
580 expected: format!("<= num_nodes={num_nodes}"),
581 context: "pagerank_heuristic".to_string(),
582 });
583 }
584 if !(0.0..=1.0).contains(&damping) {
585 return Err(GraphError::InvalidParameter {
586 param: "damping".to_string(),
587 value: damping.to_string(),
588 expected: "[0, 1]".to_string(),
589 context: "pagerank_heuristic".to_string(),
590 });
591 }
592
593 let out_degree: Vec<f64> = (0..num_nodes)
595 .map(|n| adjacency.get(&n).map(|v| v.len() as f64).unwrap_or(0.0))
596 .collect();
597
598 let base_score = (1.0 - damping) / num_nodes as f64;
600 let mut scores: Vec<f64> = vec![1.0 / num_nodes as f64; num_nodes];
601
602 for _ in 0..max_iter {
603 let mut new_scores: Vec<f64> = vec![base_score; num_nodes];
604
605 let dangling_sum: f64 = (0..num_nodes)
607 .filter(|&n| out_degree[n] == 0.0)
608 .map(|n| scores[n])
609 .sum::<f64>()
610 * damping
611 / num_nodes as f64;
612
613 for n in 0..num_nodes {
614 new_scores[n] += dangling_sum;
615 }
616
617 for (src, nbrs) in adjacency {
619 let contrib = damping * scores[*src] / out_degree[*src];
620 for &(tgt, _) in nbrs {
621 if tgt < num_nodes {
622 new_scores[tgt] += contrib;
623 }
624 }
625 }
626
627 let delta: f64 = scores
629 .iter()
630 .zip(new_scores.iter())
631 .map(|(a, b)| (a - b).abs())
632 .sum();
633 scores = new_scores;
634 if delta < tol {
635 break;
636 }
637 }
638
639 let mut ranked: Vec<(usize, f64)> = scores.iter().cloned().enumerate().collect();
641 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
642 let seeds: Vec<usize> = ranked.iter().take(k).map(|&(node, _)| node).collect();
643
644 let (estimated_spread, oracle_calls) = estimate_spread(adjacency, num_nodes, &seeds, config)?;
645
646 Ok(InfluenceMaxResult {
647 seeds,
648 estimated_spread,
649 oracle_calls,
650 })
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Directed path `0 -> 1 -> ... -> n-1` with weight `p` on every edge.
    fn path_adj(n: usize, p: f64) -> AdjList {
        let mut graph: AdjList = HashMap::new();
        for src in 0..(n - 1) {
            graph.entry(src).or_default().push((src + 1, p));
        }
        graph
    }

    /// Star with hub 0 pointing at leaves `1..n`, weight `p` on every edge.
    fn star_adj(n: usize, p: f64) -> AdjList {
        let mut graph: AdjList = HashMap::new();
        for leaf in 1..n {
            graph.entry(0).or_default().push((leaf, p));
        }
        graph
    }

    /// IC configuration with the given number of simulations per estimate.
    fn ic_config(num_simulations: usize) -> InfluenceMaxConfig {
        InfluenceMaxConfig {
            num_simulations,
            model: DiffusionModel::IC,
        }
    }

    #[test]
    fn test_greedy_k1_selects_hub() {
        let graph = star_adj(6, 1.0);
        let outcome = greedy_influence_max(&graph, 6, 1, &ic_config(20)).expect("greedy");
        assert_eq!(outcome.seeds.len(), 1);
        assert_eq!(outcome.seeds[0], 0);
    }

    #[test]
    fn test_greedy_k0() {
        let graph = star_adj(5, 1.0);
        let cfg = InfluenceMaxConfig::default();
        let outcome = greedy_influence_max(&graph, 5, 0, &cfg).expect("k=0");
        assert!(outcome.seeds.is_empty());
        assert_eq!(outcome.estimated_spread, 0.0);
    }

    #[test]
    fn test_greedy_k_too_large() {
        let graph = star_adj(3, 1.0);
        let cfg = InfluenceMaxConfig::default();
        let outcome = greedy_influence_max(&graph, 3, 10, &cfg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_celf_selects_hub() {
        let graph = star_adj(6, 1.0);
        let outcome = celf_influence_max(&graph, 6, 1, &ic_config(20)).expect("celf");
        assert_eq!(outcome.seeds.len(), 1);
        assert_eq!(outcome.seeds[0], 0);
    }

    #[test]
    fn test_celf_pp_selects_hub() {
        let graph = star_adj(6, 1.0);
        let outcome = celf_plus_plus(&graph, 6, 1, &ic_config(20)).expect("celf++");
        assert_eq!(outcome.seeds.len(), 1);
        assert_eq!(outcome.seeds[0], 0);
    }

    #[test]
    fn test_degree_heuristic() {
        let graph = star_adj(6, 0.5);
        let cfg = InfluenceMaxConfig::default();
        let outcome = degree_heuristic(&graph, 6, 1, &cfg).expect("degree heuristic");
        assert_eq!(outcome.seeds[0], 0);
    }

    #[test]
    fn test_pagerank_heuristic() {
        let graph = star_adj(6, 1.0);
        let outcome = pagerank_heuristic(&graph, 6, 1, &ic_config(20), 0.85, 100, 1e-6)
            .expect("pagerank heuristic");
        assert_eq!(outcome.seeds.len(), 1);
    }

    #[test]
    fn test_degree_heuristic_k2() {
        // Two hubs: node 0 has four out-edges, node 1 has three.
        let mut graph: AdjList = HashMap::new();
        for leaf in 2..6 {
            graph.entry(0).or_default().push((leaf, 0.5));
        }
        for leaf in 6..9 {
            graph.entry(1).or_default().push((leaf, 0.5));
        }
        let cfg = InfluenceMaxConfig::default();
        let outcome = degree_heuristic(&graph, 9, 2, &cfg).expect("degree k=2");
        assert_eq!(outcome.seeds.len(), 2);
        assert!(outcome.seeds.contains(&0));
        assert!(outcome.seeds.contains(&1));
    }

    #[test]
    fn test_greedy_path_k2() {
        let graph = path_adj(10, 1.0);
        let outcome = greedy_influence_max(&graph, 10, 2, &ic_config(30)).expect("greedy path");
        assert_eq!(outcome.seeds.len(), 2);
        assert!(outcome.seeds.contains(&0));
    }
}