1use crate::{Factor, FactorGraph, PgmError, Result};
21use scirs2_core::ndarray::ArrayD;
22use std::collections::HashMap;
23
24#[derive(Debug, Clone)]
29pub struct Site {
30 pub factor: Factor,
32 pub variables: Vec<String>,
34}
35
36impl Site {
37 pub fn new_uniform(
39 name: String,
40 variables: Vec<String>,
41 cardinalities: &[usize],
42 ) -> Result<Self> {
43 let total_size: usize = cardinalities.iter().product();
44 let uniform_value = 1.0 / total_size as f64;
45 let values = ArrayD::from_elem(cardinalities.to_vec(), uniform_value);
46
47 let factor = Factor::new(name, variables.clone(), values)?;
48 Ok(Self { factor, variables })
49 }
50
51 pub fn from_factor(factor: Factor) -> Self {
53 let variables = factor.variables.clone();
54 Self { factor, variables }
55 }
56}
57
58pub struct ExpectationPropagation {
63 max_iterations: usize,
65 tolerance: f64,
67 damping: f64,
69 min_value: f64,
71}
72
73impl Default for ExpectationPropagation {
74 fn default() -> Self {
75 Self::new(100, 1e-6, 0.0)
76 }
77}
78
79impl ExpectationPropagation {
80 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
82 Self {
83 max_iterations,
84 tolerance,
85 damping,
86 min_value: 1e-10,
87 }
88 }
89
90 pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
94 let mut sites = self.initialize_sites(graph)?;
96
97 let mut approx = self.compute_global_approximation(graph, &sites)?;
99
100 for iteration in 0..self.max_iterations {
102 let mut max_change: f64 = 0.0;
103
104 for (factor_idx, factor) in graph.factors().enumerate() {
106 let cavity = self.compute_cavity(&approx, &sites[factor_idx])?;
108
109 let tilted = self.compute_tilted(&cavity, factor)?;
111
112 let new_site = self.moment_match(&cavity, &tilted, &sites[factor_idx])?;
114
115 let damped_site = self.apply_damping(&sites[factor_idx], &new_site)?;
117
118 let change = self.compute_site_change(&sites[factor_idx], &damped_site)?;
120 max_change = max_change.max(change);
121
122 sites[factor_idx] = damped_site;
124 }
125
126 approx = self.compute_global_approximation(graph, &sites)?;
128
129 if max_change < self.tolerance {
131 eprintln!(
132 "EP converged in {} iterations (max change: {:.6})",
133 iteration + 1,
134 max_change
135 );
136 break;
137 }
138
139 if iteration == self.max_iterations - 1 {
140 eprintln!(
141 "EP reached maximum iterations ({}) with max change: {:.6}",
142 self.max_iterations, max_change
143 );
144 }
145 }
146
147 self.extract_marginals(graph, &approx, &sites)
149 }
150
151 fn initialize_sites(&self, graph: &FactorGraph) -> Result<Vec<Site>> {
153 let mut sites = Vec::new();
154
155 for (idx, factor) in graph.factors().enumerate() {
156 let cardinalities: Vec<usize> = factor
157 .variables
158 .iter()
159 .map(|var| graph.get_variable(var).map(|v| v.cardinality).unwrap_or(2))
160 .collect();
161
162 let site = Site::new_uniform(
163 format!("site_{}", idx),
164 factor.variables.clone(),
165 &cardinalities,
166 )?;
167
168 sites.push(site);
169 }
170
171 Ok(sites)
172 }
173
174 fn compute_global_approximation(&self, _graph: &FactorGraph, sites: &[Site]) -> Result<Factor> {
176 if sites.is_empty() {
177 return Err(PgmError::InvalidGraph(
178 "No sites to compute approximation".to_string(),
179 ));
180 }
181
182 let mut result = sites[0].factor.clone();
183
184 for site in sites.iter().skip(1) {
185 result = result.product(&site.factor)?;
186 }
187
188 result.normalize();
190
191 Ok(result)
192 }
193
194 fn compute_cavity(&self, approx: &Factor, site: &Site) -> Result<Factor> {
196 let approx_marginal = if approx.variables == site.variables {
199 approx.clone()
200 } else {
201 approx.marginalize_out_all_except(&site.variables)?
202 };
203
204 let cavity = approx_marginal.divide(&site.factor)?;
206 Ok(cavity)
207 }
208
209 fn compute_tilted(&self, cavity: &Factor, true_factor: &Factor) -> Result<Factor> {
211 let tilted = cavity.product(true_factor)?;
213 Ok(tilted)
214 }
215
216 fn moment_match(&self, cavity: &Factor, tilted: &Factor, _old_site: &Site) -> Result<Site> {
218 let new_factor = tilted.divide(cavity)?;
222
223 let mut stabilized = new_factor.clone();
225 stabilized.values.mapv_inplace(|v| v.max(self.min_value));
226
227 Ok(Site::from_factor(stabilized))
228 }
229
230 fn apply_damping(&self, old_site: &Site, new_site: &Site) -> Result<Site> {
232 if self.damping == 0.0 {
233 return Ok(new_site.clone());
234 }
235
236 let old_values = &old_site.factor.values;
238 let new_values = &new_site.factor.values;
239
240 let damped_values = (1.0 - self.damping) * new_values + self.damping * old_values;
241
242 let damped_factor = Factor::new(
243 new_site.factor.name.clone(),
244 new_site.factor.variables.clone(),
245 damped_values,
246 )?;
247
248 Ok(Site::from_factor(damped_factor))
249 }
250
251 fn compute_site_change(&self, old_site: &Site, new_site: &Site) -> Result<f64> {
253 let diff = &new_site.factor.values - &old_site.factor.values;
255 let change = diff.mapv(|v| v.abs()).sum();
256 Ok(change)
257 }
258
259 fn extract_marginals(
261 &self,
262 graph: &FactorGraph,
263 approx: &Factor,
264 _sites: &[Site],
265 ) -> Result<HashMap<String, ArrayD<f64>>> {
266 let mut marginals = HashMap::new();
267
268 for (var, _) in graph.variables() {
269 let marginal = approx.marginalize_out_all_except(std::slice::from_ref(var))?;
270 let mut normalized = marginal.clone();
271 normalized.normalize();
272 marginals.insert(var.clone(), normalized.values);
273 }
274
275 Ok(marginals)
276 }
277}
278
/// A univariate Gaussian site stored in natural (information) parameters.
///
/// In this parameterization, multiplying and dividing Gaussians becomes adding
/// and subtracting parameters. Those are exactly the operations EP needs to
/// form cavities and update sites.
#[derive(Debug, Clone)]
pub struct GaussianSite {
    /// Name of the variable the site is attached to.
    pub variable: String,
    /// Precision (inverse variance); `0.0` encodes a flat, uninformative site.
    pub precision: f64,
    /// Precision-weighted mean, i.e. `precision * mean`.
    pub precision_mean: f64,
}

