Skip to main content

tensorlogic_quantrs_hooks/
variational.rs

1//! Variational inference methods for approximate PGM inference.
2//!
3//! This module provides variational inference algorithms that approximate
4//! intractable posterior distributions with simpler distributions.
5
6use scirs2_core::ndarray::ArrayD;
7use std::collections::HashMap;
8
9use crate::error::{PgmError, Result};
10use crate::factor::Factor;
11use crate::graph::FactorGraph;
12use crate::message_passing::MessagePassingAlgorithm;
13
/// Mean-field variational inference.
///
/// Approximates the joint distribution with a product of independent marginals:
/// Q(X₁, ..., Xₙ) = ∏ᵢ Qᵢ(Xᵢ)
///
/// The marginals are refined by coordinate-ascent sweeps (see `run`): each
/// Qᵢ is updated in turn given the current estimates of all the others.
pub struct MeanFieldInference {
    /// Maximum number of coordinate-ascent sweeps before reporting a
    /// convergence failure.
    pub max_iterations: usize,
    /// Convergence tolerance: iteration stops once the largest absolute
    /// change in any marginal entry between sweeps falls below this value.
    pub tolerance: f64,
}
24
25impl Default for MeanFieldInference {
26    fn default() -> Self {
27        Self {
28            max_iterations: 100,
29            tolerance: 1e-6,
30        }
31    }
32}
33
34impl MeanFieldInference {
35    /// Create with custom parameters.
36    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
37        Self {
38            max_iterations,
39            tolerance,
40        }
41    }
42
43    /// Run mean-field variational inference.
44    ///
45    /// Returns approximate marginals for each variable.
46    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
47        // Initialize Q distributions uniformly
48        let mut q_distributions: HashMap<String, ArrayD<f64>> = HashMap::new();
49
50        for var_name in graph.variable_names() {
51            if let Some(var_node) = graph.get_variable(var_name) {
52                let uniform = ArrayD::from_elem(
53                    vec![var_node.cardinality],
54                    1.0 / var_node.cardinality as f64,
55                );
56                q_distributions.insert(var_name.clone(), uniform);
57            }
58        }
59
60        // Iterative updates
61        for iteration in 0..self.max_iterations {
62            let old_q = q_distributions.clone();
63
64            // Update each Q distribution
65            for var_name in graph.variable_names() {
66                let updated_q = self.update_q_distribution(graph, var_name, &q_distributions)?;
67                q_distributions.insert(var_name.clone(), updated_q);
68            }
69
70            // Check convergence
71            if self.check_convergence(&old_q, &q_distributions) {
72                return Ok(q_distributions);
73            }
74
75            if iteration == self.max_iterations - 1 {
76                return Err(PgmError::ConvergenceFailure(format!(
77                    "Mean-field inference did not converge after {} iterations",
78                    self.max_iterations
79                )));
80            }
81        }
82
83        Ok(q_distributions)
84    }
85
86    /// Update Q distribution for a single variable.
87    ///
88    /// Q*(Xᵢ) ∝ exp(E[log p(X)] over Q\{Xᵢ})
89    fn update_q_distribution(
90        &self,
91        graph: &FactorGraph,
92        var_name: &str,
93        q_distributions: &HashMap<String, ArrayD<f64>>,
94    ) -> Result<ArrayD<f64>> {
95        let var_node = graph
96            .get_variable(var_name)
97            .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;
98
99        // Initialize log potential
100        let mut log_potential = ArrayD::zeros(vec![var_node.cardinality]);
101
102        // Get factors containing this variable
103        if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
104            for factor_id in adjacent_factors {
105                if let Some(factor) = graph.get_factor(factor_id) {
106                    // Compute expected log factor
107                    let expected_log =
108                        self.compute_expected_log_factor(factor, var_name, q_distributions)?;
109                    log_potential = log_potential + expected_log;
110                }
111            }
112        }
113
114        // Normalize: Q*(X) = exp(log_potential) / Z
115        let unnormalized = log_potential.mapv(|x: f64| x.exp());
116        let z: f64 = unnormalized.iter().sum();
117
118        if z > 0.0 {
119            Ok(&unnormalized / z)
120        } else {
121            // Fallback to uniform if normalization fails
122            Ok(ArrayD::from_elem(
123                vec![var_node.cardinality],
124                1.0 / var_node.cardinality as f64,
125            ))
126        }
127    }
128
129    /// Compute expected log factor: E[log φ(X)] over Q\{Xᵢ}
130    fn compute_expected_log_factor(
131        &self,
132        factor: &Factor,
133        target_var: &str,
134        q_distributions: &HashMap<String, ArrayD<f64>>,
135    ) -> Result<ArrayD<f64>> {
136        // Find target variable index
137        let target_idx = factor
138            .variables
139            .iter()
140            .position(|v| v == target_var)
141            .ok_or_else(|| PgmError::VariableNotFound(target_var.to_string()))?;
142
143        let target_card = factor.values.shape()[target_idx];
144        let mut expected_log = ArrayD::zeros(vec![target_card]);
145
146        // Compute expectation over all assignments
147        let total_size: usize = factor.values.shape().iter().product();
148        for linear_idx in 0..total_size {
149            // Convert to multi-dimensional index
150            let mut assignment = Vec::new();
151            let mut temp_idx = linear_idx;
152            for &dim in factor.values.shape().iter().rev() {
153                assignment.push(temp_idx % dim);
154                temp_idx /= dim;
155            }
156            assignment.reverse();
157
158            // Get factor value
159            let factor_val = factor.values[assignment.as_slice()];
160            let log_factor_val = if factor_val > 1e-10 {
161                factor_val.ln()
162            } else {
163                -10.0 // log of very small number
164            };
165
166            // Compute probability of this assignment under Q
167            let mut q_prob = 1.0;
168            for (idx, var) in factor.variables.iter().enumerate() {
169                if var != target_var {
170                    if let Some(q) = q_distributions.get(var) {
171                        q_prob *= q[[assignment[idx]]];
172                    }
173                }
174            }
175
176            // Accumulate expected log
177            let target_val = assignment[target_idx];
178            expected_log[[target_val]] += q_prob * log_factor_val;
179        }
180
181        Ok(expected_log)
182    }
183
184    /// Check convergence by comparing Q distributions.
185    fn check_convergence(
186        &self,
187        old_q: &HashMap<String, ArrayD<f64>>,
188        new_q: &HashMap<String, ArrayD<f64>>,
189    ) -> bool {
190        let mut max_delta = 0.0_f64;
191
192        for (var, new_dist) in new_q {
193            if let Some(old_dist) = old_q.get(var) {
194                let delta: f64 = (new_dist - old_dist)
195                    .mapv(|x| x.abs())
196                    .iter()
197                    .fold(0.0_f64, |acc, &x| acc.max(x));
198                max_delta = max_delta.max(delta);
199            }
200        }
201
202        max_delta < self.tolerance
203    }
204
205    /// Compute ELBO (Evidence Lower BOund).
206    ///
207    /// ELBO = E[log p(X, Z)] - E[log q(Z)]
208    pub fn compute_elbo(
209        &self,
210        graph: &FactorGraph,
211        q_distributions: &HashMap<String, ArrayD<f64>>,
212    ) -> Result<f64> {
213        let mut elbo = 0.0;
214
215        // E[log p(X, Z)] - sum over all factors
216        for factor_id in graph.factor_ids() {
217            if let Some(factor) = graph.get_factor(factor_id) {
218                elbo += self.expected_log_joint_factor(factor, q_distributions)?;
219            }
220        }
221
222        // -E[log q(Z)] - entropy of Q
223        for q_dist in q_distributions.values() {
224            let entropy: f64 = q_dist
225                .iter()
226                .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
227                .sum();
228            elbo += entropy;
229        }
230
231        Ok(elbo)
232    }
233
234    /// Compute E[log φ(X)] for a factor.
235    fn expected_log_joint_factor(
236        &self,
237        factor: &Factor,
238        q_distributions: &HashMap<String, ArrayD<f64>>,
239    ) -> Result<f64> {
240        let mut expected = 0.0;
241
242        let total_size: usize = factor.values.shape().iter().product();
243        for linear_idx in 0..total_size {
244            let mut assignment = Vec::new();
245            let mut temp_idx = linear_idx;
246            for &dim in factor.values.shape().iter().rev() {
247                assignment.push(temp_idx % dim);
248                temp_idx /= dim;
249            }
250            assignment.reverse();
251
252            // Factor value
253            let factor_val = factor.values[assignment.as_slice()];
254            let log_factor_val = if factor_val > 1e-10 {
255                factor_val.ln()
256            } else {
257                -10.0
258            };
259
260            // Probability under Q
261            let mut q_prob = 1.0;
262            for (idx, var) in factor.variables.iter().enumerate() {
263                if let Some(q) = q_distributions.get(var) {
264                    q_prob *= q[[assignment[idx]]];
265                }
266            }
267
268            expected += q_prob * log_factor_val;
269        }
270
271        Ok(expected)
272    }
273}
274
/// Bethe approximation for structured variational inference.
///
/// Uses the factor graph structure to define a structured approximation.
/// More accurate than mean-field but still tractable.
///
/// The Bethe free energy is:
/// F_Bethe = -Σ_α H(b_α) + Σ_i (d_i - 1) H(b_i) - Σ_α <log ψ_α>_b_α
///
/// where:
/// - b_α are factor beliefs (cluster marginals)
/// - b_i are variable beliefs (node marginals)
/// - d_i is the degree of variable i
/// - H is entropy
///
/// (Signs match `compute_free_energy`: factor entropies and expected
/// log-potentials are subtracted, degree-weighted variable entropies added
/// back to correct for double counting.)
pub struct BetheApproximation {
    /// Maximum number of belief-propagation iterations.
    pub max_iterations: usize,
    /// Convergence tolerance on message change.
    pub tolerance: f64,
    /// Damping factor for message updates, clamped to [0, 1] by `new`;
    /// 0 means undamped updates.
    pub damping: f64,
}
296
297impl Default for BetheApproximation {
298    fn default() -> Self {
299        Self {
300            max_iterations: 100,
301            tolerance: 1e-6,
302            damping: 0.0,
303        }
304    }
305}
306
307impl BetheApproximation {
308    /// Create with custom parameters.
309    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
310        Self {
311            max_iterations,
312            tolerance,
313            damping: damping.clamp(0.0, 1.0),
314        }
315    }
316
317    /// Run Bethe approximation using belief propagation.
318    ///
319    /// Returns variable beliefs (marginals).
320    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
321        // Use sum-product belief propagation as the message passing algorithm
322        // The fixed points of BP correspond to stationary points of Bethe free energy
323        use crate::message_passing::SumProductAlgorithm;
324
325        let bp = SumProductAlgorithm::new(self.max_iterations, self.tolerance, self.damping);
326        bp.run(graph)
327    }
328
329    /// Compute Bethe free energy.
330    ///
331    /// F_Bethe = Σ_α E_bα[log bα - log ψα] - Σ_i (d_i - 1) H(b_i)
332    pub fn compute_free_energy(
333        &self,
334        graph: &FactorGraph,
335        variable_beliefs: &HashMap<String, ArrayD<f64>>,
336        factor_beliefs: &HashMap<String, ArrayD<f64>>,
337    ) -> Result<f64> {
338        let mut free_energy = 0.0;
339
340        // Factor contribution: Σ_α E_bα[log bα - log ψα]
341        for (factor_id, belief) in factor_beliefs {
342            if let Some(factor) = graph.get_factor(factor_id) {
343                // E[log bα]
344                let entropy_contrib: f64 = belief
345                    .iter()
346                    .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
347                    .sum();
348
349                // E[log ψα]
350                let mut energy_contrib = 0.0;
351                let total_size: usize = belief.shape().iter().product();
352                for linear_idx in 0..total_size {
353                    let mut assignment = Vec::new();
354                    let mut temp_idx = linear_idx;
355                    for &dim in belief.shape().iter().rev() {
356                        assignment.push(temp_idx % dim);
357                        temp_idx /= dim;
358                    }
359                    assignment.reverse();
360
361                    let b_val = belief[assignment.as_slice()];
362                    let psi_val = factor.values[assignment.as_slice()];
363                    if b_val > 1e-10 && psi_val > 1e-10 {
364                        energy_contrib += b_val * psi_val.ln();
365                    }
366                }
367
368                free_energy -= entropy_contrib;
369                free_energy -= energy_contrib;
370            }
371        }
372
373        // Variable contribution: Σ_i (d_i - 1) H(b_i)
374        for (var_name, belief) in variable_beliefs {
375            // Get degree of variable (number of adjacent factors)
376            let degree = if let Some(adjacent) = graph.get_adjacent_factors(var_name) {
377                adjacent.len()
378            } else {
379                0
380            };
381
382            if degree > 0 {
383                let entropy: f64 = belief
384                    .iter()
385                    .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
386                    .sum();
387
388                free_energy += (degree as f64 - 1.0) * entropy;
389            }
390        }
391
392        Ok(free_energy)
393    }
394
395    /// Compute factor beliefs from variable beliefs and factor potentials.
396    pub fn compute_factor_beliefs(
397        &self,
398        graph: &FactorGraph,
399        variable_beliefs: &HashMap<String, ArrayD<f64>>,
400    ) -> Result<HashMap<String, ArrayD<f64>>> {
401        let mut factor_beliefs = HashMap::new();
402
403        for factor_id in graph.factor_ids() {
404            if let Some(factor) = graph.get_factor(factor_id) {
405                // Start with factor potential
406                let mut belief = factor.clone();
407
408                // Multiply by variable beliefs (approximately - using product)
409                for var in &factor.variables {
410                    if let Some(var_belief) = variable_beliefs.get(var) {
411                        // Create a factor from variable belief
412                        let var_factor = Factor {
413                            name: format!("belief_{}", var),
414                            variables: vec![var.clone()],
415                            values: var_belief.clone(),
416                        };
417                        belief = belief.product(&var_factor)?;
418                    }
419                }
420
421                belief.normalize();
422                factor_beliefs.insert(factor_id.clone(), belief.values);
423            }
424        }
425
426        Ok(factor_beliefs)
427    }
428}
429
/// Tree-reweighted belief propagation (TRW-BP).
///
/// Uses a convex combination of spanning trees to provide an upper bound
/// on the log partition function. More robust than standard BP for loopy graphs.
///
/// Messages are reweighted by edge appearance probabilities ρ_e ∈ `[0,1]`.
pub struct TreeReweightedBP {
    /// Maximum number of message-passing iterations.
    pub max_iterations: usize,
    /// Convergence tolerance on the largest message change.
    pub tolerance: f64,
    /// Edge appearance probabilities keyed by (variable, factor). Edges with
    /// no entry default to weight 1.0 (see `get_edge_weight`); `run` fills
    /// this map with uniform 1/degree weights if it is empty.
    pub edge_weights: HashMap<(String, String), f64>,
}
444
445impl Default for TreeReweightedBP {
446    fn default() -> Self {
447        Self {
448            max_iterations: 100,
449            tolerance: 1e-6,
450            edge_weights: HashMap::new(),
451        }
452    }
453}
454
impl TreeReweightedBP {
    /// Create with custom parameters.
    ///
    /// Edge weights start empty; `run` initializes them uniformly unless
    /// `set_edge_weight`/`initialize_uniform_weights` is called first.
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            edge_weights: HashMap::new(),
        }
    }

    /// Set edge appearance probability for a variable-factor edge.
    ///
    /// The weight is clamped into [0, 1].
    pub fn set_edge_weight(&mut self, var: String, factor: String, weight: f64) {
        self.edge_weights
            .insert((var, factor), weight.clamp(0.0, 1.0));
    }

    /// Initialize uniform edge weights for all edges in graph.
    ///
    /// Each edge incident to a variable gets weight 1/deg(variable). A
    /// variable with zero adjacent factors inserts nothing (the inner loop
    /// never runs), so the 1/0 division result is never stored.
    pub fn initialize_uniform_weights(&mut self, graph: &FactorGraph) {
        for var_name in graph.variable_names() {
            if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
                let weight = 1.0 / adjacent_factors.len() as f64;
                for factor_id in adjacent_factors {
                    self.edge_weights
                        .insert((var_name.clone(), factor_id.clone()), weight);
                }
            }
        }
    }

    /// Get edge weight (default to 1.0 if not set).
    fn get_edge_weight(&self, var: &str, factor: &str) -> f64 {
        self.edge_weights
            .get(&(var.to_string(), factor.to_string()))
            .copied()
            .unwrap_or(1.0)
    }

    /// Run tree-reweighted belief propagation.
    ///
    /// Returns variable beliefs (marginals) and an upper bound on log Z.
    ///
    /// NOTE(review): this is a simplified message scheme — factor potentials
    /// are never consulted; messages circulate only among a variable's own
    /// edges, reweighted as m^ρ. Confirm against the full TRW formulation
    /// (factor-to-variable messages) before relying on the marginals for
    /// graphs with informative factors.
    pub fn run(&mut self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
        // Initialize uniform weights if not set
        if self.edge_weights.is_empty() {
            self.initialize_uniform_weights(graph);
        }

        // Message storage: (var, factor) -> message
        let mut messages: HashMap<(String, String), ArrayD<f64>> = HashMap::new();

        // Initialize messages uniformly
        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
                    for factor_id in adjacent_factors {
                        let init_msg = ArrayD::from_elem(
                            vec![var_node.cardinality],
                            1.0 / var_node.cardinality as f64,
                        );
                        messages.insert((var_name.clone(), factor_id.clone()), init_msg);
                    }
                }
            }
        }

        // Iterative message passing (Jacobi-style: all updates read the
        // previous iteration's messages via `old_messages`)
        for iteration in 0..self.max_iterations {
            let old_messages = messages.clone();

            // Update all messages
            for var_name in graph.variable_names() {
                if let Some(var_node) = graph.get_variable(var_name) {
                    if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
                        for target_factor in adjacent_factors {
                            // Compute reweighted message
                            let mut message = ArrayD::ones(vec![var_node.cardinality])
                                / var_node.cardinality as f64;

                            // Multiply messages from other factors (with reweighting)
                            for other_factor in adjacent_factors {
                                if other_factor != target_factor {
                                    if let Some(incoming) =
                                        old_messages.get(&(var_name.clone(), other_factor.clone()))
                                    {
                                        let rho = self.get_edge_weight(var_name, other_factor);
                                        // Reweighted message: m^ρ
                                        let reweighted = incoming.mapv(|x| x.powf(rho));
                                        message = &message * &reweighted;
                                    }
                                }
                            }

                            // Normalize (skip when the mass is ~0 to avoid
                            // division by a vanishing sum)
                            let sum: f64 = message.iter().sum();
                            if sum > 1e-10 {
                                message /= sum;
                            }

                            messages.insert((var_name.clone(), target_factor.clone()), message);
                        }
                    }
                }
            }

            // Check convergence: largest absolute change across all messages
            let mut max_delta = 0.0_f64;
            for ((var, factor), new_msg) in &messages {
                if let Some(old_msg) = old_messages.get(&(var.clone(), factor.clone())) {
                    let delta: f64 = (new_msg - old_msg)
                        .mapv(|x| x.abs())
                        .iter()
                        .fold(0.0_f64, |acc, &x| acc.max(x));
                    max_delta = max_delta.max(delta);
                }
            }

            if max_delta < self.tolerance {
                break;
            }

            // Fail if the budget is exhausted without convergence. (With
            // max_iterations == 0 the loop never runs, so no underflow here.)
            if iteration == self.max_iterations - 1 {
                return Err(PgmError::ConvergenceFailure(format!(
                    "TRW-BP did not converge after {} iterations (max_delta={})",
                    self.max_iterations, max_delta
                )));
            }
        }

        // Compute beliefs from messages: product of reweighted messages on
        // each variable's edges, then normalize
        let mut beliefs = HashMap::new();
        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let mut belief =
                    ArrayD::ones(vec![var_node.cardinality]) / var_node.cardinality as f64;

                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
                    for factor_id in adjacent_factors {
                        if let Some(message) = messages.get(&(var_name.clone(), factor_id.clone()))
                        {
                            let rho = self.get_edge_weight(var_name, factor_id);
                            let reweighted = message.mapv(|x| x.powf(rho));
                            belief = &belief * &reweighted;
                        }
                    }
                }

                // Normalize
                let sum: f64 = belief.iter().sum();
                if sum > 1e-10 {
                    belief /= sum;
                }

                beliefs.insert(var_name.clone(), belief);
            }
        }

        Ok(beliefs)
    }

    /// Compute upper bound on log partition function.
    ///
    /// log Z ≤ log Z_TRW = Σ_i ρ_i log Z_i
    ///
    /// NOTE(review): stub — always returns 0.0. The full computation needs
    /// factor beliefs and region-based bookkeeping that are not implemented.
    pub fn compute_log_partition_upper_bound(
        &self,
        _graph: &FactorGraph,
        _beliefs: &HashMap<String, ArrayD<f64>>,
    ) -> Result<f64> {
        // Simplified implementation - full version requires factor beliefs
        // and region-based computation
        Ok(0.0)
    }
}
626
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build a factor-free graph with one binary variable per name.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card((*name).to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Assert that a length-2 distribution is (numerically) uniform.
    fn assert_uniform_binary(dist: &ArrayD<f64>) {
        assert_abs_diff_eq!(dist[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(dist[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_mean_field_single_variable() {
        let graph = binary_graph(&["x"]);

        let marginals = MeanFieldInference::default()
            .run(&graph)
            .expect("mean-field should succeed on a factor-free graph");
        assert!(marginals.contains_key("x"));

        // With no factors attached, the marginal stays at its uniform init.
        assert_uniform_binary(&marginals["x"]);
    }

    #[test]
    fn test_mean_field_convergence() {
        let graph = binary_graph(&["x", "y"]);
        assert!(MeanFieldInference::new(50, 1e-6).run(&graph).is_ok());
    }

    #[test]
    fn test_elbo_computation() {
        let graph = binary_graph(&["x"]);

        let mf = MeanFieldInference::default();
        let marginals = mf.run(&graph).unwrap();
        assert!(mf.compute_elbo(&graph, &marginals).is_ok());
    }

    #[test]
    fn test_bethe_approximation_single_variable() {
        let graph = binary_graph(&["x"]);

        let marginals = BetheApproximation::default()
            .run(&graph)
            .expect("Bethe approximation should succeed");
        assert!(marginals.contains_key("x"));
        assert_uniform_binary(&marginals["x"]);
    }

    #[test]
    fn test_bethe_approximation_two_variables() {
        let graph = binary_graph(&["x", "y"]);

        let marginals = BetheApproximation::new(50, 1e-6, 0.0)
            .run(&graph)
            .expect("Bethe approximation should succeed");
        assert_eq!(marginals.len(), 2);
    }

    #[test]
    fn test_bethe_free_energy() {
        let graph = binary_graph(&["x"]);

        let bethe = BetheApproximation::default();
        let marginals = bethe.run(&graph).unwrap();
        let factor_beliefs = bethe.compute_factor_beliefs(&graph, &marginals).unwrap();

        assert!(bethe
            .compute_free_energy(&graph, &marginals, &factor_beliefs)
            .is_ok());
    }

    #[test]
    fn test_bethe_with_damping() {
        let graph = binary_graph(&["x", "y"]);
        assert!(BetheApproximation::new(50, 1e-6, 0.5).run(&graph).is_ok());
    }

    #[test]
    fn test_trw_bp_single_variable() {
        let graph = binary_graph(&["x"]);

        let mut trw = TreeReweightedBP::default();
        let beliefs = trw.run(&graph).expect("TRW-BP should succeed");
        assert!(beliefs.contains_key("x"));
        assert_uniform_binary(&beliefs["x"]);
    }

    #[test]
    fn test_trw_bp_two_variables() {
        let graph = binary_graph(&["x", "y"]);

        let mut trw = TreeReweightedBP::new(50, 1e-6);
        let beliefs = trw.run(&graph).expect("TRW-BP should succeed");
        assert_eq!(beliefs.len(), 2);
    }

    #[test]
    fn test_trw_bp_custom_weights() {
        let graph = binary_graph(&["x", "y"]);

        let mut trw = TreeReweightedBP::default();
        trw.set_edge_weight("x".to_string(), "f1".to_string(), 0.5);

        // A weight on a nonexistent edge must not break inference.
        assert!(trw.run(&graph).is_ok());
    }

    #[test]
    fn test_trw_bp_uniform_initialization() {
        let graph = binary_graph(&["x"]);

        let mut trw = TreeReweightedBP::default();
        trw.initialize_uniform_weights(&graph);

        // Either weights were created, or the graph simply has no factors.
        assert!(!trw.edge_weights.is_empty() || graph.factor_ids().count() == 0);
    }

    #[test]
    fn test_trw_bp_partition_bound() {
        let graph = binary_graph(&["x"]);

        let mut trw = TreeReweightedBP::default();
        let beliefs = trw.run(&graph).unwrap();

        assert!(trw
            .compute_log_partition_upper_bound(&graph, &beliefs)
            .is_ok());
    }
}