tensorlogic_quantrs_hooks/vmp/
beta.rs1use crate::error::{PgmError, Result};
32
33use super::exponential_family::ExponentialFamily;
34use super::special::{digamma, ln_gamma};
35
/// Beta distribution `Beta(α, β)` on `x ∈ (0, 1)`, stored by its shape parameters.
///
/// In exponential-family form the natural parameters are `η = (α − 1, β − 1)` and
/// the sufficient statistics are `t(x) = (ln x, ln(1 − x))`.
///
/// The fields are public, so direct construction bypasses validation; prefer
/// [`BetaNP::new`] or [`BetaNP::from_natural`], which require both shapes to be
/// finite and strictly positive.
#[derive(Debug, Clone)]
pub struct BetaNP {
    /// First shape parameter `α` (pseudo-count of successes plus one).
    pub alpha: f64,
    /// Second shape parameter `β` (pseudo-count of failures plus one).
    pub beta: f64,
}
49
50impl BetaNP {
51 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
54 if !alpha.is_finite() || alpha <= 0.0 {
55 return Err(PgmError::InvalidDistribution(format!(
56 "Beta shape α must be positive and finite (got {})",
57 alpha
58 )));
59 }
60 if !beta.is_finite() || beta <= 0.0 {
61 return Err(PgmError::InvalidDistribution(format!(
62 "Beta shape β must be positive and finite (got {})",
63 beta
64 )));
65 }
66 Ok(Self { alpha, beta })
67 }
68
69 pub fn from_natural(natural: &[f64]) -> Result<Self> {
71 if natural.len() != 2 {
72 return Err(PgmError::DimensionMismatch {
73 expected: vec![2],
74 got: vec![natural.len()],
75 });
76 }
77 let alpha = natural[0] + 1.0;
78 let beta = natural[1] + 1.0;
79 Self::new(alpha, beta)
80 }
81
82 pub fn expected_x(&self) -> f64 {
84 self.alpha / (self.alpha + self.beta)
85 }
86
87 pub fn expected_log_x(&self) -> f64 {
89 digamma(self.alpha) - digamma(self.alpha + self.beta)
90 }
91
92 pub fn expected_log_1mx(&self) -> f64 {
94 digamma(self.beta) - digamma(self.alpha + self.beta)
95 }
96
97 pub fn variance(&self) -> f64 {
99 let ab = self.alpha + self.beta;
100 self.alpha * self.beta / (ab * ab * (ab + 1.0))
101 }
102
103 pub fn multiply_naturals(&self, other: &BetaNP) -> Result<BetaNP> {
110 let alpha = self.alpha + other.alpha - 1.0;
111 let beta = self.beta + other.beta - 1.0;
112 BetaNP::new(alpha, beta)
113 }
114
115 pub fn kl_to(&self, other: &BetaNP) -> f64 {
128 let ap = self.alpha;
129 let bp = self.beta;
130 let aq = other.alpha;
131 let bq = other.beta;
132 let ln_beta_p = ln_gamma(ap) + ln_gamma(bp) - ln_gamma(ap + bp);
133 let ln_beta_q = ln_gamma(aq) + ln_gamma(bq) - ln_gamma(aq + bq);
134 let psi_ap = digamma(ap);
135 let psi_bp = digamma(bp);
136 let psi_abp = digamma(ap + bp);
137 ln_beta_q - ln_beta_p
138 + (ap - aq) * psi_ap
139 + (bp - bq) * psi_bp
140 + (aq - ap + bq - bp) * psi_abp
141 }
142}
143
144impl ExponentialFamily for BetaNP {
145 fn family_name(&self) -> &'static str {
146 "Beta"
147 }
148
149 fn natural_dim(&self) -> usize {
150 2
151 }
152
153 fn natural_params(&self) -> Vec<f64> {
154 vec![self.alpha - 1.0, self.beta - 1.0]
155 }
156
157 fn set_natural(&mut self, new_eta: &[f64]) -> Result<()> {
158 if new_eta.len() != 2 {
159 return Err(PgmError::DimensionMismatch {
160 expected: vec![2],
161 got: vec![new_eta.len()],
162 });
163 }
164 for &v in new_eta {
165 if !v.is_finite() {
166 return Err(PgmError::InvalidDistribution(
167 "Beta natural parameter must be finite".to_string(),
168 ));
169 }
170 }
171 let alpha = new_eta[0] + 1.0;
172 let beta = new_eta[1] + 1.0;
173 if alpha <= 0.0 {
174 return Err(PgmError::InvalidDistribution(format!(
175 "Beta shape α must stay positive (η₁ + 1 = {} ≤ 0)",
176 alpha
177 )));
178 }
179 if beta <= 0.0 {
180 return Err(PgmError::InvalidDistribution(format!(
181 "Beta shape β must stay positive (η₂ + 1 = {} ≤ 0)",
182 beta
183 )));
184 }
185 self.alpha = alpha;
186 self.beta = beta;
187 Ok(())
188 }
189
190 fn sufficient_statistics(&self, value: f64) -> Vec<f64> {
191 if value > 0.0 && value < 1.0 {
195 vec![value.ln(), (1.0 - value).ln()]
196 } else {
197 vec![f64::NEG_INFINITY, f64::NEG_INFINITY]
198 }
199 }
200
201 fn log_partition(&self, natural_params: &[f64]) -> Result<f64> {
202 if natural_params.len() != 2 {
203 return Err(PgmError::DimensionMismatch {
204 expected: vec![2],
205 got: vec![natural_params.len()],
206 });
207 }
208 let alpha = natural_params[0] + 1.0;
209 let beta = natural_params[1] + 1.0;
210 if alpha <= 0.0 || beta <= 0.0 {
211 return Err(PgmError::InvalidDistribution(format!(
212 "Beta log_partition: α = {} and β = {} must both be positive",
213 alpha, beta
214 )));
215 }
216 Ok(ln_gamma(alpha) + ln_gamma(beta) - ln_gamma(alpha + beta))
218 }
219
220 fn expected_sufficient_statistics(&self) -> Vec<f64> {
221 vec![self.expected_log_x(), self.expected_log_1mx()]
223 }
224}
225
226pub fn posterior_from_prior_and_observations(
236 prior: &BetaNP,
237 successes: u64,
238 failures: u64,
239) -> Result<BetaNP> {
240 let posterior_alpha = prior.alpha + successes as f64;
241 let posterior_beta = prior.beta + failures as f64;
242 BetaNP::new(posterior_alpha, posterior_beta)
243}
244
/// A set of Bernoulli outcomes tied, by name, to the (Beta-distributed)
/// success-probability variable they inform.
#[derive(Debug, Clone)]
pub struct BetaBernoulliObservation {
    /// Name of the probability variable these outcomes are attached to.
    pub probability_variable: String,
    /// Recorded outcomes; `true` is a success and `false` a failure.
    pub observations: Vec<bool>,
}
261
262impl BetaBernoulliObservation {
263 pub fn new(probability_variable: impl Into<String>, observations: Vec<bool>) -> Self {
265 Self {
266 probability_variable: probability_variable.into(),
267 observations,
268 }
269 }
270
271 pub fn from_counts(
274 probability_variable: impl Into<String>,
275 successes: u64,
276 failures: u64,
277 ) -> Self {
278 let mut observations = Vec::with_capacity((successes + failures) as usize);
279 observations.extend(std::iter::repeat_n(true, successes as usize));
280 observations.extend(std::iter::repeat_n(false, failures as usize));
281 Self {
282 probability_variable: probability_variable.into(),
283 observations,
284 }
285 }
286
287 pub fn num_successes(&self) -> u64 {
289 self.observations.iter().filter(|b| **b).count() as u64
290 }
291
292 pub fn num_failures(&self) -> u64 {
294 self.observations.iter().filter(|b| !**b).count() as u64
295 }
296
297 pub fn num_observations(&self) -> usize {
299 self.observations.len()
300 }
301}
302
#[cfg(test)]
mod tests {
    //! Unit tests for the Beta exponential-family node and the Beta–Bernoulli
    //! conjugate helpers.

