1use scirs2_core::ndarray::ArrayD;
7use std::collections::HashMap;
8
9use crate::error::{PgmError, Result};
10use crate::factor::Factor;
11use crate::graph::FactorGraph;
12use crate::message_passing::MessagePassingAlgorithm;
13
/// Configuration for naive mean-field variational inference.
///
/// The joint distribution is approximated by a fully factorised `q(x) = Π_i q_i(x_i)`,
/// and each `q_i` is refined in turn by coordinate ascent.
#[derive(Debug, Clone)]
pub struct MeanFieldInference {
    /// Maximum number of coordinate-ascent sweeps over all variables.
    pub max_iterations: usize,
    /// Convergence threshold on the largest absolute change of any `q_i` entry
    /// between two consecutive sweeps.
    pub tolerance: f64,
}

impl Default for MeanFieldInference {
    /// 100 sweeps with a convergence tolerance of `1e-6`.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
        }
    }
}
33
34impl MeanFieldInference {
35 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
37 Self {
38 max_iterations,
39 tolerance,
40 }
41 }
42
43 pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
47 let mut q_distributions: HashMap<String, ArrayD<f64>> = HashMap::new();
49
50 for var_name in graph.variable_names() {
51 if let Some(var_node) = graph.get_variable(var_name) {
52 let uniform = ArrayD::from_elem(
53 vec![var_node.cardinality],
54 1.0 / var_node.cardinality as f64,
55 );
56 q_distributions.insert(var_name.clone(), uniform);
57 }
58 }
59
60 for iteration in 0..self.max_iterations {
62 let old_q = q_distributions.clone();
63
64 for var_name in graph.variable_names() {
66 let updated_q = self.update_q_distribution(graph, var_name, &q_distributions)?;
67 q_distributions.insert(var_name.clone(), updated_q);
68 }
69
70 if self.check_convergence(&old_q, &q_distributions) {
72 return Ok(q_distributions);
73 }
74
75 if iteration == self.max_iterations - 1 {
76 return Err(PgmError::ConvergenceFailure(format!(
77 "Mean-field inference did not converge after {} iterations",
78 self.max_iterations
79 )));
80 }
81 }
82
83 Ok(q_distributions)
84 }
85
86 fn update_q_distribution(
90 &self,
91 graph: &FactorGraph,
92 var_name: &str,
93 q_distributions: &HashMap<String, ArrayD<f64>>,
94 ) -> Result<ArrayD<f64>> {
95 let var_node = graph
96 .get_variable(var_name)
97 .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;
98
99 let mut log_potential = ArrayD::zeros(vec![var_node.cardinality]);
101
102 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
104 for factor_id in adjacent_factors {
105 if let Some(factor) = graph.get_factor(factor_id) {
106 let expected_log =
108 self.compute_expected_log_factor(factor, var_name, q_distributions)?;
109 log_potential = log_potential + expected_log;
110 }
111 }
112 }
113
114 let unnormalized = log_potential.mapv(|x: f64| x.exp());
116 let z: f64 = unnormalized.iter().sum();
117
118 if z > 0.0 {
119 Ok(&unnormalized / z)
120 } else {
121 Ok(ArrayD::from_elem(
123 vec![var_node.cardinality],
124 1.0 / var_node.cardinality as f64,
125 ))
126 }
127 }
128
129 fn compute_expected_log_factor(
131 &self,
132 factor: &Factor,
133 target_var: &str,
134 q_distributions: &HashMap<String, ArrayD<f64>>,
135 ) -> Result<ArrayD<f64>> {
136 let target_idx = factor
138 .variables
139 .iter()
140 .position(|v| v == target_var)
141 .ok_or_else(|| PgmError::VariableNotFound(target_var.to_string()))?;
142
143 let target_card = factor.values.shape()[target_idx];
144 let mut expected_log = ArrayD::zeros(vec![target_card]);
145
146 let total_size: usize = factor.values.shape().iter().product();
148 for linear_idx in 0..total_size {
149 let mut assignment = Vec::new();
151 let mut temp_idx = linear_idx;
152 for &dim in factor.values.shape().iter().rev() {
153 assignment.push(temp_idx % dim);
154 temp_idx /= dim;
155 }
156 assignment.reverse();
157
158 let factor_val = factor.values[assignment.as_slice()];
160 let log_factor_val = if factor_val > 1e-10 {
161 factor_val.ln()
162 } else {
163 -10.0 };
165
166 let mut q_prob = 1.0;
168 for (idx, var) in factor.variables.iter().enumerate() {
169 if var != target_var {
170 if let Some(q) = q_distributions.get(var) {
171 q_prob *= q[[assignment[idx]]];
172 }
173 }
174 }
175
176 let target_val = assignment[target_idx];
178 expected_log[[target_val]] += q_prob * log_factor_val;
179 }
180
181 Ok(expected_log)
182 }
183
184 fn check_convergence(
186 &self,
187 old_q: &HashMap<String, ArrayD<f64>>,
188 new_q: &HashMap<String, ArrayD<f64>>,
189 ) -> bool {
190 let mut max_delta = 0.0_f64;
191
192 for (var, new_dist) in new_q {
193 if let Some(old_dist) = old_q.get(var) {
194 let delta: f64 = (new_dist - old_dist)
195 .mapv(|x| x.abs())
196 .iter()
197 .fold(0.0_f64, |acc, &x| acc.max(x));
198 max_delta = max_delta.max(delta);
199 }
200 }
201
202 max_delta < self.tolerance
203 }
204
205 pub fn compute_elbo(
209 &self,
210 graph: &FactorGraph,
211 q_distributions: &HashMap<String, ArrayD<f64>>,
212 ) -> Result<f64> {
213 let mut elbo = 0.0;
214
215 for factor_id in graph.factor_ids() {
217 if let Some(factor) = graph.get_factor(factor_id) {
218 elbo += self.expected_log_joint_factor(factor, q_distributions)?;
219 }
220 }
221
222 for q_dist in q_distributions.values() {
224 let entropy: f64 = q_dist
225 .iter()
226 .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
227 .sum();
228 elbo += entropy;
229 }
230
231 Ok(elbo)
232 }
233
234 fn expected_log_joint_factor(
236 &self,
237 factor: &Factor,
238 q_distributions: &HashMap<String, ArrayD<f64>>,
239 ) -> Result<f64> {
240 let mut expected = 0.0;
241
242 let total_size: usize = factor.values.shape().iter().product();
243 for linear_idx in 0..total_size {
244 let mut assignment = Vec::new();
245 let mut temp_idx = linear_idx;
246 for &dim in factor.values.shape().iter().rev() {
247 assignment.push(temp_idx % dim);
248 temp_idx /= dim;
249 }
250 assignment.reverse();
251
252 let factor_val = factor.values[assignment.as_slice()];
254 let log_factor_val = if factor_val > 1e-10 {
255 factor_val.ln()
256 } else {
257 -10.0
258 };
259
260 let mut q_prob = 1.0;
262 for (idx, var) in factor.variables.iter().enumerate() {
263 if let Some(q) = q_distributions.get(var) {
264 q_prob *= q[[assignment[idx]]];
265 }
266 }
267
268 expected += q_prob * log_factor_val;
269 }
270
271 Ok(expected)
272 }
273}
274
/// Configuration for Bethe-approximation inference, solved with loopy sum-product BP.
#[derive(Debug, Clone)]
pub struct BetheApproximation {
    /// Maximum number of message-passing iterations.
    pub max_iterations: usize,
    /// Convergence tolerance forwarded to the sum-product solver.
    pub tolerance: f64,
    /// Message damping forwarded to the sum-product solver; `new` clamps it into `[0, 1]`.
    pub damping: f64,
}

