1use std::collections::{HashMap, HashSet, VecDeque};
13
14use scirs2_core::random::{Rng, RngExt};
15
16use crate::error::{GraphError, Result};
17
/// Directed adjacency list: maps a source node id to its outgoing `(target, weight)` pairs.
///
/// How the weight is interpreted depends on the simulator:
/// * [`simulate_ic`] treats it as the probability that the source activates the target.
/// * [`simulate_lt`] treats it as the influence weight summed against the target's threshold.
/// * [`simulate_sir`] / [`simulate_sis`] ignore it and use only the connectivity.
pub type AdjList = HashMap<usize, Vec<(usize, f64)>>;
24
/// Compartment a node occupies during an SIR epidemic simulation.
///
/// In [`simulate_sir`] a node only ever moves forward:
/// `Susceptible` -> `Infected` -> `Recovered`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SirState {
    /// Not yet infected; can be infected by an infected neighbour.
    Susceptible,
    /// Currently infected; may infect susceptible neighbours and may recover.
    Infected,
    /// Recovered; the simulation never changes this state again.
    Recovered,
}
39
/// Outcome of a single diffusion or epidemic simulation run.
#[derive(Debug, Clone)]
pub struct SimulationResult {
    /// Nodes that were activated (IC / LT) or infected at some point (SIR / SIS).
    pub activated: HashSet<usize>,
    /// Per-step `(susceptible, infected, recovered)` counts.
    ///
    /// Empty for the IC and LT simulations. The SIS simulation always reports `0` recovered.
    pub time_series: Vec<(usize, usize, usize)>,
    /// Outbreak size.
    ///
    /// For IC, LT and SIS this is `activated.len()`. For SIR it is the number of nodes that
    /// finished the run infected or recovered.
    pub spread: usize,
}
55
/// Independent Cascade (IC) diffusion model.
///
/// Each edge weight is the probability that an active source activates its target; see
/// [`simulate_ic`].
#[derive(Debug, Clone)]
pub struct IndependentCascade {
    /// Outgoing `(target, activation probability)` edges per source node.
    pub adjacency: AdjList,
    /// Number of nodes in the graph. Stored for callers; the IC simulation itself does not
    /// consult it.
    pub num_nodes: usize,
}
72
73impl IndependentCascade {
74 pub fn new(adjacency: AdjList, num_nodes: usize) -> Self {
80 IndependentCascade {
81 adjacency,
82 num_nodes,
83 }
84 }
85
86 pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
88 let mut adjacency: AdjList = HashMap::new();
89 for &(src, tgt, prob) in edges {
90 adjacency.entry(src).or_default().push((tgt, prob));
91 }
92 IndependentCascade::new(adjacency, num_nodes)
93 }
94
95 pub fn simulate(&self, seeds: &[usize]) -> Result<SimulationResult> {
97 simulate_ic(&self.adjacency, seeds)
98 }
99
100 pub fn expected_spread(&self, seeds: &[usize], num_simulations: usize) -> Result<f64> {
102 expected_spread_ic(&self.adjacency, seeds, num_simulations)
103 }
104}
105
/// Linear Threshold (LT) diffusion model.
///
/// A node activates once the summed weight of its active in-neighbours reaches its threshold;
/// see [`simulate_lt`].
#[derive(Debug, Clone)]
pub struct LinearThreshold {
    /// Outgoing `(target, influence weight)` edges per source node.
    pub adjacency: AdjList,
    /// Number of nodes. When `thresholds` is `None`, this is how many random thresholds are
    /// drawn per run.
    pub num_nodes: usize,
    /// Fixed per-node thresholds, or `None` to draw fresh random thresholds on every run.
    pub thresholds: Option<Vec<f64>>,
}
128
129impl LinearThreshold {
130 pub fn new(adjacency: AdjList, num_nodes: usize) -> Self {
132 LinearThreshold {
133 adjacency,
134 num_nodes,
135 thresholds: None,
136 }
137 }
138
139 pub fn with_thresholds(adjacency: AdjList, thresholds: Vec<f64>) -> Result<Self> {
141 let num_nodes = thresholds.len();
142 for (i, &t) in thresholds.iter().enumerate() {
143 if !(0.0..=1.0).contains(&t) {
144 return Err(GraphError::InvalidParameter {
145 param: format!("thresholds[{i}]"),
146 value: t.to_string(),
147 expected: "value in [0, 1]".to_string(),
148 context: "LinearThreshold::with_thresholds".to_string(),
149 });
150 }
151 }
152 Ok(LinearThreshold {
153 adjacency,
154 num_nodes,
155 thresholds: Some(thresholds),
156 })
157 }
158
159 pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
161 let mut adjacency: AdjList = HashMap::new();
162 for &(src, tgt, w) in edges {
163 adjacency.entry(src).or_default().push((tgt, w));
164 }
165 LinearThreshold::new(adjacency, num_nodes)
166 }
167
168 pub fn simulate(&self, seeds: &[usize]) -> Result<SimulationResult> {
170 simulate_lt(
171 &self.adjacency,
172 self.num_nodes,
173 seeds,
174 self.thresholds.as_deref(),
175 )
176 }
177
178 pub fn expected_spread(&self, seeds: &[usize], num_simulations: usize) -> Result<f64> {
180 expected_spread_lt(
181 &self.adjacency,
182 self.num_nodes,
183 seeds,
184 self.thresholds.as_deref(),
185 num_simulations,
186 )
187 }
188}
189
/// Discrete-time SIR (susceptible-infected-recovered) epidemic model; see [`simulate_sir`].
#[derive(Debug, Clone)]
pub struct SIRModel {
    /// Contact graph. Edge weights are ignored by the simulation; only connectivity matters.
    pub adjacency: AdjList,
    /// Per-step probability that an infected node infects each susceptible neighbour.
    /// Checked to be in `[0, 1]` by `new` and again by the simulation.
    pub beta: f64,
    /// Per-step probability that an infected node recovers.
    /// Checked to be in `[0, 1]` by `new` and again by the simulation.
    pub gamma: f64,
    /// Number of nodes. Only node ids below this value take part in the simulation.
    pub num_nodes: usize,
}
210
211impl SIRModel {
212 pub fn new(adjacency: AdjList, num_nodes: usize, beta: f64, gamma: f64) -> Result<Self> {
217 if !(0.0..=1.0).contains(&beta) {
218 return Err(GraphError::InvalidParameter {
219 param: "beta".to_string(),
220 value: beta.to_string(),
221 expected: "[0, 1]".to_string(),
222 context: "SIRModel::new".to_string(),
223 });
224 }
225 if !(0.0..=1.0).contains(&gamma) {
226 return Err(GraphError::InvalidParameter {
227 param: "gamma".to_string(),
228 value: gamma.to_string(),
229 expected: "[0, 1]".to_string(),
230 context: "SIRModel::new".to_string(),
231 });
232 }
233 Ok(SIRModel {
234 adjacency,
235 beta,
236 gamma,
237 num_nodes,
238 })
239 }
240
241 pub fn from_edges(
243 edges: &[(usize, usize)],
244 num_nodes: usize,
245 beta: f64,
246 gamma: f64,
247 ) -> Result<Self> {
248 let mut adjacency: AdjList = HashMap::new();
249 for &(src, tgt) in edges {
250 adjacency.entry(src).or_default().push((tgt, 1.0));
251 adjacency.entry(tgt).or_default().push((src, 1.0));
252 }
253 SIRModel::new(adjacency, num_nodes, beta, gamma)
254 }
255
256 pub fn simulate(&self, initial_infected: &[usize]) -> Result<SimulationResult> {
258 simulate_sir(
259 &self.adjacency,
260 self.num_nodes,
261 initial_infected,
262 self.beta,
263 self.gamma,
264 )
265 }
266}
267
/// Discrete-time SIS (susceptible-infected-susceptible) epidemic model; see [`simulate_sis`].
#[derive(Debug, Clone)]
pub struct SISModel {
    /// Contact graph. Edge weights are ignored by the simulation; only connectivity matters.
    pub adjacency: AdjList,
    /// Per-step probability that an infected node infects each non-infected neighbour.
    /// Checked to be in `[0, 1]` by `new` and again by the simulation.
    pub beta: f64,
    /// Per-step probability that an infected node becomes susceptible again.
    /// Checked to be in `[0, 1]` by `new` and again by the simulation.
    pub gamma: f64,
    /// Number of nodes. Neighbours with ids at or above this value are never infected.
    pub num_nodes: usize,
    /// Upper bound on the number of simulation steps.
    pub max_steps: usize,
}
290
291impl SISModel {
292 pub fn new(
297 adjacency: AdjList,
298 num_nodes: usize,
299 beta: f64,
300 gamma: f64,
301 max_steps: usize,
302 ) -> Result<Self> {
303 if !(0.0..=1.0).contains(&beta) {
304 return Err(GraphError::InvalidParameter {
305 param: "beta".to_string(),
306 value: beta.to_string(),
307 expected: "[0, 1]".to_string(),
308 context: "SISModel::new".to_string(),
309 });
310 }
311 if !(0.0..=1.0).contains(&gamma) {
312 return Err(GraphError::InvalidParameter {
313 param: "gamma".to_string(),
314 value: gamma.to_string(),
315 expected: "[0, 1]".to_string(),
316 context: "SISModel::new".to_string(),
317 });
318 }
319 Ok(SISModel {
320 adjacency,
321 beta,
322 gamma,
323 num_nodes,
324 max_steps,
325 })
326 }
327
328 pub fn from_edges(
330 edges: &[(usize, usize)],
331 num_nodes: usize,
332 beta: f64,
333 gamma: f64,
334 max_steps: usize,
335 ) -> Result<Self> {
336 let mut adjacency: AdjList = HashMap::new();
337 for &(src, tgt) in edges {
338 adjacency.entry(src).or_default().push((tgt, 1.0));
339 adjacency.entry(tgt).or_default().push((src, 1.0));
340 }
341 SISModel::new(adjacency, num_nodes, beta, gamma, max_steps)
342 }
343
344 pub fn simulate(&self, initial_infected: &[usize]) -> Result<SimulationResult> {
346 simulate_sis(
347 &self.adjacency,
348 self.num_nodes,
349 initial_infected,
350 self.beta,
351 self.gamma,
352 self.max_steps,
353 )
354 }
355}
356
357pub fn simulate_ic(adjacency: &AdjList, seeds: &[usize]) -> Result<SimulationResult> {
371 let mut rng = scirs2_core::random::rng();
372 let mut active: HashSet<usize> = seeds.iter().cloned().collect();
373 let mut queue: VecDeque<usize> = seeds.iter().cloned().collect();
374
375 while let Some(node) = queue.pop_front() {
376 if let Some(neighbors) = adjacency.get(&node) {
377 for &(nbr, prob) in neighbors {
378 if !active.contains(&nbr) && rng.random::<f64>() < prob {
379 active.insert(nbr);
380 queue.push_back(nbr);
381 }
382 }
383 }
384 }
385
386 let spread = active.len();
387 Ok(SimulationResult {
388 activated: active,
389 time_series: Vec::new(),
390 spread,
391 })
392}
393
394pub fn expected_spread(
402 adjacency: &AdjList,
403 seeds: &[usize],
404 num_simulations: usize,
405) -> Result<f64> {
406 expected_spread_ic(adjacency, seeds, num_simulations)
407}
408
409fn expected_spread_ic(adjacency: &AdjList, seeds: &[usize], num_simulations: usize) -> Result<f64> {
410 if num_simulations == 0 {
411 return Err(GraphError::InvalidParameter {
412 param: "num_simulations".to_string(),
413 value: "0".to_string(),
414 expected: ">= 1".to_string(),
415 context: "expected_spread_ic".to_string(),
416 });
417 }
418 let mut total = 0.0_f64;
419 for _ in 0..num_simulations {
420 let result = simulate_ic(adjacency, seeds)?;
421 total += result.spread as f64;
422 }
423 Ok(total / num_simulations as f64)
424}
425
426pub fn simulate_lt(
440 adjacency: &AdjList,
441 num_nodes: usize,
442 seeds: &[usize],
443 fixed_thresholds: Option<&[f64]>,
444) -> Result<SimulationResult> {
445 let mut reverse: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
447 for (&src, nbrs) in adjacency {
448 for &(tgt, w) in nbrs {
449 reverse.entry(tgt).or_default().push((src, w));
450 }
451 }
452
453 let mut rng = scirs2_core::random::rng();
454
455 let thresholds: Vec<f64> = match fixed_thresholds {
457 Some(t) => {
458 if t.len() < num_nodes {
459 return Err(GraphError::InvalidParameter {
460 param: "fixed_thresholds".to_string(),
461 value: format!("len={}", t.len()),
462 expected: format!(">= num_nodes={num_nodes}"),
463 context: "simulate_lt".to_string(),
464 });
465 }
466 t.to_vec()
467 }
468 None => (0..num_nodes).map(|_| rng.random::<f64>()).collect(),
469 };
470
471 let mut active: HashSet<usize> = seeds.iter().cloned().collect();
472 let mut changed = true;
473
474 while changed {
476 changed = false;
477 let candidates: Vec<usize> = reverse
479 .keys()
480 .filter(|&&node| !active.contains(&node))
481 .cloned()
482 .collect();
483
484 for node in candidates {
485 let weight_sum: f64 = reverse
486 .get(&node)
487 .map(|in_nbrs| {
488 in_nbrs
489 .iter()
490 .filter(|(src, _)| active.contains(src))
491 .map(|(_, w)| w)
492 .sum()
493 })
494 .unwrap_or(0.0);
495
496 let threshold = if node < thresholds.len() {
497 thresholds[node]
498 } else {
499 1.0
500 };
501
502 if weight_sum >= threshold {
503 active.insert(node);
504 changed = true;
505 }
506 }
507 }
508
509 let spread = active.len();
510 Ok(SimulationResult {
511 activated: active,
512 time_series: Vec::new(),
513 spread,
514 })
515}
516
517fn expected_spread_lt(
518 adjacency: &AdjList,
519 num_nodes: usize,
520 seeds: &[usize],
521 fixed_thresholds: Option<&[f64]>,
522 num_simulations: usize,
523) -> Result<f64> {
524 if num_simulations == 0 {
525 return Err(GraphError::InvalidParameter {
526 param: "num_simulations".to_string(),
527 value: "0".to_string(),
528 expected: ">= 1".to_string(),
529 context: "expected_spread_lt".to_string(),
530 });
531 }
532 let mut total = 0.0_f64;
533 for _ in 0..num_simulations {
534 let result = simulate_lt(adjacency, num_nodes, seeds, fixed_thresholds)?;
535 total += result.spread as f64;
536 }
537 Ok(total / num_simulations as f64)
538}
539
540pub fn simulate_sir(
554 adjacency: &AdjList,
555 num_nodes: usize,
556 initial_infected: &[usize],
557 beta: f64,
558 gamma: f64,
559) -> Result<SimulationResult> {
560 if !(0.0..=1.0).contains(&beta) || !(0.0..=1.0).contains(&gamma) {
561 return Err(GraphError::InvalidParameter {
562 param: "beta/gamma".to_string(),
563 value: format!("beta={beta}, gamma={gamma}"),
564 expected: "both in [0, 1]".to_string(),
565 context: "simulate_sir".to_string(),
566 });
567 }
568
569 let mut rng = scirs2_core::random::rng();
570 let mut states: Vec<SirState> = vec![SirState::Susceptible; num_nodes];
571 for &node in initial_infected {
572 if node < num_nodes {
573 states[node] = SirState::Infected;
574 }
575 }
576
577 let mut time_series: Vec<(usize, usize, usize)> = Vec::new();
578 let mut ever_infected: HashSet<usize> = initial_infected.iter().cloned().collect();
579
580 loop {
581 let n_infected = states.iter().filter(|&&s| s == SirState::Infected).count();
582 let n_recovered = states.iter().filter(|&&s| s == SirState::Recovered).count();
583 let n_susceptible = num_nodes - n_infected - n_recovered;
584 time_series.push((n_susceptible, n_infected, n_recovered));
585
586 if n_infected == 0 {
587 break;
588 }
589
590 let mut next_states = states.clone();
591
592 for node in 0..num_nodes {
594 if states[node] == SirState::Infected {
595 if let Some(neighbors) = adjacency.get(&node) {
596 for &(nbr, _) in neighbors {
597 if nbr < num_nodes
598 && states[nbr] == SirState::Susceptible
599 && rng.random::<f64>() < beta
600 {
601 next_states[nbr] = SirState::Infected;
602 ever_infected.insert(nbr);
603 }
604 }
605 }
606 }
607 }
608
609 for node in 0..num_nodes {
611 if states[node] == SirState::Infected && rng.random::<f64>() < gamma {
612 next_states[node] = SirState::Recovered;
613 }
614 }
615
616 states = next_states;
617 }
618
619 Ok(SimulationResult {
620 activated: ever_infected,
621 time_series,
622 spread: states.iter().filter(|&&s| s == SirState::Recovered).count()
623 + states.iter().filter(|&&s| s == SirState::Infected).count(),
624 })
625}
626
627pub fn simulate_sis(
637 adjacency: &AdjList,
638 num_nodes: usize,
639 initial_infected: &[usize],
640 beta: f64,
641 gamma: f64,
642 max_steps: usize,
643) -> Result<SimulationResult> {
644 if !(0.0..=1.0).contains(&beta) || !(0.0..=1.0).contains(&gamma) {
645 return Err(GraphError::InvalidParameter {
646 param: "beta/gamma".to_string(),
647 value: format!("beta={beta}, gamma={gamma}"),
648 expected: "both in [0, 1]".to_string(),
649 context: "simulate_sis".to_string(),
650 });
651 }
652
653 let mut rng = scirs2_core::random::rng();
654 let mut infected: HashSet<usize> = initial_infected.iter().cloned().collect();
655 let mut ever_infected = infected.clone();
656 let mut time_series: Vec<(usize, usize, usize)> = Vec::new();
657
658 for _step in 0..max_steps {
659 let n_infected = infected.len();
660 time_series.push((num_nodes - n_infected, n_infected, 0));
661
662 if n_infected == 0 {
663 break;
664 }
665
666 let mut new_infections: HashSet<usize> = HashSet::new();
667 let mut new_recoveries: HashSet<usize> = HashSet::new();
668
669 for &node in &infected {
670 if let Some(neighbors) = adjacency.get(&node) {
672 for &(nbr, _) in neighbors {
673 if nbr < num_nodes && !infected.contains(&nbr) && rng.random::<f64>() < beta {
674 new_infections.insert(nbr);
675 ever_infected.insert(nbr);
676 }
677 }
678 }
679 if rng.random::<f64>() < gamma {
681 new_recoveries.insert(node);
682 }
683 }
684
685 for node in new_recoveries {
686 infected.remove(&node);
687 }
688 for node in new_infections {
689 infected.insert(node);
690 }
691 }
692
693 let spread = ever_infected.len();
694 Ok(SimulationResult {
695 activated: ever_infected,
696 time_series,
697 spread,
698 })
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Star graph: hub `0` with a directed edge of weight `prob` to each of nodes `1..n`.
    fn star_adjacency(n: usize, prob: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 1..n {
            adj.entry(0).or_default().push((i, prob));
        }
        adj
    }

    /// Undirected path `0 - 1 - ... - (n - 1)` with unit weights.
    fn path_adjacency(n: usize) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 0..n.saturating_sub(1) {
            adj.entry(i).or_default().push((i + 1, 1.0));
            adj.entry(i + 1).or_default().push((i, 1.0));
        }
        adj
    }

    #[test]
    fn test_simulate_ic_full_spread() {
        let adj = star_adjacency(6, 1.0);
        let result = simulate_ic(&adj, &[0]).expect("ic simulation");
        assert_eq!(result.spread, 6);
    }

    #[test]
    fn test_simulate_ic_no_spread() {
        let adj = star_adjacency(6, 0.0);
        let result = simulate_ic(&adj, &[0]).expect("ic simulation");
        assert_eq!(result.spread, 1);
    }

    #[test]
    fn test_simulate_lt_deterministic_threshold() {
        // Hub 0 feeds weight 1.0 into nodes 1..=3, which all have threshold 0.5.
        let mut adj: AdjList = HashMap::new();
        for i in 1..4_usize {
            adj.entry(0).or_default().push((i, 1.0));
        }
        let thresholds = vec![0.5_f64; 4];
        let result = simulate_lt(&adj, 4, &[0], Some(&thresholds)).expect("lt simulation");
        assert_eq!(result.spread, 4);
        assert_eq!(result.activated, (0..4).collect::<HashSet<usize>>());
        assert!(result.time_series.is_empty());
    }

    #[test]
    fn test_simulate_lt_threshold_not_reached() {
        // Incoming weight 0.5 never reaches a threshold of 1.0, so only the seed is active.
        let mut adj: AdjList = HashMap::new();
        for i in 1..4_usize {
            adj.entry(0).or_default().push((i, 0.5));
        }
        let thresholds = vec![1.0_f64; 4];
        let result = simulate_lt(&adj, 4, &[0], Some(&thresholds)).expect("lt simulation");
        assert_eq!(result.spread, 1);
    }

    #[test]
    fn test_simulate_lt_short_thresholds_rejected() {
        let adj = star_adjacency(4, 1.0);
        let thresholds = vec![0.5_f64; 2];
        assert!(simulate_lt(&adj, 4, &[0], Some(&thresholds)).is_err());
    }

    #[test]
    fn test_simulate_sir_terminates() {
        let adj = path_adjacency(5);
        let result = simulate_sir(&adj, 5, &[0], 0.8, 0.5).expect("sir");
        assert!(result.spread >= 1);
        assert!(result.spread <= 5);
        assert!(!result.time_series.is_empty());
        // Every snapshot partitions all 5 nodes, and the run ends with nobody infected.
        assert!(result.time_series.iter().all(|&(s, i, r)| s + i + r == 5));
        assert_eq!(result.time_series.last().map(|&(_, i, _)| i), Some(0));
        assert_eq!(result.activated.len(), result.spread);
    }

    #[test]
    fn test_simulate_sis_terminates() {
        let adj = path_adjacency(5);
        let result = simulate_sis(&adj, 5, &[0], 0.5, 0.9, 1000).expect("sis");
        assert!(result.spread >= 1);
        assert!(result.time_series.len() <= 1000);
    }

    #[test]
    fn test_expected_spread_ic() {
        let adj = star_adjacency(5, 1.0);
        let spread = expected_spread(&adj, &[0], 50).expect("expected spread");
        assert!((spread - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_expected_spread_rejects_zero_simulations() {
        let adj = star_adjacency(5, 1.0);
        assert!(expected_spread(&adj, &[0], 0).is_err());
    }

    #[test]
    fn test_sir_bad_params() {
        let adj: AdjList = HashMap::new();
        let err = simulate_sir(&adj, 1, &[], 2.0, 0.5);
        assert!(err.is_err());
    }

    #[test]
    fn test_ic_struct() {
        let edges = vec![(0_usize, 1_usize, 1.0_f64), (0, 2, 1.0), (0, 3, 1.0)];
        let ic = IndependentCascade::from_edges(&edges, 4);
        let res = ic.simulate(&[0]).expect("simulate");
        assert_eq!(res.spread, 4);
    }

    #[test]
    fn test_lt_bad_threshold() {
        let adj: AdjList = HashMap::new();
        let err = LinearThreshold::with_thresholds(adj, vec![0.5, 1.5]);
        assert!(err.is_err());
    }
}