1use scirs2_core::ndarray::ArrayD;
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8use crate::graph::FactorGraph;
9
10pub trait MessagePassingAlgorithm: Send + Sync {
12 fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>>;
14
15 fn name(&self) -> &str;
17}
18
19#[derive(Clone, Debug)]
21struct MessageStore {
22 var_to_factor: HashMap<(String, String), Factor>,
24 factor_to_var: HashMap<(String, String), Factor>,
26}
27
28impl MessageStore {
29 fn new() -> Self {
30 Self {
31 var_to_factor: HashMap::new(),
32 factor_to_var: HashMap::new(),
33 }
34 }
35
36 fn get_var_to_factor(&self, var: &str, factor: &str) -> Option<&Factor> {
37 self.var_to_factor
38 .get(&(var.to_string(), factor.to_string()))
39 }
40
41 fn set_var_to_factor(&mut self, var: String, factor: String, message: Factor) {
42 self.var_to_factor.insert((var, factor), message);
43 }
44
45 fn get_factor_to_var(&self, factor: &str, var: &str) -> Option<&Factor> {
46 self.factor_to_var
47 .get(&(factor.to_string(), var.to_string()))
48 }
49
50 fn set_factor_to_var(&mut self, factor: String, var: String, message: Factor) {
51 self.factor_to_var.insert((factor, var), message);
52 }
53}
54
/// Summary of a convergence check for an iterative inference run.
///
/// NOTE(review): nothing in this file constructs this type; confirm whether
/// callers elsewhere rely on it or whether it can be wired into `run`.
#[derive(Clone, Debug)]
pub struct ConvergenceStats {
    /// Presumably the number of message-passing sweeps performed.
    pub iterations: usize,
    /// Presumably the largest message change observed in the last sweep.
    pub max_delta: f64,
    /// Whether the run met its convergence criterion.
    pub converged: bool,
}
65
/// Configuration of the loopy sum-product (marginal inference) solver.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and copied like
/// any other public value type.
#[derive(Clone, Debug)]
pub struct SumProductAlgorithm {
    /// Maximum number of message-passing sweeps before giving up.
    pub max_iterations: usize,
    /// A sweep is considered converged when every factor→variable message entry
    /// changed by less than this amount.
    pub tolerance: f64,
    /// Weight kept from the previous factor→variable message (`0.0` = no damping).
    pub damping: f64,
}
78
79impl Default for SumProductAlgorithm {
80 fn default() -> Self {
81 Self {
82 max_iterations: 100,
83 tolerance: 1e-6,
84 damping: 0.0,
85 }
86 }
87}
88
89impl SumProductAlgorithm {
90 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
92 Self {
93 max_iterations,
94 tolerance,
95 damping: damping.clamp(0.0, 1.0),
96 }
97 }
98
99 fn compute_var_to_factor_message(
103 &self,
104 graph: &FactorGraph,
105 messages: &MessageStore,
106 var: &str,
107 target_factor: &str,
108 ) -> Result<Factor> {
109 let var_node = graph
110 .get_variable(var)
111 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
112
113 let adjacent_factors = graph
115 .get_adjacent_factors(var)
116 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
117
118 let other_factors: Vec<&String> = adjacent_factors
119 .iter()
120 .filter(|&f| f != target_factor)
121 .collect();
122
123 let mut message = Factor::uniform(
125 format!("msg_{}_{}", var, target_factor),
126 vec![var.to_string()],
127 var_node.cardinality,
128 );
129
130 for &factor_id in &other_factors {
132 if let Some(incoming) = messages.get_factor_to_var(factor_id, var) {
133 message = message.product(incoming)?;
134 }
135 }
136
137 message.normalize();
139
140 Ok(message)
141 }
142
143 fn compute_factor_to_var_message(
147 &self,
148 graph: &FactorGraph,
149 messages: &MessageStore,
150 factor_id: &str,
151 target_var: &str,
152 ) -> Result<Factor> {
153 let factor = graph
154 .get_factor(factor_id)
155 .ok_or_else(|| PgmError::FactorNotFound(factor_id.to_string()))?;
156
157 let mut message = factor.clone();
159
160 let other_vars: Vec<&String> = factor
162 .variables
163 .iter()
164 .filter(|&v| v != target_var)
165 .collect();
166
167 for &var in &other_vars {
169 if let Some(incoming) = messages.get_var_to_factor(var, factor_id) {
170 message = message.product(incoming)?;
171 }
172 }
173
174 for &var in &other_vars {
176 message = message.marginalize_out(var)?;
177 }
178
179 message.normalize();
181
182 Ok(message)
183 }
184
185 fn compute_beliefs(
187 &self,
188 graph: &FactorGraph,
189 messages: &MessageStore,
190 ) -> Result<HashMap<String, ArrayD<f64>>> {
191 let mut beliefs = HashMap::new();
192
193 for var_name in graph.variable_names() {
195 if let Some(var_node) = graph.get_variable(var_name) {
196 let mut belief = Factor::uniform(
197 format!("belief_{}", var_name),
198 vec![var_name.clone()],
199 var_node.cardinality,
200 );
201
202 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
204 for factor_id in adjacent_factors {
205 if let Some(message) = messages.get_factor_to_var(factor_id, var_name) {
206 belief = belief.product(message)?;
207 }
208 }
209 }
210
211 belief.normalize();
212 beliefs.insert(var_name.clone(), belief.values);
213 }
214 }
215
216 Ok(beliefs)
217 }
218
219 fn check_convergence(
221 &self,
222 old_messages: &MessageStore,
223 new_messages: &MessageStore,
224 ) -> (bool, f64) {
225 let mut max_delta: f64 = 0.0;
226
227 for ((factor, var), new_msg) in &new_messages.factor_to_var {
229 if let Some(old_msg) = old_messages.get_factor_to_var(factor, var) {
230 let delta: f64 = (&new_msg.values - &old_msg.values)
231 .mapv(|x| x.abs())
232 .iter()
233 .fold(0.0_f64, |acc, &x| acc.max(x));
234 max_delta = max_delta.max(delta);
235 }
236 }
237
238 (max_delta < self.tolerance, max_delta)
239 }
240
241 fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Result<Factor> {
243 if self.damping == 0.0 {
244 return Ok(new_msg.clone());
245 }
246
247 let damped_values = &new_msg.values * (1.0 - self.damping) + &old_msg.values * self.damping;
249
250 Ok(Factor {
251 name: new_msg.name.clone(),
252 variables: new_msg.variables.clone(),
253 values: damped_values,
254 })
255 }
256}
257
258impl MessagePassingAlgorithm for SumProductAlgorithm {
259 fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
260 let mut messages = MessageStore::new();
261
262 for var_name in graph.variable_names() {
264 if let Some(var_node) = graph.get_variable(var_name) {
265 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
266 for factor_id in adjacent_factors {
267 let init_msg = Factor::uniform(
268 format!("init_{}_{}", var_name, factor_id),
269 vec![var_name.clone()],
270 var_node.cardinality,
271 );
272 messages.set_var_to_factor(var_name.clone(), factor_id.clone(), init_msg);
273 }
274 }
275 }
276 }
277
278 for iteration in 0..self.max_iterations {
280 let old_messages = messages.clone();
281
282 for var_name in graph.variable_names() {
284 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
285 for factor_id in adjacent_factors {
286 let new_msg = self
287 .compute_var_to_factor_message(graph, &messages, var_name, factor_id)?;
288 messages.set_var_to_factor(var_name.clone(), factor_id.clone(), new_msg);
289 }
290 }
291 }
292
293 for factor_id in graph.factor_ids() {
295 if let Some(adjacent_vars) = graph.get_adjacent_variables(factor_id) {
296 for var in adjacent_vars {
297 let new_msg =
298 self.compute_factor_to_var_message(graph, &messages, factor_id, var)?;
299
300 let damped_msg =
302 if let Some(old_msg) = old_messages.get_factor_to_var(factor_id, var) {
303 self.apply_damping(old_msg, &new_msg)?
304 } else {
305 new_msg
306 };
307
308 messages.set_factor_to_var(factor_id.clone(), var.clone(), damped_msg);
309 }
310 }
311 }
312
313 let (converged, max_delta) = self.check_convergence(&old_messages, &messages);
315
316 if converged {
317 return self.compute_beliefs(graph, &messages);
319 }
320
321 if iteration == self.max_iterations - 1 {
323 return Err(PgmError::ConvergenceFailure(format!(
324 "Failed to converge after {} iterations (max_delta={})",
325 self.max_iterations, max_delta
326 )));
327 }
328 }
329
330 self.compute_beliefs(graph, &messages)
332 }
333
334 fn name(&self) -> &str {
335 "SumProduct"
336 }
337}
338
/// Configuration of the loopy max-product solver, which propagates messages by
/// maximizing (rather than summing) over eliminated variables.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and copied like
/// any other public value type.
#[derive(Clone, Debug)]
pub struct MaxProductAlgorithm {
    /// Maximum number of message-passing sweeps.
    pub max_iterations: usize,
    /// Threshold on the largest change of any factor→variable message entry
    /// between consecutive sweeps.
    pub tolerance: f64,
}
348
349impl Default for MaxProductAlgorithm {
350 fn default() -> Self {
351 Self {
352 max_iterations: 100,
353 tolerance: 1e-6,
354 }
355 }
356}
357
358impl MaxProductAlgorithm {
359 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
361 Self {
362 max_iterations,
363 tolerance,
364 }
365 }
366}
367
368impl MaxProductAlgorithm {
369 fn compute_var_to_factor_message(
371 &self,
372 graph: &FactorGraph,
373 messages: &MessageStore,
374 var: &str,
375 target_factor: &str,
376 ) -> Result<Factor> {
377 let var_node = graph
378 .get_variable(var)
379 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
380
381 let adjacent_factors = graph
382 .get_adjacent_factors(var)
383 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
384
385 let other_factors: Vec<&String> = adjacent_factors
386 .iter()
387 .filter(|&f| f != target_factor)
388 .collect();
389
390 let mut message = Factor::uniform(
391 format!("msg_{}_{}", var, target_factor),
392 vec![var.to_string()],
393 var_node.cardinality,
394 );
395
396 for &factor_id in &other_factors {
397 if let Some(incoming) = messages.get_factor_to_var(factor_id, var) {
398 message = message.product(incoming)?;
399 }
400 }
401
402 message.normalize();
403 Ok(message)
404 }
405
406 fn compute_factor_to_var_message(
408 &self,
409 graph: &FactorGraph,
410 messages: &MessageStore,
411 factor_id: &str,
412 target_var: &str,
413 ) -> Result<Factor> {
414 let factor = graph
415 .get_factor(factor_id)
416 .ok_or_else(|| PgmError::FactorNotFound(factor_id.to_string()))?;
417
418 let mut message = factor.clone();
419
420 let other_vars: Vec<&String> = factor
421 .variables
422 .iter()
423 .filter(|&v| v != target_var)
424 .collect();
425
426 for &var in &other_vars {
427 if let Some(incoming) = messages.get_var_to_factor(var, factor_id) {
428 message = message.product(incoming)?;
429 }
430 }
431
432 for &var in &other_vars {
434 message = message.maximize_out(var)?;
435 }
436
437 message.normalize();
438 Ok(message)
439 }
440
441 fn compute_beliefs(
443 &self,
444 graph: &FactorGraph,
445 messages: &MessageStore,
446 ) -> Result<HashMap<String, ArrayD<f64>>> {
447 let mut beliefs = HashMap::new();
448
449 for var_name in graph.variable_names() {
450 if let Some(var_node) = graph.get_variable(var_name) {
451 let mut belief = Factor::uniform(
452 format!("belief_{}", var_name),
453 vec![var_name.clone()],
454 var_node.cardinality,
455 );
456
457 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
458 for factor_id in adjacent_factors {
459 if let Some(message) = messages.get_factor_to_var(factor_id, var_name) {
460 belief = belief.product(message)?;
461 }
462 }
463 }
464
465 belief.normalize();
466 beliefs.insert(var_name.clone(), belief.values);
467 }
468 }
469
470 Ok(beliefs)
471 }
472}
473
474impl MessagePassingAlgorithm for MaxProductAlgorithm {
475 fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
476 let mut messages = MessageStore::new();
477
478 for var_name in graph.variable_names() {
480 if let Some(var_node) = graph.get_variable(var_name) {
481 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
482 for factor_id in adjacent_factors {
483 let init_msg = Factor::uniform(
484 format!("init_{}_{}", var_name, factor_id),
485 vec![var_name.clone()],
486 var_node.cardinality,
487 );
488 messages.set_var_to_factor(var_name.clone(), factor_id.clone(), init_msg);
489 }
490 }
491 }
492 }
493
494 for _iteration in 0..self.max_iterations {
496 let _old_messages = messages.clone();
497
498 for var_name in graph.variable_names() {
500 if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
501 for factor_id in adjacent_factors {
502 let new_msg = self
503 .compute_var_to_factor_message(graph, &messages, var_name, factor_id)?;
504 messages.set_var_to_factor(var_name.clone(), factor_id.clone(), new_msg);
505 }
506 }
507 }
508
509 for factor_id in graph.factor_ids() {
511 if let Some(adjacent_vars) = graph.get_adjacent_variables(factor_id) {
512 for var in adjacent_vars {
513 let new_msg =
514 self.compute_factor_to_var_message(graph, &messages, factor_id, var)?;
515 messages.set_factor_to_var(factor_id.clone(), var.clone(), new_msg);
516 }
517 }
518 }
519 }
520
521 self.compute_beliefs(graph, &messages)
522 }
523
524 fn name(&self) -> &str {
525 "MaxProduct"
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;

    /// Graph with a single variable `var_0` (domain `D1`) and no factors.
    fn one_variable_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("var_0".to_string(), "D1".to_string());
        graph
    }

    #[test]
    fn test_sum_product_algorithm() {
        let solver = SumProductAlgorithm::default();
        assert_eq!(solver.name(), "SumProduct");

        let outcome = solver.run(&one_variable_graph());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_max_product_algorithm() {
        let solver = MaxProductAlgorithm::default();
        assert_eq!(solver.name(), "MaxProduct");

        let outcome = solver.run(&one_variable_graph());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_message_store() {
        let message = Factor::uniform("test".to_string(), vec!["x".to_string()], 2);
        let mut store = MessageStore::new();

        store.set_var_to_factor("x".to_string(), "f1".to_string(), message.clone());
        assert!(store.get_var_to_factor("x", "f1").is_some());

        store.set_factor_to_var("f1".to_string(), "x".to_string(), message);
        assert!(store.get_factor_to_var("f1", "x").is_some());
    }

    #[test]
    fn test_sum_product_with_damping() {
        let solver = SumProductAlgorithm::new(50, 1e-5, 0.5);
        assert_eq!(solver.damping, 0.5);

        let outcome = solver.run(&one_variable_graph());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_belief_normalization() {
        let graph = one_variable_graph();
        let beliefs = SumProductAlgorithm::default().run(&graph).unwrap();

        if let Some(belief) = beliefs.get("var_0") {
            let total: f64 = belief.iter().sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
        }
    }
}