Skip to main content

tensorlogic_quantrs_hooks/
parallel_message_passing.rs

1//! Parallel message passing algorithms using rayon.
2//!
3//! This module provides parallel implementations of belief propagation algorithms
4//! that can significantly speed up inference on large factor graphs with many variables.
5
6use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8use crate::graph::FactorGraph;
9use crate::message_passing::ConvergenceStats;
10use rayon::prelude::*;
11use scirs2_core::ndarray::ArrayD;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
/// Parallel sum-product belief propagation.
///
/// Uses rayon to compute messages in parallel, which can provide significant
/// speedup for large factor graphs.
#[derive(Debug, Clone)]
pub struct ParallelSumProduct {
    /// Maximum number of message-passing sweeps before giving up.
    pub max_iterations: usize,
    /// Convergence tolerance: the largest allowed change of a single message between two
    /// sweeps (sum of absolute element-wise differences).
    pub tolerance: f64,
    /// Damping factor (0.0 = no damping, 1.0 = full damping): the weight given to the
    /// previous message when blending in a newly computed one.
    pub damping: f64,
}

impl Default for ParallelSumProduct {
    /// 100 sweeps, tolerance `1e-6`, no damping.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
            damping: 0.0,
        }
    }
}
37
38impl ParallelSumProduct {
39    /// Create with custom parameters.
40    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
41        Self {
42            max_iterations,
43            tolerance,
44            damping,
45        }
46    }
47
48    /// Run parallel belief propagation.
49    pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
50        // Initialize messages
51        let messages = Arc::new(Mutex::new(self.initialize_messages(graph)?));
52
53        // Iterative message passing
54        for iteration in 0..self.max_iterations {
55            let old_messages = messages
56                .lock()
57                .expect("lock should not be poisoned")
58                .clone();
59
60            // Parallel computation of variable-to-factor messages
61            let var_factor_updates: Vec<_> = graph
62                .variable_names()
63                .par_bridge()
64                .flat_map(|var_name| {
65                    if let Some(factors) = graph.get_adjacent_factors(var_name) {
66                        factors
67                            .par_iter()
68                            .filter_map(|factor_id| {
69                                if let Some(factor) = graph.get_factor(factor_id) {
70                                    let key = (var_name.to_string(), factor.name.clone());
71                                    match self.compute_var_to_factor_message(
72                                        graph,
73                                        &old_messages,
74                                        var_name,
75                                        &factor.name,
76                                    ) {
77                                        Ok(msg) => Some((key, msg)),
78                                        Err(_) => None,
79                                    }
80                                } else {
81                                    None
82                                }
83                            })
84                            .collect::<Vec<_>>()
85                    } else {
86                        Vec::new()
87                    }
88                })
89                .collect();
90
91            // Parallel computation of factor-to-variable messages
92            let factor_var_updates: Vec<_> = graph
93                .factor_ids()
94                .par_bridge()
95                .filter_map(|factor_id| graph.get_factor(factor_id))
96                .flat_map(|factor| {
97                    factor
98                        .variables
99                        .par_iter()
100                        .filter_map(|var_name| {
101                            let key = (factor.name.clone(), var_name.clone());
102                            match self.compute_factor_to_var_message(
103                                graph,
104                                &old_messages,
105                                &factor.name,
106                                var_name,
107                            ) {
108                                Ok(msg) => Some((key, msg)),
109                                Err(_) => None,
110                            }
111                        })
112                        .collect::<Vec<_>>()
113                })
114                .collect();
115
116            // Update messages with damping
117            {
118                let mut messages_guard = messages.lock().expect("lock should not be poisoned");
119                for (key, new_msg) in var_factor_updates.into_iter().chain(factor_var_updates) {
120                    if let Some(old_msg) = messages_guard.get(&key) {
121                        if self.damping > 0.0 {
122                            // Apply damping: msg_new = (1-d)*msg_new + d*msg_old
123                            let damped = self.apply_damping(old_msg, &new_msg);
124                            messages_guard.insert(key, damped);
125                        } else {
126                            messages_guard.insert(key, new_msg);
127                        }
128                    } else {
129                        messages_guard.insert(key, new_msg);
130                    }
131                }
132            }
133
134            // Check convergence
135            let converged = self.check_convergence(
136                &old_messages,
137                &messages.lock().expect("lock should not be poisoned"),
138            );
139            if converged {
140                break;
141            }
142
143            if iteration == self.max_iterations - 1 {
144                return Err(PgmError::ConvergenceFailure(format!(
145                    "Parallel belief propagation did not converge after {} iterations",
146                    self.max_iterations
147                )));
148            }
149        }
150
151        // Compute final marginals in parallel
152        let marginals: HashMap<String, ArrayD<f64>> = graph
153            .variable_names()
154            .par_bridge()
155            .filter_map(|var_name| {
156                match self.compute_marginal(
157                    graph,
158                    &messages.lock().expect("lock should not be poisoned"),
159                    var_name,
160                ) {
161                    Ok(marginal) => Some((var_name.to_string(), marginal)),
162                    Err(_) => None,
163                }
164            })
165            .collect();
166
167        Ok(marginals)
168    }
169
170    /// Initialize messages with uniform distributions.
171    fn initialize_messages(
172        &self,
173        graph: &FactorGraph,
174    ) -> Result<HashMap<(String, String), Factor>> {
175        let mut messages = HashMap::new();
176
177        // Initialize variable-to-factor messages
178        for var_name in graph.variable_names() {
179            if let Some(var_node) = graph.get_variable(var_name) {
180                let uniform_values = vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
181                let uniform_array =
182                    scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
183
184                if let Some(factors) = graph.get_adjacent_factors(var_name) {
185                    for factor_id in factors {
186                        if let Some(factor) = graph.get_factor(factor_id) {
187                            let msg = Factor::new(
188                                format!("msg_{}_{}", var_name, factor.name),
189                                vec![var_name.to_string()],
190                                uniform_array.clone(),
191                            )?;
192                            messages.insert((var_name.to_string(), factor.name.clone()), msg);
193                        }
194                    }
195                }
196            }
197        }
198
199        // Initialize factor-to-variable messages
200        for factor_id in graph.factor_ids() {
201            if let Some(factor) = graph.get_factor(factor_id) {
202                for var_name in &factor.variables {
203                    if let Some(var_node) = graph.get_variable(var_name) {
204                        let uniform_values =
205                            vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
206                        let uniform_array =
207                            scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
208
209                        let msg = Factor::new(
210                            format!("msg_{}_{}", factor.name, var_name),
211                            vec![var_name.to_string()],
212                            uniform_array,
213                        )?;
214                        messages.insert((factor.name.clone(), var_name.to_string()), msg);
215                    }
216                }
217            }
218        }
219
220        Ok(messages)
221    }
222
223    /// Compute variable-to-factor message.
224    fn compute_var_to_factor_message(
225        &self,
226        graph: &FactorGraph,
227        messages: &HashMap<(String, String), Factor>,
228        var: &str,
229        target_factor: &str,
230    ) -> Result<Factor> {
231        let var_node = graph
232            .get_variable(var)
233            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
234
235        // Start with uniform
236        let mut message_values = vec![1.0; var_node.cardinality];
237
238        // Multiply all incoming factor-to-variable messages except from target
239        if let Some(factors) = graph.get_adjacent_factors(var) {
240            for factor_id in factors {
241                if let Some(factor) = graph.get_factor(factor_id) {
242                    if factor.name != target_factor {
243                        let key = (factor.name.clone(), var.to_string());
244                        if let Some(incoming_msg) = messages.get(&key) {
245                            for (i, message_value) in message_values
246                                .iter_mut()
247                                .enumerate()
248                                .take(var_node.cardinality)
249                            {
250                                *message_value *= incoming_msg.values[[i]];
251                            }
252                        }
253                    }
254                }
255            }
256        }
257
258        let array = scirs2_core::ndarray::Array::from_vec(message_values).into_dyn();
259        Factor::new(
260            format!("msg_{}_{}", var, target_factor),
261            vec![var.to_string()],
262            array,
263        )
264    }
265
266    /// Compute factor-to-variable message.
267    fn compute_factor_to_var_message(
268        &self,
269        graph: &FactorGraph,
270        messages: &HashMap<(String, String), Factor>,
271        factor_name: &str,
272        target_var: &str,
273    ) -> Result<Factor> {
274        let factor = graph
275            .get_factor_by_name(factor_name)
276            .ok_or_else(|| PgmError::InvalidGraph(format!("Factor {} not found", factor_name)))?;
277
278        // Start with the factor
279        let mut product = factor.clone();
280
281        // Multiply all incoming variable-to-factor messages except from target
282        for var in &factor.variables {
283            if var != target_var {
284                let key = (var.clone(), factor_name.to_string());
285                if let Some(incoming_msg) = messages.get(&key) {
286                    product = product.product(incoming_msg)?;
287                }
288            }
289        }
290
291        // Marginalize out all variables except target
292        for var in &factor.variables {
293            if var != target_var {
294                product = product.marginalize_out(var)?;
295            }
296        }
297
298        Ok(product)
299    }
300
301    /// Compute marginal for a variable.
302    fn compute_marginal(
303        &self,
304        graph: &FactorGraph,
305        messages: &HashMap<(String, String), Factor>,
306        var: &str,
307    ) -> Result<ArrayD<f64>> {
308        let var_node = graph
309            .get_variable(var)
310            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
311
312        let mut marginal_values = vec![1.0; var_node.cardinality];
313
314        // Multiply all incoming factor-to-variable messages
315        if let Some(factors) = graph.get_adjacent_factors(var) {
316            for factor_id in factors {
317                if let Some(factor) = graph.get_factor(factor_id) {
318                    let key = (factor.name.clone(), var.to_string());
319                    if let Some(msg) = messages.get(&key) {
320                        for (i, marginal_value) in marginal_values
321                            .iter_mut()
322                            .enumerate()
323                            .take(var_node.cardinality)
324                        {
325                            *marginal_value *= msg.values[[i]];
326                        }
327                    }
328                }
329            }
330        }
331
332        // Normalize
333        let sum: f64 = marginal_values.iter().sum();
334        if sum > 0.0 {
335            for val in &mut marginal_values {
336                *val /= sum;
337            }
338        }
339
340        Ok(scirs2_core::ndarray::Array::from_vec(marginal_values).into_dyn())
341    }
342
343    /// Apply damping to messages.
344    fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Factor {
345        let mut damped_values = new_msg.values.clone();
346        for i in 0..damped_values.len() {
347            damped_values[[i]] =
348                (1.0 - self.damping) * damped_values[[i]] + self.damping * old_msg.values[[i]];
349        }
350
351        Factor::new(
352            new_msg.name.clone(),
353            new_msg.variables.clone(),
354            damped_values,
355        )
356        .unwrap_or_else(|_| new_msg.clone())
357    }
358
359    /// Check convergence of messages.
360    fn check_convergence(
361        &self,
362        old_messages: &HashMap<(String, String), Factor>,
363        new_messages: &HashMap<(String, String), Factor>,
364    ) -> bool {
365        for (key, new_msg) in new_messages {
366            if let Some(old_msg) = old_messages.get(key) {
367                let diff: f64 = new_msg
368                    .values
369                    .iter()
370                    .zip(old_msg.values.iter())
371                    .map(|(a, b)| (a - b).abs())
372                    .sum();
373
374                if diff > self.tolerance {
375                    return false;
376                }
377            }
378        }
379        true
380    }
381
382    /// Get convergence statistics.
383    pub fn get_stats(&self) -> ConvergenceStats {
384        ConvergenceStats {
385            iterations: 0,
386            converged: false,
387            max_delta: 0.0,
388        }
389    }
390}
391
/// Parallel max-product algorithm for MAP inference.
#[derive(Debug, Clone)]
pub struct ParallelMaxProduct {
    /// Maximum number of message-passing sweeps before giving up.
    pub max_iterations: usize,
    /// Convergence tolerance: the largest allowed change of a single message between two
    /// sweeps.
    pub tolerance: f64,
}

