1use std::cmp::Reverse;
17use std::collections::{HashMap, HashSet, VecDeque};
18
19use scirs2_core::random::{Rng, RngExt};
20
21use crate::base::{EdgeWeight, Graph, Node};
22use crate::error::{GraphError, Result};
23
/// Dense node index used throughout this module: the position of a node in
/// the order returned by `graph.nodes()`.
pub type NodeId = usize;

/// Boxed spread estimator: maps a seed set (as [`NodeId`]s) to its estimated
/// (Monte Carlo averaged) cascade size.
type SpreadFn = Box<dyn Fn(&[NodeId]) -> f64>;
29
/// Diffusion model used to simulate how influence spreads through a graph.
///
/// Defaults to [`CascadeModel::IndependentCascade`].
#[derive(Default, Clone, Debug, PartialEq)]
pub enum CascadeModel {
    /// Every newly activated node gets one chance to activate each inactive
    /// neighbour, succeeding with that edge's probability.
    #[default]
    IndependentCascade,
    /// An inactive node activates once the summed weight of its active
    /// neighbours reaches its randomly drawn threshold.
    LinearThreshold,
}
46
47fn estimate_spread_ic(
56 adj: &HashMap<NodeId, Vec<(NodeId, f64)>>,
57 seeds: &[NodeId],
58 num_simulations: usize,
59) -> f64 {
60 let mut rng = scirs2_core::random::rng();
61 let mut total = 0.0f64;
62
63 for _ in 0..num_simulations {
64 let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
65 let mut queue: VecDeque<NodeId> = seeds.iter().cloned().collect();
66
67 while let Some(node) = queue.pop_front() {
68 if let Some(neighbors) = adj.get(&node) {
69 for &(nbr, prob) in neighbors {
70 if !active.contains(&nbr) && rng.random::<f64>() < prob {
71 active.insert(nbr);
72 queue.push_back(nbr);
73 }
74 }
75 }
76 }
77 total += active.len() as f64;
78 }
79
80 total / num_simulations as f64
81}
82
83fn estimate_spread_lt(
85 adj: &HashMap<NodeId, Vec<(NodeId, f64)>>,
86 n_nodes: usize,
87 seeds: &[NodeId],
88 num_simulations: usize,
89) -> f64 {
90 let mut rng = scirs2_core::random::rng();
91 let mut total = 0.0f64;
92
93 for _ in 0..num_simulations {
94 let thresholds: Vec<f64> = (0..n_nodes).map(|_| rng.random::<f64>()).collect();
96 let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
97 let mut changed = true;
98
99 while changed {
100 changed = false;
101 for node in 0..n_nodes {
102 if active.contains(&node) {
103 continue;
104 }
105 let influence: f64 = adj
107 .get(&node)
108 .map(|nbrs| {
109 nbrs.iter()
110 .filter(|&&(nbr, _)| active.contains(&nbr))
111 .map(|&(_, w)| w)
112 .sum::<f64>()
113 })
114 .unwrap_or(0.0);
115
116 if influence >= thresholds[node] {
117 active.insert(node);
118 changed = true;
119 }
120 }
121 }
122 total += active.len() as f64;
123 }
124
125 total / num_simulations as f64
126}
127
128#[derive(Debug, Clone)]
130pub struct InfluenceConfig {
131 pub model: CascadeModel,
133 pub num_simulations: usize,
135 pub default_prob: f64,
137}
138
139impl Default for InfluenceConfig {
140 fn default() -> Self {
141 InfluenceConfig {
142 model: CascadeModel::IndependentCascade,
143 num_simulations: 100,
144 default_prob: 0.1,
145 }
146 }
147}
148
149pub fn influence_maximization<N, E, Ix>(
163 graph: &Graph<N, E, Ix>,
164 k: usize,
165 config: &InfluenceConfig,
166) -> Result<Vec<NodeId>>
167where
168 N: Node + Clone + std::fmt::Debug,
169 E: EdgeWeight + Clone + Into<f64>,
170 Ix: petgraph::graph::IndexType,
171{
172 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
173 let n = nodes.len();
174
175 if k == 0 {
176 return Ok(Vec::new());
177 }
178 if k > n {
179 return Err(GraphError::InvalidParameter {
180 param: "k".to_string(),
181 value: k.to_string(),
182 expected: format!("<= n_nodes ({})", n),
183 context: "influence_maximization".to_string(),
184 });
185 }
186
187 let node_to_idx: HashMap<N, NodeId> = nodes
189 .iter()
190 .enumerate()
191 .map(|(i, nd)| (nd.clone(), i))
192 .collect();
193
194 let mut adj: HashMap<NodeId, Vec<(NodeId, f64)>> = HashMap::new();
195 for edge in graph.edges() {
196 let si = *node_to_idx
197 .get(&edge.source)
198 .ok_or_else(|| GraphError::node_not_found("source node"))?;
199 let ti = *node_to_idx
200 .get(&edge.target)
201 .ok_or_else(|| GraphError::node_not_found("target node"))?;
202 let w: f64 = edge.weight.clone().into();
203 let prob = if w > 0.0 && w <= 1.0 {
204 w
205 } else {
206 config.default_prob
207 };
208 adj.entry(si).or_default().push((ti, prob));
209 adj.entry(ti).or_default().push((si, prob)); }
211
212 let spread_fn: SpreadFn = match &config.model {
213 CascadeModel::IndependentCascade => {
214 let adj_ref = adj.clone();
215 let sims = config.num_simulations;
216 Box::new(move |seeds| estimate_spread_ic(&adj_ref, seeds, sims))
217 }
218 CascadeModel::LinearThreshold => {
219 let adj_ref = adj.clone();
220 let sims = config.num_simulations;
221 Box::new(move |seeds| estimate_spread_lt(&adj_ref, n, seeds, sims))
222 }
223 };
224
225 let mut seeds: Vec<NodeId> = Vec::with_capacity(k);
226 let mut current_spread = 0.0f64;
227
228 for _ in 0..k {
229 let mut best_node = None;
230 let mut best_gain = f64::NEG_INFINITY;
231
232 for candidate in 0..n {
233 if seeds.contains(&candidate) {
234 continue;
235 }
236 let mut trial_seeds = seeds.clone();
237 trial_seeds.push(candidate);
238 let spread = spread_fn(&trial_seeds);
239 let gain = spread - current_spread;
240
241 if gain > best_gain {
242 best_gain = gain;
243 best_node = Some(candidate);
244 }
245 }
246
247 if let Some(node) = best_node {
248 seeds.push(node);
249 current_spread += best_gain;
250 }
251 }
252
253 Ok(seeds)
254}
255
/// Coarse structural role assigned to a node by [`role_detection`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RoleType {
    /// Degree well above the graph's mean degree.
    Hub,
    /// Low degree combined with low local clustering.
    Peripheral,
    /// Low local clustering despite having at least two neighbours.
    Bridge,
    /// Connected node that fits none of the other roles.
    Member,
    /// Node with degree zero.
    Isolated,
}

