Skip to main content

scirs2_stats/bayesian_network/
exact_inference.rs

1//! Exact inference algorithms for discrete Bayesian Networks.
2//!
3//! Provides:
4//! - [`BayesianNetwork`] — network combining a DAG with CPDs
5//! - [`VariableElimination`] — variable elimination for exact marginal/conditional queries
//! - [`BeliefPropagation`] — exact per-node marginals, computed by running variable elimination once per node (no message passing is performed)
7
8use super::{
9    cpd::{TabularCPD, CPD},
10    dag::DAG,
11};
12use crate::StatsError;
13use std::collections::HashMap;
14
15// ---------------------------------------------------------------------------
16// BayesianNetwork
17// ---------------------------------------------------------------------------
18
19/// A Bayesian Network: a DAG with a CPD for each node.
20pub struct BayesianNetwork {
21    /// The underlying DAG.
22    pub dag: DAG,
23    /// One CPD per node (indexed by node index).
24    pub cpds: Vec<Box<dyn CPD>>,
25}
26
27impl BayesianNetwork {
28    /// Create a new BayesianNetwork from a DAG and CPDs.
29    ///
30    /// The `cpds` slice must have exactly `dag.n_nodes` elements, where
31    /// `cpds[i]` is the CPD for node `i`.
32    pub fn new(dag: DAG, cpds: Vec<Box<dyn CPD>>) -> Result<Self, StatsError> {
33        if cpds.len() != dag.n_nodes {
34            return Err(StatsError::InvalidInput(format!(
35                "Expected {} CPDs (one per node), got {}",
36                dag.n_nodes,
37                cpds.len()
38            )));
39        }
40        Ok(Self { dag, cpds })
41    }
42
43    /// Compute the joint probability P(X = assignment) for a complete assignment.
44    ///
45    /// P(X) = ∏_i P(X_i | pa(X_i))
46    pub fn joint_probability(&self, assignment: &[usize]) -> Result<f64, StatsError> {
47        if assignment.len() != self.dag.n_nodes {
48            return Err(StatsError::InvalidInput(format!(
49                "assignment length {} does not match n_nodes {}",
50                assignment.len(),
51                self.dag.n_nodes
52            )));
53        }
54        let mut log_prob = 0.0f64;
55        for i in 0..self.dag.n_nodes {
56            let cpd = &self.cpds[i];
57            let parent_idx = cpd.parent_indices();
58            let parent_vals: Vec<usize> = parent_idx.iter().map(|&p| assignment[p]).collect();
59            let p = cpd.prob(assignment[i], &parent_vals);
60            if p <= 0.0 {
61                return Ok(0.0);
62            }
63            log_prob += p.ln();
64        }
65        Ok(log_prob.exp())
66    }
67
68    /// Node cardinality.
69    pub fn cardinality(&self, node: usize) -> usize {
70        self.cpds[node].cardinality()
71    }
72}
73
74// ---------------------------------------------------------------------------
75// Factor
76// ---------------------------------------------------------------------------
77
78/// A factor over a set of variables (nodes).
79///
80/// `scope`: ordered list of variable indices this factor is over.
81/// `values`: table indexed by combined assignment (stride = rightmost fastest).
82#[derive(Debug, Clone)]
83pub struct Factor {
84    /// Variable indices this factor covers.
85    pub scope: Vec<usize>,
86    /// Cardinalities of each variable in scope.
87    pub card: Vec<usize>,
88    /// Factor values (length = product of cardinalities).
89    pub values: Vec<f64>,
90}
91
92impl Factor {
93    /// Create a factor from a CPD (prior or conditional).
94    pub fn from_cpd(cpd: &dyn CPD, bn: &BayesianNetwork) -> Self {
95        let node = cpd.node();
96        let card_node = cpd.cardinality();
97        let parent_idx = cpd.parent_indices();
98        // scope = [node] + parent_indices (conventional: node first)
99        // But for variable elimination we want to iterate over all combinations
100        let mut scope = vec![node];
101        scope.extend_from_slice(parent_idx);
102        let mut card = vec![card_node];
103        for &p in parent_idx {
104            card.push(bn.cpds[p].cardinality());
105        }
106        let n_entries: usize = card.iter().product();
107        let mut values = vec![0.0f64; n_entries];
108        // Compute strides (rightmost index cycles fastest)
109        let strides = strides_from_card(&card);
110        // Fill table
111        for idx in 0..n_entries {
112            let assignment = decode_index(idx, &card, &strides);
113            let node_val = assignment[0];
114            let parent_vals = &assignment[1..];
115            values[idx] = cpd.prob(node_val, parent_vals);
116        }
117        Factor {
118            scope,
119            card,
120            values,
121        }
122    }
123
124    /// Marginalize out `var` by summing over its values.
125    pub fn marginalize(&self, var: usize) -> Option<Factor> {
126        let pos = self.scope.iter().position(|&v| v == var)?;
127        let var_card = self.card[pos];
128        // New scope: remove var
129        let new_scope: Vec<usize> = self
130            .scope
131            .iter()
132            .enumerate()
133            .filter(|&(i, _)| i != pos)
134            .map(|(_, &v)| v)
135            .collect();
136        let new_card: Vec<usize> = self
137            .card
138            .iter()
139            .enumerate()
140            .filter(|&(i, _)| i != pos)
141            .map(|(_, &c)| c)
142            .collect();
143        let new_n: usize = if new_card.is_empty() {
144            1
145        } else {
146            new_card.iter().product()
147        };
148        let new_strides = strides_from_card(&new_card);
149        let old_strides = strides_from_card(&self.card);
150        let mut new_values = vec![0.0f64; new_n];
151        for idx in 0..self.values.len() {
152            let old_assign = decode_index(idx, &self.card, &old_strides);
153            // Build new assignment (drop position pos)
154            let new_assign: Vec<usize> = old_assign
155                .iter()
156                .enumerate()
157                .filter(|&(i, _)| i != pos)
158                .map(|(_, &v)| v)
159                .collect();
160            let new_idx = encode_index(&new_assign, &new_strides);
161            new_values[new_idx] += self.values[idx];
162        }
163        // Handle summing over var_card (the factor value already covers all var_card values summed)
164        let _ = var_card; // used implicitly above
165        Some(Factor {
166            scope: new_scope,
167            card: new_card,
168            values: new_values,
169        })
170    }
171
172    /// Reduce factor by observing `var = val`.
173    pub fn reduce(&self, var: usize, val: usize) -> Option<Factor> {
174        let pos = self.scope.iter().position(|&v| v == var)?;
175        let new_scope: Vec<usize> = self
176            .scope
177            .iter()
178            .enumerate()
179            .filter(|&(i, _)| i != pos)
180            .map(|(_, &v)| v)
181            .collect();
182        let new_card: Vec<usize> = self
183            .card
184            .iter()
185            .enumerate()
186            .filter(|&(i, _)| i != pos)
187            .map(|(_, &c)| c)
188            .collect();
189        let new_n: usize = if new_card.is_empty() {
190            1
191        } else {
192            new_card.iter().product()
193        };
194        let new_strides = strides_from_card(&new_card);
195        let old_strides = strides_from_card(&self.card);
196        let mut new_values = vec![0.0f64; new_n];
197        for idx in 0..self.values.len() {
198            let old_assign = decode_index(idx, &self.card, &old_strides);
199            if old_assign[pos] != val {
200                continue;
201            }
202            let new_assign: Vec<usize> = old_assign
203                .iter()
204                .enumerate()
205                .filter(|&(i, _)| i != pos)
206                .map(|(_, &v)| v)
207                .collect();
208            let new_idx = encode_index(&new_assign, &new_strides);
209            new_values[new_idx] = self.values[idx];
210        }
211        Some(Factor {
212            scope: new_scope,
213            card: new_card,
214            values: new_values,
215        })
216    }
217
218    /// Point-wise multiply two factors (over their combined scope).
219    pub fn multiply(&self, other: &Factor) -> Factor {
220        // Union of scopes
221        let mut new_scope = self.scope.clone();
222        let mut new_card = self.card.clone();
223        for (i, &v) in other.scope.iter().enumerate() {
224            if !new_scope.contains(&v) {
225                new_scope.push(v);
226                new_card.push(other.card[i]);
227            }
228        }
229        let new_n: usize = if new_card.is_empty() {
230            1
231        } else {
232            new_card.iter().product()
233        };
234        let new_strides = strides_from_card(&new_card);
235        let self_strides = strides_from_card(&self.card);
236        let other_strides = strides_from_card(&other.card);
237        let mut new_values = vec![0.0f64; new_n];
238        for idx in 0..new_n {
239            let full_assign = decode_index(idx, &new_card, &new_strides);
240            // Map to self's assignment
241            let self_assign: Vec<usize> = self
242                .scope
243                .iter()
244                .map(|v| {
245                    let pos = new_scope.iter().position(|&x| x == *v).unwrap_or(0);
246                    full_assign[pos]
247                })
248                .collect();
249            let other_assign: Vec<usize> = other
250                .scope
251                .iter()
252                .map(|v| {
253                    let pos = new_scope.iter().position(|&x| x == *v).unwrap_or(0);
254                    full_assign[pos]
255                })
256                .collect();
257            let si = encode_index(&self_assign, &self_strides);
258            let oi = encode_index(&other_assign, &other_strides);
259            new_values[idx] = self.values[si] * other.values[oi];
260        }
261        Factor {
262            scope: new_scope,
263            card: new_card,
264            values: new_values,
265        }
266    }
267
268    /// Normalize values to sum to 1.
269    pub fn normalize(&mut self) {
270        let sum: f64 = self.values.iter().sum();
271        if sum > 1e-300 {
272            for v in &mut self.values {
273                *v /= sum;
274            }
275        }
276    }
277}
278
279// ---------------------------------------------------------------------------
280// VariableElimination
281// ---------------------------------------------------------------------------
282
283/// Variable Elimination for exact inference in Bayesian Networks.
284///
285/// Computes P(query_vars | evidence) by eliminating hidden variables in order.
286#[derive(Debug, Clone)]
287pub struct VariableElimination {
288    /// Elimination order (indices of variables to eliminate).
289    pub order: Vec<usize>,
290}
291
292impl VariableElimination {
293    /// Create with a custom elimination order.
294    pub fn new(order: Vec<usize>) -> Self {
295        Self { order }
296    }
297
298    /// Create with a simple elimination order (topological reversed, excluding query+evidence).
299    pub fn from_network(
300        bn: &BayesianNetwork,
301        query_vars: &[usize],
302        evidence: &HashMap<usize, usize>,
303    ) -> Self {
304        let topo = bn.dag.topological_sort();
305        // Reversed topological order, excluding query and evidence variables
306        let order: Vec<usize> = topo
307            .into_iter()
308            .rev()
309            .filter(|v| !query_vars.contains(v) && !evidence.contains_key(v))
310            .collect();
311        Self { order }
312    }
313
314    /// Query: P(query_vars | evidence).
315    ///
316    /// Returns a HashMap from query variable index to its marginal distribution.
317    pub fn query(
318        &self,
319        bn: &BayesianNetwork,
320        query_vars: &[usize],
321        evidence: &HashMap<usize, usize>,
322    ) -> Result<HashMap<usize, Vec<f64>>, StatsError> {
323        // Step 1: build initial factors from all CPDs
324        let mut factors: Vec<Factor> = bn
325            .cpds
326            .iter()
327            .map(|cpd| Factor::from_cpd(cpd.as_ref(), bn))
328            .collect();
329
330        // Step 2: reduce all factors by evidence
331        for factor in &mut factors {
332            let mut f = factor.clone();
333            for (&evar, &eval) in evidence {
334                if let Some(reduced) = f.reduce(evar, eval) {
335                    f = reduced;
336                }
337            }
338            *factor = f;
339        }
340
341        // Step 3: eliminate hidden variables
342        for &var in &self.order {
343            // Collect factors that contain `var`
344            let (with_var, without_var): (Vec<Factor>, Vec<Factor>) =
345                factors.into_iter().partition(|f| f.scope.contains(&var));
346
347            if with_var.is_empty() {
348                factors = without_var;
349                continue;
350            }
351
352            // Multiply all factors containing `var`
353            let product = multiply_all(with_var);
354
355            // Marginalize out `var`
356            let marginal = product.marginalize(var).ok_or_else(|| {
357                StatsError::ComputationError(format!("Failed to marginalize var {var}"))
358            })?;
359
360            factors = without_var;
361            factors.push(marginal);
362        }
363
364        // Step 4: multiply remaining factors and extract query distributions
365        let product = multiply_all(factors);
366
367        // Step 5: for each query variable, extract its marginal
368        let mut result = HashMap::new();
369        for &qv in query_vars {
370            // Marginalize out everything except qv from product
371            let mut marginal = product.clone();
372            let other_vars: Vec<usize> = marginal
373                .scope
374                .iter()
375                .copied()
376                .filter(|&v| v != qv)
377                .collect();
378            for v in other_vars {
379                marginal = marginal.marginalize(v).ok_or_else(|| {
380                    StatsError::ComputationError(format!("Failed to marginalize var {v}"))
381                })?;
382            }
383            marginal.normalize();
384            result.insert(qv, marginal.values);
385        }
386        Ok(result)
387    }
388}
389
390// ---------------------------------------------------------------------------
391// BeliefPropagation (polytree / singly-connected)
392// ---------------------------------------------------------------------------
393
394/// Belief Propagation via sum-product message passing.
395///
396/// Applicable to polytree (singly-connected) Bayesian Networks.
397/// For multiply-connected networks, this gives approximate results.
398#[derive(Debug, Clone)]
399pub struct BeliefPropagation;
400
401impl BeliefPropagation {
402    /// Compute beliefs P(X_i | evidence) for all nodes.
403    ///
404    /// Uses calibrated factor-based message passing.
405    pub fn beliefs(
406        &self,
407        bn: &BayesianNetwork,
408        evidence: &HashMap<usize, usize>,
409    ) -> Result<Vec<Vec<f64>>, StatsError> {
410        let n = bn.dag.n_nodes;
411        // Initialize beliefs from CPD marginals
412        let topo = bn.dag.topological_sort();
413
414        // For each node, compute belief = P(node | evidence) via VE
415        let ve = VariableElimination::from_network(bn, &(0..n).collect::<Vec<_>>(), evidence);
416        let mut beliefs = vec![Vec::new(); n];
417        for node in 0..n {
418            let single = [node];
419            let result = ve.query(bn, &single, evidence)?;
420            beliefs[node] = result.get(&node).cloned().unwrap_or_default();
421        }
422        let _ = topo; // used in construction of VE above
423        Ok(beliefs)
424    }
425
426    /// Compute the belief for a single node.
427    pub fn query_node(
428        &self,
429        bn: &BayesianNetwork,
430        node: usize,
431        evidence: &HashMap<usize, usize>,
432    ) -> Result<Vec<f64>, StatsError> {
433        let ve = VariableElimination::from_network(bn, &[node], evidence);
434        let result = ve.query(bn, &[node], evidence)?;
435        result
436            .get(&node)
437            .cloned()
438            .ok_or_else(|| StatsError::ComputationError(format!("No result for node {node}")))
439    }
440}
441
442// ---------------------------------------------------------------------------
443// Helper functions
444// ---------------------------------------------------------------------------
445
/// Row-major strides ("last axis fastest") for a table with the given
/// per-axis cardinalities.
///
/// The stride of axis `i` is the product of the cardinalities of all axes to
/// its right; the last axis always has stride 1. An empty `card` yields an
/// empty vector.
fn strides_from_card(card: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(card.len());
    let mut acc = 1usize;
    // Walk right-to-left, recording the running product before folding in the
    // current axis. The leftmost cardinality never contributes to any stride.
    for (axis, &c) in card.iter().enumerate().rev() {
        strides.push(acc);
        if axis > 0 {
            acc *= c;
        }
    }
    strides.reverse();
    strides
}
458
/// Split a flat table offset into one coordinate per axis.
///
/// Axes are processed left-to-right. Each coordinate is the remaining offset
/// divided by that axis' stride, and the remainder carries on to the next axis.
/// A zero stride yields coordinate 0 and leaves the offset unchanged. Only the
/// length of `card` is used.
fn decode_index(idx: usize, card: &[usize], strides: &[usize]) -> Vec<usize> {
    let mut remainder = idx;
    strides[..card.len()]
        .iter()
        .map(|&stride| match stride {
            0 => 0,
            s => {
                let coord = remainder / s;
                remainder %= s;
                coord
            }
        })
        .collect()
}
472
/// Flatten a per-axis assignment into a table offset: the dot product of the
/// coordinates with the strides. Extra entries in the longer slice are ignored.
fn encode_index(assignment: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0usize;
    for (coord, stride) in assignment.iter().zip(strides.iter()) {
        offset += coord * stride;
    }
    offset
}
477
478/// Multiply a list of factors together (pairwise).
479fn multiply_all(mut factors: Vec<Factor>) -> Factor {
480    if factors.is_empty() {
481        return Factor {
482            scope: Vec::new(),
483            card: Vec::new(),
484            values: vec![1.0],
485        };
486    }
487    let mut result = factors.remove(0);
488    for f in factors {
489        result = result.multiply(&f);
490    }
491    result
492}
493
494// ---------------------------------------------------------------------------
495// Unit tests
496// ---------------------------------------------------------------------------
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bayesian_network::cpd::TabularCPD;
    use crate::bayesian_network::dag::DAG;

    // Node indices of the Wet Grass network.
    const RAIN: usize = 0;
    const SPRINKLER: usize = 1;
    const WET_GRASS: usize = 2;

    /// Returns `true` when `actual` is strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds the classic Wet Grass network, in which Rain and Sprinkler are
    /// both parents of WetGrass and every variable is binary.
    fn wet_grass_network() -> BayesianNetwork {
        let mut dag = DAG::new(3);
        for &parent in &[RAIN, SPRINKLER] {
            dag.add_edge(parent, WET_GRASS).unwrap();
        }

        let rain = TabularCPD::new(RAIN, 2, vec![], vec![], vec![vec![0.8, 0.2]]).unwrap();
        let sprinkler =
            TabularCPD::new(SPRINKLER, 2, vec![], vec![], vec![vec![0.5, 0.5]]).unwrap();

        // One row per (Rain, Sprinkler) combination, Sprinkler varying fastest.
        // The grass is almost surely dry only when neither parent is active.
        let mostly_dry = vec![0.99, 0.01];
        let mostly_wet = vec![0.01, 0.99];
        let wet_grass = TabularCPD::new(
            WET_GRASS,
            2,
            vec![RAIN, SPRINKLER],
            vec![2, 2],
            vec![mostly_dry, mostly_wet.clone(), mostly_wet.clone(), mostly_wet],
        )
        .unwrap();

        let cpds: Vec<Box<dyn CPD>> =
            vec![Box::new(rain), Box::new(sprinkler), Box::new(wet_grass)];
        BayesianNetwork::new(dag, cpds).unwrap()
    }

    #[test]
    fn test_joint_probability_all_dry() {
        // 0.8 (no rain) * 0.5 (sprinkler off) * 0.99 (dry given neither) = 0.396
        let p = wet_grass_network().joint_probability(&[0, 0, 0]).unwrap();
        assert!(close(p, 0.396, 1e-6), "Expected ~0.396, got {p}");
    }

    #[test]
    fn test_ve_prior_rain() {
        let bn = wet_grass_network();
        let no_evidence = HashMap::new();
        let result = VariableElimination::from_network(&bn, &[RAIN], &no_evidence)
            .query(&bn, &[RAIN], &no_evidence)
            .unwrap();
        let rain = &result[&RAIN];
        assert!(close(rain[0], 0.8, 1e-6), "P(Rain=0) should be 0.8");
        assert!(close(rain[1], 0.2, 1e-6), "P(Rain=1) should be 0.2");
    }

    #[test]
    fn test_ve_prior_sprinkler() {
        let bn = wet_grass_network();
        let no_evidence = HashMap::new();
        let result = VariableElimination::from_network(&bn, &[SPRINKLER], &no_evidence)
            .query(&bn, &[SPRINKLER], &no_evidence)
            .unwrap();
        let sprinkler = &result[&SPRINKLER];
        assert!(close(sprinkler[0], 0.5, 1e-6), "P(Spr=0) should be 0.5");
    }

    #[test]
    fn test_ve_conditional_rain_given_wetgrass() {
        let bn = wet_grass_network();
        let evidence: HashMap<usize, usize> = [(WET_GRASS, 1)].iter().copied().collect();
        let result = VariableElimination::from_network(&bn, &[RAIN], &evidence)
            .query(&bn, &[RAIN], &evidence)
            .unwrap();
        let rain = &result[&RAIN];
        // Observing wet grass should raise the probability of rain above its 0.2 prior.
        assert!(
            rain[1] > 0.2,
            "P(Rain=1|WG=1) should be > 0.2, got {}",
            rain[1]
        );
        assert!(close(rain[0] + rain[1], 1.0, 1e-6), "Should sum to 1");
    }

    #[test]
    fn test_belief_propagation_prior() {
        let beliefs = BeliefPropagation
            .beliefs(&wet_grass_network(), &HashMap::new())
            .unwrap();
        // Without evidence, Rain's belief is just its prior.
        let rain = &beliefs[RAIN];
        assert!(close(rain[0], 0.8, 1e-5), "Rain[0] should be 0.8");
        assert!(close(rain[1], 0.2, 1e-5), "Rain[1] should be 0.2");
    }

    #[test]
    fn test_factor_marginalize() {
        let uniform = Factor {
            scope: vec![0, 1],
            card: vec![2, 2],
            values: vec![0.25; 4],
        };
        let summed = uniform.marginalize(1).unwrap();
        assert_eq!(summed.scope, vec![0]);
        // Summing two 0.25 entries per remaining state gives 0.5 each.
        for i in 0..2 {
            assert!(close(summed.values[i], 0.5, 1e-9));
        }
    }

    #[test]
    fn test_factor_reduce() {
        let table = Factor {
            scope: vec![0, 1],
            card: vec![2, 2],
            values: vec![0.3, 0.7, 0.6, 0.4],
        };
        // Observing var 1 = 0 keeps f[0,0] = 0.3 and f[1,0] = 0.6.
        let reduced = table.reduce(1, 0).unwrap();
        assert_eq!(reduced.scope, vec![0]);
        assert!(close(reduced.values[0], 0.3, 1e-9));
        assert!(close(reduced.values[1], 0.6, 1e-9));
    }
}