tensorlogic_quantrs_hooks/loopy_bp/
engine.rs1use scirs2_core::ndarray::{Array1, ArrayD};
5use std::collections::HashMap;
6
7use crate::error::{PgmError, Result};
8use crate::graph::FactorGraph;
9use crate::message_passing::MessagePassingAlgorithm;
10
11use super::config::{LoopyBpConfig, LoopyBpResult};
12use super::cycle::CycleDetector;
13use super::energy::bethe_free_energy;
14use super::types::{LbpConvergenceMonitor, LbpIterStats, LogMessage, UpdateSchedule};
15
16#[derive(Clone, Debug, Default)]
22pub(super) struct LogMessageStore {
23 pub(super) var_to_factor: HashMap<(String, String), LogMessage>,
25 pub(super) factor_to_var: HashMap<(String, String), LogMessage>,
27}
28
29impl LogMessageStore {
30 pub(super) fn new() -> Self {
31 Self::default()
32 }
33
34 pub(super) fn get_vtf(&self, var: &str, fac: &str) -> Option<&LogMessage> {
35 self.var_to_factor.get(&(var.to_string(), fac.to_string()))
36 }
37
38 pub(super) fn set_vtf(&mut self, var: String, fac: String, msg: LogMessage) {
39 self.var_to_factor.insert((var, fac), msg);
40 }
41
42 pub(super) fn get_ftv(&self, fac: &str, var: &str) -> Option<&LogMessage> {
43 self.factor_to_var.get(&(fac.to_string(), var.to_string()))
44 }
45
46 pub(super) fn set_ftv(&mut self, fac: String, var: String, msg: LogMessage) {
47 self.factor_to_var.insert((fac, var), msg);
48 }
49}
50
51pub struct LoopyBeliefPropagation {
75 pub config: LoopyBpConfig,
77}
78
79impl LoopyBeliefPropagation {
80 pub fn new(config: LoopyBpConfig) -> Self {
82 Self { config }
83 }
84
85 pub fn run_full(&self, graph: &FactorGraph) -> Result<LoopyBpResult> {
87 let cycle_analysis = CycleDetector::new(graph).analyse();
89
90 let mut messages = self.initialise_messages(graph);
92
93 let mut monitor = LbpConvergenceMonitor::new();
94
95 match self.config.schedule {
96 UpdateSchedule::Synchronous => {
97 self.run_synchronous(graph, &mut messages, &mut monitor)?;
98 }
99 UpdateSchedule::Sequential => {
100 self.run_sequential(graph, &mut messages, &mut monitor)?;
101 }
102 UpdateSchedule::Residual => {
103 self.run_residual(graph, &mut messages, &mut monitor)?;
104 }
105 }
106
107 let beliefs = self.compute_variable_beliefs(graph, &messages)?;
109 let factor_beliefs = self.compute_factor_beliefs(graph, &messages)?;
110
111 let bethe = if self.config.compute_bethe {
113 Some(bethe_free_energy(graph, &beliefs, &factor_beliefs))
114 } else {
115 None
116 };
117
118 Ok(LoopyBpResult {
119 beliefs,
120 factor_beliefs,
121 convergence: monitor,
122 bethe,
123 cycle_analysis,
124 })
125 }
126
127 pub(super) fn initialise_messages(&self, graph: &FactorGraph) -> LogMessageStore {
130 let mut store = LogMessageStore::new();
131
132 for var_name in graph.variable_names() {
133 let card = graph
134 .get_variable(var_name)
135 .map(|v| v.cardinality)
136 .unwrap_or(2);
137
138 if let Some(fac_ids) = graph.get_adjacent_factors(var_name) {
139 for fac_id in fac_ids {
140 store.set_vtf(
142 var_name.clone(),
143 fac_id.clone(),
144 LogMessage::uniform(var_name, card),
145 );
146
147 let ftv_msg = if let Some(factor) = graph.get_factor(fac_id) {
149 let marginal = self.marginalise_factor_to_var(factor, var_name);
151 marginal.unwrap_or_else(|_| LogMessage::uniform(var_name, card))
152 } else {
153 LogMessage::uniform(var_name, card)
154 };
155 store.set_ftv(fac_id.clone(), var_name.clone(), ftv_msg);
156 }
157 }
158 }
159
160 store
161 }
162
163 pub(super) fn compute_vtf_message(
168 &self,
169 graph: &FactorGraph,
170 messages: &LogMessageStore,
171 var: &str,
172 target_fac: &str,
173 ) -> Result<LogMessage> {
174 let card = graph
175 .get_variable(var)
176 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
177 .cardinality;
178
179 let mut log_msg = Array1::<f64>::zeros(card);
180
181 if let Some(fac_ids) = graph.get_adjacent_factors(var) {
182 for fac_id in fac_ids {
183 if fac_id == target_fac {
184 continue;
185 }
186 if let Some(ftv) = messages.get_ftv(fac_id, var) {
187 log_msg += &ftv.log_values;
189 }
190 }
191 }
192
193 let mut msg = LogMessage {
194 variable: var.to_string(),
195 log_values: log_msg,
196 };
197 msg.log_normalise();
198 Ok(msg)
199 }
200
201 pub(super) fn compute_ftv_message(
204 &self,
205 graph: &FactorGraph,
206 messages: &LogMessageStore,
207 fac_id: &str,
208 target_var: &str,
209 ) -> Result<LogMessage> {
210 let factor = graph
211 .get_factor(fac_id)
212 .ok_or_else(|| PgmError::FactorNotFound(fac_id.to_string()))?;
213
214 let target_idx = factor
215 .variables
216 .iter()
217 .position(|v| v == target_var)
218 .ok_or_else(|| {
219 PgmError::VariableNotFound(format!(
220 "Variable '{}' not in factor '{}'",
221 target_var, fac_id
222 ))
223 })?;
224
225 let target_card = factor.values.shape()[target_idx];
226
227 let total_size: usize = factor.values.shape().iter().product();
230 let mut log_joint = Vec::with_capacity(total_size);
231
232 for lin_idx in 0..total_size {
233 let assignment = linear_to_assignment(lin_idx, factor.values.shape());
234 let mut log_val = {
235 let phi = factor.values[assignment.as_slice()];
236 if phi > 1e-300 {
237 phi.ln()
238 } else {
239 -700.0
240 }
241 };
242 for (dim, var_name) in factor.variables.iter().enumerate() {
244 if var_name == target_var {
245 continue;
246 }
247 if let Some(vtf) = messages.get_vtf(var_name, fac_id) {
248 let val_idx = assignment[dim];
249 let lv = vtf.log_values.get(val_idx).copied().unwrap_or(-700.0);
250 log_val += lv;
251 }
252 }
253 log_joint.push((assignment[target_idx], log_val));
254 }
255
256 let mut result = vec![f64::NEG_INFINITY; target_card];
258 for (t_val, lv) in log_joint {
259 let cur = result[t_val];
261 if cur == f64::NEG_INFINITY {
262 result[t_val] = lv;
263 } else {
264 let m = cur.max(lv);
265 result[t_val] = m + ((cur - m).exp() + (lv - m).exp()).ln();
266 }
267 }
268
269 let mut msg = LogMessage {
270 variable: target_var.to_string(),
271 log_values: Array1::from(result),
272 };
273 msg.log_normalise();
274 Ok(msg)
275 }
276
277 pub(super) fn apply_updates_and_track(
279 &self,
280 messages: &mut LogMessageStore,
281 new_vtf: HashMap<(String, String), LogMessage>,
282 new_ftv: HashMap<(String, String), LogMessage>,
283 iteration: usize,
284 ) -> LbpIterStats {
285 let mut max_residual = 0.0_f64;
286 let mut sum_residual = 0.0_f64;
287 let mut count = 0usize;
288 let mut active = 0usize;
289
290 for ((var, fac), new_msg) in new_vtf {
291 let old = messages.get_vtf(&var, &fac).cloned();
292 let residual = old
293 .as_ref()
294 .map(|o| new_msg.residual_linf(o))
295 .unwrap_or(1.0);
296 let lambda = self.config.damping.effective_lambda(residual);
297 let final_msg = if let Some(o) = &old {
298 new_msg.damp(o, lambda)
299 } else {
300 new_msg
301 };
302 max_residual = max_residual.max(residual);
303 sum_residual += residual;
304 count += 1;
305 if residual >= self.config.tolerance {
306 active += 1;
307 }
308 messages.set_vtf(var, fac, final_msg);
309 }
310
311 for ((fac, var), new_msg) in new_ftv {
312 let old = messages.get_ftv(&fac, &var).cloned();
313 let residual = old
314 .as_ref()
315 .map(|o| new_msg.residual_linf(o))
316 .unwrap_or(1.0);
317 let lambda = self.config.damping.effective_lambda(residual);
318 let final_msg = if let Some(o) = &old {
319 new_msg.damp(o, lambda)
320 } else {
321 new_msg
322 };
323 max_residual = max_residual.max(residual);
324 sum_residual += residual;
325 count += 1;
326 if residual >= self.config.tolerance {
327 active += 1;
328 }
329 messages.set_ftv(fac, var, final_msg);
330 }
331
332 let mean_residual = if count > 0 {
333 sum_residual / count as f64
334 } else {
335 0.0
336 };
337
338 LbpIterStats {
339 iteration,
340 max_residual,
341 mean_residual,
342 active_messages: active,
343 }
344 }
345}
346
347impl MessagePassingAlgorithm for LoopyBeliefPropagation {
352 fn run(
353 &self,
354 graph: &FactorGraph,
355 ) -> std::result::Result<HashMap<String, ArrayD<f64>>, crate::error::PgmError> {
356 let result = self.run_full(graph)?;
357 let beliefs_dyn: HashMap<String, ArrayD<f64>> = result
359 .beliefs
360 .into_iter()
361 .map(|(k, v)| (k, v.into_dyn()))
362 .collect();
363 Ok(beliefs_dyn)
364 }
365
366 fn name(&self) -> &str {
367 "LoopyBeliefPropagation"
368 }
369}
370
371pub(super) fn linear_to_assignment(mut lin: usize, shape: &[usize]) -> Vec<usize> {
377 let mut assignment = vec![0usize; shape.len()];
378 for (i, &dim) in shape.iter().enumerate().rev() {
379 assignment[i] = lin % dim;
380 lin /= dim;
381 }
382 assignment
383}