1use scirs2_core::ndarray::ArrayD;
11use std::collections::HashMap;
12
13use crate::error::{PgmError, Result};
14use crate::factor::Factor;
15use crate::graph::FactorGraph;
16
17pub struct BayesianNetwork {
22 graph: FactorGraph,
23 structure: HashMap<String, Vec<String>>, }
25
26impl BayesianNetwork {
27 pub fn new() -> Self {
29 Self {
30 graph: FactorGraph::new(),
31 structure: HashMap::new(),
32 }
33 }
34
35 pub fn add_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
37 self.graph
38 .add_variable_with_card(name.clone(), "Discrete".to_string(), cardinality);
39 self.structure.insert(name, Vec::new());
40 self
41 }
42
43 pub fn add_cpd(
50 &mut self,
51 child: String,
52 parents: Vec<String>,
53 cpd: ArrayD<f64>,
54 ) -> Result<&mut Self> {
55 if self.graph.get_variable(&child).is_none() {
57 return Err(PgmError::VariableNotFound(child));
58 }
59
60 for parent in &parents {
62 if self.graph.get_variable(parent).is_none() {
63 return Err(PgmError::VariableNotFound(parent.clone()));
64 }
65 }
66
67 self.structure.insert(child.clone(), parents.clone());
69
70 let mut factor_vars = parents.clone();
72 factor_vars.push(child.clone());
73
74 let factor = Factor::new(format!("P({}|{:?})", child, parents), factor_vars, cpd)?;
75
76 self.graph.add_factor(factor)?;
77 Ok(self)
78 }
79
80 pub fn add_prior(&mut self, variable: String, prior: ArrayD<f64>) -> Result<&mut Self> {
82 let factor = Factor::new(format!("P({})", variable), vec![variable.clone()], prior)?;
83 self.graph.add_factor(factor)?;
84 self.structure.insert(variable, Vec::new());
85 Ok(self)
86 }
87
88 pub fn graph(&self) -> &FactorGraph {
90 &self.graph
91 }
92
93 pub fn is_acyclic(&self) -> bool {
95 let mut visited = HashMap::new();
97 let mut rec_stack = HashMap::new();
98
99 for node in self.structure.keys() {
100 if !visited.contains_key(node) && self.has_cycle(node, &mut visited, &mut rec_stack) {
101 return false;
102 }
103 }
104
105 true
106 }
107
108 fn has_cycle(
109 &self,
110 node: &str,
111 visited: &mut HashMap<String, bool>,
112 rec_stack: &mut HashMap<String, bool>,
113 ) -> bool {
114 visited.insert(node.to_string(), true);
115 rec_stack.insert(node.to_string(), true);
116
117 if let Some(parents) = self.structure.get(node) {
118 for parent in parents {
119 if !visited.contains_key(parent) {
120 if self.has_cycle(parent, visited, rec_stack) {
121 return true;
122 }
123 } else if rec_stack.get(parent) == Some(&true) {
124 return true;
125 }
126 }
127 }
128
129 rec_stack.insert(node.to_string(), false);
130 false
131 }
132
133 pub fn topological_order(&self) -> Result<Vec<String>> {
135 if !self.is_acyclic() {
136 return Err(PgmError::InvalidGraph(
137 "Network contains cycles".to_string(),
138 ));
139 }
140
141 let mut in_degree: HashMap<String, usize> = HashMap::new();
142 let mut children: HashMap<String, Vec<String>> = HashMap::new();
143
144 for (child, parents) in &self.structure {
146 in_degree.insert(child.clone(), parents.len());
147 for parent in parents {
148 children
149 .entry(parent.clone())
150 .or_default()
151 .push(child.clone());
152 }
153 }
154
155 let mut queue: Vec<String> = in_degree
157 .iter()
158 .filter(|(_, °)| deg == 0)
159 .map(|(v, _)| v.clone())
160 .collect();
161
162 let mut result = Vec::new();
163
164 while let Some(node) = queue.pop() {
165 result.push(node.clone());
166
167 if let Some(child_nodes) = children.get(&node) {
168 for child in child_nodes {
169 if let Some(deg) = in_degree.get_mut(child) {
170 *deg -= 1;
171 if *deg == 0 {
172 queue.push(child.clone());
173 }
174 }
175 }
176 }
177 }
178
179 if result.len() != self.structure.len() {
180 return Err(PgmError::InvalidGraph(
181 "Could not compute topological order".to_string(),
182 ));
183 }
184
185 Ok(result)
186 }
187}
188
189impl Default for BayesianNetwork {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195pub struct HiddenMarkovModel {
199 graph: FactorGraph,
200 #[allow(dead_code)]
201 num_states: usize,
202 #[allow(dead_code)]
203 num_observations: usize,
204 time_steps: usize,
205}
206
207impl HiddenMarkovModel {
208 pub fn new(num_states: usize, num_observations: usize, time_steps: usize) -> Self {
215 let mut graph = FactorGraph::new();
216
217 for t in 0..time_steps {
219 graph.add_variable_with_card(
220 format!("state_{}", t),
221 "HiddenState".to_string(),
222 num_states,
223 );
224 }
225
226 for t in 0..time_steps {
228 graph.add_variable_with_card(
229 format!("obs_{}", t),
230 "Observation".to_string(),
231 num_observations,
232 );
233 }
234
235 Self {
236 graph,
237 num_states,
238 num_observations,
239 time_steps,
240 }
241 }
242
243 pub fn set_initial_distribution(&mut self, initial: ArrayD<f64>) -> Result<&mut Self> {
245 let factor = Factor::new(
246 "P(state_0)".to_string(),
247 vec!["state_0".to_string()],
248 initial,
249 )?;
250 self.graph.add_factor(factor)?;
251 Ok(self)
252 }
253
254 pub fn set_transition_matrix(&mut self, transition: ArrayD<f64>) -> Result<&mut Self> {
259 for t in 1..self.time_steps {
261 let factor = Factor::new(
262 format!("P(state_{}|state_{})", t, t - 1),
263 vec![format!("state_{}", t - 1), format!("state_{}", t)],
264 transition.clone(),
265 )?;
266 self.graph.add_factor(factor)?;
267 }
268 Ok(self)
269 }
270
271 pub fn set_emission_matrix(&mut self, emission: ArrayD<f64>) -> Result<&mut Self> {
276 for t in 0..self.time_steps {
278 let factor = Factor::new(
279 format!("P(obs_{}|state_{})", t, t),
280 vec![format!("state_{}", t), format!("obs_{}", t)],
281 emission.clone(),
282 )?;
283 self.graph.add_factor(factor)?;
284 }
285 Ok(self)
286 }
287
288 pub fn graph(&self) -> &FactorGraph {
290 &self.graph
291 }
292
293 pub fn filter(&self, observations: &[usize], t: usize) -> Result<ArrayD<f64>> {
298 if t >= self.time_steps {
299 return Err(PgmError::InvalidDistribution(format!(
300 "Time step {} exceeds sequence length {}",
301 t, self.time_steps
302 )));
303 }
304
305 if t >= observations.len() {
306 return Err(PgmError::InvalidDistribution(format!(
307 "Not enough observations: need {} but got {}",
308 t + 1,
309 observations.len()
310 )));
311 }
312
313 let mut evidence_graph = self.graph.clone();
315
316 for (time, &obs_value) in observations.iter().enumerate().take(t + 1) {
318 let obs_var = format!("obs_{}", time);
319
320 let mut evidence_values = vec![0.0; self.num_observations];
322 evidence_values[obs_value] = 1.0;
323 let evidence_factor = Factor::new(
324 format!("evidence_{}", time),
325 vec![obs_var.clone()],
326 ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
327 )?;
328 evidence_graph.add_factor(evidence_factor)?;
329 }
330
331 use crate::variable_elimination::VariableElimination;
333 let ve = VariableElimination::new();
334 let state_var = format!("state_{}", t);
335 ve.marginalize(&evidence_graph, &state_var)
336 }
337
338 pub fn smooth(&self, observations: &[usize], t: usize) -> Result<ArrayD<f64>> {
343 if t >= self.time_steps {
344 return Err(PgmError::InvalidDistribution(format!(
345 "Time step {} exceeds sequence length {}",
346 t, self.time_steps
347 )));
348 }
349
350 if observations.len() != self.time_steps {
351 return Err(PgmError::InvalidDistribution(format!(
352 "Expected {} observations but got {}",
353 self.time_steps,
354 observations.len()
355 )));
356 }
357
358 let mut evidence_graph = self.graph.clone();
360
361 for (time, &obs_value) in observations.iter().enumerate().take(self.time_steps) {
363 let obs_var = format!("obs_{}", time);
364
365 let mut evidence_values = vec![0.0; self.num_observations];
367 evidence_values[obs_value] = 1.0;
368 let evidence_factor = Factor::new(
369 format!("evidence_{}", time),
370 vec![obs_var.clone()],
371 ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
372 )?;
373 evidence_graph.add_factor(evidence_factor)?;
374 }
375
376 use crate::variable_elimination::VariableElimination;
378 let ve = VariableElimination::new();
379 let state_var = format!("state_{}", t);
380 ve.marginalize(&evidence_graph, &state_var)
381 }
382
383 pub fn viterbi(&self, observations: &[usize]) -> Result<Vec<usize>> {
388 if observations.len() != self.time_steps {
389 return Err(PgmError::InvalidDistribution(format!(
390 "Observations length {} does not match time steps {}",
391 observations.len(),
392 self.time_steps
393 )));
394 }
395
396 let mut evidence_graph = self.graph.clone();
398
399 for (time, &obs_value) in observations.iter().enumerate().take(self.time_steps) {
401 let obs_var = format!("obs_{}", time);
402
403 let mut evidence_values = vec![0.0; self.num_observations];
404 evidence_values[obs_value] = 1.0;
405 let evidence_factor = Factor::new(
406 format!("evidence_{}", time),
407 vec![obs_var.clone()],
408 ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
409 )?;
410 evidence_graph.add_factor(evidence_factor)?;
411 }
412
413 use crate::variable_elimination::VariableElimination;
415 let ve = VariableElimination::new();
416 let assignment = ve.map(&evidence_graph)?;
417
418 let mut sequence = Vec::new();
420 for t in 0..self.time_steps {
421 let state_var = format!("state_{}", t);
422 if let Some(&state) = assignment.get(&state_var) {
423 sequence.push(state);
424 } else {
425 return Err(PgmError::VariableNotFound(state_var));
426 }
427 }
428
429 Ok(sequence)
430 }
431}
432
433pub struct MarkovRandomField {
435 graph: FactorGraph,
436}
437
438impl MarkovRandomField {
439 pub fn new() -> Self {
441 Self {
442 graph: FactorGraph::new(),
443 }
444 }
445
446 pub fn add_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
448 self.graph
449 .add_variable_with_card(name, "Discrete".to_string(), cardinality);
450 self
451 }
452
453 pub fn add_pairwise_potential(
455 &mut self,
456 var1: String,
457 var2: String,
458 potential: ArrayD<f64>,
459 ) -> Result<&mut Self> {
460 let factor = Factor::new(
461 format!("φ({},{})", var1, var2),
462 vec![var1.clone(), var2.clone()],
463 potential,
464 )?;
465 self.graph.add_factor(factor)?;
466 Ok(self)
467 }
468
469 pub fn add_unary_potential(
471 &mut self,
472 var: String,
473 potential: ArrayD<f64>,
474 ) -> Result<&mut Self> {
475 let factor = Factor::new(format!("φ({})", var), vec![var.clone()], potential)?;
476 self.graph.add_factor(factor)?;
477 Ok(self)
478 }
479
480 pub fn graph(&self) -> &FactorGraph {
482 &self.graph
483 }
484}
485
486impl Default for MarkovRandomField {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
492pub struct ConditionalRandomField {
494 graph: FactorGraph,
495 input_vars: Vec<String>,
496 output_vars: Vec<String>,
497}
498
499impl ConditionalRandomField {
500 pub fn new() -> Self {
502 Self {
503 graph: FactorGraph::new(),
504 input_vars: Vec::new(),
505 output_vars: Vec::new(),
506 }
507 }
508
509 pub fn add_input_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
511 self.graph
512 .add_variable_with_card(name.clone(), "Input".to_string(), cardinality);
513 self.input_vars.push(name);
514 self
515 }
516
517 pub fn add_output_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
519 self.graph
520 .add_variable_with_card(name.clone(), "Output".to_string(), cardinality);
521 self.output_vars.push(name);
522 self
523 }
524
525 pub fn add_feature(
527 &mut self,
528 name: String,
529 variables: Vec<String>,
530 potential: ArrayD<f64>,
531 ) -> Result<&mut Self> {
532 let factor = Factor::new(format!("feature_{}", name), variables, potential)?;
533 self.graph.add_factor(factor)?;
534 Ok(self)
535 }
536
537 pub fn graph(&self) -> &FactorGraph {
539 &self.graph
540 }
541}
542
543impl Default for ConditionalRandomField {
544 fn default() -> Self {
545 Self::new()
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array, ArrayD};

    /// Builds a dynamic-dimensional array with the given shape from `values`.
    fn tensor(shape: &[usize], values: &[f64]) -> ArrayD<f64> {
        Array::from_shape_vec(shape.to_vec(), values.to_vec())
            .unwrap()
            .into_dyn()
    }

    /// Two-state, two-symbol HMM over three time steps with every parameter set.
    fn configured_hmm() -> HiddenMarkovModel {
        let mut hmm = HiddenMarkovModel::new(2, 2, 3);
        hmm.set_initial_distribution(tensor(&[2], &[0.6, 0.4]))
            .unwrap();
        hmm.set_transition_matrix(tensor(&[2, 2], &[0.7, 0.3, 0.4, 0.6]))
            .unwrap();
        hmm.set_emission_matrix(tensor(&[2, 2], &[0.9, 0.1, 0.2, 0.8]))
            .unwrap();
        hmm
    }

    #[test]
    fn test_bayesian_network_creation() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        assert!(bn.graph().get_variable("x").is_some());
        assert!(bn.graph().get_variable("y").is_some());
    }

    #[test]
    fn test_bayesian_network_cpd() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        bn.add_prior("x".to_string(), tensor(&[2], &[0.6, 0.4]))
            .unwrap();
        bn.add_cpd(
            "y".to_string(),
            vec!["x".to_string()],
            tensor(&[2, 2], &[0.9, 0.1, 0.2, 0.8]),
        )
        .unwrap();

        assert_eq!(bn.graph().num_factors(), 2);
    }

    #[test]
    fn test_bayesian_network_acyclic() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        bn.add_cpd(
            "y".to_string(),
            vec!["x".to_string()],
            tensor(&[2, 2], &[0.9, 0.1, 0.2, 0.8]),
        )
        .unwrap();

        assert!(bn.is_acyclic());
    }

    #[test]
    fn test_bayesian_network_topological_order() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);
        bn.add_variable("z".to_string(), 2);

        bn.add_cpd(
            "y".to_string(),
            vec!["x".to_string()],
            tensor(&[2, 2], &[0.9, 0.1, 0.2, 0.8]),
        )
        .unwrap();
        bn.add_cpd(
            "z".to_string(),
            vec!["y".to_string()],
            tensor(&[2, 2], &[0.8, 0.2, 0.3, 0.7]),
        )
        .unwrap();

        // The chain x -> y -> z admits exactly one topological order.
        let order = bn.topological_order().unwrap();
        assert_eq!(order, vec!["x", "y", "z"]);
    }

    #[test]
    fn test_hmm_creation() {
        let hmm = HiddenMarkovModel::new(3, 2, 5);
        // Five state variables plus five observation variables.
        assert_eq!(hmm.graph().num_variables(), 10);
    }

    #[test]
    fn test_hmm_parameters() {
        let hmm = configured_hmm();
        assert!(hmm.graph().num_factors() > 0);
    }

    #[test]
    fn test_hmm_filtering() {
        let marginal = configured_hmm().filter(&[0, 1, 0], 1).unwrap();

        assert_eq!(marginal.len(), 2);
        let sum: f64 = marginal.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hmm_smoothing() {
        let marginal = configured_hmm().smooth(&[0, 1, 0], 1).unwrap();

        assert_eq!(marginal.len(), 2);
        let sum: f64 = marginal.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hmm_viterbi() {
        let sequence = configured_hmm().viterbi(&[0, 1, 0]).unwrap();

        assert_eq!(sequence.len(), 3);
        assert!(sequence.iter().all(|&state| state < 2));
    }

    #[test]
    fn test_mrf_creation() {
        let mut mrf = MarkovRandomField::new();
        mrf.add_variable("x".to_string(), 2);
        mrf.add_variable("y".to_string(), 2);

        mrf.add_pairwise_potential(
            "x".to_string(),
            "y".to_string(),
            tensor(&[2, 2], &[1.0, 0.5, 0.5, 1.0]),
        )
        .unwrap();

        assert_eq!(mrf.graph().num_factors(), 1);
    }

    #[test]
    fn test_crf_creation() {
        let mut crf = ConditionalRandomField::new();
        crf.add_input_variable("x".to_string(), 3);
        crf.add_output_variable("y".to_string(), 2);

        crf.add_feature(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            tensor(&[3, 2], &[1.0, 0.5, 0.8, 0.2, 0.6, 0.4]),
        )
        .unwrap();

        assert_eq!(crf.graph().num_factors(), 1);
    }
}