Skip to main content

tensorlogic_quantrs_hooks/vmp/
messages.rs

//! Natural-parameter messages for Variational Message Passing.
//!
//! In VMP every message carries a vector of *natural parameters* that can be
//! summed element-wise: the product of two exponential-family densities that
//! share the same sufficient statistics is (up to normalisation) another
//! density in the same family, whose natural parameters are the sum of the two
//! input vectors. That is the basic arithmetic this module exposes — it
//! deliberately does not carry probabilities, because everything in VMP
//! happens in log / natural space.
//!
//! Each message is tagged with a direction so that the engine can distinguish
//! factor→variable messages (information flowing into a variable's posterior)
//! from variable→factor messages (expected sufficient statistics needed by the
//! factor to emit its own factor→variable updates).

15use crate::error::{PgmError, Result};
16
/// Which way a [`VmpMessage`] travels along an edge of the factor graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageDirection {
    /// Emitted by a factor and addressed to one of its neighbouring variables.
    FactorToVariable,
    /// Emitted by a variable, carrying its expected sufficient statistics to an
    /// adjacent factor.
    VariableToFactor,
}

26/// A VMP message.
27///
28/// The `natural_params` vector has the dimensionality of the *receiving
29/// variable's* exponential family in the factor→variable direction and the
30/// dimensionality of the *sender's* sufficient statistics in the opposite
31/// direction. The engine is responsible for maintaining this invariant.
32#[derive(Clone, Debug)]
33pub struct VmpMessage {
34    /// Natural-parameter vector η (or its analogue for variable→factor messages,
35    /// which carry expected sufficient statistics).
36    pub natural_params: Vec<f64>,
37    /// Sender identifier (factor id or variable name).
38    pub from: String,
39    /// Receiver identifier.
40    pub to: String,
41    /// Direction (factor→variable or variable→factor).
42    pub direction: MessageDirection,
43}
44
45impl VmpMessage {
46    /// Zero-message in the given direction and dimensionality.
47    pub fn zeros(from: String, to: String, direction: MessageDirection, dim: usize) -> Self {
48        Self {
49            natural_params: vec![0.0; dim],
50            from,
51            to,
52            direction,
53        }
54    }
55
56    /// Dimensionality of the message.
57    pub fn dim(&self) -> usize {
58        self.natural_params.len()
59    }
60
61    /// Sum two messages element-wise, producing a third (natural parameters
62    /// add under the product-of-densities rule).
63    ///
64    /// Requires identical direction / endpoints / dimensionality — otherwise
65    /// returns a dimension-mismatch error.
66    pub fn product(a: &Self, b: &Self) -> Result<Self> {
67        if a.natural_params.len() != b.natural_params.len() {
68            return Err(PgmError::DimensionMismatch {
69                expected: vec![a.natural_params.len()],
70                got: vec![b.natural_params.len()],
71            });
72        }
73        let summed = a
74            .natural_params
75            .iter()
76            .zip(b.natural_params.iter())
77            .map(|(x, y)| x + y)
78            .collect();
79        Ok(Self {
80            natural_params: summed,
81            from: a.from.clone(),
82            to: a.to.clone(),
83            direction: a.direction,
84        })
85    }
86
87    /// Add `rhs` into `self` in place, producing the natural-parameter sum.
88    pub fn accumulate(&mut self, rhs: &Self) -> Result<()> {
89        if self.natural_params.len() != rhs.natural_params.len() {
90            return Err(PgmError::DimensionMismatch {
91                expected: vec![self.natural_params.len()],
92                got: vec![rhs.natural_params.len()],
93            });
94        }
95        for (lhs, r) in self
96            .natural_params
97            .iter_mut()
98            .zip(rhs.natural_params.iter())
99        {
100            *lhs += *r;
101        }
102        Ok(())
103    }
104
105    /// L∞ residual between two messages of identical shape. Useful for
106    /// convergence monitoring in the engine.
107    pub fn linf_residual(a: &Self, b: &Self) -> Result<f64> {
108        if a.natural_params.len() != b.natural_params.len() {
109            return Err(PgmError::DimensionMismatch {
110                expected: vec![a.natural_params.len()],
111                got: vec![b.natural_params.len()],
112            });
113        }
114        let mut max = 0.0_f64;
115        for (x, y) in a.natural_params.iter().zip(b.natural_params.iter()) {
116            max = max.max((x - y).abs());
117        }
118        Ok(max)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a factor→variable message `f -> v` carrying `params`.
    fn msg(params: &[f64]) -> VmpMessage {
        let mut m = VmpMessage::zeros(
            "f".into(),
            "v".into(),
            MessageDirection::FactorToVariable,
            params.len(),
        );
        m.natural_params = params.to_vec();
        m
    }

    #[test]
    fn zeros_has_requested_dim() {
        let m = VmpMessage::zeros(
            "f".to_string(),
            "v".to_string(),
            MessageDirection::FactorToVariable,
            3,
        );
        assert_eq!(m.dim(), 3);
        assert!(m.natural_params.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn product_is_element_wise_sum() {
        let a = msg(&[1.0, 2.0, 3.0]);
        let b = msg(&[0.5, -1.0, 4.0]);
        let p = VmpMessage::product(&a, &b).expect("product");
        assert_eq!(p.natural_params, vec![1.5, 1.0, 7.0]);
    }

    #[test]
    fn product_keeps_metadata_of_first_operand() {
        // Only dimensionality is validated; messages from different senders
        // may be combined and the result inherits `a`'s endpoints/direction.
        let a = msg(&[1.0]);
        let mut b = msg(&[2.0]);
        b.from = "g".into();
        b.to = "w".into();
        b.direction = MessageDirection::VariableToFactor;
        let p = VmpMessage::product(&a, &b).expect("product");
        assert_eq!(p.natural_params, vec![3.0]);
        assert_eq!(p.from, "f");
        assert_eq!(p.to, "v");
        assert_eq!(p.direction, MessageDirection::FactorToVariable);
    }

    #[test]
    fn accumulate_matches_product() {
        let mut a = msg(&[1.0, -1.0]);
        let b = msg(&[2.5, 0.5]);
        let expected = VmpMessage::product(&a, &b).expect("product");
        a.accumulate(&b).expect("accum");
        assert_eq!(a.natural_params, vec![3.5, -0.5]);
        assert_eq!(a.natural_params, expected.natural_params);
    }

    #[test]
    fn accumulate_mismatch_leaves_lhs_untouched() {
        let mut a = msg(&[1.0, 2.0]);
        let b = msg(&[1.0, 2.0, 3.0]);
        assert!(a.accumulate(&b).is_err());
        assert_eq!(a.natural_params, vec![1.0, 2.0]);
    }

    #[test]
    fn linf_residual_is_max_abs() {
        let a = msg(&[1.0, 2.0, 3.0]);
        let b = msg(&[1.1, 1.5, 5.0]);
        let r = VmpMessage::linf_residual(&a, &b).expect("residual");
        assert!((r - 2.0).abs() < 1e-12);
    }

    #[test]
    fn dimension_mismatch_is_error() {
        let a = msg(&[0.0; 2]);
        let b = msg(&[0.0; 3]);
        assert!(VmpMessage::product(&a, &b).is_err());
        assert!(VmpMessage::linf_residual(&a, &b).is_err());
    }
}