Skip to main content

tensorlogic_quantrs_hooks/
dbn.rs

1//! Dynamic Bayesian Networks (DBN) for temporal probabilistic models.
2//!
3//! This module provides support for Dynamic Bayesian Networks, which are
4//! generalizations of Hidden Markov Models to handle multiple interacting
5//! variables over time.
6//!
7//! # Structure
8//!
9//! A DBN consists of:
10//! - Initial (prior) network at t=0
11//! - Two-time-slice transition defining state evolution
12//! - Interface variables connecting adjacent time slices
13
14use scirs2_core::ndarray::{ArrayD, IxDyn};
15use std::collections::{HashMap, HashSet};
16
17use crate::error::{PgmError, Result};
18use crate::message_passing::MessagePassingAlgorithm;
19use crate::{Factor, FactorGraph, SumProductAlgorithm, VariableElimination};
20
21/// Dynamic Bayesian Network.
22///
23/// Represents a temporal probabilistic model with variables that evolve over time.
24///
25/// # Example
26///
27/// ```
28/// use tensorlogic_quantrs_hooks::DynamicBayesianNetwork;
29/// use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
30///
31/// // Create a simple DBN with one state variable
32/// let dbn = DynamicBayesianNetwork::new(
33///     vec![("state".to_string(), 2)],  // state variables with cardinality
34///     vec![],  // no observation variables
35/// );
36/// ```
37#[derive(Debug, Clone)]
38pub struct DynamicBayesianNetwork {
39    /// State variables with cardinalities
40    pub state_vars: Vec<(String, usize)>,
41    /// Observation variables with cardinalities
42    pub observation_vars: Vec<(String, usize)>,
43    /// Initial distribution P(X_0) for each state variable
44    pub initial_dists: HashMap<String, ArrayD<f64>>,
45    /// Transition distributions P(X_t | X_{t-1})
46    pub transition_dists: HashMap<String, ArrayD<f64>>,
47    /// Emission distributions P(Y_t | X_t)
48    pub emission_dists: HashMap<String, ArrayD<f64>>,
49}
50
/// A variable pinned to a single time slice of an unrolled DBN.
///
/// Its display form, `{name}_{time}`, is the same naming scheme that
/// [`DynamicBayesianNetwork::unroll`] uses for the variables it creates.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TemporalVar {
    /// Name of the underlying, time-independent variable.
    pub name: String,
    /// Time slice index.
    pub time: usize,
}

