Skip to main content

tensorlogic_quantrs_hooks/vmp/
beta.rs

//! Beta natural parameters for Variational Message Passing.
//!
//! The Beta distribution `Beta(α, β)` with α > 0 and β > 0 is the conjugate
//! prior for the success probability of a Bernoulli / Binomial likelihood. In
//! exponential family form:
//!
//! ```text
//!   p(x | α, β) = (Γ(α + β) / (Γ(α) Γ(β))) · x^{α-1} (1 - x)^{β-1}   (0 < x < 1)
//!                = h(x) · exp(ηᵀ u(x) − A(η))
//! ```
//!
//! with base measure `h(x) = 1` on `(0, 1)`, natural parameters
//! `η = (α − 1, β − 1)`, sufficient statistics `u(x) = (log x, log(1 − x))`, and
//! log partition `A(η) = ln Γ(η₁ + 1) + ln Γ(η₂ + 1) − ln Γ(η₁ + η₂ + 2)`.
//!
//! The struct stores α and β directly for ergonomics; conversion to/from the
//! natural-parameter vector is handled at the [`ExponentialFamily`] trait
//! boundary.
//!
//! # Conjugacy cheat-sheet
//!
//! | Conjugate family | Observation likelihood              |
//! |------------------|-------------------------------------|
//! | Bernoulli        | `y ~ Bernoulli(p)`, p ~ Beta         |
//! | Binomial         | `y ~ Binomial(N, p)`, p ~ Beta       |
//!
//! Only the Bernoulli pairing is wired into the VMP engine in v0.2.0; Binomial
//! can be added with the same machinery (it contributes `(n_s, n_f)` to the
//! natural parameters, just like a batch of Bernoulli draws).
30
31use crate::error::{PgmError, Result};
32
33use super::exponential_family::ExponentialFamily;
34use super::special::{digamma, ln_gamma};
35
/// Beta distribution held as its two shape parameters (α, β).
///
/// The matching natural parameters are `η = (α − 1, β − 1)`. A well-defined
/// Beta needs both shapes to be finite and strictly greater than zero; the
/// constructor and [`ExponentialFamily::set_natural`] refuse anything outside
/// that open positive quadrant.
///
/// Both fields are public, so code that writes them directly bypasses those
/// checks and is responsible for keeping the values valid.
#[derive(Debug, Clone)]
pub struct BetaNP {
    /// First shape parameter, α > 0.
    pub alpha: f64,
    /// Second shape parameter, β > 0.
    pub beta: f64,
}
49
50impl BetaNP {
51    /// Construct from moment parameters (α, β). Both must be strictly positive
52    /// and finite.
53    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
54        if !alpha.is_finite() || alpha <= 0.0 {
55            return Err(PgmError::InvalidDistribution(format!(
56                "Beta shape α must be positive and finite (got {})",
57                alpha
58            )));
59        }
60        if !beta.is_finite() || beta <= 0.0 {
61            return Err(PgmError::InvalidDistribution(format!(
62                "Beta shape β must be positive and finite (got {})",
63                beta
64            )));
65        }
66        Ok(Self { alpha, beta })
67    }
68
69    /// Reconstruct a Beta from natural parameters `η = (α − 1, β − 1)`.
70    pub fn from_natural(natural: &[f64]) -> Result<Self> {
71        if natural.len() != 2 {
72            return Err(PgmError::DimensionMismatch {
73                expected: vec![2],
74                got: vec![natural.len()],
75            });
76        }
77        let alpha = natural[0] + 1.0;
78        let beta = natural[1] + 1.0;
79        Self::new(alpha, beta)
80    }
81
82    /// Expected value `E[x] = α / (α + β)`.
83    pub fn expected_x(&self) -> f64 {
84        self.alpha / (self.alpha + self.beta)
85    }
86
87    /// Expected log value `E[log x] = ψ(α) − ψ(α + β)`.
88    pub fn expected_log_x(&self) -> f64 {
89        digamma(self.alpha) - digamma(self.alpha + self.beta)
90    }
91
92    /// Expected log of the complement `E[log(1 − x)] = ψ(β) − ψ(α + β)`.
93    pub fn expected_log_1mx(&self) -> f64 {
94        digamma(self.beta) - digamma(self.alpha + self.beta)
95    }
96
97    /// Variance `Var[x] = α β / ((α + β)² (α + β + 1))`.
98    pub fn variance(&self) -> f64 {
99        let ab = self.alpha + self.beta;
100        self.alpha * self.beta / (ab * ab * (ab + 1.0))
101    }
102
103    /// Sum the natural parameters of `self` and `other`. Corresponds to the
104    /// pointwise product of densities: if both priors are Beta on the same
105    /// variable, their product is another Beta whose natural parameter is the
106    /// sum of the two input natural parameters.
107    ///
108    /// Concretely: `α_new = α₁ + α₂ − 1` and `β_new = β₁ + β₂ − 1`.
109    pub fn multiply_naturals(&self, other: &BetaNP) -> Result<BetaNP> {
110        let alpha = self.alpha + other.alpha - 1.0;
111        let beta = self.beta + other.beta - 1.0;
112        BetaNP::new(alpha, beta)
113    }
114
115    /// Closed-form KL divergence `KL(Beta(α_p, β_p) || Beta(α_q, β_q))`.
116    ///
117    /// Standard result:
118    ///
119    /// ```text
120    ///   KL = ln B(α_q, β_q) − ln B(α_p, β_p)
121    ///        + (α_p − α_q) ψ(α_p)
122    ///        + (β_p − β_q) ψ(β_p)
123    ///        + (α_q − α_p + β_q − β_p) ψ(α_p + β_p)
124    /// ```
125    ///
126    /// where `ln B(a, b) = ln Γ(a) + ln Γ(b) − ln Γ(a + b)`.
127    pub fn kl_to(&self, other: &BetaNP) -> f64 {
128        let ap = self.alpha;
129        let bp = self.beta;
130        let aq = other.alpha;
131        let bq = other.beta;
132        let ln_beta_p = ln_gamma(ap) + ln_gamma(bp) - ln_gamma(ap + bp);
133        let ln_beta_q = ln_gamma(aq) + ln_gamma(bq) - ln_gamma(aq + bq);
134        let psi_ap = digamma(ap);
135        let psi_bp = digamma(bp);
136        let psi_abp = digamma(ap + bp);
137        ln_beta_q - ln_beta_p
138            + (ap - aq) * psi_ap
139            + (bp - bq) * psi_bp
140            + (aq - ap + bq - bp) * psi_abp
141    }
142}
143
144impl ExponentialFamily for BetaNP {
145    fn family_name(&self) -> &'static str {
146        "Beta"
147    }
148
149    fn natural_dim(&self) -> usize {
150        2
151    }
152
153    fn natural_params(&self) -> Vec<f64> {
154        vec![self.alpha - 1.0, self.beta - 1.0]
155    }
156
157    fn set_natural(&mut self, new_eta: &[f64]) -> Result<()> {
158        if new_eta.len() != 2 {
159            return Err(PgmError::DimensionMismatch {
160                expected: vec![2],
161                got: vec![new_eta.len()],
162            });
163        }
164        for &v in new_eta {
165            if !v.is_finite() {
166                return Err(PgmError::InvalidDistribution(
167                    "Beta natural parameter must be finite".to_string(),
168                ));
169            }
170        }
171        let alpha = new_eta[0] + 1.0;
172        let beta = new_eta[1] + 1.0;
173        if alpha <= 0.0 {
174            return Err(PgmError::InvalidDistribution(format!(
175                "Beta shape α must stay positive (η₁ + 1 = {} ≤ 0)",
176                alpha
177            )));
178        }
179        if beta <= 0.0 {
180            return Err(PgmError::InvalidDistribution(format!(
181                "Beta shape β must stay positive (η₂ + 1 = {} ≤ 0)",
182                beta
183            )));
184        }
185        self.alpha = alpha;
186        self.beta = beta;
187        Ok(())
188    }
189
190    fn sufficient_statistics(&self, value: f64) -> Vec<f64> {
191        // u(x) = (log x, log(1 - x)). For `value` outside (0, 1) the stat is
192        // degenerate; we return NEG_INFINITY rather than panicking so the
193        // caller can detect the bad input.
194        if value > 0.0 && value < 1.0 {
195            vec![value.ln(), (1.0 - value).ln()]
196        } else {
197            vec![f64::NEG_INFINITY, f64::NEG_INFINITY]
198        }
199    }
200
201    fn log_partition(&self, natural_params: &[f64]) -> Result<f64> {
202        if natural_params.len() != 2 {
203            return Err(PgmError::DimensionMismatch {
204                expected: vec![2],
205                got: vec![natural_params.len()],
206            });
207        }
208        let alpha = natural_params[0] + 1.0;
209        let beta = natural_params[1] + 1.0;
210        if alpha <= 0.0 || beta <= 0.0 {
211            return Err(PgmError::InvalidDistribution(format!(
212                "Beta log_partition: α = {} and β = {} must both be positive",
213                alpha, beta
214            )));
215        }
216        // A(η) = ln Γ(α) + ln Γ(β) − ln Γ(α + β).
217        Ok(ln_gamma(alpha) + ln_gamma(beta) - ln_gamma(alpha + beta))
218    }
219
220    fn expected_sufficient_statistics(&self) -> Vec<f64> {
221        // E[u(x)] = (E[log x], E[log(1-x)]) = (ψ(α) − ψ(α+β), ψ(β) − ψ(α+β)).
222        vec![self.expected_log_x(), self.expected_log_1mx()]
223    }
224}
225
226/// Beta-Bernoulli conjugate posterior update.
227///
228/// Given a `Beta(α_prior, β_prior)` prior on the Bernoulli success probability
229/// `p` and observed `successes` successes plus `failures` failures, the exact
230/// posterior is `Beta(α_prior + successes, β_prior + failures)`.
231///
232/// This is exact because Bernoulli is conjugate to Beta; the update adds the
233/// observation-dependent sufficient statistics `(n_s, n_f)` into the natural
234/// parameters `(α − 1, β − 1)` of the prior.
235pub fn posterior_from_prior_and_observations(
236    prior: &BetaNP,
237    successes: u64,
238    failures: u64,
239) -> Result<BetaNP> {
240    let posterior_alpha = prior.alpha + successes as f64;
241    let posterior_beta = prior.beta + failures as f64;
242    BetaNP::new(posterior_alpha, posterior_beta)
243}
244
/// Bernoulli likelihood factor `y ~ Bernoulli(p)` whose success probability
/// `p` is a `BetaNP` variable.
///
/// The factor contributes `(n_s, n_f)` to the posterior natural parameters:
/// `n_s` is added to `(α − 1)` and `n_f` to `(β − 1)`.
///
/// It stores the name of its Beta-distributed probability variable together
/// with a batch of binary outcomes, which may be empty. Because Bernoulli is
/// conjugate to Beta, combining prior and factor is exact after a single VMP
/// sweep.
#[derive(Debug, Clone)]
pub struct BetaBernoulliObservation {
    /// Name of the `BetaNP` variable in the VMP graph.
    pub probability_variable: String,
    /// Observed binary outcomes (`true` = success, `false` = failure).
    pub observations: Vec<bool>,
}

