1use crate::{Factor, FactorGraph, PgmError, Result};
21use scirs2_core::ndarray::ArrayD;
22use std::collections::HashMap;
23
24#[derive(Debug, Clone)]
29pub struct Site {
30 pub factor: Factor,
32 pub variables: Vec<String>,
34}
35
36impl Site {
37 pub fn new_uniform(
39 name: String,
40 variables: Vec<String>,
41 cardinalities: &[usize],
42 ) -> Result<Self> {
43 let total_size: usize = cardinalities.iter().product();
44 let uniform_value = 1.0 / total_size as f64;
45 let values = ArrayD::from_elem(cardinalities.to_vec(), uniform_value);
46
47 let factor = Factor::new(name, variables.clone(), values)?;
48 Ok(Self { factor, variables })
49 }
50
51 pub fn from_factor(factor: Factor) -> Self {
53 let variables = factor.variables.clone();
54 Self { factor, variables }
55 }
56}
57
/// Settings for the discrete expectation-propagation solver.
pub struct ExpectationPropagation {
    /// Upper bound on the number of sweeps over the factors performed by `run`.
    max_iterations: usize,
    /// `run` stops early once the largest per-site change in a sweep drops
    /// below this value.
    tolerance: f64,
    /// Weight given to the previous site when blending it with an updated
    /// one; exactly `0.0` disables blending.
    damping: f64,
    /// Lower bound applied to every entry of a freshly matched site table.
    min_value: f64,
}
72
73impl Default for ExpectationPropagation {
74 fn default() -> Self {
75 Self::new(100, 1e-6, 0.0)
76 }
77}
78
79impl ExpectationPropagation {
80 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
82 Self {
83 max_iterations,
84 tolerance,
85 damping,
86 min_value: 1e-10,
87 }
88 }
89
90 pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
94 let mut sites = self.initialize_sites(graph)?;
96
97 let mut approx = self.compute_global_approximation(graph, &sites)?;
99
100 for iteration in 0..self.max_iterations {
102 let mut max_change: f64 = 0.0;
103
104 for (factor_idx, factor) in graph.factors().enumerate() {
106 let cavity = self.compute_cavity(&approx, &sites[factor_idx])?;
108
109 let tilted = self.compute_tilted(&cavity, factor)?;
111
112 let new_site = self.moment_match(&cavity, &tilted, &sites[factor_idx])?;
114
115 let damped_site = self.apply_damping(&sites[factor_idx], &new_site)?;
117
118 let change = self.compute_site_change(&sites[factor_idx], &damped_site)?;
120 max_change = max_change.max(change);
121
122 sites[factor_idx] = damped_site;
124 }
125
126 approx = self.compute_global_approximation(graph, &sites)?;
128
129 if max_change < self.tolerance {
131 eprintln!(
132 "EP converged in {} iterations (max change: {:.6})",
133 iteration + 1,
134 max_change
135 );
136 break;
137 }
138
139 if iteration == self.max_iterations - 1 {
140 eprintln!(
141 "EP reached maximum iterations ({}) with max change: {:.6}",
142 self.max_iterations, max_change
143 );
144 }
145 }
146
147 self.extract_marginals(graph, &approx, &sites)
149 }
150
151 fn initialize_sites(&self, graph: &FactorGraph) -> Result<Vec<Site>> {
153 let mut sites = Vec::new();
154
155 for (idx, factor) in graph.factors().enumerate() {
156 let cardinalities: Vec<usize> = factor
157 .variables
158 .iter()
159 .map(|var| graph.get_variable(var).map(|v| v.cardinality).unwrap_or(2))
160 .collect();
161
162 let site = Site::new_uniform(
163 format!("site_{}", idx),
164 factor.variables.clone(),
165 &cardinalities,
166 )?;
167
168 sites.push(site);
169 }
170
171 Ok(sites)
172 }
173
174 fn compute_global_approximation(&self, _graph: &FactorGraph, sites: &[Site]) -> Result<Factor> {
176 if sites.is_empty() {
177 return Err(PgmError::InvalidGraph(
178 "No sites to compute approximation".to_string(),
179 ));
180 }
181
182 let mut result = sites[0].factor.clone();
183
184 for site in sites.iter().skip(1) {
185 result = result.product(&site.factor)?;
186 }
187
188 result.normalize();
190
191 Ok(result)
192 }
193
194 fn compute_cavity(&self, approx: &Factor, site: &Site) -> Result<Factor> {
196 let approx_marginal = if approx.variables == site.variables {
199 approx.clone()
200 } else {
201 approx.marginalize_out_all_except(&site.variables)?
202 };
203
204 let cavity = approx_marginal.divide(&site.factor)?;
206 Ok(cavity)
207 }
208
209 fn compute_tilted(&self, cavity: &Factor, true_factor: &Factor) -> Result<Factor> {
211 let tilted = cavity.product(true_factor)?;
213 Ok(tilted)
214 }
215
216 fn moment_match(&self, cavity: &Factor, tilted: &Factor, _old_site: &Site) -> Result<Site> {
218 let new_factor = tilted.divide(cavity)?;
222
223 let mut stabilized = new_factor.clone();
225 stabilized.values.mapv_inplace(|v| v.max(self.min_value));
226
227 Ok(Site::from_factor(stabilized))
228 }
229
230 fn apply_damping(&self, old_site: &Site, new_site: &Site) -> Result<Site> {
232 if self.damping == 0.0 {
233 return Ok(new_site.clone());
234 }
235
236 let old_values = &old_site.factor.values;
238 let new_values = &new_site.factor.values;
239
240 let damped_values = (1.0 - self.damping) * new_values + self.damping * old_values;
241
242 let damped_factor = Factor::new(
243 new_site.factor.name.clone(),
244 new_site.factor.variables.clone(),
245 damped_values,
246 )?;
247
248 Ok(Site::from_factor(damped_factor))
249 }
250
251 fn compute_site_change(&self, old_site: &Site, new_site: &Site) -> Result<f64> {
253 let diff = &new_site.factor.values - &old_site.factor.values;
255 let change = diff.mapv(|v| v.abs()).sum();
256 Ok(change)
257 }
258
259 fn extract_marginals(
261 &self,
262 graph: &FactorGraph,
263 approx: &Factor,
264 _sites: &[Site],
265 ) -> Result<HashMap<String, ArrayD<f64>>> {
266 let mut marginals = HashMap::new();
267
268 for (var, _) in graph.variables() {
269 let marginal = approx.marginalize_out_all_except(std::slice::from_ref(var))?;
270 let mut normalized = marginal.clone();
271 normalized.normalize();
272 marginals.insert(var.clone(), normalized.values);
273 }
274
275 Ok(marginals)
276 }
277}
278
/// A one-dimensional Gaussian site stored in natural parameters.
#[derive(Debug, Clone)]
pub struct GaussianSite {
    /// Name of the variable the site refers to.
    pub variable: String,
    /// Inverse variance; `variance()` reports `1 / precision`.
    pub precision: f64,
    /// Precision times mean; `mean()` reports `precision_mean / precision`.
    pub precision_mean: f64,
}
291
292impl GaussianSite {
293 pub fn new(variable: String, precision: f64, precision_mean: f64) -> Self {
295 Self {
296 variable,
297 precision,
298 precision_mean,
299 }
300 }
301
302 pub fn uniform(variable: String) -> Self {
304 Self {
305 variable,
306 precision: 0.0,
307 precision_mean: 0.0,
308 }
309 }
310
311 pub fn mean(&self) -> f64 {
313 if self.precision > 1e-10 {
314 self.precision_mean / self.precision
315 } else {
316 0.0
317 }
318 }
319
320 pub fn variance(&self) -> f64 {
322 if self.precision > 1e-10 {
323 1.0 / self.precision
324 } else {
325 f64::INFINITY
326 }
327 }
328
329 pub fn product(&self, other: &GaussianSite) -> Self {
331 Self {
332 variable: self.variable.clone(),
333 precision: self.precision + other.precision,
334 precision_mean: self.precision_mean + other.precision_mean,
335 }
336 }
337
338 pub fn divide(&self, other: &GaussianSite) -> Self {
340 Self {
341 variable: self.variable.clone(),
342 precision: self.precision - other.precision,
343 precision_mean: self.precision_mean - other.precision_mean,
344 }
345 }
346}
347
/// Settings for the Gaussian expectation-propagation helpers.
// Of the stored settings, only `damping` is read by the methods in this file,
// so the dead-code lint is silenced for the others.
#[allow(dead_code)]
pub struct GaussianEP {
    /// Iteration limit as supplied to `new`.
    max_iterations: usize,
    /// Convergence tolerance as supplied to `new`.
    tolerance: f64,
    /// Weight given to the previous site in `damp_site`; exactly `0.0`
    /// disables blending.
    damping: f64,
}
358
359impl Default for GaussianEP {
360 fn default() -> Self {
361 Self::new(100, 1e-6, 0.0)
362 }
363}
364
365impl GaussianEP {
366 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
368 Self {
369 max_iterations,
370 tolerance,
371 damping,
372 }
373 }
374
375 pub fn compute_moments(
382 &self,
383 cavity: &GaussianSite,
384 _true_factor_callback: impl Fn(f64) -> f64,
385 ) -> (f64, f64) {
386 let mean = cavity.mean();
388 let variance = cavity.variance();
389 (mean, variance)
390 }
391
392 pub fn match_moments(
394 &self,
395 cavity: &GaussianSite,
396 tilted_mean: f64,
397 tilted_var: f64,
398 ) -> GaussianSite {
399 let new_precision = 1.0 / tilted_var - cavity.precision;
401 let new_precision_mean = tilted_mean / tilted_var - cavity.precision_mean;
402
403 GaussianSite::new(
404 cavity.variable.clone(),
405 new_precision.max(0.0), new_precision_mean,
407 )
408 }
409
410 pub fn damp_site(&self, old_site: &GaussianSite, new_site: &GaussianSite) -> GaussianSite {
412 if self.damping == 0.0 {
413 return new_site.clone();
414 }
415
416 let damped_precision =
417 (1.0 - self.damping) * new_site.precision + self.damping * old_site.precision;
418 let damped_precision_mean =
419 (1.0 - self.damping) * new_site.precision_mean + self.damping * old_site.precision_mean;
420
421 GaussianSite::new(
422 new_site.variable.clone(),
423 damped_precision,
424 damped_precision_mean,
425 )
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a graph containing one binary variable per given name.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for &name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Builds a factor from a flat list of values laid out in the given shape.
    fn dense_factor(name: &str, vars: &[&str], shape: Vec<usize>, data: Vec<f64>) -> Factor {
        let values = Array::from_shape_vec(shape, data).unwrap().into_dyn();
        let vars: Vec<String> = vars.iter().map(|v| v.to_string()).collect();
        Factor::new(name.to_string(), vars, values).unwrap()
    }

    /// A uniform site over one binary variable sums to one.
    #[test]
    fn test_site_creation() {
        let vars = vec!["X".to_string()];
        let site = Site::new_uniform("test_site".to_string(), vars, &[2]).unwrap();

        assert_eq!(site.variables.len(), 1);
        assert_eq!(site.factor.variables[0], "X");

        let total: f64 = site.factor.values.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    /// Mean and variance are recovered from the natural parameters.
    #[test]
    fn test_gaussian_site_moments() {
        let site = GaussianSite::new("X".to_string(), 2.0, 4.0);

        assert_abs_diff_eq!(site.mean(), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(site.variance(), 0.5, epsilon = 1e-10);
    }

    /// Multiplying sites adds their natural parameters.
    #[test]
    fn test_gaussian_site_product() {
        let lhs = GaussianSite::new("X".to_string(), 2.0, 4.0);
        let rhs = GaussianSite::new("X".to_string(), 3.0, 6.0);

        let product = lhs.product(&rhs);

        assert_abs_diff_eq!(product.precision, 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(product.precision_mean, 10.0, epsilon = 1e-10);
    }

    /// Dividing sites subtracts their natural parameters.
    #[test]
    fn test_gaussian_site_divide() {
        let numerator = GaussianSite::new("X".to_string(), 5.0, 10.0);
        let denominator = GaussianSite::new("X".to_string(), 2.0, 4.0);

        let quotient = numerator.divide(&denominator);

        assert_abs_diff_eq!(quotient.precision, 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(quotient.precision_mean, 6.0, epsilon = 1e-10);
    }

    /// The constructor stores the settings it is given.
    #[test]
    fn test_ep_initialization() {
        let ep = ExpectationPropagation::new(50, 1e-5, 0.5);

        assert_eq!(ep.max_iterations, 50);
        assert_abs_diff_eq!(ep.tolerance, 1e-5, epsilon = 1e-10);
        assert_abs_diff_eq!(ep.damping, 0.5, epsilon = 1e-10);
    }

    /// A single-factor graph yields a normalized marginal for its variable.
    #[test]
    fn test_ep_simple_graph() {
        let mut graph = binary_graph(&["X"]);
        graph
            .add_factor(dense_factor("P(X)", &["X"], vec![2], vec![0.7, 0.3]))
            .unwrap();

        let marginals = ExpectationPropagation::default().run(&graph).unwrap();

        assert!(marginals.contains_key("X"));

        let marginal = &marginals["X"];
        assert_eq!(marginal.ndim(), 1);
        assert_eq!(marginal.len(), 2);

        let total: f64 = marginal.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
    }

    /// The matched site times the cavity reproduces the tilted moments.
    #[test]
    fn test_gaussian_ep_moment_matching() {
        let gep = GaussianEP::default();
        let cavity = GaussianSite::new("X".to_string(), 1.0, 0.0);
        let (tilted_mean, tilted_var) = (2.0, 0.5);

        let new_site = gep.match_moments(&cavity, tilted_mean, tilted_var);
        let product = cavity.product(&new_site);

        assert_abs_diff_eq!(product.mean(), tilted_mean, epsilon = 1e-6);
        assert_abs_diff_eq!(product.variance(), tilted_var, epsilon = 1e-6);
    }

    /// A prior plus a conditional factor yields normalized marginals for both variables.
    #[test]
    fn test_ep_two_factor_graph() {
        let mut graph = binary_graph(&["X", "Y"]);
        graph
            .add_factor(dense_factor("P(X)", &["X"], vec![2], vec![0.6, 0.4]))
            .unwrap();
        graph
            .add_factor(dense_factor(
                "P(Y|X)",
                &["X", "Y"],
                vec![2, 2],
                vec![0.8, 0.2, 0.3, 0.7],
            ))
            .unwrap();

        let ep = ExpectationPropagation::new(100, 1e-6, 0.0);
        let marginals = ep.run(&graph).unwrap();

        assert!(marginals.contains_key("X"));
        assert!(marginals.contains_key("Y"));

        let sum_x: f64 = marginals["X"].sum();
        let sum_y: f64 = marginals["Y"].sum();
        assert_abs_diff_eq!(sum_x, 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sum_y, 1.0, epsilon = 1e-6);
    }
}