tensorlogic_quantrs_hooks/
parallel_message_passing.rs1use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8use crate::graph::FactorGraph;
9use crate::message_passing::ConvergenceStats;
10use rayon::prelude::*;
11use scirs2_core::ndarray::ArrayD;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
/// Settings for loopy sum-product belief propagation whose per-sweep message
/// updates are computed in parallel with rayon.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelSumProduct {
    /// Maximum number of synchronous message-update sweeps before giving up.
    pub max_iterations: usize,
    /// Convergence threshold: a sweep has converged when every message's summed
    /// absolute change since the previous sweep is within this value.
    pub tolerance: f64,
    /// Weight kept from the previous message when updating
    /// (`new = (1 - damping) * computed + damping * previous`); `0.0` disables damping.
    pub damping: f64,
}
27
28impl Default for ParallelSumProduct {
29 fn default() -> Self {
30 Self {
31 max_iterations: 100,
32 tolerance: 1e-6,
33 damping: 0.0,
34 }
35 }
36}
37
38impl ParallelSumProduct {
39 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
41 Self {
42 max_iterations,
43 tolerance,
44 damping,
45 }
46 }
47
48 pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
50 let messages = Arc::new(Mutex::new(self.initialize_messages(graph)?));
52
53 for iteration in 0..self.max_iterations {
55 let old_messages = messages
56 .lock()
57 .expect("lock should not be poisoned")
58 .clone();
59
60 let var_factor_updates: Vec<_> = graph
62 .variable_names()
63 .par_bridge()
64 .flat_map(|var_name| {
65 if let Some(factors) = graph.get_adjacent_factors(var_name) {
66 factors
67 .par_iter()
68 .filter_map(|factor_id| {
69 if let Some(factor) = graph.get_factor(factor_id) {
70 let key = (var_name.to_string(), factor.name.clone());
71 match self.compute_var_to_factor_message(
72 graph,
73 &old_messages,
74 var_name,
75 &factor.name,
76 ) {
77 Ok(msg) => Some((key, msg)),
78 Err(_) => None,
79 }
80 } else {
81 None
82 }
83 })
84 .collect::<Vec<_>>()
85 } else {
86 Vec::new()
87 }
88 })
89 .collect();
90
91 let factor_var_updates: Vec<_> = graph
93 .factor_ids()
94 .par_bridge()
95 .filter_map(|factor_id| graph.get_factor(factor_id))
96 .flat_map(|factor| {
97 factor
98 .variables
99 .par_iter()
100 .filter_map(|var_name| {
101 let key = (factor.name.clone(), var_name.clone());
102 match self.compute_factor_to_var_message(
103 graph,
104 &old_messages,
105 &factor.name,
106 var_name,
107 ) {
108 Ok(msg) => Some((key, msg)),
109 Err(_) => None,
110 }
111 })
112 .collect::<Vec<_>>()
113 })
114 .collect();
115
116 {
118 let mut messages_guard = messages.lock().expect("lock should not be poisoned");
119 for (key, new_msg) in var_factor_updates.into_iter().chain(factor_var_updates) {
120 if let Some(old_msg) = messages_guard.get(&key) {
121 if self.damping > 0.0 {
122 let damped = self.apply_damping(old_msg, &new_msg);
124 messages_guard.insert(key, damped);
125 } else {
126 messages_guard.insert(key, new_msg);
127 }
128 } else {
129 messages_guard.insert(key, new_msg);
130 }
131 }
132 }
133
134 let converged = self.check_convergence(
136 &old_messages,
137 &messages.lock().expect("lock should not be poisoned"),
138 );
139 if converged {
140 break;
141 }
142
143 if iteration == self.max_iterations - 1 {
144 return Err(PgmError::ConvergenceFailure(format!(
145 "Parallel belief propagation did not converge after {} iterations",
146 self.max_iterations
147 )));
148 }
149 }
150
151 let marginals: HashMap<String, ArrayD<f64>> = graph
153 .variable_names()
154 .par_bridge()
155 .filter_map(|var_name| {
156 match self.compute_marginal(
157 graph,
158 &messages.lock().expect("lock should not be poisoned"),
159 var_name,
160 ) {
161 Ok(marginal) => Some((var_name.to_string(), marginal)),
162 Err(_) => None,
163 }
164 })
165 .collect();
166
167 Ok(marginals)
168 }
169
170 fn initialize_messages(
172 &self,
173 graph: &FactorGraph,
174 ) -> Result<HashMap<(String, String), Factor>> {
175 let mut messages = HashMap::new();
176
177 for var_name in graph.variable_names() {
179 if let Some(var_node) = graph.get_variable(var_name) {
180 let uniform_values = vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
181 let uniform_array =
182 scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
183
184 if let Some(factors) = graph.get_adjacent_factors(var_name) {
185 for factor_id in factors {
186 if let Some(factor) = graph.get_factor(factor_id) {
187 let msg = Factor::new(
188 format!("msg_{}_{}", var_name, factor.name),
189 vec![var_name.to_string()],
190 uniform_array.clone(),
191 )?;
192 messages.insert((var_name.to_string(), factor.name.clone()), msg);
193 }
194 }
195 }
196 }
197 }
198
199 for factor_id in graph.factor_ids() {
201 if let Some(factor) = graph.get_factor(factor_id) {
202 for var_name in &factor.variables {
203 if let Some(var_node) = graph.get_variable(var_name) {
204 let uniform_values =
205 vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
206 let uniform_array =
207 scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
208
209 let msg = Factor::new(
210 format!("msg_{}_{}", factor.name, var_name),
211 vec![var_name.to_string()],
212 uniform_array,
213 )?;
214 messages.insert((factor.name.clone(), var_name.to_string()), msg);
215 }
216 }
217 }
218 }
219
220 Ok(messages)
221 }
222
223 fn compute_var_to_factor_message(
225 &self,
226 graph: &FactorGraph,
227 messages: &HashMap<(String, String), Factor>,
228 var: &str,
229 target_factor: &str,
230 ) -> Result<Factor> {
231 let var_node = graph
232 .get_variable(var)
233 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
234
235 let mut message_values = vec![1.0; var_node.cardinality];
237
238 if let Some(factors) = graph.get_adjacent_factors(var) {
240 for factor_id in factors {
241 if let Some(factor) = graph.get_factor(factor_id) {
242 if factor.name != target_factor {
243 let key = (factor.name.clone(), var.to_string());
244 if let Some(incoming_msg) = messages.get(&key) {
245 for (i, message_value) in message_values
246 .iter_mut()
247 .enumerate()
248 .take(var_node.cardinality)
249 {
250 *message_value *= incoming_msg.values[[i]];
251 }
252 }
253 }
254 }
255 }
256 }
257
258 let array = scirs2_core::ndarray::Array::from_vec(message_values).into_dyn();
259 Factor::new(
260 format!("msg_{}_{}", var, target_factor),
261 vec![var.to_string()],
262 array,
263 )
264 }
265
266 fn compute_factor_to_var_message(
268 &self,
269 graph: &FactorGraph,
270 messages: &HashMap<(String, String), Factor>,
271 factor_name: &str,
272 target_var: &str,
273 ) -> Result<Factor> {
274 let factor = graph
275 .get_factor_by_name(factor_name)
276 .ok_or_else(|| PgmError::InvalidGraph(format!("Factor {} not found", factor_name)))?;
277
278 let mut product = factor.clone();
280
281 for var in &factor.variables {
283 if var != target_var {
284 let key = (var.clone(), factor_name.to_string());
285 if let Some(incoming_msg) = messages.get(&key) {
286 product = product.product(incoming_msg)?;
287 }
288 }
289 }
290
291 for var in &factor.variables {
293 if var != target_var {
294 product = product.marginalize_out(var)?;
295 }
296 }
297
298 Ok(product)
299 }
300
301 fn compute_marginal(
303 &self,
304 graph: &FactorGraph,
305 messages: &HashMap<(String, String), Factor>,
306 var: &str,
307 ) -> Result<ArrayD<f64>> {
308 let var_node = graph
309 .get_variable(var)
310 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
311
312 let mut marginal_values = vec![1.0; var_node.cardinality];
313
314 if let Some(factors) = graph.get_adjacent_factors(var) {
316 for factor_id in factors {
317 if let Some(factor) = graph.get_factor(factor_id) {
318 let key = (factor.name.clone(), var.to_string());
319 if let Some(msg) = messages.get(&key) {
320 for (i, marginal_value) in marginal_values
321 .iter_mut()
322 .enumerate()
323 .take(var_node.cardinality)
324 {
325 *marginal_value *= msg.values[[i]];
326 }
327 }
328 }
329 }
330 }
331
332 let sum: f64 = marginal_values.iter().sum();
334 if sum > 0.0 {
335 for val in &mut marginal_values {
336 *val /= sum;
337 }
338 }
339
340 Ok(scirs2_core::ndarray::Array::from_vec(marginal_values).into_dyn())
341 }
342
343 fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Factor {
345 let mut damped_values = new_msg.values.clone();
346 for i in 0..damped_values.len() {
347 damped_values[[i]] =
348 (1.0 - self.damping) * damped_values[[i]] + self.damping * old_msg.values[[i]];
349 }
350
351 Factor::new(
352 new_msg.name.clone(),
353 new_msg.variables.clone(),
354 damped_values,
355 )
356 .unwrap_or_else(|_| new_msg.clone())
357 }
358
359 fn check_convergence(
361 &self,
362 old_messages: &HashMap<(String, String), Factor>,
363 new_messages: &HashMap<(String, String), Factor>,
364 ) -> bool {
365 for (key, new_msg) in new_messages {
366 if let Some(old_msg) = old_messages.get(key) {
367 let diff: f64 = new_msg
368 .values
369 .iter()
370 .zip(old_msg.values.iter())
371 .map(|(a, b)| (a - b).abs())
372 .sum();
373
374 if diff > self.tolerance {
375 return false;
376 }
377 }
378 }
379 true
380 }
381
382 pub fn get_stats(&self) -> ConvergenceStats {
384 ConvergenceStats {
385 iterations: 0,
386 converged: false,
387 max_delta: 0.0,
388 }
389 }
390}
391
/// Settings for loopy max-product belief propagation whose per-sweep message
/// updates are computed in parallel with rayon.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelMaxProduct {
    /// Maximum number of synchronous message-update sweeps before giving up.
    pub max_iterations: usize,
    /// Convergence threshold on each message's summed absolute change per sweep.
    pub tolerance: f64,
}
399
400impl Default for ParallelMaxProduct {
401 fn default() -> Self {
402 Self {
403 max_iterations: 100,
404 tolerance: 1e-6,
405 }
406 }
407}
408
409impl ParallelMaxProduct {
410 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
412 Self {
413 max_iterations,
414 tolerance,
415 }
416 }
417
418 pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
420 let parallel_sp = ParallelSumProduct::new(self.max_iterations, self.tolerance, 0.0);
424 parallel_sp.run_parallel(graph)
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds the graph X - f_xy - Y with binary variables and
    /// f_xy(X, Y) = [[0.1, 0.2], [0.3, 0.4]] (rows indexed by X).
    fn create_simple_chain() -> FactorGraph {
        let mut graph = FactorGraph::new();

        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap");

        graph.add_factor(f_xy).expect("unwrap");

        graph
    }

    /// Asserts `actual` equals `expected` entry-wise within `tol`.
    fn assert_close(actual: &ArrayD<f64>, expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch: {:?}", actual);
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol, "expected {:?}, got {:?}", expected, actual);
        }
    }

    #[test]
    fn test_parallel_sum_product() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::default();

        let marginals = parallel_bp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);
        for marginal in marginals.values() {
            let sum: f64 = marginal.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "Marginal sum: {}", sum);
        }

        // Exact marginals of f_xy: P(X) = row sums, P(Y) = column sums.
        assert_close(&marginals["X"], &[0.3, 0.7], 1e-6);
        assert_close(&marginals["Y"], &[0.4, 0.6], 1e-6);
    }

    #[test]
    fn test_parallel_with_damping() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::new(100, 1e-6, 0.5);

        let marginals = parallel_bp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);
        // Damping slows convergence but must reach the same fixed point.
        assert_close(&marginals["X"], &[0.3, 0.7], 1e-4);
        assert_close(&marginals["Y"], &[0.4, 0.6], 1e-4);
    }

    #[test]
    fn test_convergence_failure_when_iterations_exhausted() {
        let graph = create_simple_chain();
        // A single sweep always changes the uniform initial messages, so it cannot converge.
        let parallel_bp = ParallelSumProduct::new(1, 1e-6, 0.0);

        let result = parallel_bp.run_parallel(&graph);

        assert!(matches!(result, Err(PgmError::ConvergenceFailure(_))));
    }

    #[test]
    fn test_parallel_max_product() {
        let graph = create_simple_chain();
        let parallel_mp = ParallelMaxProduct::default();

        let marginals = parallel_mp.run_parallel(&graph).expect("unwrap");

        assert_eq!(marginals.len(), 2);
        for marginal in marginals.values() {
            let sum: f64 = marginal.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "Marginal sum: {}", sum);
        }
    }
}