impl BetaBernoulliObservation {
    /// Create the factor from a variable name and a batch of raw outcomes.
    pub fn new(probability_variable: impl Into<String>, observations: Vec<bool>) -> Self {
        let probability_variable = probability_variable.into();
        Self {
            probability_variable,
            observations,
        }
    }

    /// Create the factor from aggregate counts `(n_s, n_f)` when the raw batch
    /// is not available.
    ///
    /// The stored batch is materialised as `successes` `true` values followed
    /// by `failures` `false` values.
    pub fn from_counts(
        probability_variable: impl Into<String>,
        successes: u64,
        failures: u64,
    ) -> Self {
        let mut observations = Vec::with_capacity((successes + failures) as usize);
        // Successes first, then failures.
        for (outcome, count) in [(true, successes), (false, failures)] {
            observations.extend(std::iter::repeat_n(outcome, count as usize));
        }
        Self {
            probability_variable: probability_variable.into(),
            observations,
        }
    }

    /// Number of successes n_s = Σ 1[y_i = 1]; the increment applied to α.
    pub fn num_successes(&self) -> u64 {
        self.observations.iter().filter(|&&outcome| outcome).count() as u64
    }

    /// Number of failures n_f = Σ 1[y_i = 0]; the increment applied to β.
    pub fn num_failures(&self) -> u64 {
        self.observations.iter().filter(|&&outcome| !outcome).count() as u64
    }

