tensorlogic_quantrs_hooks/
expectation_propagation.rs

1//! Expectation Propagation (EP) for approximate inference.
2//!
3//! EP is an iterative algorithm that approximates complex posterior distributions
4//! using products of simpler "site" approximations via moment matching.
5//!
6//! # Algorithm
7//!
8//! 1. Initialize site approximations (factors)
9//! 2. For each factor:
10//!    a. Compute cavity distribution (remove current site)
11//!    b. Compute tilted distribution (include true factor)
12//!    c. Moment match to update site approximation
13//! 3. Repeat until convergence
14//!
15//! # References
16//!
17//! - Minka, "Expectation Propagation for approximate Bayesian inference" (2001)
18//! - Bishop, "Pattern Recognition and Machine Learning" (2006), Section 10.7
19
20use crate::{Factor, FactorGraph, PgmError, Result};
21use scirs2_core::ndarray::ArrayD;
22use std::collections::HashMap;
23
24/// Site approximation for a single factor.
25///
26/// In EP, each true factor f_i(x) is approximated by a simpler site s_i(x).
27/// For discrete distributions, we store the site as a factor.
28#[derive(Debug, Clone)]
29pub struct Site {
30    /// The site approximation (as a factor)
31    pub factor: Factor,
32    /// Variables this site depends on
33    pub variables: Vec<String>,
34}
35
36impl Site {
37    /// Create a new site initialized to uniform distribution.
38    pub fn new_uniform(
39        name: String,
40        variables: Vec<String>,
41        cardinalities: &[usize],
42    ) -> Result<Self> {
43        let total_size: usize = cardinalities.iter().product();
44        let uniform_value = 1.0 / total_size as f64;
45        let values = ArrayD::from_elem(cardinalities.to_vec(), uniform_value);
46
47        let factor = Factor::new(name, variables.clone(), values)?;
48        Ok(Self { factor, variables })
49    }
50
51    /// Create a new site from a factor.
52    pub fn from_factor(factor: Factor) -> Self {
53        let variables = factor.variables.clone();
54        Self { factor, variables }
55    }
56}
57
58/// Expectation Propagation algorithm for approximate inference.
59///
60/// EP approximates the posterior distribution by iteratively refining
61/// local approximations using moment matching.
62pub struct ExpectationPropagation {
63    /// Maximum number of iterations
64    max_iterations: usize,
65    /// Convergence tolerance
66    tolerance: f64,
67    /// Damping factor (0.0 = no damping, 1.0 = full damping)
68    damping: f64,
69    /// Minimum value for numerical stability
70    min_value: f64,
71}
72
73impl Default for ExpectationPropagation {
74    fn default() -> Self {
75        Self::new(100, 1e-6, 0.0)
76    }
77}
78
79impl ExpectationPropagation {
80    /// Create a new EP algorithm with custom parameters.
81    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
82        Self {
83            max_iterations,
84            tolerance,
85            damping,
86            min_value: 1e-10,
87        }
88    }
89
90    /// Run EP inference on a factor graph.
91    ///
92    /// Returns the approximate marginal distributions for each variable.
93    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
94        // Initialize sites
95        let mut sites = self.initialize_sites(graph)?;
96
97        // Compute initial approximation
98        let mut approx = self.compute_global_approximation(graph, &sites)?;
99
100        // EP iterations
101        for iteration in 0..self.max_iterations {
102            let mut max_change: f64 = 0.0;
103
104            // Update each site
105            for (factor_idx, factor) in graph.factors().enumerate() {
106                // Compute cavity distribution (remove current site)
107                let cavity = self.compute_cavity(&approx, &sites[factor_idx])?;
108
109                // Compute tilted distribution (include true factor)
110                let tilted = self.compute_tilted(&cavity, factor)?;
111
112                // Moment match to update site
113                let new_site = self.moment_match(&cavity, &tilted, &sites[factor_idx])?;
114
115                // Apply damping
116                let damped_site = self.apply_damping(&sites[factor_idx], &new_site)?;
117
118                // Compute change
119                let change = self.compute_site_change(&sites[factor_idx], &damped_site)?;
120                max_change = max_change.max(change);
121
122                // Update site
123                sites[factor_idx] = damped_site;
124            }
125
126            // Recompute global approximation
127            approx = self.compute_global_approximation(graph, &sites)?;
128
129            // Check convergence
130            if max_change < self.tolerance {
131                eprintln!(
132                    "EP converged in {} iterations (max change: {:.6})",
133                    iteration + 1,
134                    max_change
135                );
136                break;
137            }
138
139            if iteration == self.max_iterations - 1 {
140                eprintln!(
141                    "EP reached maximum iterations ({}) with max change: {:.6}",
142                    self.max_iterations, max_change
143                );
144            }
145        }
146
147        // Extract marginals from global approximation
148        self.extract_marginals(graph, &approx, &sites)
149    }
150
151    /// Initialize sites uniformly.
152    fn initialize_sites(&self, graph: &FactorGraph) -> Result<Vec<Site>> {
153        let mut sites = Vec::new();
154
155        for (idx, factor) in graph.factors().enumerate() {
156            let cardinalities: Vec<usize> = factor
157                .variables
158                .iter()
159                .map(|var| graph.get_variable(var).map(|v| v.cardinality).unwrap_or(2))
160                .collect();
161
162            let site = Site::new_uniform(
163                format!("site_{}", idx),
164                factor.variables.clone(),
165                &cardinalities,
166            )?;
167
168            sites.push(site);
169        }
170
171        Ok(sites)
172    }
173
174    /// Compute global approximation as product of all sites.
175    fn compute_global_approximation(&self, _graph: &FactorGraph, sites: &[Site]) -> Result<Factor> {
176        if sites.is_empty() {
177            return Err(PgmError::InvalidGraph(
178                "No sites to compute approximation".to_string(),
179            ));
180        }
181
182        let mut result = sites[0].factor.clone();
183
184        for site in sites.iter().skip(1) {
185            result = result.product(&site.factor)?;
186        }
187
188        // Normalize
189        result.normalize();
190
191        Ok(result)
192    }
193
194    /// Compute cavity distribution by removing a site.
195    fn compute_cavity(&self, approx: &Factor, site: &Site) -> Result<Factor> {
196        // First, marginalize the approximation to the variables in the site
197        // to ensure both have the same scope before division
198        let approx_marginal = if approx.variables == site.variables {
199            approx.clone()
200        } else {
201            approx.marginalize_out_all_except(&site.variables)?
202        };
203
204        // Cavity = approximation / site
205        let cavity = approx_marginal.divide(&site.factor)?;
206        Ok(cavity)
207    }
208
209    /// Compute tilted distribution by including the true factor.
210    fn compute_tilted(&self, cavity: &Factor, true_factor: &Factor) -> Result<Factor> {
211        // Tilted = cavity × true_factor
212        let tilted = cavity.product(true_factor)?;
213        Ok(tilted)
214    }
215
216    /// Moment match: find site that makes cavity × site ≈ tilted.
217    fn moment_match(&self, cavity: &Factor, tilted: &Factor, _old_site: &Site) -> Result<Site> {
218        // For discrete distributions, we can compute the new site as:
219        // new_site = tilted / cavity
220
221        let new_factor = tilted.divide(cavity)?;
222
223        // Ensure numerical stability
224        let mut stabilized = new_factor.clone();
225        stabilized.values.mapv_inplace(|v| v.max(self.min_value));
226
227        Ok(Site::from_factor(stabilized))
228    }
229
230    /// Apply damping to site update.
231    fn apply_damping(&self, old_site: &Site, new_site: &Site) -> Result<Site> {
232        if self.damping == 0.0 {
233            return Ok(new_site.clone());
234        }
235
236        // Damped = (1 - damping) × new + damping × old
237        let old_values = &old_site.factor.values;
238        let new_values = &new_site.factor.values;
239
240        let damped_values = (1.0 - self.damping) * new_values + self.damping * old_values;
241
242        let damped_factor = Factor::new(
243            new_site.factor.name.clone(),
244            new_site.factor.variables.clone(),
245            damped_values,
246        )?;
247
248        Ok(Site::from_factor(damped_factor))
249    }
250
251    /// Compute change between two sites (for convergence check).
252    fn compute_site_change(&self, old_site: &Site, new_site: &Site) -> Result<f64> {
253        // Compute L1 distance between site parameters
254        let diff = &new_site.factor.values - &old_site.factor.values;
255        let change = diff.mapv(|v| v.abs()).sum();
256        Ok(change)
257    }
258
259    /// Extract marginals from the global approximation.
260    fn extract_marginals(
261        &self,
262        graph: &FactorGraph,
263        approx: &Factor,
264        _sites: &[Site],
265    ) -> Result<HashMap<String, ArrayD<f64>>> {
266        let mut marginals = HashMap::new();
267
268        for (var, _) in graph.variables() {
269            let marginal = approx.marginalize_out_all_except(std::slice::from_ref(var))?;
270            let mut normalized = marginal.clone();
271            normalized.normalize();
272            marginals.insert(var.clone(), normalized.values);
273        }
274
275        Ok(marginals)
276    }
277}
278
/// One-dimensional Gaussian site for a continuous variable.
///
/// The site is stored in natural parameters: the precision and the
/// precision-weighted mean.
#[derive(Debug, Clone)]
pub struct GaussianSite {
    /// Name of the variable this site is attached to.
    pub variable: String,
    /// Precision (inverse variance).
    pub precision: f64,
    /// Precision-weighted mean (precision × mean).
    pub precision_mean: f64,
}

