tensorlogic_quantrs_hooks/vmp/
gamma.rs1use crate::error::{PgmError, Result};
33
34use super::exponential_family::ExponentialFamily;
35use super::special::{digamma, ln_gamma};
36
/// Gamma distribution `Gamma(α, β)` in the shape/rate parameterisation.
///
/// Viewed as an exponential family, its sufficient statistics are `[ln x, x]`
/// and its natural parameters are `η = [α − 1, −β]`. The log-partition is
/// `A(η) = ln Γ(α) − α ln β`.
///
/// The fields are public, so a value built with a struct literal bypasses the
/// positivity/finiteness checks performed by [`GammaNP::new`]. Prefer the
/// constructor.
#[derive(Clone, Debug, PartialEq)]
pub struct GammaNP {
    /// Shape parameter α (expected to be positive and finite).
    pub alpha: f64,
    /// Rate parameter β, i.e. inverse scale (expected to be positive and finite).
    pub beta: f64,
}
50
51impl GammaNP {
52 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
55 if !alpha.is_finite() || alpha <= 0.0 {
56 return Err(PgmError::InvalidDistribution(format!(
57 "Gamma shape α must be positive and finite (got {})",
58 alpha
59 )));
60 }
61 if !beta.is_finite() || beta <= 0.0 {
62 return Err(PgmError::InvalidDistribution(format!(
63 "Gamma rate β must be positive and finite (got {})",
64 beta
65 )));
66 }
67 Ok(Self { alpha, beta })
68 }
69
70 pub fn from_natural(natural: &[f64]) -> Result<Self> {
72 if natural.len() != 2 {
73 return Err(PgmError::DimensionMismatch {
74 expected: vec![2],
75 got: vec![natural.len()],
76 });
77 }
78 let alpha = natural[0] + 1.0;
79 let beta = -natural[1];
80 Self::new(alpha, beta)
81 }
82
83 pub fn expected_x(&self) -> f64 {
85 self.alpha / self.beta
86 }
87
88 pub fn expected_log_x(&self) -> f64 {
90 digamma(self.alpha) - self.beta.ln()
91 }
92
93 pub fn variance(&self) -> f64 {
95 self.alpha / (self.beta * self.beta)
96 }
97
98 pub fn multiply_naturals(&self, other: &GammaNP) -> Result<GammaNP> {
105 let alpha = self.alpha + other.alpha - 1.0;
106 let beta = self.beta + other.beta;
107 GammaNP::new(alpha, beta)
108 }
109
110 pub fn kl_to(&self, other: &GammaNP) -> f64 {
119 let ap = self.alpha;
120 let bp = self.beta;
121 let aq = other.alpha;
122 let bq = other.beta;
123 (ap - aq) * digamma(ap) - ln_gamma(ap)
124 + ln_gamma(aq)
125 + aq * (bp.ln() - bq.ln())
126 + ap * (bq - bp) / bp
127 }
128}
129
130impl ExponentialFamily for GammaNP {
131 fn family_name(&self) -> &'static str {
132 "Gamma"
133 }
134
135 fn natural_dim(&self) -> usize {
136 2
137 }
138
139 fn natural_params(&self) -> Vec<f64> {
140 vec![self.alpha - 1.0, -self.beta]
141 }
142
143 fn set_natural(&mut self, new_eta: &[f64]) -> Result<()> {
144 if new_eta.len() != 2 {
145 return Err(PgmError::DimensionMismatch {
146 expected: vec![2],
147 got: vec![new_eta.len()],
148 });
149 }
150 for &v in new_eta {
151 if !v.is_finite() {
152 return Err(PgmError::InvalidDistribution(
153 "Gamma natural parameter must be finite".to_string(),
154 ));
155 }
156 }
157 let alpha = new_eta[0] + 1.0;
158 let beta = -new_eta[1];
159 if alpha <= 0.0 {
160 return Err(PgmError::InvalidDistribution(format!(
161 "Gamma shape must stay positive (η₁ + 1 = {} ≤ 0)",
162 alpha
163 )));
164 }
165 if beta <= 0.0 {
166 return Err(PgmError::InvalidDistribution(format!(
167 "Gamma rate must stay positive (−η₂ = {} ≤ 0)",
168 beta
169 )));
170 }
171 self.alpha = alpha;
172 self.beta = beta;
173 Ok(())
174 }
175
176 fn sufficient_statistics(&self, value: f64) -> Vec<f64> {
177 if value > 0.0 {
181 vec![value.ln(), value]
182 } else {
183 vec![f64::NEG_INFINITY, value]
184 }
185 }
186
187 fn log_partition(&self, natural_params: &[f64]) -> Result<f64> {
188 if natural_params.len() != 2 {
189 return Err(PgmError::DimensionMismatch {
190 expected: vec![2],
191 got: vec![natural_params.len()],
192 });
193 }
194 let alpha = natural_params[0] + 1.0;
195 let neg_beta = natural_params[1];
196 if alpha <= 0.0 || neg_beta >= 0.0 {
197 return Err(PgmError::InvalidDistribution(format!(
198 "Gamma log_partition: α = {} must be positive and −β = {} negative",
199 alpha, neg_beta
200 )));
201 }
202 let beta = -neg_beta;
204 Ok(ln_gamma(alpha) - alpha * beta.ln())
205 }
206
207 fn expected_sufficient_statistics(&self) -> Vec<f64> {
208 vec![self.expected_log_x(), self.expected_x()]
210 }
211}
212
213pub fn posterior_from_prior_and_observations(
223 prior: &GammaNP,
224 observations: &[u64],
225) -> Result<GammaNP> {
226 let n = observations.len() as f64;
227 let sum: u64 = observations.iter().sum();
228 let posterior_alpha = prior.alpha + sum as f64;
229 let posterior_beta = prior.beta + n;
230 GammaNP::new(posterior_alpha, posterior_beta)
231}
232
/// A batch of Poisson count observations that all share one rate variable.
///
/// NOTE(review): the rate variable is presumably a Gamma-distributed node in
/// the model; confirm against the code that consumes this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GammaPoissonObservation {
    /// Name of the rate variable these counts depend on.
    pub rate_variable: String,
    /// Observed counts.
    pub observations: Vec<u64>,
}
249
250impl GammaPoissonObservation {
251 pub fn new(rate_variable: impl Into<String>, observations: Vec<u64>) -> Self {
253 Self {
254 rate_variable: rate_variable.into(),
255 observations,
256 }
257 }
258
259 pub fn count_sum(&self) -> u64 {
261 self.observations.iter().sum()
262 }
263
264 pub fn num_observations(&self) -> usize {
266 self.observations.len()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vmp::special::{digamma, ln_gamma};

    #[test]
    fn gamma_expected_x_matches_alpha_over_beta() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.0, 0.5), (3.7, 4.2), (0.25, 10.0)] {
            let g = GammaNP::new(alpha, beta).expect("ctor");
            let ex = g.expected_x();
            assert!(
                (ex - alpha / beta).abs() < 1e-12,
                "E[x] = {} but α/β = {}",
                ex,
                alpha / beta
            );
        }
    }

    #[test]
    fn gamma_expected_log_x_matches_digamma_minus_lnbeta() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.5, 0.5), (4.0, 2.0)] {
            let g = GammaNP::new(alpha, beta).expect("ctor");
            let el = g.expected_log_x();
            let expected = digamma(alpha) - beta.ln();
            assert!(
                (el - expected).abs() < 1e-12,
                "E[log x] = {}, expected ψ(α)−ln β = {}",
                el,
                expected
            );
        }
    }

    #[test]
    fn gamma_new_rejects_invalid_parameters() {
        for &(alpha, beta) in &[
            (0.0_f64, 1.0_f64),
            (-1.0, 1.0),
            (1.0, 0.0),
            (1.0, -2.0),
            (f64::NAN, 1.0),
            (1.0, f64::INFINITY),
        ] {
            assert!(
                GammaNP::new(alpha, beta).is_err(),
                "accepted α = {}, β = {}",
                alpha,
                beta
            );
        }
    }

    #[test]
    fn gamma_from_natural_rejects_wrong_length() {
        assert!(GammaNP::from_natural(&[1.0]).is_err());
        assert!(GammaNP::from_natural(&[1.0, -1.0, 0.0]).is_err());
    }

    #[test]
    fn gamma_multiply_naturals_sums_natural_params() {
        // Product of densities: α = 2 + 3 − 1 = 4, β = 1 + 2 = 3.
        let a = GammaNP::new(2.0, 1.0).expect("ctor a");
        let b = GammaNP::new(3.0, 2.0).expect("ctor b");
        let p = a.multiply_naturals(&b).expect("product");
        assert!((p.alpha - 4.0).abs() < 1e-12, "α = {}", p.alpha);
        assert!((p.beta - 3.0).abs() < 1e-12, "β = {}", p.beta);

        // The same result via explicit natural-parameter addition.
        let eta_a = a.natural_params();
        let eta_b = b.natural_params();
        let eta_sum: Vec<f64> = eta_a.iter().zip(eta_b.iter()).map(|(x, y)| x + y).collect();
        let p2 = GammaNP::from_natural(&eta_sum).expect("from nat");
        assert!((p2.alpha - p.alpha).abs() < 1e-12);
        assert!((p2.beta - p.beta).abs() < 1e-12);
    }

    #[test]
    fn gamma_kl_is_zero_for_self_positive_otherwise() {
        let g = GammaNP::new(3.0, 2.0).expect("ctor");
        let self_kl = g.kl_to(&g);
        assert!(self_kl.abs() < 1e-10, "KL(g||g) = {}", self_kl);

        let other = GammaNP::new(1.5, 4.0).expect("ctor other");
        let kl = g.kl_to(&other);
        assert!(kl > 0.0, "KL(g||other) should be positive, got {}", kl);

        // KL is asymmetric, but both directions must be positive.
        let kl_rev = other.kl_to(&g);
        assert!(
            kl_rev > 0.0,
            "KL(other||g) should be positive, got {}",
            kl_rev
        );
    }

    #[test]
    fn gamma_poisson_posterior_adds_sum_and_count() {
        // α = 1 + (3 + 5 + 2) = 11, β = 1 + 3 = 4.
        let prior = GammaNP::new(1.0, 1.0).expect("prior");
        let obs: [u64; 3] = [3, 5, 2];
        let post = posterior_from_prior_and_observations(&prior, &obs).expect("posterior");
        assert!((post.alpha - 11.0).abs() < 1e-12, "α = {}", post.alpha);
        assert!((post.beta - 4.0).abs() < 1e-12, "β = {}", post.beta);
    }

    #[test]
    fn gamma_log_partition_matches_closed_form() {
        let g = GammaNP::new(2.5, 3.0).expect("ctor");
        let eta = g.natural_params();
        let a = g.log_partition(&eta).expect("lp");
        let expected = ln_gamma(2.5) - 2.5 * 3.0_f64.ln();
        assert!(
            (a - expected).abs() < 1e-12,
            "A(η) = {}, expected {}",
            a,
            expected
        );

        // Central differences of A(η) must recover ∇A(η) = E[T(x)].
        let h = 1e-6;
        let a_plus_1 = g.log_partition(&[eta[0] + h, eta[1]]).expect("lp+1");
        let a_minus_1 = g.log_partition(&[eta[0] - h, eta[1]]).expect("lp-1");
        let d1 = (a_plus_1 - a_minus_1) / (2.0 * h);
        let a_plus_2 = g.log_partition(&[eta[0], eta[1] + h]).expect("lp+2");
        let a_minus_2 = g.log_partition(&[eta[0], eta[1] - h]).expect("lp-2");
        let d2 = (a_plus_2 - a_minus_2) / (2.0 * h);
        assert!(
            (d1 - g.expected_log_x()).abs() < 1e-5,
            "dA/dη1 = {}, expected {}",
            d1,
            g.expected_log_x()
        );
        assert!(
            (d2 - g.expected_x()).abs() < 1e-5,
            "dA/dη2 = {}, expected {}",
            d2,
            g.expected_x()
        );
    }

    #[test]
    fn gamma_natural_round_trip() {
        let g = GammaNP::new(4.5, 2.25).expect("ctor");
        let eta = g.natural_params();
        let back = GammaNP::from_natural(&eta).expect("round trip");
        assert!((back.alpha - 4.5).abs() < 1e-12);
        assert!((back.beta - 2.25).abs() < 1e-12);
    }

    #[test]
    fn gamma_set_natural_rejects_negative_alpha() {
        let mut g = GammaNP::new(2.0, 1.0).expect("ctor");
        // η₁ = −1.5 ⇒ α = −0.5, not a valid shape.
        assert!(g.set_natural(&[-1.5, -1.0]).is_err());
        // A rejected update must leave the parameters untouched.
        assert!((g.alpha - 2.0).abs() < 1e-12);
        assert!((g.beta - 1.0).abs() < 1e-12);
    }

    #[test]
    fn gamma_set_natural_rejects_non_negative_eta2() {
        let mut g = GammaNP::new(2.0, 1.0).expect("ctor");
        // η₂ = +1 ⇒ β = −1, not a valid rate.
        assert!(g.set_natural(&[0.5, 1.0]).is_err());
        assert!((g.alpha - 2.0).abs() < 1e-12);
        assert!((g.beta - 1.0).abs() < 1e-12);
    }

    #[test]
    fn gamma_set_natural_accepts_valid_params() {
        let mut g = GammaNP::new(2.0, 1.0).expect("ctor");
        g.set_natural(&[0.5, -1.0]).expect("valid η");
        assert!((g.alpha - 1.5).abs() < 1e-12, "α = {}", g.alpha);
        assert!((g.beta - 1.0).abs() < 1e-12, "β = {}", g.beta);
    }
}