    use super::*;
    use crate::vmp::special::{digamma, ln_gamma};

    #[test]
    fn beta_expected_x_matches_alpha_over_total() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.0, 3.0), (4.5, 0.5), (0.25, 10.0)] {
            let b = BetaNP::new(alpha, beta).expect("ctor");
            let ex = b.expected_x();
            let expected = alpha / (alpha + beta);
            assert!(
                (ex - expected).abs() < 1e-12,
                "E[x] = {} but α/(α+β) = {}",
                ex,
                expected
            );
        }
    }

    #[test]
    fn beta_new_rejects_invalid_shapes() {
        for &(alpha, beta) in &[
            (0.0_f64, 1.0_f64),
            (1.0, 0.0),
            (-1.0, 2.0),
            (2.0, -0.5),
            (f64::NAN, 1.0),
            (1.0, f64::NAN),
            (f64::INFINITY, 1.0),
            (1.0, f64::INFINITY),
        ] {
            assert!(
                BetaNP::new(alpha, beta).is_err(),
                "Beta({}, {}) should be rejected",
                alpha,
                beta
            );
        }
    }

    #[test]
    fn beta_from_natural_rejects_wrong_length() {
        assert!(BetaNP::from_natural(&[]).is_err());
        assert!(BetaNP::from_natural(&[0.0]).is_err());
        assert!(BetaNP::from_natural(&[0.0, 0.0, 0.0]).is_err());
    }

    #[test]
    fn beta_expected_log_x_and_1mx_match_digamma() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.5, 1.5), (4.0, 2.0)] {
            let b = BetaNP::new(alpha, beta).expect("ctor");
            let el_x = b.expected_log_x();
            let el_1mx = b.expected_log_1mx();
            let expected_log_x = digamma(alpha) - digamma(alpha + beta);
            let expected_log_1mx = digamma(beta) - digamma(alpha + beta);
            assert!((el_x - expected_log_x).abs() < 1e-12);
            assert!((el_1mx - expected_log_1mx).abs() < 1e-12);
        }
    }

    #[test]
    fn beta_variance_matches_closed_form() {
        let b = BetaNP::new(2.0, 3.0).expect("ctor");
        // αβ / ((α+β)²(α+β+1)) = 6 / (25 · 6) = 0.04
        assert!((b.variance() - 0.04).abs() < 1e-12, "Var = {}", b.variance());
    }

    #[test]
    fn beta_multiply_naturals_sums_natural_params() {
        let a = BetaNP::new(2.0, 3.0).expect("ctor a");
        let c = BetaNP::new(3.0, 5.0).expect("ctor b");
        let p = a.multiply_naturals(&c).expect("product");
        assert!((p.alpha - 4.0).abs() < 1e-12, "α = {}", p.alpha);
        assert!((p.beta - 7.0).abs() < 1e-12, "β = {}", p.beta);
        // The product must agree with adding natural parameters directly.
        let eta_a = a.natural_params();
        let eta_c = c.natural_params();
        let eta_sum: Vec<f64> = eta_a.iter().zip(eta_c.iter()).map(|(x, y)| x + y).collect();
        let p2 = BetaNP::from_natural(&eta_sum).expect("from nat");
        assert!((p2.alpha - p.alpha).abs() < 1e-12);
        assert!((p2.beta - p.beta).abs() < 1e-12);
    }

    #[test]
    fn beta_kl_is_zero_for_self_positive_otherwise() {
        let b = BetaNP::new(3.0, 2.0).expect("ctor");
        let self_kl = b.kl_to(&b);
        assert!(self_kl.abs() < 1e-10, "KL(b||b) = {}", self_kl);

        let other = BetaNP::new(1.5, 4.0).expect("ctor other");
        let kl = b.kl_to(&other);
        assert!(kl > 0.0, "KL(b||other) should be positive, got {}", kl);

        let kl_rev = other.kl_to(&b);
        assert!(
            kl_rev > 0.0,
            "KL(other||b) should be positive, got {}",
            kl_rev
        );
    }

    #[test]
    fn beta_bernoulli_posterior_adds_counts() {
        let prior = BetaNP::new(1.0, 1.0).expect("prior");
        let post = posterior_from_prior_and_observations(&prior, 7, 3).expect("posterior");
        assert!((post.alpha - 8.0).abs() < 1e-12, "α = {}", post.alpha);
        assert!((post.beta - 4.0).abs() < 1e-12, "β = {}", post.beta);

        // No data: the posterior is the prior.
        let unchanged = posterior_from_prior_and_observations(&prior, 0, 0).expect("posterior");
        assert!((unchanged.alpha - prior.alpha).abs() < 1e-12);
        assert!((unchanged.beta - prior.beta).abs() < 1e-12);
    }

    #[test]
    fn beta_log_partition_matches_closed_form() {
        let b = BetaNP::new(2.5, 3.0).expect("ctor");
        let eta = b.natural_params();
        let a = b.log_partition(&eta).expect("lp");
        let expected = ln_gamma(2.5) + ln_gamma(3.0) - ln_gamma(5.5);
        assert!(
            (a - expected).abs() < 1e-12,
            "A(η) = {}, expected {}",
            a,
            expected
        );

        // Central differences: ∇A(η) must match the expected sufficient statistics.
        let h = 1e-6;
        let a_plus_1 = b.log_partition(&[eta[0] + h, eta[1]]).expect("lp+1");
        let a_minus_1 = b.log_partition(&[eta[0] - h, eta[1]]).expect("lp-1");
        let d1 = (a_plus_1 - a_minus_1) / (2.0 * h);
        let a_plus_2 = b.log_partition(&[eta[0], eta[1] + h]).expect("lp+2");
        let a_minus_2 = b.log_partition(&[eta[0], eta[1] - h]).expect("lp-2");
        let d2 = (a_plus_2 - a_minus_2) / (2.0 * h);
        assert!(
            (d1 - b.expected_log_x()).abs() < 1e-5,
            "dA/dη1 = {}, expected {}",
            d1,
            b.expected_log_x()
        );
        assert!(
            (d2 - b.expected_log_1mx()).abs() < 1e-5,
            "dA/dη2 = {}, expected {}",
            d2,
            b.expected_log_1mx()
        );
    }

    #[test]
    fn beta_sufficient_statistics_are_logs_inside_support() {
        let b = BetaNP::new(2.0, 2.0).expect("ctor");
        let t = b.sufficient_statistics(0.25);
        assert_eq!(t.len(), 2);
        assert!((t[0] - 0.25_f64.ln()).abs() < 1e-15);
        assert!((t[1] - 0.75_f64.ln()).abs() < 1e-15);
        let outside = b.sufficient_statistics(1.5);
        assert!(outside.iter().all(|v| *v == f64::NEG_INFINITY));
    }

    #[test]
    fn beta_natural_round_trip() {
        let b = BetaNP::new(4.5, 2.25).expect("ctor");
        let eta = b.natural_params();
        let back = BetaNP::from_natural(&eta).expect("round trip");
        assert!((back.alpha - 4.5).abs() < 1e-12);
        assert!((back.beta - 2.25).abs() < 1e-12);
    }

    #[test]
    fn beta_set_natural_rejects_invalid_shapes() {
        let mut b = BetaNP::new(2.0, 2.0).expect("ctor");
        let err = b.set_natural(&[-1.5, 0.0]);
        assert!(err.is_err());
        let err = b.set_natural(&[0.0, -2.0]);
        assert!(err.is_err());
        let err = b.set_natural(&[f64::NAN, 0.0]);
        assert!(err.is_err());
        let err = b.set_natural(&[0.1]);
        assert!(err.is_err());
        // Rejected updates must not have touched the shapes.
        assert!((b.alpha - 2.0).abs() < 1e-12);
        assert!((b.beta - 2.0).abs() < 1e-12);
        let ok = b.set_natural(&[0.5, 1.5]);
        assert!(ok.is_ok());
        assert!((b.alpha - 1.5).abs() < 1e-12);
        assert!((b.beta - 2.5).abs() < 1e-12);
    }

    #[test]
    fn beta_bernoulli_observation_counts() {
        let obs = BetaBernoulliObservation::new("p", vec![true, false, true, true, false, true]);
        assert_eq!(obs.num_successes(), 4);
        assert_eq!(obs.num_failures(), 2);
        assert_eq!(obs.num_observations(), 6);

        let from_counts = BetaBernoulliObservation::from_counts("p", 5, 3);
        assert_eq!(from_counts.num_successes(), 5);
        assert_eq!(from_counts.num_failures(), 3);
        assert_eq!(from_counts.num_observations(), 8);
    }
}