tensorlogic_quantrs_hooks/
parallel_message_passing.rs1use crate::error::{PgmError, Result};
7use crate::factor::Factor;
8use crate::graph::FactorGraph;
9use crate::message_passing::ConvergenceStats;
10use rayon::prelude::*;
11use scirs2_core::ndarray::ArrayD;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
/// Configuration for parallel (rayon-based) sum-product loopy belief propagation.
///
/// Messages are updated synchronously: every sweep recomputes all messages from a
/// snapshot of the previous sweep, which is what allows the updates to run in parallel.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelSumProduct {
    /// Maximum number of message-update sweeps before reporting a convergence failure.
    pub max_iterations: usize,
    /// Convergence threshold: a sweep is accepted as converged when no message changed by
    /// more than this amount (L1 distance between successive message values).
    pub tolerance: f64,
    /// Weight given to the previous message when updating
    /// (`new = (1 - damping) * computed + damping * previous`); `0.0` disables damping.
    pub damping: f64,
}
27
28impl Default for ParallelSumProduct {
29 fn default() -> Self {
30 Self {
31 max_iterations: 100,
32 tolerance: 1e-6,
33 damping: 0.0,
34 }
35 }
36}
37
38impl ParallelSumProduct {
39 pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
41 Self {
42 max_iterations,
43 tolerance,
44 damping,
45 }
46 }
47
48 pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
50 let messages = Arc::new(Mutex::new(self.initialize_messages(graph)?));
52
53 for iteration in 0..self.max_iterations {
55 let old_messages = messages.lock().unwrap().clone();
56
57 let var_factor_updates: Vec<_> = graph
59 .variable_names()
60 .par_bridge()
61 .flat_map(|var_name| {
62 if let Some(factors) = graph.get_adjacent_factors(var_name) {
63 factors
64 .par_iter()
65 .filter_map(|factor_id| {
66 if let Some(factor) = graph.get_factor(factor_id) {
67 let key = (var_name.to_string(), factor.name.clone());
68 match self.compute_var_to_factor_message(
69 graph,
70 &old_messages,
71 var_name,
72 &factor.name,
73 ) {
74 Ok(msg) => Some((key, msg)),
75 Err(_) => None,
76 }
77 } else {
78 None
79 }
80 })
81 .collect::<Vec<_>>()
82 } else {
83 Vec::new()
84 }
85 })
86 .collect();
87
88 let factor_var_updates: Vec<_> = graph
90 .factor_ids()
91 .par_bridge()
92 .filter_map(|factor_id| graph.get_factor(factor_id))
93 .flat_map(|factor| {
94 factor
95 .variables
96 .par_iter()
97 .filter_map(|var_name| {
98 let key = (factor.name.clone(), var_name.clone());
99 match self.compute_factor_to_var_message(
100 graph,
101 &old_messages,
102 &factor.name,
103 var_name,
104 ) {
105 Ok(msg) => Some((key, msg)),
106 Err(_) => None,
107 }
108 })
109 .collect::<Vec<_>>()
110 })
111 .collect();
112
113 {
115 let mut messages_guard = messages.lock().unwrap();
116 for (key, new_msg) in var_factor_updates.into_iter().chain(factor_var_updates) {
117 if let Some(old_msg) = messages_guard.get(&key) {
118 if self.damping > 0.0 {
119 let damped = self.apply_damping(old_msg, &new_msg);
121 messages_guard.insert(key, damped);
122 } else {
123 messages_guard.insert(key, new_msg);
124 }
125 } else {
126 messages_guard.insert(key, new_msg);
127 }
128 }
129 }
130
131 let converged = self.check_convergence(&old_messages, &messages.lock().unwrap());
133 if converged {
134 break;
135 }
136
137 if iteration == self.max_iterations - 1 {
138 return Err(PgmError::ConvergenceFailure(format!(
139 "Parallel belief propagation did not converge after {} iterations",
140 self.max_iterations
141 )));
142 }
143 }
144
145 let marginals: HashMap<String, ArrayD<f64>> = graph
147 .variable_names()
148 .par_bridge()
149 .filter_map(|var_name| {
150 match self.compute_marginal(graph, &messages.lock().unwrap(), var_name) {
151 Ok(marginal) => Some((var_name.to_string(), marginal)),
152 Err(_) => None,
153 }
154 })
155 .collect();
156
157 Ok(marginals)
158 }
159
160 fn initialize_messages(
162 &self,
163 graph: &FactorGraph,
164 ) -> Result<HashMap<(String, String), Factor>> {
165 let mut messages = HashMap::new();
166
167 for var_name in graph.variable_names() {
169 if let Some(var_node) = graph.get_variable(var_name) {
170 let uniform_values = vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
171 let uniform_array =
172 scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
173
174 if let Some(factors) = graph.get_adjacent_factors(var_name) {
175 for factor_id in factors {
176 if let Some(factor) = graph.get_factor(factor_id) {
177 let msg = Factor::new(
178 format!("msg_{}_{}", var_name, factor.name),
179 vec![var_name.to_string()],
180 uniform_array.clone(),
181 )?;
182 messages.insert((var_name.to_string(), factor.name.clone()), msg);
183 }
184 }
185 }
186 }
187 }
188
189 for factor_id in graph.factor_ids() {
191 if let Some(factor) = graph.get_factor(factor_id) {
192 for var_name in &factor.variables {
193 if let Some(var_node) = graph.get_variable(var_name) {
194 let uniform_values =
195 vec![1.0 / var_node.cardinality as f64; var_node.cardinality];
196 let uniform_array =
197 scirs2_core::ndarray::Array::from_vec(uniform_values).into_dyn();
198
199 let msg = Factor::new(
200 format!("msg_{}_{}", factor.name, var_name),
201 vec![var_name.to_string()],
202 uniform_array,
203 )?;
204 messages.insert((factor.name.clone(), var_name.to_string()), msg);
205 }
206 }
207 }
208 }
209
210 Ok(messages)
211 }
212
213 fn compute_var_to_factor_message(
215 &self,
216 graph: &FactorGraph,
217 messages: &HashMap<(String, String), Factor>,
218 var: &str,
219 target_factor: &str,
220 ) -> Result<Factor> {
221 let var_node = graph
222 .get_variable(var)
223 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
224
225 let mut message_values = vec![1.0; var_node.cardinality];
227
228 if let Some(factors) = graph.get_adjacent_factors(var) {
230 for factor_id in factors {
231 if let Some(factor) = graph.get_factor(factor_id) {
232 if factor.name != target_factor {
233 let key = (factor.name.clone(), var.to_string());
234 if let Some(incoming_msg) = messages.get(&key) {
235 for (i, message_value) in message_values
236 .iter_mut()
237 .enumerate()
238 .take(var_node.cardinality)
239 {
240 *message_value *= incoming_msg.values[[i]];
241 }
242 }
243 }
244 }
245 }
246 }
247
248 let array = scirs2_core::ndarray::Array::from_vec(message_values).into_dyn();
249 Factor::new(
250 format!("msg_{}_{}", var, target_factor),
251 vec![var.to_string()],
252 array,
253 )
254 }
255
256 fn compute_factor_to_var_message(
258 &self,
259 graph: &FactorGraph,
260 messages: &HashMap<(String, String), Factor>,
261 factor_name: &str,
262 target_var: &str,
263 ) -> Result<Factor> {
264 let factor = graph
265 .get_factor_by_name(factor_name)
266 .ok_or_else(|| PgmError::InvalidGraph(format!("Factor {} not found", factor_name)))?;
267
268 let mut product = factor.clone();
270
271 for var in &factor.variables {
273 if var != target_var {
274 let key = (var.clone(), factor_name.to_string());
275 if let Some(incoming_msg) = messages.get(&key) {
276 product = product.product(incoming_msg)?;
277 }
278 }
279 }
280
281 for var in &factor.variables {
283 if var != target_var {
284 product = product.marginalize_out(var)?;
285 }
286 }
287
288 Ok(product)
289 }
290
291 fn compute_marginal(
293 &self,
294 graph: &FactorGraph,
295 messages: &HashMap<(String, String), Factor>,
296 var: &str,
297 ) -> Result<ArrayD<f64>> {
298 let var_node = graph
299 .get_variable(var)
300 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
301
302 let mut marginal_values = vec![1.0; var_node.cardinality];
303
304 if let Some(factors) = graph.get_adjacent_factors(var) {
306 for factor_id in factors {
307 if let Some(factor) = graph.get_factor(factor_id) {
308 let key = (factor.name.clone(), var.to_string());
309 if let Some(msg) = messages.get(&key) {
310 for (i, marginal_value) in marginal_values
311 .iter_mut()
312 .enumerate()
313 .take(var_node.cardinality)
314 {
315 *marginal_value *= msg.values[[i]];
316 }
317 }
318 }
319 }
320 }
321
322 let sum: f64 = marginal_values.iter().sum();
324 if sum > 0.0 {
325 for val in &mut marginal_values {
326 *val /= sum;
327 }
328 }
329
330 Ok(scirs2_core::ndarray::Array::from_vec(marginal_values).into_dyn())
331 }
332
333 fn apply_damping(&self, old_msg: &Factor, new_msg: &Factor) -> Factor {
335 let mut damped_values = new_msg.values.clone();
336 for i in 0..damped_values.len() {
337 damped_values[[i]] =
338 (1.0 - self.damping) * damped_values[[i]] + self.damping * old_msg.values[[i]];
339 }
340
341 Factor::new(
342 new_msg.name.clone(),
343 new_msg.variables.clone(),
344 damped_values,
345 )
346 .unwrap_or_else(|_| new_msg.clone())
347 }
348
349 fn check_convergence(
351 &self,
352 old_messages: &HashMap<(String, String), Factor>,
353 new_messages: &HashMap<(String, String), Factor>,
354 ) -> bool {
355 for (key, new_msg) in new_messages {
356 if let Some(old_msg) = old_messages.get(key) {
357 let diff: f64 = new_msg
358 .values
359 .iter()
360 .zip(old_msg.values.iter())
361 .map(|(a, b)| (a - b).abs())
362 .sum();
363
364 if diff > self.tolerance {
365 return false;
366 }
367 }
368 }
369 true
370 }
371
372 pub fn get_stats(&self) -> ConvergenceStats {
374 ConvergenceStats {
375 iterations: 0,
376 converged: false,
377 max_delta: 0.0,
378 }
379 }
380}
381
/// Configuration for parallel max-product belief propagation.
///
/// NOTE(review): `run_parallel` currently delegates to `ParallelSumProduct` with damping `0.0`,
/// so it produces sum-product marginals rather than max-marginals.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelMaxProduct {
    /// Maximum number of message-update sweeps.
    pub max_iterations: usize,
    /// Convergence threshold on message changes between sweeps.
    pub tolerance: f64,
}
389
390impl Default for ParallelMaxProduct {
391 fn default() -> Self {
392 Self {
393 max_iterations: 100,
394 tolerance: 1e-6,
395 }
396 }
397}
398
399impl ParallelMaxProduct {
400 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
402 Self {
403 max_iterations,
404 tolerance,
405 }
406 }
407
408 pub fn run_parallel(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
410 let parallel_sp = ParallelSumProduct::new(self.max_iterations, self.tolerance, 0.0);
414 parallel_sp.run_parallel(graph)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds the two-variable graph `X — f_xy — Y` with binary variables.
    ///
    /// `f_xy` holds `[0.1, 0.2, 0.3, 0.4]` in a 2x2 array declared over `[X, Y]`. Assuming
    /// the first axis indexes `X`, the exact marginals are `P(X) = [0.3, 0.7]` and
    /// `P(Y) = [0.4, 0.6]`.
    fn create_simple_chain() -> FactorGraph {
        let mut graph = FactorGraph::new();

        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        graph.add_factor(f_xy).unwrap();

        graph
    }

    /// Asserts that `actual` equals `expected` element-wise within `tol`.
    fn assert_marginal_close(actual: &ArrayD<f64>, expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "marginal length mismatch");
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < tol,
                "marginal {:?} differs from expected {:?}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_parallel_sum_product() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::default();

        let marginals = parallel_bp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);

        for marginal in marginals.values() {
            let sum: f64 = marginal.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "Marginal sum: {}", sum);
        }
    }

    #[test]
    fn test_parallel_sum_product_exact_on_tree() {
        // Belief propagation is exact on a tree, so the marginals must match brute force.
        let graph = create_simple_chain();
        let marginals = ParallelSumProduct::default().run_parallel(&graph).unwrap();

        assert_marginal_close(&marginals["X"], &[0.3, 0.7], 1e-6);
        assert_marginal_close(&marginals["Y"], &[0.4, 0.6], 1e-6);
    }

    #[test]
    fn test_parallel_with_damping() {
        let graph = create_simple_chain();
        let parallel_bp = ParallelSumProduct::new(100, 1e-6, 0.5);

        let marginals = parallel_bp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);
        // Damping slows convergence but does not move the fixed point.
        assert_marginal_close(&marginals["X"], &[0.3, 0.7], 1e-4);
        assert_marginal_close(&marginals["Y"], &[0.4, 0.6], 1e-4);
    }

    #[test]
    fn test_parallel_reports_convergence_failure() {
        // A single sweep moves the messages away from their uniform initialisation,
        // so it cannot satisfy the convergence test.
        let graph = create_simple_chain();
        let result = ParallelSumProduct::new(1, 1e-6, 0.0).run_parallel(&graph);

        assert!(matches!(result, Err(PgmError::ConvergenceFailure(_))));
    }

    #[test]
    fn test_parallel_max_product() {
        let graph = create_simple_chain();
        let parallel_mp = ParallelMaxProduct::default();

        let marginals = parallel_mp.run_parallel(&graph).unwrap();

        assert_eq!(marginals.len(), 2);
    }
}