impl GaussianSite {
    /// Build a site from its natural parameters.
    pub fn new(variable: String, precision: f64, precision_mean: f64) -> Self {
        Self {
            variable,
            precision,
            precision_mean,
        }
    }

    /// Build an uninformative site: both natural parameters are zero.
    pub fn uniform(variable: String) -> Self {
        Self::new(variable, 0.0, 0.0)
    }

    /// Mean recovered from the natural parameters.
    ///
    /// Returns `precision_mean / precision` when the precision exceeds
    /// `1e-10`, and `0.0` otherwise.
    pub fn mean(&self) -> f64 {
        match self.precision {
            p if p > 1e-10 => self.precision_mean / p,
            _ => 0.0,
        }
    }

    /// Variance recovered from the precision.
    ///
    /// Returns `1 / precision` when the precision exceeds `1e-10`, and
    /// positive infinity otherwise.
    pub fn variance(&self) -> f64 {
        match self.precision {
            p if p > 1e-10 => 1.0 / p,
            _ => f64::INFINITY,
        }
    }

    /// Multiply two sites: natural parameters add.
    ///
    /// The result keeps `self`'s variable name.
    pub fn product(&self, other: &GaussianSite) -> Self {
        Self::new(
            self.variable.clone(),
            self.precision + other.precision,
            self.precision_mean + other.precision_mean,
        )
    }