impl Default for ParallelMaxProduct {
    /// 100 sweeps, tolerance `1e-6`.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
        }
    }
}
408
409impl ParallelMaxProduct {
410    /// Create with custom parameters.
411    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
412        Self {
413            max_iterations,
414            tolerance,
415        }
416    }
417
418    /// Run parallel max-product.
419    pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
420        // Similar to ParallelSumProduct but using max instead of sum
421        // Implementation follows the same pattern with max operations
422
423        let parallel_sp = ParallelSumProduct::new(self.max_iterations, self.tolerance, 0.0);
424        // For now, delegate to sum-product (in a full implementation, replace sum with max)
425        parallel_sp.run_parallel(graph)
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Two binary variables X, Y joined by one factor f_xy with
    /// f(0,0)=0.1, f(0,1)=0.2, f(1,0)=0.3, f(1,1)=0.4.
    fn create_simple_chain() -> FactorGraph {
        let mut graph = FactorGraph::new();

        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap");

        graph.add_factor(f_xy).expect("unwrap");

        graph
    }

    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_parallel_sum_product() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::default();

        let marginals = parallel_bp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);

        // Check normalization
        for marginal in marginals.values() {
            let sum: f64 = marginal.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "Marginal sum: {}", sum);
        }
    }

    #[test]
    fn test_parallel_sum_product_exact_marginals() {
        let graph = create_simple_chain();
        let marginals = ParallelSumProduct::default()
            .run_parallel(&graph)
            .expect("unwrap");

        // On a tree BP is exact: P(X) = row sums [0.3, 0.7], P(Y) = column sums [0.4, 0.6].
        let x = &marginals["X"];
        let y = &marginals["Y"];
        assert_close(x[[0]], 0.3, 1e-6);
        assert_close(x[[1]], 0.7, 1e-6);
        assert_close(y[[0]], 0.4, 1e-6);
        assert_close(y[[1]], 0.6, 1e-6);
    }

    #[test]
    fn test_parallel_with_damping() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::new(100, 1e-6, 0.5);

        let marginals = parallel_bp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);
        // Damping slows convergence but must not move the fixed point.
        assert_close(marginals["X"][[0]], 0.3, 1e-4);
        assert_close(marginals["Y"][[1]], 0.6, 1e-4);
    }

    #[test]
    fn test_parallel_reports_non_convergence() {
        let graph = create_simple_chain();
        // One sweep is not enough for the messages to stop changing.
        let result = ParallelSumProduct::new(1, 1e-6, 0.0).run_parallel(&graph);
        assert!(matches!(result, Err(PgmError::ConvergenceFailure(_))));
    }

    #[test]
    fn test_parallel_zero_iterations_uses_initial_messages() {
        let graph = create_simple_chain();
        let marginals = ParallelSumProduct::new(0, 1e-6, 0.0)
            .run_parallel(&graph)
            .expect("unwrap");

        for marginal in marginals.values() {
            assert_close(marginal[[0]], 0.5, 1e-12);
            assert_close(marginal[[1]], 0.5, 1e-12);
        }
    }

    #[test]
    fn test_parallel_max_product() {
        let graph = create_simple_chain();
        let parallel_mp = ParallelMaxProduct::default();

        let marginals = parallel_mp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);
    }
}