impl std::fmt::Display for TemporalVar {
    /// Formats as `{name}_{time}`, e.g. `state_3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)?;
        write!(f, "_{}", self.time)
    }
}
65
66impl DynamicBayesianNetwork {
67    /// Create a new DBN with state and observation variables.
68    pub fn new(state_vars: Vec<(String, usize)>, observation_vars: Vec<(String, usize)>) -> Self {
69        Self {
70            state_vars,
71            observation_vars,
72            initial_dists: HashMap::new(),
73            transition_dists: HashMap::new(),
74            emission_dists: HashMap::new(),
75        }
76    }
77
78    /// Set initial distribution for a state variable.
79    pub fn set_initial(&mut self, var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
80        if !self.state_vars.iter().any(|(name, _)| name == var) {
81            return Err(PgmError::VariableNotFound(var.to_string()));
82        }
83        self.initial_dists.insert(var.to_string(), dist);
84        Ok(self)
85    }
86
87    /// Set transition distribution P(X_t | X_{t-1}) for a state variable.
88    pub fn set_transition(&mut self, var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
89        if !self.state_vars.iter().any(|(name, _)| name == var) {
90            return Err(PgmError::VariableNotFound(var.to_string()));
91        }
92        self.transition_dists.insert(var.to_string(), dist);
93        Ok(self)
94    }
95
96    /// Set emission distribution P(Y | X) for an observation variable.
97    pub fn set_emission(&mut self, obs_var: &str, dist: ArrayD<f64>) -> Result<&mut Self> {
98        if !self
99            .observation_vars
100            .iter()
101            .any(|(name, _)| name == obs_var)
102        {
103            return Err(PgmError::VariableNotFound(obs_var.to_string()));
104        }
105        self.emission_dists.insert(obs_var.to_string(), dist);
106        Ok(self)
107    }
108
109    /// Unroll the DBN for a fixed number of time steps.
110    ///
111    /// Returns a FactorGraph representing the unrolled DBN.
112    pub fn unroll(&self, num_steps: usize) -> Result<FactorGraph> {
113        if num_steps == 0 {
114            return Err(PgmError::InvalidDistribution(
115                "Number of steps must be positive".to_string(),
116            ));
117        }
118
119        let mut graph = FactorGraph::new();
120
121        // Add all variables for all time steps
122        for t in 0..num_steps {
123            // State variables
124            for (var, card) in &self.state_vars {
125                let temporal_name = format!("{}_{}", var, t);
126                graph.add_variable_with_card(temporal_name, "State".to_string(), *card);
127            }
128
129            // Observation variables
130            for (var, card) in &self.observation_vars {
131                let temporal_name = format!("{}_{}", var, t);
132                graph.add_variable_with_card(temporal_name, "Observation".to_string(), *card);
133            }
134        }
135
136        // Add initial factors (t=0)
137        for (var, card) in &self.state_vars {
138            let temporal_name = format!("{}_{}", var, 0);
139            let dist = self.initial_dists.get(var).cloned().unwrap_or_else(|| {
140                // Default to uniform
141                ArrayD::from_elem(IxDyn(&[*card]), 1.0 / *card as f64)
142            });
143
144            let factor = Factor::new(format!("P0_{}", var), vec![temporal_name], dist)?;
145            graph.add_factor(factor)?;
146        }
147
148        // Add transition factors (t=1 to num_steps-1)
149        for t in 1..num_steps {
150            for (var, card) in &self.state_vars {
151                let prev_name = format!("{}_{}", var, t - 1);
152                let curr_name = format!("{}_{}", var, t);
153
154                let dist = self.transition_dists.get(var).cloned().unwrap_or_else(|| {
155                    // Default to identity transition
156                    let mut identity = ArrayD::zeros(IxDyn(&[*card, *card]));
157                    for i in 0..*card {
158                        identity[[i, i]] = 1.0;
159                    }
160                    identity
161                });
162
163                let factor =
164                    Factor::new(format!("T{}_{}", t, var), vec![prev_name, curr_name], dist)?;
165                graph.add_factor(factor)?;
166            }
167        }
168
169        // Add emission factors (all time steps)
170        for t in 0..num_steps {
171            for (obs_var, _) in &self.observation_vars {
172                if let Some(dist) = self.emission_dists.get(obs_var) {
173                    // Emission depends on state variables
174                    let mut factor_vars: Vec<String> = self
175                        .state_vars
176                        .iter()
177                        .map(|(v, _)| format!("{}_{}", v, t))
178                        .collect();
179                    factor_vars.push(format!("{}_{}", obs_var, t));
180
181                    let factor =
182                        Factor::new(format!("E{}_{}", t, obs_var), factor_vars, dist.clone())?;
183                    graph.add_factor(factor)?;
184                }
185            }
186        }
187
188        Ok(graph)
189    }
190
191    /// Perform filtering to compute P(X_t | y_{1:t}).
192    ///
193    /// Returns marginal distributions for state variables at each time step.
194    pub fn filter(
195        &self,
196        observations: &[HashMap<String, usize>],
197    ) -> Result<Vec<HashMap<String, ArrayD<f64>>>> {
198        let num_steps = observations.len();
199        if num_steps == 0 {
200            return Ok(Vec::new());
201        }
202
203        // Unroll the DBN
204        let graph = self.unroll(num_steps)?;
205
206        // Set evidence
207        let mut evidence: HashMap<String, usize> = HashMap::new();
208        for (t, obs) in observations.iter().enumerate() {
209            for (var, &value) in obs {
210                let temporal_name = format!("{}_{}", var, t);
211                evidence.insert(temporal_name, value);
212            }
213        }
214
215        // Run inference for each time step
216        let ve = VariableElimination::default();
217        let mut results = Vec::new();
218
219        for t in 0..num_steps {
220            let mut marginals = HashMap::new();
221
222            for (var, _) in &self.state_vars {
223                let temporal_name = format!("{}_{}", var, t);
224                if let Ok(marginal) = ve.marginalize(&graph, &temporal_name) {
225                    marginals.insert(var.clone(), marginal);
226                }
227            }
228
229            results.push(marginals);
230        }
231
232        Ok(results)
233    }
234
235    /// Perform smoothing to compute P(X_t | y_{1:T}) for all t.
236    ///
237    /// Uses variable elimination on the unrolled DBN.
238    pub fn smooth(
239        &self,
240        observations: &[HashMap<String, usize>],
241    ) -> Result<Vec<HashMap<String, ArrayD<f64>>>> {
242        // For smoothing, we need all evidence before computing marginals
243        // The implementation is the same as filter for exact inference
244        self.filter(observations)
245    }
246
247    /// Compute most likely sequence using Viterbi algorithm on unrolled DBN.
248    pub fn viterbi(
249        &self,
250        observations: &[HashMap<String, usize>],
251    ) -> Result<Vec<HashMap<String, usize>>> {
252        let num_steps = observations.len();
253        if num_steps == 0 {
254            return Ok(Vec::new());
255        }
256
257        // Unroll the DBN
258        let graph = self.unroll(num_steps)?;
259
260        // Set evidence
261        let mut evidence: HashMap<String, usize> = HashMap::new();
262        for (t, obs) in observations.iter().enumerate() {
263            for (var, &value) in obs {
264                let temporal_name = format!("{}_{}", var, t);
265                evidence.insert(temporal_name, value);
266            }
267        }
268
269        // Run MAP inference using marginalization
270        let ve = VariableElimination::default();
271
272        let mut results = Vec::new();
273
274        for t in 0..num_steps {
275            let mut state = HashMap::new();
276
277            for (var, _) in &self.state_vars {
278                let temporal_name = format!("{}_{}", var, t);
279                if let Ok(marginal) = ve.marginalize(&graph, &temporal_name) {
280                    // Get argmax
281                    let max_idx = marginal
282                        .iter()
283                        .enumerate()
284                        .max_by(|(_, a), (_, b)| {
285                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
286                        })
287                        .map(|(idx, _)| idx)
288                        .unwrap_or(0);
289                    state.insert(var.clone(), max_idx);
290                }
291            }
292
293            results.push(state);
294        }
295
296        Ok(results)
297    }
298
299    /// Get interface (state) variable cardinalities.
300    pub fn state_cardinalities(&self) -> HashMap<String, usize> {
301        self.state_vars.iter().cloned().collect()
302    }
303
304    /// Get observation variable cardinalities.
305    pub fn observation_cardinalities(&self) -> HashMap<String, usize> {
306        self.observation_vars.iter().cloned().collect()
307    }
308
309    /// Get all variables in the DBN.
310    pub fn all_variables(&self) -> HashSet<String> {
311        let mut vars = HashSet::new();
312
313        for (var, _) in &self.state_vars {
314            vars.insert(var.clone());
315        }
316
317        for (var, _) in &self.observation_vars {
318            vars.insert(var.clone());
319        }
320
321        vars
322    }
323
324    /// Run belief propagation on the unrolled DBN.
325    pub fn run_belief_propagation(
326        &self,
327        num_steps: usize,
328        evidence: &HashMap<String, usize>,
329    ) -> Result<HashMap<String, ArrayD<f64>>> {
330        let graph = self.unroll(num_steps)?;
331
332        // Convert evidence to temporal format
333        let mut temporal_evidence: HashMap<String, usize> = HashMap::new();
334        for (var, &value) in evidence {
335            // Assume evidence is for the last time step if no time suffix
336            if var.contains('_') {
337                temporal_evidence.insert(var.clone(), value);
338            } else {
339                temporal_evidence.insert(format!("{}_{}", var, num_steps - 1), value);
340            }
341        }
342
343        // Run sum-product
344        let algorithm = SumProductAlgorithm::new(100, 1e-6, 0.0);
345        algorithm.run(&graph)
346    }
347}
348
349/// Builder for creating DBNs with fluent API.
350pub struct DBNBuilder {
351    state_vars: Vec<(String, usize)>,
352    obs_vars: Vec<(String, usize)>,
353    initial: HashMap<String, ArrayD<f64>>,
354    transitions: HashMap<String, ArrayD<f64>>,
355    emissions: HashMap<String, ArrayD<f64>>,
356}
357
358impl Default for DBNBuilder {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
364impl DBNBuilder {
365    /// Create a new DBN builder.
366    pub fn new() -> Self {
367        Self {
368            state_vars: Vec::new(),
369            obs_vars: Vec::new(),
370            initial: HashMap::new(),
371            transitions: HashMap::new(),
372            emissions: HashMap::new(),
373        }
374    }
375
376    /// Add a state variable.
377    pub fn add_state_var(mut self, name: String, cardinality: usize) -> Self {
378        self.state_vars.push((name, cardinality));
379        self
380    }
381
382    /// Add an observation variable.
383    pub fn add_observation_var(mut self, name: String, cardinality: usize) -> Self {
384        self.obs_vars.push((name, cardinality));
385        self
386    }
387
388    /// Set initial distribution for a state variable.
389    pub fn set_initial(mut self, var: &str, dist: ArrayD<f64>) -> Self {
390        self.initial.insert(var.to_string(), dist);
391        self
392    }
393
394    /// Set transition distribution P(X_t | X_{t-1}).
395    pub fn set_transition(mut self, var: &str, dist: ArrayD<f64>) -> Self {
396        self.transitions.insert(var.to_string(), dist);
397        self
398    }
399
400    /// Set emission distribution P(Y_t | X_t).
401    pub fn set_emission(mut self, obs_var: &str, dist: ArrayD<f64>) -> Self {
402        self.emissions.insert(obs_var.to_string(), dist);
403        self
404    }
405
406    /// Build the DBN.
407    pub fn build(self) -> Result<DynamicBayesianNetwork> {
408        let mut dbn = DynamicBayesianNetwork::new(self.state_vars, self.obs_vars);
409
410        for (var, dist) in self.initial {
411            dbn.set_initial(&var, dist)?;
412        }
413
414        for (var, dist) in self.transitions {
415            dbn.set_transition(&var, dist)?;
416        }
417
418        for (var, dist) in self.emissions {
419            dbn.set_emission(&var, dist)?;
420        }
421
422        Ok(dbn)
423    }
424}
425
426/// Coupled DBN with multiple interacting processes.
427#[derive(Debug, Clone)]
428pub struct CoupledDBN {
429    /// Individual DBN processes
430    pub processes: Vec<DynamicBayesianNetwork>,
431    /// Coupling factors between processes
432    pub couplings: Vec<CouplingFactor>,
433}
434
435/// Coupling factor between DBN processes.
436#[derive(Debug, Clone)]
437pub struct CouplingFactor {
438    /// Process indices involved
439    pub process_indices: Vec<usize>,
440    /// Variables involved
441    pub variables: Vec<String>,
442    /// Coupling potential
443    pub potential: ArrayD<f64>,
444}
445
446impl CoupledDBN {
447    /// Create a new coupled DBN.
448    pub fn new(processes: Vec<DynamicBayesianNetwork>) -> Self {
449        Self {
450            processes,
451            couplings: Vec::new(),
452        }
453    }
454
455    /// Add a coupling factor.
456    pub fn add_coupling(&mut self, coupling: CouplingFactor) {
457        self.couplings.push(coupling);
458    }
459
460    /// Unroll the coupled DBN.
461    pub fn unroll(&self, num_steps: usize) -> Result<FactorGraph> {
462        let mut graph = FactorGraph::new();
463
464        // Unroll each process
465        for (i, process) in self.processes.iter().enumerate() {
466            let process_graph = process.unroll(num_steps)?;
467
468            // Add variables with process prefix
469            for var_name in process_graph.variable_names() {
470                let full_name = format!("p{}_{}", i, var_name);
471                if let Some(var) = process_graph.get_variable(var_name) {
472                    graph.add_variable_with_card(full_name, var.domain.clone(), var.cardinality);
473                }
474            }
475
476            // Add factors with process prefix
477            for factor_id in process_graph.factor_ids() {
478                if let Some(factor) = process_graph.get_factor(factor_id) {
479                    let new_vars: Vec<String> = factor
480                        .variables
481                        .iter()
482                        .map(|v| format!("p{}_{}", i, v))
483                        .collect();
484
485                    let new_factor = Factor::new(
486                        format!("p{}_{}", i, factor.name),
487                        new_vars,
488                        factor.values.clone(),
489                    )?;
490
491                    graph.add_factor(new_factor)?;
492                }
493            }
494        }
495
496        // Add coupling factors
497        for (i, coupling) in self.couplings.iter().enumerate() {
498            let coupled_vars: Vec<String> = coupling
499                .variables
500                .iter()
501                .enumerate()
502                .map(|(j, v)| {
503                    if j < coupling.process_indices.len() {
504                        format!("p{}_{}", coupling.process_indices[j], v)
505                    } else {
506                        v.clone()
507                    }
508                })
509                .collect();
510
511            let coupling_factor = Factor::new(
512                format!("coupling_{}", i),
513                coupled_vars,
514                coupling.potential.clone(),
515            )?;
516
517            graph.add_factor(coupling_factor)?;
518        }
519
520        Ok(graph)
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Model with one state variable named `state` of cardinality `card`.
    fn single_state(card: usize, observations: Vec<(String, usize)>) -> DynamicBayesianNetwork {
        DynamicBayesianNetwork::new(vec![(String::from("state"), card)], observations)
    }

    #[test]
    fn test_dbn_creation() {
        let dbn = single_state(2, Vec::new());

        assert_eq!(dbn.state_vars.len(), 1);
        assert!(dbn.observation_vars.is_empty());
    }

    #[test]
    fn test_dbn_set_distributions() {
        let mut dbn = single_state(2, vec![(String::from("obs"), 3)]);

        dbn.set_initial("state", Array::from_vec(vec![0.6, 0.4]).into_dyn())
            .unwrap();

        let transition =
            ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![0.7, 0.3, 0.4, 0.6]).unwrap();
        dbn.set_transition("state", transition).unwrap();

        assert!(dbn.initial_dists.contains_key("state"));
        assert!(dbn.transition_dists.contains_key("state"));
    }

    #[test]
    fn test_dbn_unroll() {
        let mut dbn = single_state(2, Vec::new());
        dbn.set_initial("state", Array::from_vec(vec![0.6, 0.4]).into_dyn())
            .unwrap();

        let graph = dbn.unroll(3).unwrap();

        // One copy of the state variable per time step.
        for t in 0..3 {
            let name = format!("state_{}", t);
            assert!(graph.get_variable(name.as_str()).is_some(), "missing {}", name);
        }
    }

    #[test]
    fn test_dbn_builder() {
        let transition =
            ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![0.7, 0.3, 0.3, 0.7]).unwrap();
        let dbn = DBNBuilder::new()
            .add_state_var(String::from("weather"), 2)
            .add_observation_var(String::from("umbrella"), 2)
            .set_initial("weather", Array::from_vec(vec![0.5, 0.5]).into_dyn())
            .set_transition("weather", transition)
            .build()
            .unwrap();

        assert_eq!(dbn.state_vars.len(), 1);
        assert_eq!(dbn.observation_vars.len(), 1);
    }

    #[test]
    fn test_dbn_state_cardinalities() {
        let cards = single_state(3, Vec::new()).state_cardinalities();

        assert_eq!(cards.get("state").copied(), Some(3));
    }

    #[test]
    fn test_dbn_all_variables() {
        let dbn = DynamicBayesianNetwork::new(
            vec![(String::from("x"), 2), (String::from("y"), 2)],
            vec![(String::from("obs"), 3)],
        );

        let vars = dbn.all_variables();
        for name in ["x", "y", "obs"].iter() {
            assert!(vars.contains(*name), "missing {}", name);
        }
    }

    #[test]
    fn test_coupled_dbn() {
        let processes = vec![single_state(2, Vec::new()), single_state(2, Vec::new())];

        let coupled = CoupledDBN::new(processes);

        assert_eq!(coupled.processes.len(), 2);
    }

    #[test]
    fn test_temporal_var_display() {
        let tv = TemporalVar {
            name: String::from("state"),
            time: 3,
        };

        assert_eq!(tv.to_string(), "state_3");
    }

    #[test]
    fn test_dbn_filter_empty() {
        let dbn = single_state(2, Vec::new());

        assert!(dbn.filter(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_dbn_viterbi_empty() {
        let dbn = single_state(2, Vec::new());

        assert!(dbn.viterbi(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_dbn_unroll_zero_steps() {
        let dbn = single_state(2, Vec::new());

        assert!(matches!(dbn.unroll(0), Err(_)));
    }

    #[test]
    fn test_dbn_set_invalid_var() {
        let mut dbn = single_state(2, Vec::new());

        let outcome = dbn.set_initial("invalid", Array::from_vec(vec![0.5, 0.5]).into_dyn());
        assert!(outcome.is_err());
    }
}