1use scirs2_core::ndarray::{ArrayD, IxDyn};
15use std::collections::{HashMap, HashSet};
16
17use crate::error::{PgmError, Result};
18use crate::message_passing::MessagePassingAlgorithm;
19use crate::{Factor, FactorGraph, SumProductAlgorithm, VariableElimination};
20
/// A dynamic Bayesian network described by a single time slice.
///
/// `unroll` replicates the slice over time. Each state variable's transition factor links
/// only its own previous and current value. Each emission factor spans every state
/// variable of a slice plus one observation variable.
#[derive(Debug, Clone)]
pub struct DynamicBayesianNetwork {
    /// Hidden state variables as `(name, cardinality)` pairs, in declaration order.
    pub state_vars: Vec<(String, usize)>,
    /// Observation variables as `(name, cardinality)` pairs, in declaration order.
    pub observation_vars: Vec<(String, usize)>,
    /// Prior over the first slice, keyed by state-variable name.
    /// `unroll` uses a uniform prior for any state variable missing from this map.
    pub initial_dists: HashMap<String, ArrayD<f64>>,
    /// Transition tables keyed by state-variable name.
    /// `unroll` attaches them over `[var_{t-1}, var_t]` and substitutes an identity
    /// matrix when a variable has no entry.
    pub transition_dists: HashMap<String, ArrayD<f64>>,
    /// Emission tables keyed by observation-variable name.
    /// `unroll` attaches them over every state variable of the slice (declaration order)
    /// followed by the observation variable. Observations without an entry get no factor.
    pub emission_dists: HashMap<String, ArrayD<f64>>,
}
50
/// A variable name paired with a time-slice index.
///
/// Its `Display` form is `name_time` (e.g. `state_3`). `DynamicBayesianNetwork::unroll`
/// uses the same `{}_{}` pattern for the variables it creates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TemporalVar {
    /// Base variable name as declared on the network.
    pub name: String,
    /// Time-slice index (`unroll` numbers slices from 0).
    pub time: usize,
}
59
60impl std::fmt::Display for TemporalVar {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}_{}", self.name, self.time)
63 }
64}
65
66impl DynamicBayesianNetwork {
67 pub fn new(state_vars: Vec<(String, usize)>, observation_vars: Vec<(String, usize)>) -> Self {
69 Self {
70 state_vars,
71 observation_vars,
72 initial_dists: HashMap::new(),
73 transition_dists: HashMap::new(),
74 emission_dists: HashMap::new(),
75 }
76 }
77
78 pub fn set_initial(&mut self, var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
80 if !self.state_vars.iter().any(|(name, _)| name == var) {
81 return Err(PgmError::VariableNotFound(var.to_string()));
82 }
83 self.initial_dists.insert(var.to_string(), dist);
84 Ok(self)
85 }
86
87 pub fn set_transition(&mut self, var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
89 if !self.state_vars.iter().any(|(name, _)| name == var) {
90 return Err(PgmError::VariableNotFound(var.to_string()));
91 }
92 self.transition_dists.insert(var.to_string(), dist);
93 Ok(self)
94 }
95
96 pub fn set_emission(&mut self, obs_var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
98 if !self
99 .observation_vars
100 .iter()
101 .any(|(name, _)| name == obs_var)
102 {
103 return Err(PgmError::VariableNotFound(obs_var.to_string()));
104 }
105 self.emission_dists.insert(obs_var.to_string(), dist);
106 Ok(self)
107 }
108
109 pub fn unroll(&self, num_steps: usize) -> Result<FactorGraph> {
113 if num_steps == 0 {
114 return Err(PgmError::InvalidDistribution(
115 "Number of steps must be positive".to_string(),
116 ));
117 }
118
119 let mut graph = FactorGraph::new();
120
121 for t in 0..num_steps {
123 for (var, card) in &self.state_vars {
125 let temporal_name = format!("{}_{}", var, t);
126 graph.add_variable_with_card(temporal_name, "State".to_string(), *card);
127 }
128
129 for (var, card) in &self.observation_vars {
131 let temporal_name = format!("{}_{}", var, t);
132 graph.add_variable_with_card(temporal_name, "Observation".to_string(), *card);
133 }
134 }
135
136 for (var, card) in &self.state_vars {
138 let temporal_name = format!("{}_{}", var, 0);
139 let dist = self.initial_dists.get(var).cloned().unwrap_or_else(|| {
140 ArrayD::from_elem(IxDyn(&[*card]), 1.0 / *card as f64)
142 });
143
144 let factor = Factor::new(format!("P0_{}", var), vec![temporal_name], dist)?;
145 graph.add_factor(factor)?;
146 }
147
148 for t in 1..num_steps {
150 for (var, card) in &self.state_vars {
151 let prev_name = format!("{}_{}", var, t - 1);
152 let curr_name = format!("{}_{}", var, t);
153
154 let dist = self.transition_dists.get(var).cloned().unwrap_or_else(|| {
155 let mut identity = ArrayD::zeros(IxDyn(&[*card, *card]));
157 for i in 0..*card {
158 identity[[i, i]] = 1.0;
159 }
160 identity
161 });
162
163 let factor =
164 Factor::new(format!("T{}_{}", t, var), vec![prev_name, curr_name], dist)?;
165 graph.add_factor(factor)?;
166 }
167 }
168
169 for t in 0..num_steps {
171 for (obs_var, _) in &self.observation_vars {
172 if let Some(dist) = self.emission_dists.get(obs_var) {
173 let mut factor_vars: Vec<String> = self
175 .state_vars
176 .iter()
177 .map(|(v, _)| format!("{}_{}", v, t))
178 .collect();
179 factor_vars.push(format!("{}_{}", obs_var, t));
180
181 let factor =
182 Factor::new(format!("E{}_{}", t, obs_var), factor_vars, dist.clone())?;
183 graph.add_factor(factor)?;
184 }
185 }
186 }
187
188 Ok(graph)
189 }
190
191 pub fn filter(
195 &self,
196 observations: &[HashMap<String, usize>],
197 ) -> Result<Vec<HashMap<String, ArrayD<f64>>>> {
198 let num_steps = observations.len();
199 if num_steps == 0 {
200 return Ok(Vec::new());
201 }
202
203 let graph = self.unroll(num_steps)?;
205
206 let mut evidence: HashMap<String, usize> = HashMap::new();
208 for (t, obs) in observations.iter().enumerate() {
209 for (var, &value) in obs {
210 let temporal_name = format!("{}_{}", var, t);
211 evidence.insert(temporal_name, value);
212 }
213 }
214
215 let ve = VariableElimination::default();
217 let mut results = Vec::new();
218
219 for t in 0..num_steps {
220 let mut marginals = HashMap::new();
221
222 for (var, _) in &self.state_vars {
223 let temporal_name = format!("{}_{}", var, t);
224 if let Ok(marginal) = ve.marginalize(&graph, &temporal_name) {
225 marginals.insert(var.clone(), marginal);
226 }
227 }
228
229 results.push(marginals);
230 }
231
232 Ok(results)
233 }
234
235 pub fn smooth(
239 &self,
240 observations: &[HashMap<String, usize>],
241 ) -> Result<Vec<HashMap<String, ArrayD<f64>>>> {
242 self.filter(observations)
245 }
246
247 pub fn viterbi(
249 &self,
250 observations: &[HashMap<String, usize>],
251 ) -> Result<Vec<HashMap<String, usize>>> {
252 let num_steps = observations.len();
253 if num_steps == 0 {
254 return Ok(Vec::new());
255 }
256
257 let graph = self.unroll(num_steps)?;
259
260 let mut evidence: HashMap<String, usize> = HashMap::new();
262 for (t, obs) in observations.iter().enumerate() {
263 for (var, &value) in obs {
264 let temporal_name = format!("{}_{}", var, t);
265 evidence.insert(temporal_name, value);
266 }
267 }
268
269 let ve = VariableElimination::default();
271
272 let mut results = Vec::new();
273
274 for t in 0..num_steps {
275 let mut state = HashMap::new();
276
277 for (var, _) in &self.state_vars {
278 let temporal_name = format!("{}_{}", var, t);
279 if let Ok(marginal) = ve.marginalize(&graph, &temporal_name) {
280 let max_idx = marginal
282 .iter()
283 .enumerate()
284 .max_by(|(_, a), (_, b)| {
285 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
286 })
287 .map(|(idx, _)| idx)
288 .unwrap_or(0);
289 state.insert(var.clone(), max_idx);
290 }
291 }
292
293 results.push(state);
294 }
295
296 Ok(results)
297 }
298
299 pub fn state_cardinalities(&self) -> HashMap<String, usize> {
301 self.state_vars.iter().cloned().collect()
302 }
303
304 pub fn observation_cardinalities(&self) -> HashMap<String, usize> {
306 self.observation_vars.iter().cloned().collect()
307 }
308
309 pub fn all_variables(&self) -> HashSet<String> {
311 let mut vars = HashSet::new();
312
313 for (var, _) in &self.state_vars {
314 vars.insert(var.clone());
315 }
316
317 for (var, _) in &self.observation_vars {
318 vars.insert(var.clone());
319 }
320
321 vars
322 }
323
324 pub fn run_belief_propagation(
326 &self,
327 num_steps: usize,
328 evidence: &HashMap<String, usize>,
329 ) -> Result<HashMap<String, ArrayD<f64>>> {
330 let graph = self.unroll(num_steps)?;
331
332 let mut temporal_evidence: HashMap<String, usize> = HashMap::new();
334 for (var, &value) in evidence {
335 if var.contains('_') {
337 temporal_evidence.insert(var.clone(), value);
338 } else {
339 temporal_evidence.insert(format!("{}_{}", var, num_steps - 1), value);
340 }
341 }
342
343 let algorithm = SumProductAlgorithm::new(100, 1e-6, 0.0);
345 algorithm.run(&graph)
346 }
347}
348
349pub struct DBNBuilder {
351 state_vars: Vec<(String, usize)>,
352 obs_vars: Vec<(String, usize)>,
353 initial: HashMap<String, ArrayD<f64>>,
354 transitions: HashMap<String, ArrayD<f64>>,
355 emissions: HashMap<String, ArrayD<f64>>,
356}
357
358impl Default for DBNBuilder {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364impl DBNBuilder {
365 pub fn new() -> Self {
367 Self {
368 state_vars: Vec::new(),
369 obs_vars: Vec::new(),
370 initial: HashMap::new(),
371 transitions: HashMap::new(),
372 emissions: HashMap::new(),
373 }
374 }
375
376 pub fn add_state_var(mut self, name: String, cardinality: usize) -> Self {
378 self.state_vars.push((name, cardinality));
379 self
380 }
381
382 pub fn add_observation_var(mut self, name: String, cardinality: usize) -> Self {
384 self.obs_vars.push((name, cardinality));
385 self
386 }
387
388 pub fn set_initial(mut self, var: &str, dist: ArrayD<f64>) -> Self {
390 self.initial.insert(var.to_string(), dist);
391 self
392 }
393
394 pub fn set_transition(mut self, var: &str, dist: ArrayD<f64>) -> Self {
396 self.transitions.insert(var.to_string(), dist);
397 self
398 }
399
400 pub fn set_emission(mut self, obs_var: &str, dist: ArrayD<f64>) -> Self {
402 self.emissions.insert(obs_var.to_string(), dist);
403 self
404 }
405
406 pub fn build(self) -> Result<DynamicBayesianNetwork> {
408 let mut dbn = DynamicBayesianNetwork::new(self.state_vars, self.obs_vars);
409
410 for (var, dist) in self.initial {
411 dbn.set_initial(&var, dist)?;
412 }
413
414 for (var, dist) in self.transitions {
415 dbn.set_transition(&var, dist)?;
416 }
417
418 for (var, dist) in self.emissions {
419 dbn.set_emission(&var, dist)?;
420 }
421
422 Ok(dbn)
423 }
424}
425
/// Several dynamic Bayesian networks unrolled side by side and linked by coupling factors.
#[derive(Debug, Clone)]
pub struct CoupledDBN {
    /// The individual processes. When unrolled, every variable and factor of process `i`
    /// gets a `p{i}_` prefix.
    pub processes: Vec<DynamicBayesianNetwork>,
    /// Cross-process factors, added after all processes have been unrolled.
    pub couplings: Vec<CouplingFactor>,
}
434
/// An extra factor over variables drawn from the processes of a [`CoupledDBN`].
#[derive(Debug, Clone)]
pub struct CouplingFactor {
    /// `process_indices[j]` is the index into `CoupledDBN::processes` owning `variables[j]`.
    pub process_indices: Vec<usize>,
    /// Per-process unrolled names such as `"state_0"`.
    /// Entries with a matching process index are looked up as `p{index}_{name}`; entries
    /// beyond `process_indices` are used verbatim.
    pub variables: Vec<String>,
    /// Factor values, passed to `Factor::new` together with the resolved `variables`.
    /// Presumably one axis per variable, in the same order; confirm against `Factor`.
    pub potential: ArrayD<f64>,
}
445
446impl CoupledDBN {
447 pub fn new(processes: Vec<DynamicBayesianNetwork>) -> Self {
449 Self {
450 processes,
451 couplings: Vec::new(),
452 }
453 }
454
455 pub fn add_coupling(&mut self, coupling: CouplingFactor) {
457 self.couplings.push(coupling);
458 }
459
460 pub fn unroll(&self, num_steps: usize) -> Result<FactorGraph> {
462 let mut graph = FactorGraph::new();
463
464 for (i, process) in self.processes.iter().enumerate() {
466 let process_graph = process.unroll(num_steps)?;
467
468 for var_name in process_graph.variable_names() {
470 let full_name = format!("p{}_{}", i, var_name);
471 if let Some(var) = process_graph.get_variable(var_name) {
472 graph.add_variable_with_card(full_name, var.domain.clone(), var.cardinality);
473 }
474 }
475
476 for factor_id in process_graph.factor_ids() {
478 if let Some(factor) = process_graph.get_factor(factor_id) {
479 let new_vars: Vec<String> = factor
480 .variables
481 .iter()
482 .map(|v| format!("p{}_{}", i, v))
483 .collect();
484
485 let new_factor = Factor::new(
486 format!("p{}_{}", i, factor.name),
487 new_vars,
488 factor.values.clone(),
489 )?;
490
491 graph.add_factor(new_factor)?;
492 }
493 }
494 }
495
496 for (i, coupling) in self.couplings.iter().enumerate() {
498 let coupled_vars: Vec<String> = coupling
499 .variables
500 .iter()
501 .enumerate()
502 .map(|(j, v)| {
503 if j < coupling.process_indices.len() {
504 format!("p{}_{}", coupling.process_indices[j], v)
505 } else {
506 v.clone()
507 }
508 })
509 .collect();
510
511 let coupling_factor = Factor::new(
512 format!("coupling_{}", i),
513 coupled_vars,
514 coupling.potential.clone(),
515 )?;
516
517 graph.add_factor(coupling_factor)?;
518 }
519
520 Ok(graph)
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// One binary hidden variable named `state` and no observation variables.
    fn single_binary_state() -> DynamicBayesianNetwork {
        DynamicBayesianNetwork::new(vec![("state".to_string(), 2)], vec![])
    }

    #[test]
    fn test_dbn_creation() {
        let network = single_binary_state();

        assert_eq!(network.state_vars.len(), 1);
        assert!(network.observation_vars.is_empty());
    }

    #[test]
    fn test_dbn_set_distributions() {
        let mut network = DynamicBayesianNetwork::new(
            vec![("state".to_string(), 2)],
            vec![("obs".to_string(), 3)],
        );

        let prior = ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.6, 0.4]).unwrap();
        let transition =
            ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![0.7, 0.3, 0.4, 0.6]).unwrap();

        network.set_initial("state", prior).unwrap();
        network.set_transition("state", transition).unwrap();

        assert!(network.initial_dists.contains_key("state"));
        assert!(network.transition_dists.contains_key("state"));
    }

    #[test]
    fn test_dbn_unroll() {
        let mut network = single_binary_state();
        network
            .set_initial("state", Array::from_vec(vec![0.6, 0.4]).into_dyn())
            .unwrap();

        let graph = network.unroll(3).unwrap();

        for t in 0..3 {
            let name = format!("state_{}", t);
            assert!(graph.get_variable(&name).is_some(), "missing {}", name);
        }
    }

    #[test]
    fn test_dbn_builder() {
        let transition =
            ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![0.7, 0.3, 0.3, 0.7]).unwrap();

        let network = DBNBuilder::new()
            .add_state_var("weather".to_string(), 2)
            .add_observation_var("umbrella".to_string(), 2)
            .set_initial("weather", Array::from_vec(vec![0.5, 0.5]).into_dyn())
            .set_transition("weather", transition)
            .build()
            .unwrap();

        assert_eq!(network.state_vars.len(), 1);
        assert_eq!(network.observation_vars.len(), 1);
    }

    #[test]
    fn test_dbn_state_cardinalities() {
        let network = DynamicBayesianNetwork::new(vec![("state".to_string(), 3)], vec![]);

        assert_eq!(network.state_cardinalities().get("state").copied(), Some(3));
    }

    #[test]
    fn test_dbn_all_variables() {
        let network = DynamicBayesianNetwork::new(
            vec![("x".to_string(), 2), ("y".to_string(), 2)],
            vec![("obs".to_string(), 3)],
        );

        let vars = network.all_variables();
        for expected in ["x", "y", "obs"].iter() {
            assert!(vars.contains(*expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_coupled_dbn() {
        let coupled = CoupledDBN::new(vec![single_binary_state(), single_binary_state()]);

        assert_eq!(coupled.processes.len(), 2);
    }

    #[test]
    fn test_temporal_var_display() {
        let tv = TemporalVar {
            name: "state".to_string(),
            time: 3,
        };

        assert_eq!(tv.to_string(), "state_3");
    }

    #[test]
    fn test_dbn_filter_empty() {
        assert!(single_binary_state().filter(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_dbn_viterbi_empty() {
        assert!(single_binary_state().viterbi(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_dbn_unroll_zero_steps() {
        assert!(single_binary_state().unroll(0).is_err());
    }

    #[test]
    fn test_dbn_set_invalid_var() {
        let mut network = single_binary_state();
        let dist = Array::from_vec(vec![0.5, 0.5]).into_dyn();

        assert!(network.set_initial("invalid", dist).is_err());
    }
}
660}