impl std::fmt::Display for RoleType {
    /// Writes the variant name, e.g. `Hub`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Hub => "Hub",
            Self::Peripheral => "Peripheral",
            Self::Bridge => "Bridge",
            Self::Member => "Member",
            Self::Isolated => "Isolated",
        };
        f.write_str(label)
    }
}
286
287pub fn role_detection<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<HashMap<NodeId, RoleType>>
302where
303 N: Node + Clone + std::fmt::Debug,
304 E: EdgeWeight + Clone + Into<f64>,
305 Ix: petgraph::graph::IndexType,
306{
307 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
308 let n = nodes.len();
309
310 if n == 0 {
311 return Ok(HashMap::new());
312 }
313
314 let degrees: Vec<f64> = nodes.iter().map(|nd| graph.degree(nd) as f64).collect();
316 let mean_deg = degrees.iter().sum::<f64>() / n as f64;
317 let var_deg = degrees.iter().map(|d| (d - mean_deg).powi(2)).sum::<f64>() / n as f64;
318 let std_deg = var_deg.sqrt();
319
320 let clustering: Vec<f64> = nodes
322 .iter()
323 .map(|nd| local_clustering_coefficient(graph, nd))
324 .collect();
325
326 let mean_clustering = if n > 0 {
327 clustering.iter().sum::<f64>() / n as f64
328 } else {
329 0.0
330 };
331
332 let mut roles = HashMap::with_capacity(n);
333
334 for (i, _node) in nodes.iter().enumerate() {
335 let deg = degrees[i];
336 let clust = clustering[i];
337
338 let role = if deg == 0.0 {
339 RoleType::Isolated
340 } else if deg > mean_deg + std_deg {
341 RoleType::Hub
342 } else if deg < (mean_deg - 0.5 * std_deg).max(1.0) && clust < mean_clustering * 0.5 {
343 RoleType::Peripheral
344 } else if clust < mean_clustering * 0.4 && deg >= 2.0 {
345 RoleType::Bridge
346 } else {
347 RoleType::Member
348 };
349
350 roles.insert(i, role);
351 }
352
353 Ok(roles)
354}
355
356fn local_clustering_coefficient<N, E, Ix>(graph: &Graph<N, E, Ix>, node: &N) -> f64
358where
359 N: Node + Clone + std::fmt::Debug,
360 E: EdgeWeight + Clone + Into<f64>,
361 Ix: petgraph::graph::IndexType,
362{
363 let neighbors: Vec<N> = match graph.neighbors(node) {
364 Ok(nbrs) => nbrs,
365 Err(_) => return 0.0,
366 };
367 let k = neighbors.len();
368 if k < 2 {
369 return 0.0;
370 }
371
372 let mut triangles = 0usize;
373 for i in 0..k {
374 for j in i + 1..k {
375 if graph.has_edge(&neighbors[i], &neighbors[j]) {
376 triangles += 1;
377 }
378 }
379 }
380
381 let max_possible = k * (k - 1) / 2;
382 if max_possible == 0 {
383 0.0
384 } else {
385 triangles as f64 / max_possible as f64
386 }
387}
388
389pub fn echo_chamber_detection<N, E, Ix>(
407 graph: &Graph<N, E, Ix>,
408 features: &[Vec<f64>],
409) -> Result<Vec<Vec<NodeId>>>
410where
411 N: Node + Clone + std::fmt::Debug,
412 E: EdgeWeight + Clone + Into<f64>,
413 Ix: petgraph::graph::IndexType,
414{
415 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
416 let n = nodes.len();
417
418 if n == 0 {
419 return Ok(Vec::new());
420 }
421 if features.len() != n {
422 return Err(GraphError::InvalidParameter {
423 param: "features".to_string(),
424 value: format!("{} rows", features.len()),
425 expected: format!("{} rows (one per node)", n),
426 context: "echo_chamber_detection".to_string(),
427 });
428 }
429
430 let node_to_idx: HashMap<N, NodeId> = nodes
432 .iter()
433 .enumerate()
434 .map(|(i, nd)| (nd.clone(), i))
435 .collect();
436
437 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
438 for edge in graph.edges() {
439 if let (Some(&si), Some(&ti)) =
440 (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
441 {
442 adj[si].push(ti);
443 adj[ti].push(si);
444 }
445 }
446
447 let mut labels: Vec<NodeId> = (0..n).collect();
450
451 for _round in 0..20 {
453 let mut changed = false;
454
455 let mut order: Vec<usize> = (0..n).collect();
457 for i in (1..n).rev() {
459 let j = i
460 .wrapping_mul(6364136223846793005)
461 .wrapping_add(1442695040888963407)
462 % (i + 1);
463 order.swap(i, j);
464 }
465
466 for &node in &order {
467 let nbrs = &adj[node];
468 if nbrs.is_empty() {
469 continue;
470 }
471
472 let mut label_scores: HashMap<NodeId, f64> = HashMap::new();
474 for &nbr in nbrs {
475 let lbl = labels[nbr];
476 let sim = feature_similarity(&features[node], &features[nbr]);
477 *label_scores.entry(lbl).or_default() += 1.0 + sim;
478 }
479
480 if let Some((&best_label, _)) = label_scores
481 .iter()
482 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
483 {
484 if best_label != labels[node] {
485 labels[node] = best_label;
486 changed = true;
487 }
488 }
489 }
490
491 if !changed {
492 break;
493 }
494 }
495
496 let mut chambers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
498 for (node, &lbl) in labels.iter().enumerate() {
499 chambers.entry(lbl).or_default().push(node);
500 }
501
502 let mut result: Vec<Vec<NodeId>> = chambers.into_values().collect();
503 result.sort_by_key(|b| Reverse(b.len())); Ok(result)
505}
506
/// Cosine similarity of two feature vectors of equal length.
///
/// Returns `0.0` if either slice is empty or the lengths differ. Each norm is
/// floored at `1e-10`, so a zero vector yields `0.0` instead of dividing by
/// zero.
fn feature_similarity(a: &[f64], b: &[f64]) -> f64 {
    let comparable = !a.is_empty() && !b.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();

    dot / (norm(a) * norm(b))
}
517
518pub fn polarization_index<N, E, Ix>(graph: &Graph<N, E, Ix>, features: &[Vec<f64>]) -> Result<f64>
539where
540 N: Node + Clone + std::fmt::Debug,
541 E: EdgeWeight + Clone + Into<f64>,
542 Ix: petgraph::graph::IndexType,
543{
544 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
545 let n = nodes.len();
546
547 if n < 2 {
548 return Ok(0.0);
549 }
550
551 let feat_len = features.first().map(|f| f.len()).unwrap_or(0);
552 let feature_pad: Vec<Vec<f64>>;
553 let features_ref: &[Vec<f64>] = if features.len() == n {
554 features
555 } else {
556 feature_pad = vec![vec![0.0; feat_len.max(1)]; n];
557 &feature_pad
558 };
559
560 let chambers = echo_chamber_detection(graph, features_ref)?;
562 let num_chambers = chambers.len();
563
564 if num_chambers <= 1 {
565 return Ok(0.0);
566 }
567
568 let mut node_chamber: Vec<usize> = vec![0; n];
570 for (cid, chamber) in chambers.iter().enumerate() {
571 for &node in chamber {
572 if node < n {
573 node_chamber[node] = cid;
574 }
575 }
576 }
577
578 let node_to_idx: HashMap<N, NodeId> = nodes
579 .iter()
580 .enumerate()
581 .map(|(i, nd)| (nd.clone(), i))
582 .collect();
583
584 let edges = graph.edges();
585 let total_edges = edges.len() as f64;
586
587 if total_edges == 0.0 {
588 return Ok(0.0);
589 }
590
591 let mut intra = 0.0f64;
593 let mut cross = 0.0f64;
594 let mut intra_sim = 0.0f64;
595 let mut cross_sim = 0.0f64;
596
597 for edge in &edges {
598 if let (Some(&si), Some(&ti)) =
599 (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
600 {
601 let sim = feature_similarity(
602 features_ref.get(si).map(|v| v.as_slice()).unwrap_or(&[]),
603 features_ref.get(ti).map(|v| v.as_slice()).unwrap_or(&[]),
604 );
605 if node_chamber[si] == node_chamber[ti] {
606 intra += 1.0;
607 intra_sim += sim;
608 } else {
609 cross += 1.0;
610 cross_sim += sim;
611 }
612 }
613 }
614
615 let modularity_component = intra / total_edges;
617
618 let homophily_component = if feat_len > 0 && (intra + cross) > 0.0 {
620 let avg_intra_sim = if intra > 0.0 { intra_sim / intra } else { 0.0 };
621 let avg_cross_sim = if cross > 0.0 { cross_sim / cross } else { 0.0 };
622 ((avg_intra_sim - avg_cross_sim + 2.0) / 4.0).clamp(0.0, 1.0)
624 } else {
625 0.5 };
627
628 let polarization = 0.6 * modularity_component + 0.4 * homophily_component;
630 Ok(polarization.clamp(0.0, 1.0))
631}
632
633pub fn simulate_spread<N, E, Ix>(
642 graph: &Graph<N, E, Ix>,
643 seeds: &[NodeId],
644 config: &InfluenceConfig,
645) -> Result<HashSet<NodeId>>
646where
647 N: Node + Clone + std::fmt::Debug,
648 E: EdgeWeight + Clone + Into<f64>,
649 Ix: petgraph::graph::IndexType,
650{
651 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
652 let n = nodes.len();
653 let node_to_idx: HashMap<N, NodeId> = nodes
654 .iter()
655 .enumerate()
656 .map(|(i, nd)| (nd.clone(), i))
657 .collect();
658
659 let mut adj: HashMap<NodeId, Vec<(NodeId, f64)>> = HashMap::new();
660 for edge in graph.edges() {
661 let si = *node_to_idx
662 .get(&edge.source)
663 .ok_or_else(|| GraphError::node_not_found("source"))?;
664 let ti = *node_to_idx
665 .get(&edge.target)
666 .ok_or_else(|| GraphError::node_not_found("target"))?;
667 let w: f64 = edge.weight.clone().into();
668 let prob = if w > 0.0 && w <= 1.0 {
669 w
670 } else {
671 config.default_prob
672 };
673 adj.entry(si).or_default().push((ti, prob));
674 adj.entry(ti).or_default().push((si, prob));
675 }
676
677 let active_count = match &config.model {
678 CascadeModel::IndependentCascade => {
679 let mut rng = scirs2_core::random::rng();
681 let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
682 let mut queue: VecDeque<NodeId> = seeds.iter().cloned().collect();
683 while let Some(node) = queue.pop_front() {
684 if let Some(neighbors) = adj.get(&node) {
685 for &(nbr, prob) in neighbors {
686 if !active.contains(&nbr) && rng.random::<f64>() < prob {
687 active.insert(nbr);
688 queue.push_back(nbr);
689 }
690 }
691 }
692 }
693 active
694 }
695 CascadeModel::LinearThreshold => {
696 let mut rng = scirs2_core::random::rng();
697 let thresholds: Vec<f64> = (0..n).map(|_| rng.random::<f64>()).collect();
698 let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
699 let mut changed = true;
700 while changed {
701 changed = false;
702 for node in 0..n {
703 if active.contains(&node) {
704 continue;
705 }
706 let influence: f64 = adj
707 .get(&node)
708 .map(|nbrs| {
709 nbrs.iter()
710 .filter(|&&(nbr, _)| active.contains(&nbr))
711 .map(|&(_, w)| w)
712 .sum::<f64>()
713 })
714 .unwrap_or(0.0);
715 if influence >= thresholds[node] {
716 active.insert(node);
717 changed = true;
718 }
719 }
720 }
721 active
722 }
723 };
724
725 Ok(active_count)
726}
727
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::Graph;

    /// Connects every pair of distinct nodes in `ids` with an edge of `weight`.
    fn add_clique(g: &mut Graph<usize, f64>, ids: std::ops::Range<usize>, weight: f64) {
        for i in ids.clone() {
            for j in (i + 1)..ids.end {
                let _ = g.add_edge(i, j, weight);
            }
        }
    }

    /// Two 4-cliques, {0..=3} and {5..=8}, with edge weight 0.3. They are
    /// joined through node 4 by the weak edges 3-4 and 4-5 (weight 0.1).
    fn make_social_graph() -> Graph<usize, f64> {
        let mut g: Graph<usize, f64> = Graph::new();
        add_clique(&mut g, 0..4, 0.3);
        add_clique(&mut g, 5..9, 0.3);
        let _ = g.add_edge(3, 4, 0.1);
        let _ = g.add_edge(4, 5, 0.1);
        g
    }

    /// Config with the given model and run count; fallback probability 0.3.
    fn config_with(model: CascadeModel, num_simulations: usize) -> InfluenceConfig {
        InfluenceConfig {
            model,
            num_simulations,
            default_prob: 0.3,
        }
    }

    #[test]
    fn test_influence_maximization_returns_k_seeds() {
        let graph = make_social_graph();
        let cfg = config_with(CascadeModel::IndependentCascade, 20);
        let picked = influence_maximization(&graph, 3, &cfg).expect("IM failed");
        assert_eq!(picked.len(), 3, "Should return exactly k seeds");
        let distinct = picked.iter().collect::<HashSet<_>>();
        assert_eq!(distinct.len(), 3, "Seeds should be unique");
    }

    #[test]
    fn test_influence_maximization_linear_threshold() {
        let graph = make_social_graph();
        let cfg = config_with(CascadeModel::LinearThreshold, 20);
        let picked = influence_maximization(&graph, 2, &cfg).expect("IM LT failed");
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn test_influence_maximization_k_zero() {
        let graph = make_social_graph();
        let picked =
            influence_maximization(&graph, 0, &InfluenceConfig::default()).expect("IM k=0");
        assert!(picked.is_empty());
    }

    #[test]
    fn test_role_detection_identifies_hub() {
        let graph = make_social_graph();
        let roles = role_detection(&graph).expect("Role detection failed");
        // Keys are positions in `graph.nodes()` order, not node values.
        assert!(roles.contains_key(&4), "Node 4 should have a role");
        let has_hub = roles.values().any(|role| *role == RoleType::Hub);
        assert!(has_hub, "Should detect at least one hub");
    }

    #[test]
    fn test_role_detection_isolated() {
        let mut graph: Graph<usize, f64> = Graph::new();
        graph.add_node(0);
        graph.add_node(1);
        let _ = graph.add_edge(0, 1, 1.0);
        graph.add_node(2);
        let roles = role_detection(&graph).expect("Roles failed");
        assert_eq!(roles.get(&2), Some(&RoleType::Isolated));
    }

    #[test]
    fn test_echo_chamber_detection_two_groups() {
        let graph = make_social_graph();
        let features: Vec<Vec<f64>> = (0..9)
            .map(|i| vec![if i < 4 { 0.1 } else { 0.9 }])
            .collect();
        let chambers = echo_chamber_detection(&graph, &features).expect("Echo chamber failed");
        assert!(!chambers.is_empty(), "Should detect at least one chamber");
        let assigned = chambers.iter().map(Vec::len).sum::<usize>();
        assert_eq!(assigned, 9, "All nodes must be assigned to a chamber");
    }

    #[test]
    fn test_echo_chamber_feature_size_mismatch() {
        let graph = make_social_graph();
        let too_few: Vec<Vec<f64>> = vec![vec![0.5]; 3];
        let outcome = echo_chamber_detection(&graph, &too_few);
        assert!(
            outcome.is_err(),
            "Should return error for mismatched features"
        );
    }

    #[test]
    fn test_polarization_index_range() {
        let graph = make_social_graph();
        let features: Vec<Vec<f64>> = (0..9)
            .map(|i| vec![if i < 4 { 0.0 } else { 1.0 }])
            .collect();
        let pi = polarization_index(&graph, &features).expect("Polarization failed");
        assert!(
            (0.0..=1.0).contains(&pi),
            "Polarization index must be in [0,1], got {}",
            pi
        );
    }

    #[test]
    fn test_polarization_index_no_features() {
        let graph = make_social_graph();
        let empty_rows: Vec<Vec<f64>> = vec![Vec::new(); 9];
        let pi = polarization_index(&graph, &empty_rows).expect("Polarization (no feat)");
        assert!((0.0..=1.0).contains(&pi));
    }

    #[test]
    fn test_simulate_spread_ic() {
        let graph = make_social_graph();
        let cfg = config_with(CascadeModel::IndependentCascade, 10);
        let activated = simulate_spread(&graph, &[0], &cfg).expect("Spread failed");
        assert!(activated.contains(&0), "Seed must be in activated set");
    }
}