    /// Divide `self` by `other`: natural parameters subtract.
    ///
    /// The result keeps `self`'s variable name. The resulting precision is not
    /// clamped and may be negative.
    pub fn divide(&self, other: &GaussianSite) -> Self {
        Self::new(
            self.variable.clone(),
            self.precision - other.precision,
            self.precision_mean - other.precision_mean,
        )
    }
}
347
348/// Gaussian EP for continuous variables with moment matching.
349#[allow(dead_code)]
350pub struct GaussianEP {
351    /// Maximum number of iterations
352    max_iterations: usize,
353    /// Convergence tolerance
354    tolerance: f64,
355    /// Damping factor
356    damping: f64,
357}
358
359impl Default for GaussianEP {
360    fn default() -> Self {
361        Self::new(100, 1e-6, 0.0)
362    }
363}
364
365impl GaussianEP {
366    /// Create a new Gaussian EP instance.
367    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
368        Self {
369            max_iterations,
370            tolerance,
371            damping,
372        }
373    }
374
375    /// Compute Gaussian moments (mean, variance) from a tilted distribution.
376    ///
377    /// This is a placeholder for moment computation. In practice, you would:
378    /// 1. Compute cavity distribution
379    /// 2. Multiply by true factor
380    /// 3. Compute mean and variance of the result
381    pub fn compute_moments(
382        &self,
383        cavity: &GaussianSite,
384        _true_factor_callback: impl Fn(f64) -> f64,
385    ) -> (f64, f64) {
386        // This is simplified - in practice, you'd integrate to get moments
387        let mean = cavity.mean();
388        let variance = cavity.variance();
389        (mean, variance)
390    }
391
392    /// Match moments to update site.
393    pub fn match_moments(
394        &self,
395        cavity: &GaussianSite,
396        tilted_mean: f64,
397        tilted_var: f64,
398    ) -> GaussianSite {
399        // Compute new site such that cavity × site has given moments
400        let new_precision = 1.0 / tilted_var - cavity.precision;
401        let new_precision_mean = tilted_mean / tilted_var - cavity.precision_mean;
402
403        GaussianSite::new(
404            cavity.variable.clone(),
405            new_precision.max(0.0), // Ensure non-negative
406            new_precision_mean,
407        )
408    }
409
410    /// Apply damping to site update.
411    pub fn damp_site(&self, old_site: &GaussianSite, new_site: &GaussianSite) -> GaussianSite {
412        if self.damping == 0.0 {
413            return new_site.clone();
414        }
415
416        let damped_precision =
417            (1.0 - self.damping) * new_site.precision + self.damping * old_site.precision;
418        let damped_precision_mean =
419            (1.0 - self.damping) * new_site.precision_mean + self.damping * old_site.precision_mean;
420
421        GaussianSite::new(
422            new_site.variable.clone(),
423            damped_precision,
424            damped_precision_mean,
425        )
426    }
427}
428
429#[cfg(test)]
430mod tests {
431    use super::*;
432    use approx::assert_abs_diff_eq;
433    use scirs2_core::ndarray::Array;
434
435    #[test]
436    fn test_site_creation() {
437        let site = Site::new_uniform("test_site".to_string(), vec!["X".to_string()], &[2]).unwrap();
438
439        assert_eq!(site.variables.len(), 1);
440        assert_eq!(site.factor.variables[0], "X");
441
442        // Should be uniform
443        let sum: f64 = site.factor.values.sum();
444        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
445    }
446
447    #[test]
448    fn test_gaussian_site_moments() {
449        let site = GaussianSite::new("X".to_string(), 2.0, 4.0);
450
451        // mean = precision_mean / precision = 4.0 / 2.0 = 2.0
452        assert_abs_diff_eq!(site.mean(), 2.0, epsilon = 1e-10);
453
454        // variance = 1 / precision = 1 / 2.0 = 0.5
455        assert_abs_diff_eq!(site.variance(), 0.5, epsilon = 1e-10);
456    }
457
458    #[test]
459    fn test_gaussian_site_product() {
460        let site1 = GaussianSite::new("X".to_string(), 2.0, 4.0);
461        let site2 = GaussianSite::new("X".to_string(), 3.0, 6.0);
462
463        let product = site1.product(&site2);
464
465        // Precision adds: 2.0 + 3.0 = 5.0
466        assert_abs_diff_eq!(product.precision, 5.0, epsilon = 1e-10);
467
468        // Precision-weighted means add: 4.0 + 6.0 = 10.0
469        assert_abs_diff_eq!(product.precision_mean, 10.0, epsilon = 1e-10);
470    }
471
472    #[test]
473    fn test_gaussian_site_divide() {
474        let site1 = GaussianSite::new("X".to_string(), 5.0, 10.0);
475        let site2 = GaussianSite::new("X".to_string(), 2.0, 4.0);
476
477        let quotient = site1.divide(&site2);
478
479        // Precision subtracts: 5.0 - 2.0 = 3.0
480        assert_abs_diff_eq!(quotient.precision, 3.0, epsilon = 1e-10);
481
482        // Precision-weighted means subtract: 10.0 - 4.0 = 6.0
483        assert_abs_diff_eq!(quotient.precision_mean, 6.0, epsilon = 1e-10);
484    }
485
486    #[test]
487    fn test_ep_initialization() {
488        let ep = ExpectationPropagation::new(50, 1e-5, 0.5);
489        assert_eq!(ep.max_iterations, 50);
490        assert_abs_diff_eq!(ep.tolerance, 1e-5, epsilon = 1e-10);
491        assert_abs_diff_eq!(ep.damping, 0.5, epsilon = 1e-10);
492    }
493
494    #[test]
495    fn test_ep_simple_graph() {
496        use crate::FactorGraph;
497
498        // Create a simple factor graph with one binary variable
499        let mut graph = FactorGraph::new();
500        graph.add_variable_with_card("X".to_string(), "Binary".to_string(), 2);
501
502        // Add a simple factor P(X) = [0.7, 0.3]
503        let values = Array::from_shape_vec(vec![2], vec![0.7, 0.3])
504            .unwrap()
505            .into_dyn();
506        let factor = Factor::new("P(X)".to_string(), vec!["X".to_string()], values).unwrap();
507        graph.add_factor(factor).unwrap();
508
509        // Run EP
510        let ep = ExpectationPropagation::default();
511        let marginals = ep.run(&graph).unwrap();
512
513        // Check that we got a marginal for X
514        assert!(marginals.contains_key("X"));
515
516        let marginal = &marginals["X"];
517        assert_eq!(marginal.ndim(), 1);
518        assert_eq!(marginal.len(), 2);
519
520        // Should be normalized
521        let sum: f64 = marginal.sum();
522        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
523    }
524
525    #[test]
526    fn test_gaussian_ep_moment_matching() {
527        let gep = GaussianEP::default();
528
529        // Cavity distribution: N(mean=0, var=1) => precision=1, precision_mean=0
530        let cavity = GaussianSite::new("X".to_string(), 1.0, 0.0);
531
532        // Tilted distribution has mean=2, var=0.5
533        let tilted_mean = 2.0;
534        let tilted_var = 0.5;
535
536        // Match moments
537        let new_site = gep.match_moments(&cavity, tilted_mean, tilted_var);
538
539        // Verify product has correct moments
540        let product = cavity.product(&new_site);
541
542        assert_abs_diff_eq!(product.mean(), tilted_mean, epsilon = 1e-6);
543        assert_abs_diff_eq!(product.variance(), tilted_var, epsilon = 1e-6);
544    }
545
546    #[test]
547    fn test_ep_two_factor_graph() {
548        use crate::FactorGraph;
549
550        // Create a factor graph with two variables and two factors
551        let mut graph = FactorGraph::new();
552        graph.add_variable_with_card("X".to_string(), "Binary".to_string(), 2);
553        graph.add_variable_with_card("Y".to_string(), "Binary".to_string(), 2);
554
555        // Factor P(X) = [0.6, 0.4]
556        let px_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
557            .unwrap()
558            .into_dyn();
559        let px = Factor::new("P(X)".to_string(), vec!["X".to_string()], px_values).unwrap();
560        graph.add_factor(px).unwrap();
561
562        // Factor P(Y|X)
563        let pyx_values = Array::from_shape_vec(
564            vec![2, 2],
565            vec![0.8, 0.2, 0.3, 0.7], // P(Y|X=0), P(Y|X=1)
566        )
567        .unwrap()
568        .into_dyn();
569        let pyx = Factor::new(
570            "P(Y|X)".to_string(),
571            vec!["X".to_string(), "Y".to_string()],
572            pyx_values,
573        )
574        .unwrap();
575        graph.add_factor(pyx).unwrap();
576
577        // Run EP
578        let ep = ExpectationPropagation::new(100, 1e-6, 0.0);
579        let marginals = ep.run(&graph).unwrap();
580
581        // Check marginals
582        assert!(marginals.contains_key("X"));
583        assert!(marginals.contains_key("Y"));
584
585        // Both should be normalized
586        let sum_x: f64 = marginals["X"].sum();
587        let sum_y: f64 = marginals["Y"].sum();
588        assert_abs_diff_eq!(sum_x, 1.0, epsilon = 1e-6);
589        assert_abs_diff_eq!(sum_y, 1.0, epsilon = 1e-6);
590    }
591}