
tensorlogic_quantrs_hooks/parallel_message_passing.rs

//! Parallel message passing algorithms using rayon.
//!
//! This module provides parallel implementations of belief propagation algorithms
//! that can significantly speed up inference on large factor graphs with many variables.
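//!
//! # Example
//!
//! A minimal usage sketch. It is marked `ignore` because it assumes the crate is
//! named `tensorlogic_quantrs_hooks` and that `FactorGraph`, `Factor`, and
//! `ParallelSumProduct` are reachable at the paths below.
//!
//! ```ignore
//! use scirs2_core::ndarray::Array;
//! use tensorlogic_quantrs_hooks::factor::Factor;
//! use tensorlogic_quantrs_hooks::graph::FactorGraph;
//! use tensorlogic_quantrs_hooks::parallel_message_passing::ParallelSumProduct;
//!
//! // Two binary variables joined by a single pairwise factor.
//! let mut graph = FactorGraph::new();
//! graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
//! graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);
//! let f_xy = Factor::new(
//!     "f_xy".to_string(),
//!     vec!["X".to_string(), "Y".to_string()],
//!     Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
//!         .unwrap()
//!         .into_dyn(),
//! )
//! .unwrap();
//! graph.add_factor(f_xy).unwrap();
//!
//! // Run parallel sum-product and read off one normalized marginal per variable.
//! let marginals = ParallelSumProduct::default().run_parallel(&graph).unwrap();
//! assert_eq!(marginals.len(), 2);
//! ```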

use crate::error::{PgmError, Result};
use crate::factor::Factor;
use crate::graph::FactorGraph;
use crate::message_passing::ConvergenceStats;
use rayon::prelude::*;
use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Parallel sum-product belief propagation.
///
/// Uses rayon to compute messages in parallel, which can provide significant
/// speedup for large factor graphs.
pub struct ParallelSumProduct {
    /// Maximum iterations for convergence
    pub max_iterations: usize,
    /// Convergence tolerance
    pub tolerance: f64,
    /// Damping factor (0.0 = no damping, 1.0 = full damping)
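    /// Messages are blended as `(1.0 - damping) * new + damping * old`.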
    pub damping: f64,
}

impl Default for ParallelSumProduct {
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
            damping: 0.0,
        }
    }
}

impl ParallelSumProduct {
    /// Create with custom parameters.
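    ///
    /// A small construction sketch (marked `ignore`):
    ///
    /// ```ignore
    /// // At most 200 iterations, 1e-8 tolerance, 50% damping.
    /// let bp = ParallelSumProduct::new(200, 1e-8, 0.5);
    /// ```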
    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            damping,
        }
    }

    /// Run parallel belief propagation.
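    ///
    /// Returns the normalized marginal distribution for each variable, keyed by
    /// variable name, or `PgmError::ConvergenceFailure` if the messages do not
    /// converge within `max_iterations`.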
    pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
        // Initialize messages
        let messages = Arc::new(Mutex::new(self.initialize_messages(graph)?));

        // Iterative message passing
        for iteration in 0..self.max_iterations {
            let old_messages = messages.lock().unwrap().clone();

            // Parallel computation of variable-to-factor messages
            let var_factor_updates: Vec<_> = graph
                .variable_names()
                .par_bridge()
                .flat_map(|var_name| {
                    if let Some(factors) = graph.get_adjacent_factors(var_name) {
                        factors
                            .par_iter()
                            .filter_map(|factor_id| {
                                if let Some(factor) = graph.get_factor(factor_id) {
                                    let key = (var_name.to_string(), factor.name.clone());
                                    match self.compute_var_to_factor_message(
                                        graph,
                                        &old_messages,
                                        var_name,
                                        &factor.name,
                                    ) {
                                        Ok(msg) => Some((key, msg)),
                                        Err(_) => None,
                                    }
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                    } else {
                        Vec::new()
                    }
                })
                .collect();

            // Parallel computation of factor-to-variable messages
            let factor_var_updates: Vec<_> = graph
                .factor_ids()
                .par_bridge()
                .filter_map(|factor_id| graph.get_factor(factor_id))
                .flat_map(|factor| {
                    factor
                        .variables
                        .par_iter()
                        .filter_map(|var_name| {
                            let key = (factor.name.clone(), var_name.clone());
                            match self.compute_factor_to_var_message(
                                graph,
                                &old_messages,
                                &factor.name,
                                var_name,
                            ) {
                                Ok(msg) => Some((key, msg)),
                                Err(_) => None,
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect();

            // Update messages with damping
            {
                let mut messages_guard = messages.lock().unwrap();
                for (key, new_msg) in var_factor_updates.into_iter().chain(factor_var_updates) {
                    if let Some(old_msg) = messages_guard.get(&key) {
                        if self.damping > 0.0 {
                            // Apply damping: msg_new = (1-d)*msg_new + d*msg_old
                            let damped = self.apply_damping(old_msg, &new_msg);
                            messages_guard.insert(key, damped);
                        } else {
                            messages_guard.insert(key, new_msg);
                        }
                    } else {
                        messages_guard.insert(key, new_msg);
                    }
                }
            }

            // Check convergence
            let converged = self.check_convergence(&old_messages, &messages.lock().unwrap());
            if converged {
                break;
            }

            if iteration == self.max_iterations - 1 {
                return Err(PgmError::ConvergenceFailure(format!(
                    "Parallel belief propagation did not converge after {} iterations",
                    self.max_iterations
                )));
            }
        }

        // Compute final marginals in parallel. Message passing is done, so take a
        // single snapshot of the messages instead of re-locking the mutex per variable.
        let final_messages = messages.lock().unwrap().clone();
        let marginals: HashMap<String, ArrayD<f64>> = graph
            .variable_names()
            .par_bridge()
            .filter_map(|var_name| {
                match self.compute_marginal(graph, &final_messages, var_name) {
                    Ok(marginal) => Some((var_name.to_string(), marginal)),
                    Err(_) => None,
                }
            })
            .collect();

        Ok(marginals)
    }

    /// Initialize messages with uniform distributions.
    fn initialize_messages(
        &self,
        graph: &FactorGraph,
    ) -> Result<HashMap<(String, String), Factor>> {
        let mut messages = HashMap::new();

        // Initialize variable-to-factor messages
        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let uniform_values = vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
                let uniform_array =
                    scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();

                if let Some(factors) = graph.get_adjacent_factors(var_name) {
                    for factor_id in factors {
                        if let Some(factor) = graph.get_factor(factor_id) {
                            let msg = Factor::new(
                                format!("msg_{}_{}", var_name, factor.name),
                                vec![var_name.to_string()],
                                uniform_array.clone(),
                            )?;
                            messages.insert((var_name.to_string(), factor.name.clone()), msg);
                        }
                    }
                }
            }
        }

        // Initialize factor-to-variable messages
        for factor_id in graph.factor_ids() {
            if let Some(factor) = graph.get_factor(factor_id) {
                for var_name in &factor.variables {
                    if let Some(var_node) = graph.get_variable(var_name) {
                        let uniform_values =
                            vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
                        let uniform_array =
                            scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();

                        let msg = Factor::new(
                            format!("msg_{}_{}", factor.name, var_name),
                            vec![var_name.to_string()],
                            uniform_array,
                        )?;
                        messages.insert((factor.name.clone(), var_name.to_string()), msg);
                    }
                }
            }
        }

        Ok(messages)
    }

    /// Compute variable-to-factor message.
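    ///
    /// Sum-product update: the message from variable `v` to factor `f` is the
    /// elementwise product of the incoming factor-to-variable messages from every
    /// other adjacent factor:
    ///
    /// `m_{v->f}(v) = prod_{g != f} m_{g->v}(v)`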
    fn compute_var_to_factor_message(
        &self,
        graph: &FactorGraph,
        messages: &HashMap<(String, String), Factor>,
        var: &str,
        target_factor: &str,
    ) -> Result<Factor> {
        let var_node = graph
            .get_variable(var)
            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;

        // Start with uniform
        let mut message_values = vec![1.0; var_node.cardinality];

        // Multiply all incoming factor-to-variable messages except from target
        if let Some(factors) = graph.get_adjacent_factors(var) {
            for factor_id in factors {
                if let Some(factor) = graph.get_factor(factor_id) {
                    if factor.name != target_factor {
                        let key = (factor.name.clone(), var.to_string());
                        if let Some(incoming_msg) = messages.get(&key) {
                            for (i, message_value) in message_values
                                .iter_mut()
                                .enumerate()
                                .take(var_node.cardinality)
                            {
                                *message_value *= incoming_msg.values[[i]];
                            }
                        }
                    }
                }
            }
        }

        let array = scirs2_core::ndarray::Array::from_vec(message_values).into_dyn();
        Factor::new(
            format!("msg_{}_{}", var, target_factor),
            vec![var.to_string()],
            array,
        )
    }

    /// Compute factor-to-variable message.
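    ///
    /// Sum-product update: multiply the factor by the incoming variable-to-factor
    /// messages from every variable other than the target, then marginalize (sum)
    /// out all variables except the target:
    ///
    /// `m_{f->v}(v) = sum_{all vars except v} [ f(...) * prod_{u != v} m_{u->f}(u) ]`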
    fn compute_factor_to_var_message(
        &self,
        graph: &FactorGraph,
        messages: &HashMap<(String, String), Factor>,
        factor_name: &str,
        target_var: &str,
    ) -> Result<Factor> {
        let factor = graph
            .get_factor_by_name(factor_name)
            .ok_or_else(|| PgmError::InvalidGraph(format!("Factor {} not found", factor_name)))?;

        // Start with the factor
        let mut product = factor.clone();

        // Multiply all incoming variable-to-factor messages except from target
        for var in &factor.variables {
            if var != target_var {
                let key = (var.clone(), factor_name.to_string());
                if let Some(incoming_msg) = messages.get(&key) {
                    product = product.product(incoming_msg)?;
                }
            }
        }

        // Marginalize out all variables except target
        for var in &factor.variables {
            if var != target_var {
                product = product.marginalize_out(var)?;
            }
        }

        Ok(product)
    }

    /// Compute marginal for a variable.
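    ///
    /// The marginal is proportional to the product of all incoming
    /// factor-to-variable messages and is normalized to sum to 1.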
    fn compute_marginal(
        &self,
        graph: &FactorGraph,
        messages: &HashMap<(String, String), Factor>,
        var: &str,
    ) -> Result<ArrayD<f64>> {
        let var_node = graph
            .get_variable(var)
            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;

        let mut marginal_values = vec![1.0; var_node.cardinality];

        // Multiply all incoming factor-to-variable messages
        if let Some(factors) = graph.get_adjacent_factors(var) {
            for factor_id in factors {
                if let Some(factor) = graph.get_factor(factor_id) {
                    let key = (factor.name.clone(), var.to_string());
                    if let Some(msg) = messages.get(&key) {
                        for (i, marginal_value) in marginal_values
                            .iter_mut()
                            .enumerate()
                            .take(var_node.cardinality)
                        {
                            *marginal_value *= msg.values[[i]];
                        }
                    }
                }
            }
        }

        // Normalize
        let sum: f64 = marginal_values.iter().sum();
        if sum > 0.0 {
            for val in &mut marginal_values {
                *val /= sum;
            }
        }

        Ok(scirs2_core::ndarray::Array::from_vec(marginal_values).into_dyn())
    }

    /// Apply damping to messages.
    fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Factor {
        let mut damped_values = new_msg.values.clone();
        for i in 0..damped_values.len() {
            damped_values[[i]] =
                (1.0 - self.damping) * damped_values[[i]] + self.damping * old_msg.values[[i]];
        }

        Factor::new(
            new_msg.name.clone(),
            new_msg.variables.clone(),
            damped_values,
        )
        .unwrap_or_else(|_| new_msg.clone())
    }

    /// Check convergence of messages.
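    ///
    /// Messages are considered converged when, for every message, the L1 distance
    /// between its old and new values is at most `tolerance`.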
    fn check_convergence(
        &self,
        old_messages: &HashMap<(String, String), Factor>,
        new_messages: &HashMap<(String, String), Factor>,
    ) -> bool {
        for (key, new_msg) in new_messages {
            if let Some(old_msg) = old_messages.get(key) {
                let diff: f64 = new_msg
                    .values
                    .iter()
                    .zip(old_msg.values.iter())
                    .map(|(a, b)| (a - b).abs())
                    .sum();

                if diff > self.tolerance {
                    return false;
                }
            }
        }
        true
    }

    /// Get convergence statistics.
    ///
    /// Note: statistics are not yet tracked during `run_parallel`, so this
    /// currently returns placeholder values.
    pub fn get_stats(&self) -> ConvergenceStats {
        ConvergenceStats {
            iterations: 0,
            converged: false,
            max_delta: 0.0,
        }
    }
}

/// Parallel max-product algorithm for MAP inference.
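///
/// Max-product replaces the sum in the factor-to-variable update with a max, so
/// the resulting max-marginals can be used to recover a most probable (MAP)
/// assignment:
///
/// `m_{f->v}(v) = max_{all vars except v} [ f(...) * prod_{u != v} m_{u->f}(u) ]`
///
/// Note: the current implementation delegates to [`ParallelSumProduct`]; the
/// max-marginalization step is not yet implemented.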
pub struct ParallelMaxProduct {
    /// Maximum iterations
    pub max_iterations: usize,
    /// Convergence tolerance
    pub tolerance: f64,
}

impl Default for ParallelMaxProduct {
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-6,
        }
    }
}

impl ParallelMaxProduct {
    /// Create with custom parameters.
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
        }
    }

    /// Run parallel max-product.
    pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
        // For now, delegate to sum-product; a full implementation would follow the
        // same message-passing pattern but replace the sum-marginalization with max.
        let parallel_sp = ParallelSumProduct::new(self.max_iterations, self.tolerance, 0.0);
        parallel_sp.run_parallel(graph)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    fn create_simple_chain() -> FactorGraph {
        let mut graph = FactorGraph::new();

        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        graph.add_factor(f_xy).unwrap();

        graph
    }

    #[test]
    fn test_parallel_sum_product() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::default();

        let marginals = parallel_bp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);

        // Check normalization
        for marginal in marginals.values() {
            let sum: f64 = marginal.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "Marginal sum: {}", sum);
        }
    }

    #[test]
    fn test_parallel_with_damping() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::new(100, 1e-6, 0.5);

        let marginals = parallel_bp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);
    }
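
    // Consistency check (a sketch): damping changes the message-update dynamics but
    // not the fixed point, so on this small graph the damped and undamped runs
    // should converge to (nearly) the same marginals.
    #[test]
    fn test_damping_preserves_marginals() {
        let graph = create_simple_chain();

        let undamped = ParallelSumProduct::new(100, 1e-6, 0.0)
            .run_parallel(&graph)
            .unwrap();
        let damped = ParallelSumProduct::new(100, 1e-6, 0.5)
            .run_parallel(&graph)
            .unwrap();

        for (var, marginal) in &undamped {
            for (a, b) in marginal.iter().zip(damped[var].iter()) {
                assert!((a - b).abs() < 1e-4, "{}: {} vs {}", var, a, b);
            }
        }
    }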

    #[test]
    fn test_parallel_max_product() {
        let graph = create_simple_chain();
        let parallel_mp = ParallelMaxProduct::default();

        let marginals = parallel_mp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);
    }
}