tensorlogic_quantrs_hooks/vmp/
messages.rs1use crate::error::{PgmError, Result};
16
/// Which way a message travels between a factor and a variable.
///
/// `Hash` is derived next to the comparison traits so a direction can be used
/// in hashed collections, for example as part of a map key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum MessageDirection {
    /// The message is sent by a factor and received by a variable.
    FactorToVariable,
    /// The message is sent by a variable and received by a factor.
    VariableToFactor,
}
25
26#[derive(Clone, Debug)]
33pub struct VmpMessage {
34 pub natural_params: Vec<f64>,
37 pub from: String,
39 pub to: String,
41 pub direction: MessageDirection,
43}
44
45impl VmpMessage {
46 pub fn zeros(from: String, to: String, direction: MessageDirection, dim: usize) -> Self {
48 Self {
49 natural_params: vec![0.0; dim],
50 from,
51 to,
52 direction,
53 }
54 }
55
56 pub fn dim(&self) -> usize {
58 self.natural_params.len()
59 }
60
61 pub fn product(a: &Self, b: &Self) -> Result<Self> {
67 if a.natural_params.len() != b.natural_params.len() {
68 return Err(PgmError::DimensionMismatch {
69 expected: vec![a.natural_params.len()],
70 got: vec![b.natural_params.len()],
71 });
72 }
73 let summed = a
74 .natural_params
75 .iter()
76 .zip(b.natural_params.iter())
77 .map(|(x, y)| x + y)
78 .collect();
79 Ok(Self {
80 natural_params: summed,
81 from: a.from.clone(),
82 to: a.to.clone(),
83 direction: a.direction,
84 })
85 }
86
87 pub fn accumulate(&mut self, rhs: &Self) -> Result<()> {
89 if self.natural_params.len() != rhs.natural_params.len() {
90 return Err(PgmError::DimensionMismatch {
91 expected: vec![self.natural_params.len()],
92 got: vec![rhs.natural_params.len()],
93 });
94 }
95 for (lhs, r) in self
96 .natural_params
97 .iter_mut()
98 .zip(rhs.natural_params.iter())
99 {
100 *lhs += *r;
101 }
102 Ok(())
103 }
104
105 pub fn linf_residual(a: &Self, b: &Self) -> Result<f64> {
108 if a.natural_params.len() != b.natural_params.len() {
109 return Err(PgmError::DimensionMismatch {
110 expected: vec![a.natural_params.len()],
111 got: vec![b.natural_params.len()],
112 });
113 }
114 let mut max = 0.0_f64;
115 for (x, y) in a.natural_params.iter().zip(b.natural_params.iter()) {
116 max = max.max((x - y).abs());
117 }
118 Ok(max)
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a factor-to-variable message from `"f"` to `"v"` carrying
    /// `params`, so each test only spells out the numbers it cares about.
    fn message(params: &[f64]) -> VmpMessage {
        let mut m = VmpMessage::zeros(
            "f".to_string(),
            "v".to_string(),
            MessageDirection::FactorToVariable,
            params.len(),
        );
        m.natural_params = params.to_vec();
        m
    }

    #[test]
    fn zeros_has_requested_dim() {
        let m = VmpMessage::zeros(
            "f".to_string(),
            "v".to_string(),
            MessageDirection::FactorToVariable,
            3,
        );
        assert_eq!(m.dim(), 3);
        assert!(m.natural_params.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn product_is_element_wise_sum() {
        let a = message(&[1.0, 2.0, 3.0]);
        let b = message(&[0.5, -1.0, 4.0]);
        let p = VmpMessage::product(&a, &b).expect("product");
        assert_eq!(p.natural_params, vec![1.5, 1.0, 7.0]);
    }

    #[test]
    fn accumulate_matches_product() {
        let mut a = message(&[1.0, -1.0]);
        let b = message(&[2.5, 0.5]);
        a.accumulate(&b).expect("accum");
        assert_eq!(a.natural_params, vec![3.5, -0.5]);
    }

    #[test]
    fn linf_residual_is_max_abs() {
        let a = message(&[1.0, 2.0, 3.0]);
        let b = message(&[1.1, 1.5, 5.0]);
        let r = VmpMessage::linf_residual(&a, &b).expect("residual");
        assert!((r - 2.0).abs() < 1e-12);
    }

    #[test]
    fn dimension_mismatch_is_error() {
        let mut a = message(&[0.0, 0.0]);
        let b = message(&[0.0, 0.0, 0.0]);
        assert!(VmpMessage::product(&a, &b).is_err());
        assert!(VmpMessage::linf_residual(&a, &b).is_err());
        // The in-place variant must reject the mismatch as well...
        assert!(a.accumulate(&b).is_err());
        // ...and must not have modified the receiver while doing so.
        assert_eq!(a.natural_params, vec![0.0, 0.0]);
    }
}