    /// Total number of observations N = n_s + n_f.
    pub fn num_observations(&self) -> usize {
        self.observations.len()
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vmp::special::{digamma, ln_gamma};

    /// Absolute tolerance for quantities that should agree up to round-off.
    const TIGHT: f64 = 1e-12;

    /// `expected_x` must reproduce α / (α + β) across a spread of shapes.
    #[test]
    fn beta_expected_x_matches_alpha_over_total() {
        let shapes = [(1.0_f64, 1.0_f64), (2.0, 3.0), (4.5, 0.5), (0.25, 10.0)];
        for (alpha, beta) in shapes {
            let dist = BetaNP::new(alpha, beta).expect("ctor");
            let ex = dist.expected_x();
            let expected = alpha / (alpha + beta);
            assert!(
                (ex - expected).abs() < TIGHT,
                "E[x] = {} but α/(α+β) = {}",
                ex,
                expected
            );
        }
    }

    /// The log-moment accessors must equal the digamma differences.
    #[test]
    fn beta_expected_log_x_and_1mx_match_digamma() {
        for (alpha, beta) in [(1.0_f64, 1.0_f64), (2.5, 1.5), (4.0, 2.0)] {
            let dist = BetaNP::new(alpha, beta).expect("ctor");
            let want_log_x = digamma(alpha) - digamma(alpha + beta);
            let want_log_1mx = digamma(beta) - digamma(alpha + beta);
            assert!((dist.expected_log_x() - want_log_x).abs() < TIGHT);
            assert!((dist.expected_log_1mx() - want_log_1mx).abs() < TIGHT);
        }
    }

    /// Multiplying densities adds natural parameters.
    #[test]
    fn beta_multiply_naturals_sums_natural_params() {
        // Beta(2, 3) has η = (1, 2); Beta(3, 5) has η = (2, 4).
        // Sum = (3, 6), i.e. Beta(4, 7).
        let lhs = BetaNP::new(2.0, 3.0).expect("ctor a");
        let rhs = BetaNP::new(3.0, 5.0).expect("ctor b");
        let product = lhs.multiply_naturals(&rhs).expect("product");
        assert!((product.alpha - 4.0).abs() < TIGHT, "α = {}", product.alpha);
        assert!((product.beta - 7.0).abs() < TIGHT, "β = {}", product.beta);

        // Round-trip through natural parameters.
        let eta_sum: Vec<f64> = lhs
            .natural_params()
            .into_iter()
            .zip(rhs.natural_params())
            .map(|(x, y)| x + y)
            .collect();
        let rebuilt = BetaNP::from_natural(&eta_sum).expect("from nat");
        assert!((rebuilt.alpha - product.alpha).abs() < TIGHT);
        assert!((rebuilt.beta - product.beta).abs() < TIGHT);
    }

    /// KL vanishes against itself and is positive in both directions otherwise.
    #[test]
    fn beta_kl_is_zero_for_self_positive_otherwise() {
        let base = BetaNP::new(3.0, 2.0).expect("ctor");
        let other = BetaNP::new(1.5, 4.0).expect("ctor other");

        let self_kl = base.kl_to(&base);
        assert!(self_kl.abs() < 1e-10, "KL(b||b) = {}", self_kl);

        let kl = base.kl_to(&other);
        assert!(kl > 0.0, "KL(b||other) should be positive, got {}", kl);

        let kl_rev = other.kl_to(&base);
        assert!(
            kl_rev > 0.0,
            "KL(other||b) should be positive, got {}",
            kl_rev
        );
    }

    /// The conjugate update adds the counts to the prior shapes.
    #[test]
    fn beta_bernoulli_posterior_adds_counts() {
        // Beta(1, 1) + 7 successes, 3 failures → Beta(8, 4).
        let prior = BetaNP::new(1.0, 1.0).expect("prior");
        let post = posterior_from_prior_and_observations(&prior, 7, 3).expect("posterior");
        assert!((post.alpha - 8.0).abs() < TIGHT, "α = {}", post.alpha);
        assert!((post.beta - 4.0).abs() < TIGHT, "β = {}", post.beta);
    }

    /// The log partition matches its closed form and its gradient matches the
    /// expected sufficient statistics.
    #[test]
    fn beta_log_partition_matches_closed_form() {
        // A(η) = ln Γ(α) + ln Γ(β) − ln Γ(α + β).
        let dist = BetaNP::new(2.5, 3.0).expect("ctor");
        let eta = dist.natural_params();
        let a = dist.log_partition(&eta).expect("lp");
        let expected = ln_gamma(2.5) + ln_gamma(3.0) - ln_gamma(5.5);
        assert!(
            (a - expected).abs() < TIGHT,
            "A(η) = {}, expected {}",
            a,
            expected
        );

        // ∂A/∂η₁ = ψ(α) − ψ(α+β) = E[log x] and ∂A/∂η₂ = ψ(β) − ψ(α+β) = E[log(1-x)],
        // checked with central finite differences.
        let h = 1e-6;
        let (eta_1, eta_2) = (eta[0], eta[1]);

        let up_1 = dist.log_partition(&[eta_1 + h, eta_2]).expect("lp+1");
        let down_1 = dist.log_partition(&[eta_1 - h, eta_2]).expect("lp-1");
        let d1 = (up_1 - down_1) / (2.0 * h);

        let up_2 = dist.log_partition(&[eta_1, eta_2 + h]).expect("lp+2");
        let down_2 = dist.log_partition(&[eta_1, eta_2 - h]).expect("lp-2");
        let d2 = (up_2 - down_2) / (2.0 * h);

        assert!(
            (d1 - dist.expected_log_x()).abs() < 1e-5,
            "dA/dη1 = {}, expected {}",
            d1,
            dist.expected_log_x()
        );
        assert!(
            (d2 - dist.expected_log_1mx()).abs() < 1e-5,
            "dA/dη2 = {}, expected {}",
            d2,
            dist.expected_log_1mx()
        );
    }

    /// Shapes survive a trip through the natural parameterisation.
    #[test]
    fn beta_natural_round_trip() {
        let original = BetaNP::new(4.5, 2.25).expect("ctor");
        let back = BetaNP::from_natural(&original.natural_params()).expect("round trip");
        assert!((back.alpha - 4.5).abs() < TIGHT);
        assert!((back.beta - 2.25).abs() < TIGHT);
    }

    /// `set_natural` rejects bad input and accepts a valid vector.
    #[test]
    fn beta_set_natural_rejects_invalid_shapes() {
        let mut dist = BetaNP::new(2.0, 2.0).expect("ctor");

        let rejected: [&[f64]; 4] = [
            // α = 1 + (-1.5) = -0.5 < 0
            &[-1.5, 0.0],
            // β = 1 + (-2.0) = -1.0 < 0
            &[0.0, -2.0],
            // NaN
            &[f64::NAN, 0.0],
            // Wrong length
            &[0.1],
        ];
        for eta in rejected {
            assert!(dist.set_natural(eta).is_err());
        }

        // Valid.
        assert!(dist.set_natural(&[0.5, 1.5]).is_ok());
        assert!((dist.alpha - 1.5).abs() < TIGHT);
        assert!((dist.beta - 2.5).abs() < TIGHT);
    }

    /// Counting helpers agree for both the raw-batch and the aggregate constructor.
    #[test]
    fn beta_bernoulli_observation_counts() {
        let raw = BetaBernoulliObservation::new("p", vec![true, false, true, true, false, true]);
        assert_eq!(raw.num_successes(), 4);
        assert_eq!(raw.num_failures(), 2);
        assert_eq!(raw.num_observations(), 6);

        let aggregated = BetaBernoulliObservation::from_counts("p", 5, 3);
        assert_eq!(aggregated.num_successes(), 5);
        assert_eq!(aggregated.num_failures(), 3);
        assert_eq!(aggregated.num_observations(), 8);
    }
}