impl Default for BetheApproximation {
    /// 100 iterations, tolerance `1e-6`, no damping.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
            damping: 0.0,
        }
    }
}
306
307impl BetheApproximation {
308 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
310 Self {
311 max_iterations,
312 tolerance,
313 damping: damping.clamp(0.0, 1.0),
314 }
315 }
316
317 pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
321 use crate::message_passing::SumProductAlgorithm;
324
325 let bp = SumProductAlgorithm::new(self.max_iterations, self.tolerance, self.damping);
326 bp.run(graph)
327 }
328
329 pub fn compute_free_energy(
333 &self,
334 graph: &FactorGraph,
335 variable_beliefs: &HashMap<String, ArrayD<f64>>,
336 factor_beliefs: &HashMap<String, ArrayD<f64>>,
337 ) -> Result<f64> {
338 let mut free_energy = 0.0;
339
340 for (factor_id, belief) in factor_beliefs {
342 if let Some(factor) = graph.get_factor(factor_id) {
343 let entropy_contrib: f64 = belief
345 .iter()
346 .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
347 .sum();
348
349 let mut energy_contrib = 0.0;
351 let total_size: usize = belief.shape().iter().product();
352 for linear_idx in 0..total_size {
353 let mut assignment = Vec::new();
354 let mut temp_idx = linear_idx;
355 for &dim in belief.shape().iter().rev() {
356 assignment.push(temp_idx % dim);
357 temp_idx /= dim;
358 }
359 assignment.reverse();
360
361 let b_val = belief[assignment.as_slice()];
362 let psi_val = factor.values[assignment.as_slice()];
363 if b_val > 1e-10 && psi_val > 1e-10 {
364 energy_contrib += b_val * psi_val.ln();
365 }
366 }
367
368 free_energy -= entropy_contrib;
369 free_energy -= energy_contrib;
370 }
371 }
372
373 for (var_name, belief) in variable_beliefs {
375 let degree = if let Some(adjacent) = graph.get_adjacent_factors(var_name) {
377 adjacent.len()
378 } else {
379 0
380 };
381
382 if degree > 0 {
383 let entropy: f64 = belief
384 .iter()
385 .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
386 .sum();
387
388 free_energy += (degree as f64 - 1.0) * entropy;
389 }
390 }
391
392 Ok(free_energy)
393 }
394
395 pub fn compute_factor_beliefs(
397 &self,
398 graph: &FactorGraph,
399 variable_beliefs: &HashMap<String, ArrayD<f64>>,
400 ) -> Result<HashMap<String, ArrayD<f64>>> {
401 let mut factor_beliefs = HashMap::new();
402
403 for factor_id in graph.factor_ids() {
404 if let Some(factor) = graph.get_factor(factor_id) {
405 let mut belief = factor.clone();
407
408 for var in &factor.variables {
410 if let Some(var_belief) = variable_beliefs.get(var) {
411 let var_factor = Factor {
413 name: format!("belief_{}", var),
414 variables: vec![var.clone()],
415 values: var_belief.clone(),
416 };
417 belief = belief.product(&var_factor)?;
418 }
419 }
420
421 belief.normalize();
422 factor_beliefs.insert(factor_id.clone(), belief.values);
423 }
424 }
425
426 Ok(factor_beliefs)
427 }
428}
429
/// Configuration and edge weights for tree-reweighted belief propagation (TRW-BP).
#[derive(Debug, Clone)]
pub struct TreeReweightedBP {
    /// Maximum number of message-passing sweeps.
    pub max_iterations: usize,
    /// Convergence threshold on the largest absolute change of any message entry
    /// between two consecutive sweeps.
    pub tolerance: f64,
    /// Edge weights ρ keyed by `(variable_name, factor_id)`. Edges without an entry are
    /// treated as weight `1.0`.
    pub edge_weights: HashMap<(String, String), f64>,
}

