Skip to main content

tensorlogic_quantrs_hooks/loopy_bp/
engine.rs

1//! Core Loopy BP engine: message store, belief computation, and the
2//! [`MessagePassingAlgorithm`] trait implementation.
3
4use scirs2_core::ndarray::{Array1, ArrayD};
5use std::collections::HashMap;
6
7use crate::error::{PgmError, Result};
8use crate::graph::FactorGraph;
9use crate::message_passing::MessagePassingAlgorithm;
10
11use super::config::{LoopyBpConfig, LoopyBpResult};
12use super::cycle::CycleDetector;
13use super::energy::bethe_free_energy;
14use super::types::{LbpConvergenceMonitor, LbpIterStats, LogMessage, UpdateSchedule};
15
16// ──────────────────────────────────────────────────────────────────────────────
17// Message store (log-domain)
18// ──────────────────────────────────────────────────────────────────────────────
19
20/// Log-domain message store for LBP.
21#[derive(Clone, Debug, Default)]
22pub(super) struct LogMessageStore {
23    /// (variable, factor) → log-message
24    pub(super) var_to_factor: HashMap<(String, String), LogMessage>,
25    /// (factor, variable) → log-message
26    pub(super) factor_to_var: HashMap<(String, String), LogMessage>,
27}
28
29impl LogMessageStore {
30    pub(super) fn new() -> Self {
31        Self::default()
32    }
33
34    pub(super) fn get_vtf(&self, var: &str, fac: &str) -> Option<&LogMessage> {
35        self.var_to_factor.get(&(var.to_string(), fac.to_string()))
36    }
37
38    pub(super) fn set_vtf(&mut self, var: String, fac: String, msg: LogMessage) {
39        self.var_to_factor.insert((var, fac), msg);
40    }
41
42    pub(super) fn get_ftv(&self, fac: &str, var: &str) -> Option<&LogMessage> {
43        self.factor_to_var.get(&(fac.to_string(), var.to_string()))
44    }
45
46    pub(super) fn set_ftv(&mut self, fac: String, var: String, msg: LogMessage) {
47        self.factor_to_var.insert((fac, var), msg);
48    }
49}
50
51// ──────────────────────────────────────────────────────────────────────────────
52// Core Loopy BP engine
53// ──────────────────────────────────────────────────────────────────────────────
54
55/// Loopy Belief Propagation on general (cyclic) factor graphs.
56///
57/// # Example
58/// ```rust,ignore
59/// use tensorlogic_quantrs_hooks::loopy_bp::{LoopyBpConfig, LoopyBeliefPropagation, LbpDampingPolicy, UpdateSchedule};
60/// use tensorlogic_quantrs_hooks::graph::FactorGraph;
61///
62/// let mut graph = FactorGraph::new();
63/// graph.add_variable("x".to_string(), "Binary".to_string());
64/// graph.add_variable("y".to_string(), "Binary".to_string());
65///
66/// let config = LoopyBpConfig::default()
67///     .with_damping(LbpDampingPolicy::Uniform(0.5))
68///     .with_schedule(UpdateSchedule::Synchronous);
69///
70/// let lbp = LoopyBeliefPropagation::new(config);
71/// let result = lbp.run_full(&graph).expect("LBP failed");
72/// println!("Converged: {}", result.convergence.is_converged());
73/// ```
74pub struct LoopyBeliefPropagation {
75    /// Configuration.
76    pub config: LoopyBpConfig,
77}
78
79impl LoopyBeliefPropagation {
80    /// Create a new LBP engine with the given configuration.
81    pub fn new(config: LoopyBpConfig) -> Self {
82        Self { config }
83    }
84
85    /// Run Loopy BP and return the full [`LoopyBpResult`].
86    pub fn run_full(&self, graph: &FactorGraph) -> Result<LoopyBpResult> {
87        // Analyse cycles first.
88        let cycle_analysis = CycleDetector::new(graph).analyse();
89
90        // Initialise message store with uniform messages.
91        let mut messages = self.initialise_messages(graph);
92
93        let mut monitor = LbpConvergenceMonitor::new();
94
95        match self.config.schedule {
96            UpdateSchedule::Synchronous => {
97                self.run_synchronous(graph, &mut messages, &mut monitor)?;
98            }
99            UpdateSchedule::Sequential => {
100                self.run_sequential(graph, &mut messages, &mut monitor)?;
101            }
102            UpdateSchedule::Residual => {
103                self.run_residual(graph, &mut messages, &mut monitor)?;
104            }
105        }
106
107        // Compute variable beliefs.
108        let beliefs = self.compute_variable_beliefs(graph, &messages)?;
109        let factor_beliefs = self.compute_factor_beliefs(graph, &messages)?;
110
111        // Optionally compute Bethe free energy.
112        let bethe = if self.config.compute_bethe {
113            Some(bethe_free_energy(graph, &beliefs, &factor_beliefs))
114        } else {
115            None
116        };
117
118        Ok(LoopyBpResult {
119            beliefs,
120            factor_beliefs,
121            convergence: monitor,
122            bethe,
123            cycle_analysis,
124        })
125    }
126
127    // ── Initialisation ──────────────────────────────────────────────────────
128
129    pub(super) fn initialise_messages(&self, graph: &FactorGraph) -> LogMessageStore {
130        let mut store = LogMessageStore::new();
131
132        for var_name in graph.variable_names() {
133            let card = graph
134                .get_variable(var_name)
135                .map(|v| v.cardinality)
136                .unwrap_or(2);
137
138            if let Some(fac_ids) = graph.get_adjacent_factors(var_name) {
139                for fac_id in fac_ids {
140                    // var → factor  (uniform)
141                    store.set_vtf(
142                        var_name.clone(),
143                        fac_id.clone(),
144                        LogMessage::uniform(var_name, card),
145                    );
146
147                    // factor → var  (uniform, but seeded from factor values if available)
148                    let ftv_msg = if let Some(factor) = graph.get_factor(fac_id) {
149                        // Marginalise factor over all variables except this one.
150                        let marginal = self.marginalise_factor_to_var(factor, var_name);
151                        marginal.unwrap_or_else(|_| LogMessage::uniform(var_name, card))
152                    } else {
153                        LogMessage::uniform(var_name, card)
154                    };
155                    store.set_ftv(fac_id.clone(), var_name.clone(), ftv_msg);
156                }
157            }
158        }
159
160        store
161    }
162
163    // ── Message computation ───────────────────────────────────────────────────
164
165    /// Compute a variable→factor log-message:
166    /// `log μ(x→f, xᵢ) = ∑_{g ∈ N(x) \ f} log μ(g→x, xᵢ)`
167    pub(super) fn compute_vtf_message(
168        &self,
169        graph: &FactorGraph,
170        messages: &LogMessageStore,
171        var: &str,
172        target_fac: &str,
173    ) -> Result<LogMessage> {
174        let card = graph
175            .get_variable(var)
176            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
177            .cardinality;
178
179        let mut log_msg = Array1::<f64>::zeros(card);
180
181        if let Some(fac_ids) = graph.get_adjacent_factors(var) {
182            for fac_id in fac_ids {
183                if fac_id == target_fac {
184                    continue;
185                }
186                if let Some(ftv) = messages.get_ftv(fac_id, var) {
187                    // Sum log-messages (= product in probability space).
188                    log_msg += &ftv.log_values;
189                }
190            }
191        }
192
193        let mut msg = LogMessage {
194            variable: var.to_string(),
195            log_values: log_msg,
196        };
197        msg.log_normalise();
198        Ok(msg)
199    }
200
201    /// Compute a factor→variable log-message (log-domain sum-product):
202    /// `log μ(f→x, xᵢ) = log ∑_{~xᵢ} [φ_f(x) ∏_{y∈N(f)\xᵢ} μ(y→f, y_j)]`
203    pub(super) fn compute_ftv_message(
204        &self,
205        graph: &FactorGraph,
206        messages: &LogMessageStore,
207        fac_id: &str,
208        target_var: &str,
209    ) -> Result<LogMessage> {
210        let factor = graph
211            .get_factor(fac_id)
212            .ok_or_else(|| PgmError::FactorNotFound(fac_id.to_string()))?;
213
214        let target_idx = factor
215            .variables
216            .iter()
217            .position(|v| v == target_var)
218            .ok_or_else(|| {
219                PgmError::VariableNotFound(format!(
220                    "Variable '{}' not in factor '{}'",
221                    target_var, fac_id
222                ))
223            })?;
224
225        let target_card = factor.values.shape()[target_idx];
226
227        // Compute log(φ_f(x) * ∏_{y≠target} μ(y→f, y_j)) for every joint assignment,
228        // then marginalise (log-sum-exp) over all dimensions except target.
229        let total_size: usize = factor.values.shape().iter().product();
230        let mut log_joint = Vec::with_capacity(total_size);
231
232        for lin_idx in 0..total_size {
233            let assignment = linear_to_assignment(lin_idx, factor.values.shape());
234            let mut log_val = {
235                let phi = factor.values[assignment.as_slice()];
236                if phi > 1e-300 {
237                    phi.ln()
238                } else {
239                    -700.0
240                }
241            };
242            // Multiply in incoming var→factor messages.
243            for (dim, var_name) in factor.variables.iter().enumerate() {
244                if var_name == target_var {
245                    continue;
246                }
247                if let Some(vtf) = messages.get_vtf(var_name, fac_id) {
248                    let val_idx = assignment[dim];
249                    let lv = vtf.log_values.get(val_idx).copied().unwrap_or(-700.0);
250                    log_val += lv;
251                }
252            }
253            log_joint.push((assignment[target_idx], log_val));
254        }
255
256        // Log-sum-exp over all assignments sharing the same target value.
257        let mut result = vec![f64::NEG_INFINITY; target_card];
258        for (t_val, lv) in log_joint {
259            // log-sum-exp accumulate: log(exp(a) + exp(b)) = max + log(1 + exp(min-max))
260            let cur = result[t_val];
261            if cur == f64::NEG_INFINITY {
262                result[t_val] = lv;
263            } else {
264                let m = cur.max(lv);
265                result[t_val] = m + ((cur - m).exp() + (lv - m).exp()).ln();
266            }
267        }
268
269        let mut msg = LogMessage {
270            variable: target_var.to_string(),
271            log_values: Array1::from(result),
272        };
273        msg.log_normalise();
274        Ok(msg)
275    }
276
277    /// Apply all updated messages with damping and track residuals.
278    pub(super) fn apply_updates_and_track(
279        &self,
280        messages: &mut LogMessageStore,
281        new_vtf: HashMap<(String, String), LogMessage>,
282        new_ftv: HashMap<(String, String), LogMessage>,
283        iteration: usize,
284    ) -> LbpIterStats {
285        let mut max_residual = 0.0_f64;
286        let mut sum_residual = 0.0_f64;
287        let mut count = 0usize;
288        let mut active = 0usize;
289
290        for ((var, fac), new_msg) in new_vtf {
291            let old = messages.get_vtf(&var, &fac).cloned();
292            let residual = old
293                .as_ref()
294                .map(|o| new_msg.residual_linf(o))
295                .unwrap_or(1.0);
296            let lambda = self.config.damping.effective_lambda(residual);
297            let final_msg = if let Some(o) = &old {
298                new_msg.damp(o, lambda)
299            } else {
300                new_msg
301            };
302            max_residual = max_residual.max(residual);
303            sum_residual += residual;
304            count += 1;
305            if residual >= self.config.tolerance {
306                active += 1;
307            }
308            messages.set_vtf(var, fac, final_msg);
309        }
310
311        for ((fac, var), new_msg) in new_ftv {
312            let old = messages.get_ftv(&fac, &var).cloned();
313            let residual = old
314                .as_ref()
315                .map(|o| new_msg.residual_linf(o))
316                .unwrap_or(1.0);
317            let lambda = self.config.damping.effective_lambda(residual);
318            let final_msg = if let Some(o) = &old {
319                new_msg.damp(o, lambda)
320            } else {
321                new_msg
322            };
323            max_residual = max_residual.max(residual);
324            sum_residual += residual;
325            count += 1;
326            if residual >= self.config.tolerance {
327                active += 1;
328            }
329            messages.set_ftv(fac, var, final_msg);
330        }
331
332        let mean_residual = if count > 0 {
333            sum_residual / count as f64
334        } else {
335            0.0
336        };
337
338        LbpIterStats {
339            iteration,
340            max_residual,
341            mean_residual,
342            active_messages: active,
343        }
344    }
345}
346
347// ──────────────────────────────────────────────────────────────────────────────
348// MessagePassingAlgorithm impl (integrates with the existing trait)
349// ──────────────────────────────────────────────────────────────────────────────
350
351impl MessagePassingAlgorithm for LoopyBeliefPropagation {
352    fn run(
353        &self,
354        graph: &FactorGraph,
355    ) -> std::result::Result<HashMap<String, ArrayD<f64>>, crate::error::PgmError> {
356        let result = self.run_full(graph)?;
357        // Convert Array1 → ArrayD for compatibility with the trait.
358        let beliefs_dyn: HashMap<String, ArrayD<f64>> = result
359            .beliefs
360            .into_iter()
361            .map(|(k, v)| (k, v.into_dyn()))
362            .collect();
363        Ok(beliefs_dyn)
364    }
365
366    fn name(&self) -> &str {
367        "LoopyBeliefPropagation"
368    }
369}
370
371// ──────────────────────────────────────────────────────────────────────────────
372// Utility
373// ──────────────────────────────────────────────────────────────────────────────
374
375/// Convert a linear index to a multi-dimensional index for a given shape.
376pub(super) fn linear_to_assignment(mut lin: usize, shape: &[usize]) -> Vec<usize> {
377    let mut assignment = vec![0usize; shape.len()];
378    for (i, &dim) in shape.iter().enumerate().rev() {
379        assignment[i] = lin % dim;
380        lin /= dim;
381    }
382    assignment
383}