impl GaussianSite {
    /// Builds a site directly from its natural parameters.
    pub fn new(variable: String, precision: f64, precision_mean: f64) -> Self {
        Self {
            variable,
            precision,
            precision_mean,
        }
    }

    /// Builds a flat site (zero precision, zero precision-mean).
    pub fn uniform(variable: String) -> Self {
        Self {
            variable,
            precision: 0.0,
            precision_mean: 0.0,
        }
    }

    /// Mean `precision_mean / precision`.
    ///
    /// Returns `0.0` when the precision is at most `1e-10` (flat or improper site).
    pub fn mean(&self) -> f64 {
        if self.precision > 1e-10 {
            self.precision_mean / self.precision
        } else {
            0.0
        }
    }

    /// Variance `1 / precision`.
    ///
    /// Returns `f64::INFINITY` when the precision is at most `1e-10`.
    pub fn variance(&self) -> f64 {
        if self.precision > 1e-10 {
            1.0 / self.precision
        } else {
            f64::INFINITY
        }
    }

    /// Product of two Gaussians: natural parameters add. Keeps `self.variable`.
    pub fn product(&self, other: &GaussianSite) -> Self {
        Self {
            variable: self.variable.clone(),
            precision: self.precision + other.precision,
            precision_mean: self.precision_mean + other.precision_mean,
        }
    }

    /// Quotient of two Gaussians: natural parameters subtract. Keeps `self.variable`.
    ///
    /// The result may have negative precision (an improper Gaussian).
    pub fn divide(&self, other: &GaussianSite) -> Self {
        Self {
            variable: self.variable.clone(),
            precision: self.precision - other.precision,
            precision_mean: self.precision_mean - other.precision_mean,
        }
    }
}

