1use super::{
9 cpd::{TabularCPD, CPD},
10 dag::DAG,
11};
12use crate::StatsError;
13use std::collections::HashMap;
14
/// A discrete Bayesian network: a DAG plus one conditional probability
/// distribution (CPD) per node.
///
/// The methods in this module index `cpds` by node id (`cpds[i]` is used as the
/// CPD of node `i`), so the vector is expected to be stored in node order.
pub struct BayesianNetwork {
    /// Graph structure; `dag.n_nodes` is the number of variables.
    pub dag: DAG,
    /// Conditional probability distributions, one per node, at index = node id.
    pub cpds: Vec<Box<dyn CPD>>,
}
26
27impl BayesianNetwork {
28 pub fn new(dag: DAG, cpds: Vec<Box<dyn CPD>>) -> Result<Self, StatsError> {
33 if cpds.len() != dag.n_nodes {
34 return Err(StatsError::InvalidInput(format!(
35 "Expected {} CPDs (one per node), got {}",
36 dag.n_nodes,
37 cpds.len()
38 )));
39 }
40 Ok(Self { dag, cpds })
41 }
42
43 pub fn joint_probability(&self, assignment: &[usize]) -> Result<f64, StatsError> {
47 if assignment.len() != self.dag.n_nodes {
48 return Err(StatsError::InvalidInput(format!(
49 "assignment length {} does not match n_nodes {}",
50 assignment.len(),
51 self.dag.n_nodes
52 )));
53 }
54 let mut log_prob = 0.0f64;
55 for i in 0..self.dag.n_nodes {
56 let cpd = &self.cpds[i];
57 let parent_idx = cpd.parent_indices();
58 let parent_vals: Vec<usize> = parent_idx.iter().map(|&p| assignment[p]).collect();
59 let p = cpd.prob(assignment[i], &parent_vals);
60 if p <= 0.0 {
61 return Ok(0.0);
62 }
63 log_prob += p.ln();
64 }
65 Ok(log_prob.exp())
66 }
67
68 pub fn cardinality(&self, node: usize) -> usize {
70 self.cpds[node].cardinality()
71 }
72}
73
/// A dense table over a set of discrete variables.
///
/// `values` is laid out in row-major order over `scope`: the last axis varies
/// fastest (its stride from `strides_from_card` is 1).
#[derive(Debug, Clone)]
pub struct Factor {
    /// Variable (node) ids, one per axis.
    pub scope: Vec<usize>,
    /// Cardinality of each axis, parallel to `scope`.
    pub card: Vec<usize>,
    /// Flattened table; one entry per joint assignment of `scope`.
    pub values: Vec<f64>,
}
91
92impl Factor {
93 pub fn from_cpd(cpd: &dyn CPD, bn: &BayesianNetwork) -> Self {
95 let node = cpd.node();
96 let card_node = cpd.cardinality();
97 let parent_idx = cpd.parent_indices();
98 let mut scope = vec![node];
101 scope.extend_from_slice(parent_idx);
102 let mut card = vec![card_node];
103 for &p in parent_idx {
104 card.push(bn.cpds[p].cardinality());
105 }
106 let n_entries: usize = card.iter().product();
107 let mut values = vec![0.0f64; n_entries];
108 let strides = strides_from_card(&card);
110 for idx in 0..n_entries {
112 let assignment = decode_index(idx, &card, &strides);
113 let node_val = assignment[0];
114 let parent_vals = &assignment[1..];
115 values[idx] = cpd.prob(node_val, parent_vals);
116 }
117 Factor {
118 scope,
119 card,
120 values,
121 }
122 }
123
124 pub fn marginalize(&self, var: usize) -> Option<Factor> {
126 let pos = self.scope.iter().position(|&v| v == var)?;
127 let var_card = self.card[pos];
128 let new_scope: Vec<usize> = self
130 .scope
131 .iter()
132 .enumerate()
133 .filter(|&(i, _)| i != pos)
134 .map(|(_, &v)| v)
135 .collect();
136 let new_card: Vec<usize> = self
137 .card
138 .iter()
139 .enumerate()
140 .filter(|&(i, _)| i != pos)
141 .map(|(_, &c)| c)
142 .collect();
143 let new_n: usize = if new_card.is_empty() {
144 1
145 } else {
146 new_card.iter().product()
147 };
148 let new_strides = strides_from_card(&new_card);
149 let old_strides = strides_from_card(&self.card);
150 let mut new_values = vec![0.0f64; new_n];
151 for idx in 0..self.values.len() {
152 let old_assign = decode_index(idx, &self.card, &old_strides);
153 let new_assign: Vec<usize> = old_assign
155 .iter()
156 .enumerate()
157 .filter(|&(i, _)| i != pos)
158 .map(|(_, &v)| v)
159 .collect();
160 let new_idx = encode_index(&new_assign, &new_strides);
161 new_values[new_idx] += self.values[idx];
162 }
163 let _ = var_card; Some(Factor {
166 scope: new_scope,
167 card: new_card,
168 values: new_values,
169 })
170 }
171
172 pub fn reduce(&self, var: usize, val: usize) -> Option<Factor> {
174 let pos = self.scope.iter().position(|&v| v == var)?;
175 let new_scope: Vec<usize> = self
176 .scope
177 .iter()
178 .enumerate()
179 .filter(|&(i, _)| i != pos)
180 .map(|(_, &v)| v)
181 .collect();
182 let new_card: Vec<usize> = self
183 .card
184 .iter()
185 .enumerate()
186 .filter(|&(i, _)| i != pos)
187 .map(|(_, &c)| c)
188 .collect();
189 let new_n: usize = if new_card.is_empty() {
190 1
191 } else {
192 new_card.iter().product()
193 };
194 let new_strides = strides_from_card(&new_card);
195 let old_strides = strides_from_card(&self.card);
196 let mut new_values = vec![0.0f64; new_n];
197 for idx in 0..self.values.len() {
198 let old_assign = decode_index(idx, &self.card, &old_strides);
199 if old_assign[pos] != val {
200 continue;
201 }
202 let new_assign: Vec<usize> = old_assign
203 .iter()
204 .enumerate()
205 .filter(|&(i, _)| i != pos)
206 .map(|(_, &v)| v)
207 .collect();
208 let new_idx = encode_index(&new_assign, &new_strides);
209 new_values[new_idx] = self.values[idx];
210 }
211 Some(Factor {
212 scope: new_scope,
213 card: new_card,
214 values: new_values,
215 })
216 }
217
218 pub fn multiply(&self, other: &Factor) -> Factor {
220 let mut new_scope = self.scope.clone();
222 let mut new_card = self.card.clone();
223 for (i, &v) in other.scope.iter().enumerate() {
224 if !new_scope.contains(&v) {
225 new_scope.push(v);
226 new_card.push(other.card[i]);
227 }
228 }
229 let new_n: usize = if new_card.is_empty() {
230 1
231 } else {
232 new_card.iter().product()
233 };
234 let new_strides = strides_from_card(&new_card);
235 let self_strides = strides_from_card(&self.card);
236 let other_strides = strides_from_card(&other.card);
237 let mut new_values = vec![0.0f64; new_n];
238 for idx in 0..new_n {
239 let full_assign = decode_index(idx, &new_card, &new_strides);
240 let self_assign: Vec<usize> = self
242 .scope
243 .iter()
244 .map(|v| {
245 let pos = new_scope.iter().position(|&x| x == *v).unwrap_or(0);
246 full_assign[pos]
247 })
248 .collect();
249 let other_assign: Vec<usize> = other
250 .scope
251 .iter()
252 .map(|v| {
253 let pos = new_scope.iter().position(|&x| x == *v).unwrap_or(0);
254 full_assign[pos]
255 })
256 .collect();
257 let si = encode_index(&self_assign, &self_strides);
258 let oi = encode_index(&other_assign, &other_strides);
259 new_values[idx] = self.values[si] * other.values[oi];
260 }
261 Factor {
262 scope: new_scope,
263 card: new_card,
264 values: new_values,
265 }
266 }
267
268 pub fn normalize(&mut self) {
270 let sum: f64 = self.values.iter().sum();
271 if sum > 1e-300 {
272 for v in &mut self.values {
273 *v /= sum;
274 }
275 }
276 }
277}
278
/// Exact inference by variable elimination with an explicit elimination order.
#[derive(Debug, Clone)]
pub struct VariableElimination {
    /// Node ids to sum out, processed front to back.
    pub order: Vec<usize>,
}
291
292impl VariableElimination {
293 pub fn new(order: Vec<usize>) -> Self {
295 Self { order }
296 }
297
298 pub fn from_network(
300 bn: &BayesianNetwork,
301 query_vars: &[usize],
302 evidence: &HashMap<usize, usize>,
303 ) -> Self {
304 let topo = bn.dag.topological_sort();
305 let order: Vec<usize> = topo
307 .into_iter()
308 .rev()
309 .filter(|v| !query_vars.contains(v) && !evidence.contains_key(v))
310 .collect();
311 Self { order }
312 }
313
314 pub fn query(
318 &self,
319 bn: &BayesianNetwork,
320 query_vars: &[usize],
321 evidence: &HashMap<usize, usize>,
322 ) -> Result<HashMap<usize, Vec<f64>>, StatsError> {
323 let mut factors: Vec<Factor> = bn
325 .cpds
326 .iter()
327 .map(|cpd| Factor::from_cpd(cpd.as_ref(), bn))
328 .collect();
329
330 for factor in &mut factors {
332 let mut f = factor.clone();
333 for (&evar, &eval) in evidence {
334 if let Some(reduced) = f.reduce(evar, eval) {
335 f = reduced;
336 }
337 }
338 *factor = f;
339 }
340
341 for &var in &self.order {
343 let (with_var, without_var): (Vec<Factor>, Vec<Factor>) =
345 factors.into_iter().partition(|f| f.scope.contains(&var));
346
347 if with_var.is_empty() {
348 factors = without_var;
349 continue;
350 }
351
352 let product = multiply_all(with_var);
354
355 let marginal = product.marginalize(var).ok_or_else(|| {
357 StatsError::ComputationError(format!("Failed to marginalize var {var}"))
358 })?;
359
360 factors = without_var;
361 factors.push(marginal);
362 }
363
364 let product = multiply_all(factors);
366
367 let mut result = HashMap::new();
369 for &qv in query_vars {
370 let mut marginal = product.clone();
372 let other_vars: Vec<usize> = marginal
373 .scope
374 .iter()
375 .copied()
376 .filter(|&v| v != qv)
377 .collect();
378 for v in other_vars {
379 marginal = marginal.marginalize(v).ok_or_else(|| {
380 StatsError::ComputationError(format!("Failed to marginalize var {v}"))
381 })?;
382 }
383 marginal.normalize();
384 result.insert(qv, marginal.values);
385 }
386 Ok(result)
387 }
388}
389
/// Per-node marginal ("belief") queries over a [`BayesianNetwork`].
///
/// NOTE(review): the methods on this type delegate to [`VariableElimination`];
/// no message passing is performed despite the name.
#[derive(Debug, Clone)]
pub struct BeliefPropagation;
400
401impl BeliefPropagation {
402 pub fn beliefs(
406 &self,
407 bn: &BayesianNetwork,
408 evidence: &HashMap<usize, usize>,
409 ) -> Result<Vec<Vec<f64>>, StatsError> {
410 let n = bn.dag.n_nodes;
411 let topo = bn.dag.topological_sort();
413
414 let ve = VariableElimination::from_network(bn, &(0..n).collect::<Vec<_>>(), evidence);
416 let mut beliefs = vec![Vec::new(); n];
417 for node in 0..n {
418 let single = [node];
419 let result = ve.query(bn, &single, evidence)?;
420 beliefs[node] = result.get(&node).cloned().unwrap_or_default();
421 }
422 let _ = topo; Ok(beliefs)
424 }
425
426 pub fn query_node(
428 &self,
429 bn: &BayesianNetwork,
430 node: usize,
431 evidence: &HashMap<usize, usize>,
432 ) -> Result<Vec<f64>, StatsError> {
433 let ve = VariableElimination::from_network(bn, &[node], evidence);
434 let result = ve.query(bn, &[node], evidence)?;
435 result
436 .get(&node)
437 .cloned()
438 .ok_or_else(|| StatsError::ComputationError(format!("No result for node {node}")))
439 }
440}
441
/// Row-major strides for a table with the given per-axis cardinalities.
///
/// The last axis has stride 1. Each earlier axis's stride is the product of
/// the cardinalities of every axis after it. An empty `card` yields no strides.
fn strides_from_card(card: &[usize]) -> Vec<usize> {
    if card.is_empty() {
        return Vec::new();
    }
    // Build the strides right-to-left: start from the trailing 1, then keep a
    // running product of the trailing cardinalities (axis 0's own cardinality
    // never contributes), and finally flip into axis order.
    let mut strides: Vec<usize> = std::iter::once(1usize)
        .chain(card[1..].iter().rev().scan(1usize, |running, &c| {
            *running *= c;
            Some(*running)
        }))
        .collect();
    strides.reverse();
    strides
}
458
/// Converts a flat table index into one value per axis.
///
/// Axes are visited in order. A zero stride yields 0 for that axis and leaves
/// the remainder untouched. The result has `card.len()` entries.
fn decode_index(idx: usize, card: &[usize], strides: &[usize]) -> Vec<usize> {
    let mut remainder = idx;
    (0..card.len())
        .map(|axis| match strides[axis] {
            0 => 0,
            stride => {
                let digit = remainder / stride;
                remainder %= stride;
                digit
            }
        })
        .collect()
}
472
/// Converts per-axis values back into a flat table index.
///
/// This is the dot product of `assignment` and `strides`. If the two slices
/// differ in length, the extra elements of the longer one are ignored.
fn encode_index(assignment: &[usize], strides: &[usize]) -> usize {
    let mut flat = 0;
    for (value, stride) in assignment.iter().zip(strides.iter()) {
        flat += value * stride;
    }
    flat
}
477
478fn multiply_all(mut factors: Vec<Factor>) -> Factor {
480 if factors.is_empty() {
481 return Factor {
482 scope: Vec::new(),
483 card: Vec::new(),
484 values: vec![1.0],
485 };
486 }
487 let mut result = factors.remove(0);
488 for f in factors {
489 result = result.multiply(&f);
490 }
491 result
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bayesian_network::cpd::TabularCPD;
    use crate::bayesian_network::dag::DAG;

    /// Builds the three-node "wet grass" network used by these tests.
    ///
    /// Rain (0) and Sprinkler (1) are root nodes, and both are parents of
    /// WetGrass (2). Every variable has two states.
    fn wet_grass_network() -> BayesianNetwork {
        let mut dag = DAG::new(3);
        // Rain -> WetGrass and Sprinkler -> WetGrass.
        dag.add_edge(0, 2).unwrap();
        dag.add_edge(1, 2).unwrap();

        // Prior for Rain: [0.8, 0.2].
        let cpd_rain = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.8, 0.2]]).unwrap();

        // Prior for Sprinkler: [0.5, 0.5].
        let cpd_spr = TabularCPD::new(1, 2, vec![], vec![], vec![vec![0.5, 0.5]]).unwrap();

        // WetGrass given (Rain, Sprinkler), one row per parent configuration.
        // NOTE(review): presumably rows are ordered (R, S) = (0,0), (0,1), (1,0), (1,1)
        // and columns are WetGrass = 0, 1 -- confirm against TabularCPD. Under that
        // reading the grass is almost surely dry only when neither parent is active.
        let cpd_wg = TabularCPD::new(
            2,
            2,
            vec![0, 1],
            vec![2, 2],
            vec![
                vec![0.99, 0.01],
                vec![0.01, 0.99],
                vec![0.01, 0.99],
                vec![0.01, 0.99],
            ],
        )
        .unwrap();

        let cpds: Vec<Box<dyn CPD>> = vec![Box::new(cpd_rain), Box::new(cpd_spr), Box::new(cpd_wg)];
        BayesianNetwork::new(dag, cpds).unwrap()
    }

    #[test]
    fn test_joint_probability_all_dry() {
        let bn = wet_grass_network();
        let p = bn.joint_probability(&[0, 0, 0]).unwrap();
        // P(R=0) * P(S=0) * P(WG=0 | R=0, S=0) = 0.8 * 0.5 * 0.99 = 0.396.
        assert!((p - 0.396).abs() < 1e-6, "Expected ~0.396, got {p}");
    }

    #[test]
    fn test_ve_prior_rain() {
        // With no evidence, a root node's marginal equals its prior.
        let bn = wet_grass_network();
        let ve = VariableElimination::from_network(&bn, &[0], &HashMap::new());
        let result = ve.query(&bn, &[0], &HashMap::new()).unwrap();
        let rain = &result[&0];
        assert!((rain[0] - 0.8).abs() < 1e-6, "P(Rain=0) should be 0.8");
        assert!((rain[1] - 0.2).abs() < 1e-6, "P(Rain=1) should be 0.2");
    }

    #[test]
    fn test_ve_prior_sprinkler() {
        // Same check for the other root node.
        let bn = wet_grass_network();
        let ve = VariableElimination::from_network(&bn, &[1], &HashMap::new());
        let result = ve.query(&bn, &[1], &HashMap::new()).unwrap();
        let spr = &result[&1];
        assert!((spr[0] - 0.5).abs() < 1e-6, "P(Spr=0) should be 0.5");
    }

    #[test]
    fn test_ve_conditional_rain_given_wetgrass() {
        let bn = wet_grass_network();
        let mut evidence = HashMap::new();
        // Observe WetGrass = 1.
        evidence.insert(2usize, 1usize);
        let ve = VariableElimination::from_network(&bn, &[0], &evidence);
        let result = ve.query(&bn, &[0], &evidence).unwrap();
        let rain = &result[&0];
        // Observing wet grass should raise belief in rain above its 0.2 prior,
        // and the posterior must still be a normalised distribution.
        assert!(
            rain[1] > 0.2,
            "P(Rain=1|WG=1) should be > 0.2, got {}",
            rain[1]
        );
        assert!((rain[0] + rain[1] - 1.0).abs() < 1e-6, "Should sum to 1");
    }

    #[test]
    fn test_belief_propagation_prior() {
        // Without evidence, the belief for Rain should match its prior.
        let bn = wet_grass_network();
        let bp = BeliefPropagation;
        let beliefs = bp.beliefs(&bn, &HashMap::new()).unwrap();
        assert!((beliefs[0][0] - 0.8).abs() < 1e-5, "Rain[0] should be 0.8");
        assert!((beliefs[0][1] - 0.2).abs() < 1e-5, "Rain[1] should be 0.2");
    }

    #[test]
    fn test_factor_marginalize() {
        // Uniform 2x2 table: summing out variable 1 leaves [0.5, 0.5] over variable 0.
        let f = Factor {
            scope: vec![0, 1],
            card: vec![2, 2],
            values: vec![0.25, 0.25, 0.25, 0.25],
        };
        let marginal = f.marginalize(1).unwrap();
        assert_eq!(marginal.scope, vec![0]);
        assert!((marginal.values[0] - 0.5).abs() < 1e-9);
        assert!((marginal.values[1] - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_factor_reduce() {
        // Row-major layout: flat index = 2 * x0 + x1, so fixing x1 = 0 keeps
        // entries 0 and 2.
        let f = Factor {
            scope: vec![0, 1],
            card: vec![2, 2],
            values: vec![0.3, 0.7, 0.6, 0.4],
        };
        let reduced = f.reduce(1, 0).unwrap();
        assert_eq!(reduced.scope, vec![0]);
        assert!((reduced.values[0] - 0.3).abs() < 1e-9);
        assert!((reduced.values[1] - 0.6).abs() < 1e-9);
    }
}