Skip to main content

tensorlogic_quantrs_hooks/
message_passing.rs

1//! Message passing algorithms for PGM inference.
2
3use scirs2_core::ndarray::ArrayD;
4use std::collections::HashMap;
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8use crate::graph::FactorGraph;
9
10/// Trait for message passing algorithms.
11pub trait MessagePassingAlgorithm: Send + Sync {
12    /// Run the algorithm on a factor graph.
13    fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>>;
14
15    /// Get algorithm name.
16    fn name(&self) -> &str;
17}
18
19/// Message storage for belief propagation.
20#[derive(Clone, Debug)]
21struct MessageStore {
22    /// Messages from variables to factors: (var, factor) -> message
23    var_to_factor: HashMap<(String, String), Factor>,
24    /// Messages from factors to variables: (factor, var) -> message
25    factor_to_var: HashMap<(String, String), Factor>,
26}
27
28impl MessageStore {
29    fn new() -> Self {
30        Self {
31            var_to_factor: HashMap::new(),
32            factor_to_var: HashMap::new(),
33        }
34    }
35
36    fn get_var_to_factor(&self, var: &str, factor: &str) -> Option<&Factor> {
37        self.var_to_factor
38            .get(&(var.to_string(), factor.to_string()))
39    }
40
41    fn set_var_to_factor(&mut self, var: String, factor: String, message: Factor) {
42        self.var_to_factor.insert((var, factor), message);
43    }
44
45    fn get_factor_to_var(&self, factor: &str, var: &str) -> Option<&Factor> {
46        self.factor_to_var
47            .get(&(factor.to_string(), var.to_string()))
48    }
49
50    fn set_factor_to_var(&mut self, factor: String, var: String, message: Factor) {
51        self.factor_to_var.insert((factor, var), message);
52    }
53}
54
/// Convergence summary for a belief-propagation run.
#[derive(Debug, Clone)]
pub struct ConvergenceStats {
    /// How many message-passing iterations were executed.
    pub iterations: usize,
    /// Largest message change observed during the final iteration.
    pub max_delta: f64,
    /// `true` when the convergence criterion was satisfied.
    pub converged: bool,
}
65
/// Sum-product algorithm (belief propagation).
///
/// Computes marginal probabilities; on tree-structured graphs the converged
/// messages give exact marginals. For loopy graphs, runs loopy belief
/// propagation with optional damping.
#[derive(Clone, Debug)]
pub struct SumProductAlgorithm {
    /// Maximum number of message-passing iterations before the run reports
    /// a convergence failure.
    pub max_iterations: usize,
    /// Convergence threshold on the largest change of any factor→variable
    /// message between consecutive iterations.
    pub tolerance: f64,
    /// Damping factor λ (clamped to `[0.0, 1.0]` by `new`). Each updated
    /// factor→variable message is `(1 - λ) * new + λ * old`, so `0.0` means no
    /// damping and `1.0` keeps the previous message unchanged.
    pub damping: f64,
}
78
79impl Default for SumProductAlgorithm {
80    fn default() -> Self {
81        Self {
82            max_iterations: 100,
83            tolerance: 1e-6,
84            damping: 0.0,
85        }
86    }
87}
88
89impl SumProductAlgorithm {
90    /// Create with custom parameters.
91    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
92        Self {
93            max_iterations,
94            tolerance,
95            damping: damping.clamp(0.0, 1.0),
96        }
97    }
98
99    /// Compute variable-to-factor message.
100    ///
101    /// μ(x→f) = ∏_{g∈N(x)\f} μ(g→x)
102    fn compute_var_to_factor_message(
103        &self,
104        graph: &FactorGraph,
105        messages: &MessageStore,
106        var: &str,
107        target_factor: &str,
108    ) -> Result<Factor> {
109        let var_node = graph
110            .get_variable(var)
111            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
112
113        // Get all factors connected to this variable except target
114        let adjacent_factors = graph
115            .get_adjacent_factors(var)
116            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
117
118        let other_factors: Vec<&String> = adjacent_factors
119            .iter()
120            .filter(|&f| f != target_factor)
121            .collect();
122
123        // Start with uniform message
124        let mut message = Factor::uniform(
125            format!("msg_{}_{}", var, target_factor),
126            vec![var.to_string()],
127            var_node.cardinality,
128        );
129
130        // Multiply incoming messages from other factors
131        for &factor_id in &other_factors {
132            if let Some(incoming) = messages.get_factor_to_var(factor_id, var) {
133                message = message.product(incoming)?;
134            }
135        }
136
137        // Normalize
138        message.normalize();
139
140        Ok(message)
141    }
142
143    /// Compute factor-to-variable message.
144    ///
145    /// μ(f→x) = ∑_{~x} [φ(x) ∏_{y∈N(f)\x} μ(y→f)]
146    fn compute_factor_to_var_message(
147        &self,
148        graph: &FactorGraph,
149        messages: &MessageStore,
150        factor_id: &str,
151        target_var: &str,
152    ) -> Result<Factor> {
153        let factor = graph
154            .get_factor(factor_id)
155            .ok_or_else(|| PgmError::FactorNotFound(factor_id.to_string()))?;
156
157        // Start with factor itself
158        let mut message = factor.clone();
159
160        // Get variables in factor except target
161        let other_vars: Vec<&String> = factor
162            .variables
163            .iter()
164            .filter(|&v| v != target_var)
165            .collect();
166
167        // Multiply incoming messages from other variables
168        for &var in &other_vars {
169            if let Some(incoming) = messages.get_var_to_factor(var, factor_id) {
170                message = message.product(incoming)?;
171            }
172        }
173
174        // Marginalize out all variables except target
175        for &var in &other_vars {
176            message = message.marginalize_out(var)?;
177        }
178
179        // Normalize
180        message.normalize();
181
182        Ok(message)
183    }
184
185    /// Compute beliefs (marginals) from messages.
186    fn compute_beliefs(
187        &self,
188        graph: &FactorGraph,
189        messages: &MessageStore,
190    ) -> Result<HashMap<String, ArrayD<f64>>> {
191        let mut beliefs = HashMap::new();
192
193        // For each variable, multiply all incoming factor messages
194        for var_name in graph.variable_names() {
195            if let Some(var_node) = graph.get_variable(var_name) {
196                let mut belief = Factor::uniform(
197                    format!("belief_{}", var_name),
198                    vec![var_name.clone()],
199                    var_node.cardinality,
200                );
201
202                // Get all adjacent factors
203                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
204                    for factor_id in adjacent_factors {
205                        if let Some(message) = messages.get_factor_to_var(factor_id, var_name) {
206                            belief = belief.product(message)?;
207                        }
208                    }
209                }
210
211                belief.normalize();
212                beliefs.insert(var_name.clone(), belief.values);
213            }
214        }
215
216        Ok(beliefs)
217    }
218
219    /// Check convergence by comparing message differences.
220    fn check_convergence(
221        &self,
222        old_messages: &MessageStore,
223        new_messages: &MessageStore,
224    ) -> (bool, f64) {
225        let mut max_delta: f64 = 0.0;
226
227        // Check factor-to-var messages
228        for ((factor, var), new_msg) in &new_messages.factor_to_var {
229            if let Some(old_msg) = old_messages.get_factor_to_var(factor, var) {
230                let delta: f64 = (&new_msg.values - &old_msg.values)
231                    .mapv(|x| x.abs())
232                    .iter()
233                    .fold(0.0_f64, |acc, &x| acc.max(x));
234                max_delta = max_delta.max(delta);
235            }
236        }
237
238        (max_delta < self.tolerance, max_delta)
239    }
240
241    /// Apply damping to messages.
242    fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Result<Factor> {
243        if self.damping == 0.0 {
244            return Ok(new_msg.clone());
245        }
246
247        // Damped message = (1 - λ) * new + λ * old
248        let damped_values = &new_msg.values * (1.0 - self.damping) + &old_msg.values * self.damping;
249
250        Ok(Factor {
251            name: new_msg.name.clone(),
252            variables: new_msg.variables.clone(),
253            values: damped_values,
254        })
255    }
256}
257
258impl MessagePassingAlgorithm for SumProductAlgorithm {
259    fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
260        let mut messages = MessageStore::new();
261
262        // Initialize all messages to uniform
263        for var_name in graph.variable_names() {
264            if let Some(var_node) = graph.get_variable(var_name) {
265                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
266                    for factor_id in adjacent_factors {
267                        let init_msg = Factor::uniform(
268                            format!("init_{}_{}", var_name, factor_id),
269                            vec![var_name.clone()],
270                            var_node.cardinality,
271                        );
272                        messages.set_var_to_factor(var_name.clone(), factor_id.clone(), init_msg);
273                    }
274                }
275            }
276        }
277
278        // Iterative message passing
279        for iteration in 0..self.max_iterations {
280            let old_messages = messages.clone();
281
282            // Update all variable-to-factor messages
283            for var_name in graph.variable_names() {
284                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
285                    for factor_id in adjacent_factors {
286                        let new_msg = self
287                            .compute_var_to_factor_message(graph, &messages, var_name, factor_id)?;
288                        messages.set_var_to_factor(var_name.clone(), factor_id.clone(), new_msg);
289                    }
290                }
291            }
292
293            // Update all factor-to-variable messages
294            for factor_id in graph.factor_ids() {
295                if let Some(adjacent_vars) = graph.get_adjacent_variables(factor_id) {
296                    for var in adjacent_vars {
297                        let new_msg =
298                            self.compute_factor_to_var_message(graph, &messages, factor_id, var)?;
299
300                        // Apply damping if enabled
301                        let damped_msg =
302                            if let Some(old_msg) = old_messages.get_factor_to_var(factor_id, var) {
303                                self.apply_damping(old_msg, &new_msg)?
304                            } else {
305                                new_msg
306                            };
307
308                        messages.set_factor_to_var(factor_id.clone(), var.clone(), damped_msg);
309                    }
310                }
311            }
312
313            // Check convergence
314            let (converged, max_delta) = self.check_convergence(&old_messages, &messages);
315
316            if converged {
317                // Compute and return beliefs
318                return self.compute_beliefs(graph, &messages);
319            }
320
321            // Prevent infinite loop
322            if iteration == self.max_iterations - 1 {
323                return Err(PgmError::ConvergenceFailure(format!(
324                    "Failed to converge after {} iterations (max_delta={})",
325                    self.max_iterations, max_delta
326                )));
327            }
328        }
329
330        // Compute beliefs even if not converged
331        self.compute_beliefs(graph, &messages)
332    }
333
334    fn name(&self) -> &str {
335        "SumProduct"
336    }
337}
338
/// Max-product algorithm (MAP inference).
///
/// Passes messages that maximize, rather than sum, over eliminated variables.
/// Running it yields each variable's normalized max-marginal, keyed by
/// variable name; it does not itself assemble a joint assignment. On
/// tree-structured graphs without ties, taking the argmax of each
/// max-marginal recovers the most likely assignment.
#[derive(Clone, Debug)]
pub struct MaxProductAlgorithm {
    /// Maximum number of message-passing sweeps.
    pub max_iterations: usize,
    /// Convergence threshold.
    pub tolerance: f64,
}
348
349impl Default for MaxProductAlgorithm {
350    fn default() -> Self {
351        Self {
352            max_iterations: 100,
353            tolerance: 1e-6,
354        }
355    }
356}
357
358impl MaxProductAlgorithm {
359    /// Create with custom parameters.
360    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
361        Self {
362            max_iterations,
363            tolerance,
364        }
365    }
366}
367
368impl MaxProductAlgorithm {
369    /// Compute variable-to-factor message (same as sum-product).
370    fn compute_var_to_factor_message(
371        &self,
372        graph: &FactorGraph,
373        messages: &MessageStore,
374        var: &str,
375        target_factor: &str,
376    ) -> Result<Factor> {
377        let var_node = graph
378            .get_variable(var)
379            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
380
381        let adjacent_factors = graph
382            .get_adjacent_factors(var)
383            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
384
385        let other_factors: Vec<&String> = adjacent_factors
386            .iter()
387            .filter(|&f| f != target_factor)
388            .collect();
389
390        let mut message = Factor::uniform(
391            format!("msg_{}_{}", var, target_factor),
392            vec![var.to_string()],
393            var_node.cardinality,
394        );
395
396        for &factor_id in &other_factors {
397            if let Some(incoming) = messages.get_factor_to_var(factor_id, var) {
398                message = message.product(incoming)?;
399            }
400        }
401
402        message.normalize();
403        Ok(message)
404    }
405
406    /// Compute factor-to-variable message using MAX instead of SUM.
407    fn compute_factor_to_var_message(
408        &self,
409        graph: &FactorGraph,
410        messages: &MessageStore,
411        factor_id: &str,
412        target_var: &str,
413    ) -> Result<Factor> {
414        let factor = graph
415            .get_factor(factor_id)
416            .ok_or_else(|| PgmError::FactorNotFound(factor_id.to_string()))?;
417
418        let mut message = factor.clone();
419
420        let other_vars: Vec<&String> = factor
421            .variables
422            .iter()
423            .filter(|&v| v != target_var)
424            .collect();
425
426        for &var in &other_vars {
427            if let Some(incoming) = messages.get_var_to_factor(var, factor_id) {
428                message = message.product(incoming)?;
429            }
430        }
431
432        // Use MAX instead of SUM for marginalization
433        for &var in &other_vars {
434            message = message.maximize_out(var)?;
435        }
436
437        message.normalize();
438        Ok(message)
439    }
440
441    /// Compute beliefs using max-product messages.
442    fn compute_beliefs(
443        &self,
444        graph: &FactorGraph,
445        messages: &MessageStore,
446    ) -> Result<HashMap<String, ArrayD<f64>>> {
447        let mut beliefs = HashMap::new();
448
449        for var_name in graph.variable_names() {
450            if let Some(var_node) = graph.get_variable(var_name) {
451                let mut belief = Factor::uniform(
452                    format!("belief_{}", var_name),
453                    vec![var_name.clone()],
454                    var_node.cardinality,
455                );
456
457                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
458                    for factor_id in adjacent_factors {
459                        if let Some(message) = messages.get_factor_to_var(factor_id, var_name) {
460                            belief = belief.product(message)?;
461                        }
462                    }
463                }
464
465                belief.normalize();
466                beliefs.insert(var_name.clone(), belief.values);
467            }
468        }
469
470        Ok(beliefs)
471    }
472}
473
474impl MessagePassingAlgorithm for MaxProductAlgorithm {
475    fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
476        let mut messages = MessageStore::new();
477
478        // Initialize messages
479        for var_name in graph.variable_names() {
480            if let Some(var_node) = graph.get_variable(var_name) {
481                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
482                    for factor_id in adjacent_factors {
483                        let init_msg = Factor::uniform(
484                            format!("init_{}_{}", var_name, factor_id),
485                            vec![var_name.clone()],
486                            var_node.cardinality,
487                        );
488                        messages.set_var_to_factor(var_name.clone(), factor_id.clone(), init_msg);
489                    }
490                }
491            }
492        }
493
494        // Iterative message passing
495        for _iteration in 0..self.max_iterations {
496            let _old_messages = messages.clone();
497
498            // Update variable-to-factor messages
499            for var_name in graph.variable_names() {
500                if let Some(adjacent_factors) = graph.get_adjacent_factors(var_name) {
501                    for factor_id in adjacent_factors {
502                        let new_msg = self
503                            .compute_var_to_factor_message(graph, &messages, var_name, factor_id)?;
504                        messages.set_var_to_factor(var_name.clone(), factor_id.clone(), new_msg);
505                    }
506                }
507            }
508
509            // Update factor-to-variable messages
510            for factor_id in graph.factor_ids() {
511                if let Some(adjacent_vars) = graph.get_adjacent_variables(factor_id) {
512                    for var in adjacent_vars {
513                        let new_msg =
514                            self.compute_factor_to_var_message(graph, &messages, factor_id, var)?;
515                        messages.set_factor_to_var(factor_id.clone(), var.clone(), new_msg);
516                    }
517                }
518            }
519        }
520
521        self.compute_beliefs(graph, &messages)
522    }
523
524    fn name(&self) -> &str {
525        "MaxProduct"
526    }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;

    /// Graph holding a single variable `var_0` (domain `D1`) and no factors.
    fn single_variable_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable("var_0".to_string(), "D1".to_string());
        graph
    }

    #[test]
    fn test_sum_product_algorithm() {
        let algorithm = SumProductAlgorithm::default();
        assert_eq!(algorithm.name(), "SumProduct");

        let graph = single_variable_graph();
        assert!(algorithm.run(&graph).is_ok());
    }

    #[test]
    fn test_max_product_algorithm() {
        let algorithm = MaxProductAlgorithm::default();
        assert_eq!(algorithm.name(), "MaxProduct");

        let graph = single_variable_graph();
        assert!(algorithm.run(&graph).is_ok());
    }

    #[test]
    fn test_message_store() {
        let msg = Factor::uniform("test".to_string(), vec!["x".to_string()], 2);
        let mut store = MessageStore::new();

        store.set_var_to_factor("x".to_string(), "f1".to_string(), msg.clone());
        assert!(store.get_var_to_factor("x", "f1").is_some());

        store.set_factor_to_var("f1".to_string(), "x".to_string(), msg.clone());
        assert!(store.get_factor_to_var("f1", "x").is_some());
    }

    #[test]
    fn test_sum_product_with_damping() {
        let algorithm = SumProductAlgorithm::new(50, 1e-5, 0.5);
        assert_eq!(algorithm.damping, 0.5);

        let graph = single_variable_graph();
        assert!(algorithm.run(&graph).is_ok());
    }

    #[test]
    fn test_belief_normalization() {
        let graph = single_variable_graph();
        let beliefs = SumProductAlgorithm::default().run(&graph).unwrap();

        if let Some(belief) = beliefs.get("var_0") {
            let total: f64 = belief.iter().sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
        }
    }
}