/// Configuration for univariate Gaussian expectation propagation.
// `max_iterations` and `tolerance` mirror `ExpectationPropagation`'s settings
// but are not read by any method below yet, hence the lint allowance.
#[allow(dead_code)]
pub struct GaussianEP {
    max_iterations: usize,
    tolerance: f64,
    damping: f64,
}

impl Default for GaussianEP {
    /// 100 iterations, tolerance `1e-6`, and no damping.
    fn default() -> Self {
        Self::new(100, 1e-6, 0.0)
    }
}

impl GaussianEP {
    /// Number of composite-Simpson intervals used by `compute_moments` (must be even).
    const QUADRATURE_INTERVALS: usize = 2000;
    /// Half-width of the integration window, in cavity standard deviations.
    const QUADRATURE_HALF_WIDTH: f64 = 10.0;

    /// Creates a Gaussian EP configuration.
    ///
    /// `damping` is the weight kept from the old site in `damp_site`
    /// (`0.0` disables damping).
    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            damping,
        }
    }

    /// Computes the mean and variance of the tilted distribution
    /// `p(x) ∝ N(x; cavity) · f(x)`, where `f` is `true_factor_callback`.
    ///
    /// The integrals are evaluated with composite Simpson's rule. The window is
    /// the cavity mean ± 10 cavity standard deviations, split into 2000
    /// intervals, so `f` is called 2001 times. `f` should be non-negative and
    /// finite on that window.
    ///
    /// Falls back to the cavity's own `(mean, variance)` in two cases:
    /// - the cavity is improper (variance not finite and positive);
    /// - the tilted mass is not a finite positive number (e.g. `f` vanishes on
    ///   the whole window).
    ///
    /// The returned variance can be `0.0` for an extremely narrow factor.
    pub fn compute_moments(
        &self,
        cavity: &GaussianSite,
        true_factor_callback: impl Fn(f64) -> f64,
    ) -> (f64, f64) {
        let cavity_mean = cavity.mean();
        let cavity_var = cavity.variance();

        // An improper cavity has no finite integration window.
        if !(cavity_var.is_finite() && cavity_var > 0.0) {
            return (cavity_mean, cavity_var);
        }

        let std_dev = cavity_var.sqrt();
        let intervals = Self::QUADRATURE_INTERVALS;
        let half_width = Self::QUADRATURE_HALF_WIDTH;
        let step = 2.0 * half_width / intervals as f64;

        // Integrate in the standardized coordinate z = (x - mean) / std_dev,
        // where the cavity density is proportional to exp(-z^2 / 2). Constant
        // factors (the Gaussian normalizer and Simpson's step/3) cancel in the
        // moment ratios, so they are omitted.
        let (mut mass, mut first, mut second) = (0.0_f64, 0.0_f64, 0.0_f64);
        for i in 0..=intervals {
            let z = i as f64 * step - half_width;
            let simpson_weight = if i == 0 || i == intervals {
                1.0
            } else if i % 2 == 1 {
                4.0
            } else {
                2.0
            };
            let weight = simpson_weight
                * (-0.5 * z * z).exp()
                * true_factor_callback(cavity_mean + std_dev * z);
            mass += weight;
            first += weight * z;
            second += weight * z * z;
        }

        if !(mass.is_finite() && mass > 0.0) {
            return (cavity_mean, cavity_var);
        }

        let mean_z = first / mass;
        // Clamp tiny negative values caused by floating-point cancellation.
        let var_z = (second / mass - mean_z * mean_z).max(0.0);
        (cavity_mean + std_dev * mean_z, cavity_var * var_z)
    }

    /// Returns the site that, multiplied into `cavity`, reproduces the given
    /// tilted mean and variance.
    ///
    /// A negative site precision is clamped to `0.0`; `precision_mean` is left
    /// as computed. `tilted_var` is expected to be positive and finite.
    pub fn match_moments(
        &self,
        cavity: &GaussianSite,
        tilted_mean: f64,
        tilted_var: f64,
    ) -> GaussianSite {
        let new_precision = 1.0 / tilted_var - cavity.precision;
        let new_precision_mean = tilted_mean / tilted_var - cavity.precision_mean;

        GaussianSite::new(
            cavity.variable.clone(),
            new_precision.max(0.0),
            new_precision_mean,
        )
    }

    /// Blends sites in natural parameters:
    /// `(1 - damping) * new + damping * old`.
    ///
    /// With `damping == 0.0` the new site is returned unchanged.
    pub fn damp_site(&self, old_site: &GaussianSite, new_site: &GaussianSite) -> GaussianSite {
        if self.damping == 0.0 {
            return new_site.clone();
        }

        let damped_precision =
            (1.0 - self.damping) * new_site.precision + self.damping * old_site.precision;
        let damped_precision_mean =
            (1.0 - self.damping) * new_site.precision_mean + self.damping * old_site.precision_mean;

        GaussianSite::new(
            new_site.variable.clone(),
            damped_precision,
            damped_precision_mean,
        )
    }
}
428
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a dynamic-rank probability table from a shape and row-major data.
    fn table(shape: &[usize], data: &[f64]) -> ArrayD<f64> {
        Array::from_shape_vec(shape.to_vec(), data.to_vec())
            .expect("unwrap")
            .into_dyn()
    }

    /// Builds a graph containing one binary variable per name, in order.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Asserts that `var` has a normalized marginal in `marginals`.
    fn assert_normalized(marginals: &HashMap<String, ArrayD<f64>>, var: &str) {
        assert!(marginals.contains_key(var));
        let total: f64 = marginals[var].sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_site_creation() {
        let site = Site::new_uniform("test_site".to_string(), vec!["X".to_string()], &[2])
            .expect("unwrap");

        assert_eq!(site.variables.len(), 1);
        assert_eq!(site.factor.variables[0], "X");

        // A uniform site table sums to one.
        let total: f64 = site.factor.values.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_moments() {
        let site = GaussianSite::new("X".to_string(), 2.0, 4.0);
        let (mean, variance) = (site.mean(), site.variance());

        assert_abs_diff_eq!(mean, 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_product() {
        let product = GaussianSite::new("X".to_string(), 2.0, 4.0)
            .product(&GaussianSite::new("X".to_string(), 3.0, 6.0));

        assert_abs_diff_eq!(product.precision, 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(product.precision_mean, 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_divide() {
        let quotient = GaussianSite::new("X".to_string(), 5.0, 10.0)
            .divide(&GaussianSite::new("X".to_string(), 2.0, 4.0));

        assert_abs_diff_eq!(quotient.precision, 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(quotient.precision_mean, 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ep_initialization() {
        let ep = ExpectationPropagation::new(50, 1e-5, 0.5);

        assert_eq!(ep.max_iterations, 50);
        assert_abs_diff_eq!(ep.tolerance, 1e-5, epsilon = 1e-10);
        assert_abs_diff_eq!(ep.damping, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_ep_simple_graph() {
        let mut graph = binary_graph(&["X"]);
        let prior = Factor::new(
            "P(X)".to_string(),
            vec!["X".to_string()],
            table(&[2], &[0.7, 0.3]),
        )
        .expect("unwrap");
        graph.add_factor(prior).expect("unwrap");

        let marginals = ExpectationPropagation::default()
            .run(&graph)
            .expect("unwrap");

        assert!(marginals.contains_key("X"));
        let marginal = &marginals["X"];
        assert_eq!(marginal.ndim(), 1);
        assert_eq!(marginal.len(), 2);
        assert_normalized(&marginals, "X");
    }

    #[test]
    fn test_gaussian_ep_moment_matching() {
        let gep = GaussianEP::default();
        let cavity = GaussianSite::new("X".to_string(), 1.0, 0.0);
        let (tilted_mean, tilted_var) = (2.0, 0.5);

        // cavity * site must reproduce the tilted moments.
        let combined = cavity.product(&gep.match_moments(&cavity, tilted_mean, tilted_var));

        assert_abs_diff_eq!(combined.mean(), tilted_mean, epsilon = 1e-6);
        assert_abs_diff_eq!(combined.variance(), tilted_var, epsilon = 1e-6);
    }

    #[test]
    fn test_ep_two_factor_graph() {
        let mut graph = binary_graph(&["X", "Y"]);

        let prior = Factor::new(
            "P(X)".to_string(),
            vec!["X".to_string()],
            table(&[2], &[0.6, 0.4]),
        )
        .expect("unwrap");
        graph.add_factor(prior).expect("unwrap");

        let conditional = Factor::new(
            "P(Y|X)".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            table(&[2, 2], &[0.8, 0.2, 0.3, 0.7]),
        )
        .expect("unwrap");
        graph.add_factor(conditional).expect("unwrap");

        let marginals = ExpectationPropagation::new(100, 1e-6, 0.0)
            .run(&graph)
            .expect("unwrap");

        for var in ["X", "Y"] {
            assert_normalized(&marginals, var);
        }
    }
}