impl Default for TreeReweightedBP {
    /// 100 sweeps, tolerance `1e-6`, no explicit edge weights.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
            edge_weights: HashMap::new(),
        }
    }
}
454
455impl TreeReweightedBP {
456 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
458 Self {
459 max_iterations,
460 tolerance,
461 edge_weights: HashMap::new(),
462 }
463 }
464
465 pub fn set_edge_weight(&mut self, var: String, factor: String, weight: f64) {
467 self.edge_weights
468 .insert((var, factor), weight.clamp(0.0, 1.0));
469 }
470
471 pub fn initialize_uniform_weights(&mut self, graph: &FactorGraph) {
473 for var_name in graph.variable_names() {
474 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
475 let weight = 1.0 / adjacent_factors.len() as f64;
476 for factor_id in adjacent_factors {
477 self.edge_weights
478 .insert((var_name.clone(), factor_id.clone()), weight);
479 }
480 }
481 }
482 }
483
484 fn get_edge_weight(&self, var: &str, factor: &str) -> f64 {
486 self.edge_weights
487 .get(&(var.to_string(), factor.to_string()))
488 .copied()
489 .unwrap_or(1.0)
490 }
491
492 pub fn run(&mut self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
496 if self.edge_weights.is_empty() {
498 self.initialize_uniform_weights(graph);
499 }
500
501 let mut messages: HashMap<(String, String), ArrayD<f64>> = HashMap::new();
503
504 for var_name in graph.variable_names() {
506 if let Some(var_node) = graph.get_variable(var_name) {
507 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
508 for factor_id in adjacent_factors {
509 let init_msg = ArrayD::from_elem(
510 vec![var_node.cardinality],
511 1.0 / var_node.cardinality as f64,
512 );
513 messages.insert((var_name.clone(), factor_id.clone()), init_msg);
514 }
515 }
516 }
517 }
518
519 for iteration in 0..self.max_iterations {
521 let old_messages = messages.clone();
522
523 for var_name in graph.variable_names() {
525 if let Some(var_node) = graph.get_variable(var_name) {
526 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
527 for target_factor in adjacent_factors {
528 let mut message = ArrayD::ones(vec![var_node.cardinality])
530 / var_node.cardinality as f64;
531
532 for other_factor in adjacent_factors {
534 if other_factor != target_factor {
535 if let Some(incoming) =
536 old_messages.get(&(var_name.clone(), other_factor.clone()))
537 {
538 let rho = self.get_edge_weight(var_name, other_factor);
539 let reweighted = incoming.mapv(|x| x.powf(rho));
541 message = &message * &reweighted;
542 }
543 }
544 }
545
546 let sum: f64 = message.iter().sum();
548 if sum > 1e-10 {
549 message /= sum;
550 }
551
552 messages.insert((var_name.clone(), target_factor.clone()), message);
553 }
554 }
555 }
556 }
557
558 let mut max_delta = 0.0_f64;
560 for ((var, factor), new_msg) in &messages {
561 if let Some(old_msg) = old_messages.get(&(var.clone(), factor.clone())) {
562 let delta: f64 = (new_msg - old_msg)
563 .mapv(|x| x.abs())
564 .iter()
565 .fold(0.0_f64, |acc, &x| acc.max(x));
566 max_delta = max_delta.max(delta);
567 }
568 }
569
570 if max_delta < self.tolerance {
571 break;
572 }
573
574 if iteration == self.max_iterations - 1 {
575 return Err(PgmError::ConvergenceFailure(format!(
576 "TRW-BP did not converge after {} iterations (max_delta={})",
577 self.max_iterations, max_delta
578 )));
579 }
580 }
581
582 let mut beliefs = HashMap::new();
584 for var_name in graph.variable_names() {
585 if let Some(var_node) = graph.get_variable(var_name) {
586 let mut belief =
587 ArrayD::ones(vec![var_node.cardinality]) / var_node.cardinality as f64;
588
589 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
590 for factor_id in adjacent_factors {
591 if let Some(message) = messages.get(&(var_name.clone(), factor_id.clone()))
592 {
593 let rho = self.get_edge_weight(var_name, factor_id);
594 let reweighted = message.mapv(|x| x.powf(rho));
595 belief = &belief * &reweighted;
596 }
597 }
598 }
599
600 let sum: f64 = belief.iter().sum();
602 if sum > 1e-10 {
603 belief /= sum;
604 }
605
606 beliefs.insert(var_name.clone(), belief);
607 }
608 }
609
610 Ok(beliefs)
611 }
612
613 pub fn compute_log_partition_upper_bound(
617 &self,
618 _graph: &FactorGraph,
619 _beliefs: &HashMap<String, ArrayD<f64>>,
620 ) -> Result<f64> {
621 Ok(0.0)
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a factor graph with one binary variable per entry of `names` and no factors.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for &name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Checks that `dist` is (numerically) the uniform distribution over two states.
    fn assert_uniform_binary(dist: &ArrayD<f64>) {
        for state in 0..2 {
            assert_abs_diff_eq!(dist[[state]], 0.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_mean_field_single_variable() {
        let graph = binary_graph(&["x"]);

        let marginals = MeanFieldInference::default()
            .run(&graph)
            .expect("mean-field should succeed on an isolated variable");

        // With no factors the optimal factorised distribution stays uniform.
        assert_uniform_binary(marginals.get("x").expect("marginal for x"));
    }

    #[test]
    fn test_mean_field_convergence() {
        let graph = binary_graph(&["x", "y"]);
        assert!(MeanFieldInference::new(50, 1e-6).run(&graph).is_ok());
    }

    #[test]
    fn test_elbo_computation() {
        let graph = binary_graph(&["x"]);
        let mf = MeanFieldInference::default();

        let marginals = mf.run(&graph).unwrap();

        assert!(mf.compute_elbo(&graph, &marginals).is_ok());
    }

    #[test]
    fn test_bethe_approximation_single_variable() {
        let graph = binary_graph(&["x"]);

        let marginals = BetheApproximation::default()
            .run(&graph)
            .expect("Bethe approximation should succeed on an isolated variable");

        assert_uniform_binary(marginals.get("x").expect("marginal for x"));
    }

    #[test]
    fn test_bethe_approximation_two_variables() {
        let graph = binary_graph(&["x", "y"]);

        let marginals = BetheApproximation::new(50, 1e-6, 0.0)
            .run(&graph)
            .expect("Bethe approximation should succeed on two isolated variables");

        assert_eq!(marginals.len(), 2);
    }

    #[test]
    fn test_bethe_free_energy() {
        let graph = binary_graph(&["x"]);
        let bethe = BetheApproximation::default();

        let marginals = bethe.run(&graph).unwrap();
        let factor_beliefs = bethe.compute_factor_beliefs(&graph, &marginals).unwrap();

        assert!(bethe
            .compute_free_energy(&graph, &marginals, &factor_beliefs)
            .is_ok());
    }

    #[test]
    fn test_bethe_with_damping() {
        let graph = binary_graph(&["x", "y"]);
        assert!(BetheApproximation::new(50, 1e-6, 0.5).run(&graph).is_ok());
    }

    #[test]
    fn test_trw_bp_single_variable() {
        let graph = binary_graph(&["x"]);
        let mut trw = TreeReweightedBP::default();

        let beliefs = trw
            .run(&graph)
            .expect("TRW-BP should succeed on an isolated variable");

        assert_uniform_binary(beliefs.get("x").expect("belief for x"));
    }

    #[test]
    fn test_trw_bp_two_variables() {
        let graph = binary_graph(&["x", "y"]);
        let mut trw = TreeReweightedBP::new(50, 1e-6);

        let beliefs = trw
            .run(&graph)
            .expect("TRW-BP should succeed on two isolated variables");

        assert_eq!(beliefs.len(), 2);
    }

    #[test]
    fn test_trw_bp_custom_weights() {
        let graph = binary_graph(&["x", "y"]);
        let mut trw = TreeReweightedBP::default();
        trw.set_edge_weight("x".to_string(), "f1".to_string(), 0.5);

        // "f1" is not a factor of the graph; the run must still succeed.
        assert!(trw.run(&graph).is_ok());
    }

    #[test]
    fn test_trw_bp_uniform_initialization() {
        let graph = binary_graph(&["x"]);
        let mut trw = TreeReweightedBP::default();

        trw.initialize_uniform_weights(&graph);

        let graph_has_factors = graph.factor_ids().count() > 0;
        assert!(!trw.edge_weights.is_empty() || !graph_has_factors);
    }

    #[test]
    fn test_trw_bp_partition_bound() {
        let graph = binary_graph(&["x"]);
        let mut trw = TreeReweightedBP::default();

        let beliefs = trw.run(&graph).unwrap();

        assert!(trw
            .compute_log_partition_upper_bound(&graph, &beliefs)
            .is_ok());
    }
}