Skip to main content

tensorlogic_quantrs_hooks/vmp/
gamma.rs

1//! Gamma natural parameters for Variational Message Passing.
2//!
3//! The Gamma distribution `Gamma(α, β)` with shape α > 0 and **rate** β > 0 is
4//! the conjugate prior for the Poisson rate and the precision parameter of a
5//! univariate Gaussian (the latter is out of scope for the v0.2.0 research
6//! preview). In exponential family form:
7//!
8//! ```text
9//!   p(x | α, β) = (β^α / Γ(α)) · x^{α-1} · exp(-β x)   (x > 0)
10//!                = h(x) · exp(ηᵀ u(x) − A(η))
11//! ```
12//!
13//! with base measure `h(x) = 1` on `x > 0`, natural parameters
14//! `η = (α − 1, −β)`, sufficient statistics `u(x) = (log x, x)`, and log
15//! partition `A(η) = ln Γ(η₁ + 1) − (η₁ + 1) ln(−η₂)`.
16//!
17//! The struct stores α and β directly for ergonomics; conversion to/from the
18//! natural-parameter vector is handled at the [`ExponentialFamily`] trait
19//! boundary.
20//!
21//! # Conjugacy cheat-sheet
22//!
//! | Likelihood family | Model (Gamma prior on the rate / precision) |
24//! |------------------|-------------------------------------|
25//! | Poisson          | `y ~ Poisson(λ)`, λ ~ Gamma          |
26//! | Exponential      | `y ~ Exp(λ)`, λ ~ Gamma              |
27//! | Gaussian (σ²)    | `y ~ N(μ, σ²)`, τ = 1/σ² ~ Gamma     |
28//!
29//! Only the Poisson pairing is wired into the VMP engine in v0.2.0; the
30//! remaining two can be added without touching [`GammaNP`] itself.
31
32use crate::error::{PgmError, Result};
33
34use super::exponential_family::ExponentialFamily;
35use super::special::{digamma, ln_gamma};
36
37/// Gamma distribution stored in (shape, rate) moment parameterisation.
38///
39/// Natural parameters are `η = (α − 1, −β)`. Both α and β must be strictly
40/// positive and finite for the distribution to be well-defined; the
41/// constructor and [`ExponentialFamily::set_natural`] reject values outside
42/// that open half-plane.
43#[derive(Clone, Debug)]
44pub struct GammaNP {
45    /// Shape parameter α > 0.
46    pub alpha: f64,
47    /// Rate parameter β > 0 (NOT the scale 1/β).
48    pub beta: f64,
49}
50
51impl GammaNP {
52    /// Construct from moment parameters (α, β). Both must be strictly positive
53    /// and finite.
54    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
55        if !alpha.is_finite() || alpha <= 0.0 {
56            return Err(PgmError::InvalidDistribution(format!(
57                "Gamma shape α must be positive and finite (got {})",
58                alpha
59            )));
60        }
61        if !beta.is_finite() || beta <= 0.0 {
62            return Err(PgmError::InvalidDistribution(format!(
63                "Gamma rate β must be positive and finite (got {})",
64                beta
65            )));
66        }
67        Ok(Self { alpha, beta })
68    }
69
70    /// Reconstruct a Gamma from natural parameters `η = (α − 1, −β)`.
71    pub fn from_natural(natural: &[f64]) -> Result<Self> {
72        if natural.len() != 2 {
73            return Err(PgmError::DimensionMismatch {
74                expected: vec![2],
75                got: vec![natural.len()],
76            });
77        }
78        let alpha = natural[0] + 1.0;
79        let beta = -natural[1];
80        Self::new(alpha, beta)
81    }
82
83    /// Expected value `E[x] = α / β`.
84    pub fn expected_x(&self) -> f64 {
85        self.alpha / self.beta
86    }
87
88    /// Expected log value `E[log x] = ψ(α) − ln β`.
89    pub fn expected_log_x(&self) -> f64 {
90        digamma(self.alpha) - self.beta.ln()
91    }
92
93    /// Variance `Var[x] = α / β²`.
94    pub fn variance(&self) -> f64 {
95        self.alpha / (self.beta * self.beta)
96    }
97
98    /// Sum the natural parameters of `self` and `other`. Corresponds to the
99    /// pointwise product of densities: if both priors are Gamma on the same
100    /// variable, their product is another Gamma whose natural parameter is
101    /// the sum of the two input natural parameters.
102    ///
103    /// Concretely: `α_new = α₁ + α₂ − 1` and `β_new = β₁ + β₂`.
104    pub fn multiply_naturals(&self, other: &GammaNP) -> Result<GammaNP> {
105        let alpha = self.alpha + other.alpha - 1.0;
106        let beta = self.beta + other.beta;
107        GammaNP::new(alpha, beta)
108    }
109
110    /// Closed-form KL divergence `KL(Gamma(α_p, β_p) || Gamma(α_q, β_q))`.
111    ///
112    /// Standard result (Penny, 2001):
113    ///
114    /// ```text
115    ///   KL = (α_p − α_q) ψ(α_p) − ln Γ(α_p) + ln Γ(α_q)
116    ///        + α_q (ln β_p − ln β_q) + α_p (β_q − β_p) / β_p
117    /// ```
118    pub fn kl_to(&self, other: &GammaNP) -> f64 {
119        let ap = self.alpha;
120        let bp = self.beta;
121        let aq = other.alpha;
122        let bq = other.beta;
123        (ap - aq) * digamma(ap) - ln_gamma(ap)
124            + ln_gamma(aq)
125            + aq * (bp.ln() - bq.ln())
126            + ap * (bq - bp) / bp
127    }
128}
129
130impl ExponentialFamily for GammaNP {
131    fn family_name(&self) -> &'static str {
132        "Gamma"
133    }
134
135    fn natural_dim(&self) -> usize {
136        2
137    }
138
139    fn natural_params(&self) -> Vec<f64> {
140        vec![self.alpha - 1.0, -self.beta]
141    }
142
143    fn set_natural(&mut self, new_eta: &[f64]) -> Result<()> {
144        if new_eta.len() != 2 {
145            return Err(PgmError::DimensionMismatch {
146                expected: vec![2],
147                got: vec![new_eta.len()],
148            });
149        }
150        for &v in new_eta {
151            if !v.is_finite() {
152                return Err(PgmError::InvalidDistribution(
153                    "Gamma natural parameter must be finite".to_string(),
154                ));
155            }
156        }
157        let alpha = new_eta[0] + 1.0;
158        let beta = -new_eta[1];
159        if alpha <= 0.0 {
160            return Err(PgmError::InvalidDistribution(format!(
161                "Gamma shape must stay positive (η₁ + 1 = {} ≤ 0)",
162                alpha
163            )));
164        }
165        if beta <= 0.0 {
166            return Err(PgmError::InvalidDistribution(format!(
167                "Gamma rate must stay positive (−η₂ = {} ≤ 0)",
168                beta
169            )));
170        }
171        self.alpha = alpha;
172        self.beta = beta;
173        Ok(())
174    }
175
176    fn sufficient_statistics(&self, value: f64) -> Vec<f64> {
177        // u(x) = (log x, x). For `value <= 0` log x is undefined; we return a
178        // best-effort NEG_INFINITY so the caller can see that the stat is
179        // degenerate without panicking.
180        if value > 0.0 {
181            vec![value.ln(), value]
182        } else {
183            vec![f64::NEG_INFINITY, value]
184        }
185    }
186
187    fn log_partition(&self, natural_params: &[f64]) -> Result<f64> {
188        if natural_params.len() != 2 {
189            return Err(PgmError::DimensionMismatch {
190                expected: vec![2],
191                got: vec![natural_params.len()],
192            });
193        }
194        let alpha = natural_params[0] + 1.0;
195        let neg_beta = natural_params[1];
196        if alpha <= 0.0 || neg_beta >= 0.0 {
197            return Err(PgmError::InvalidDistribution(format!(
198                "Gamma log_partition: α = {} must be positive and −β = {} negative",
199                alpha, neg_beta
200            )));
201        }
202        // A(η) = ln Γ(α) − α ln β.
203        let beta = -neg_beta;
204        Ok(ln_gamma(alpha) - alpha * beta.ln())
205    }
206
207    fn expected_sufficient_statistics(&self) -> Vec<f64> {
208        // E[u(x)] = (E[log x], E[x]) = (ψ(α) − ln β, α / β).
209        vec![self.expected_log_x(), self.expected_x()]
210    }
211}
212
213/// Gamma-Poisson conjugate posterior update.
214///
215/// Given a `Gamma(α_prior, β_prior)` prior on the Poisson rate λ and a batch
216/// of `N` observed counts `y_i`, the exact posterior is
217/// `Gamma(α_prior + Σ y_i, β_prior + N)`.
218///
219/// This is exact because Poisson is conjugate to Gamma; the update adds the
220/// observation-dependent sufficient statistics (Σ y_i, N) into the natural
221/// parameters `(α − 1, −β)` of the prior.
222pub fn posterior_from_prior_and_observations(
223    prior: &GammaNP,
224    observations: &[u64],
225) -> Result<GammaNP> {
226    let n = observations.len() as f64;
227    let sum: u64 = observations.iter().sum();
228    let posterior_alpha = prior.alpha + sum as f64;
229    let posterior_beta = prior.beta + n;
230    GammaNP::new(posterior_alpha, posterior_beta)
231}
232
/// Poisson observation factor `y_i ~ Poisson(λ)` whose rate `λ` is a
/// [`GammaNP`] variable in the VMP graph.
///
/// In natural-parameter terms, the factor contributes `(Σ y_i, −N)` to the
/// rate variable's `η = (α − 1, −β)`. Equivalently, it adds `Σ y_i` to the
/// shape α and `N` to the rate β. Because Poisson is conjugate to Gamma,
/// combining the prior with this factor gives the exact posterior in one VMP
/// sweep.
///
/// The factor refers to its rate variable by name and stores a possibly
/// empty batch of observed counts.
#[derive(Clone, Debug)]
pub struct GammaPoissonObservation {
    /// Name of the `GammaNP` variable in the VMP graph.
    pub rate_variable: String,
    /// Observed Poisson counts.
    pub observations: Vec<u64>,
}

impl GammaPoissonObservation {
    /// Builds a new Gamma-Poisson observation factor.
    pub fn new(rate_variable: impl Into<String>, observations: Vec<u64>) -> Self {
        Self {
            rate_variable: rate_variable.into(),
            observations,
        }
    }

    /// Sum of observed counts `Σ y_i`; used as the shape-parameter increment.
    ///
    /// Saturates at `u64::MAX` instead of overflowing. A plain `u64` sum
    /// panics in debug builds and silently wraps to a small, wrong total in
    /// release builds.
    pub fn count_sum(&self) -> u64 {
        self.observations
            .iter()
            .fold(0_u64, |total, &y| total.saturating_add(y))
    }

    /// Number of observations `N`; used as the rate-parameter increment.
    pub fn num_observations(&self) -> usize {
        self.observations.len()
    }
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vmp::special::{digamma, ln_gamma};

    #[test]
    fn gamma_new_rejects_invalid_parameters() {
        let invalid: [(f64, f64); 8] = [
            (0.0, 1.0),
            (-1.0, 1.0),
            (f64::NAN, 1.0),
            (f64::INFINITY, 1.0),
            (1.0, 0.0),
            (1.0, -2.0),
            (1.0, f64::NAN),
            (1.0, f64::INFINITY),
        ];
        for &(alpha, beta) in &invalid {
            assert!(
                GammaNP::new(alpha, beta).is_err(),
                "GammaNP::new({}, {}) should be rejected",
                alpha,
                beta
            );
        }
    }

    #[test]
    fn gamma_from_natural_rejects_wrong_length() {
        assert!(GammaNP::from_natural(&[]).is_err());
        assert!(GammaNP::from_natural(&[1.0]).is_err());
        assert!(GammaNP::from_natural(&[1.0, -1.0, 0.0]).is_err());
    }

    #[test]
    fn gamma_expected_x_matches_alpha_over_beta() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.0, 0.5), (3.7, 4.2), (0.25, 10.0)] {
            let g = GammaNP::new(alpha, beta).expect("ctor");
            let ex = g.expected_x();
            assert!(
                (ex - alpha / beta).abs() < 1e-12,
                "E[x] = {} but α/β = {}",
                ex,
                alpha / beta
            );
        }
    }

    #[test]
    fn gamma_expected_log_x_matches_digamma_minus_lnbeta() {
        for &(alpha, beta) in &[(1.0_f64, 1.0_f64), (2.5, 0.5), (4.0, 2.0)] {
            let g = GammaNP::new(alpha, beta).expect("ctor");
            let el = g.expected_log_x();
            let expected = digamma(alpha) - beta.ln();
            assert!(
                (el - expected).abs() < 1e-12,
                "E[log x] = {}, expected ψ(α)−ln β = {}",
                el,
                expected
            );
        }
    }

    #[test]
    fn gamma_variance_matches_alpha_over_beta_squared() {
        // Gamma(3, 2): Var[x] = 3 / 4.
        let g = GammaNP::new(3.0, 2.0).expect("ctor");
        let var = g.variance();
        assert!((var - 0.75).abs() < 1e-12, "Var[x] = {}", var);
    }

    #[test]
    fn gamma_sufficient_statistics_are_log_x_and_x() {
        let g = GammaNP::new(2.0, 1.0).expect("ctor");
        let u = g.sufficient_statistics(2.0);
        assert_eq!(u.len(), 2);
        assert!((u[0] - 2.0_f64.ln()).abs() < 1e-15);
        assert!((u[1] - 2.0).abs() < 1e-15);

        // log x is undefined at x <= 0; the first statistic degrades to −∞.
        let degenerate = g.sufficient_statistics(0.0);
        assert!(degenerate[0].is_infinite() && degenerate[0] < 0.0);
        assert!(degenerate[1].abs() < 1e-15);

        let expected = g.expected_sufficient_statistics();
        assert_eq!(expected, vec![g.expected_log_x(), g.expected_x()]);
    }

    #[test]
    fn gamma_multiply_naturals_sums_natural_params() {
        // Gamma(2, 1) has η = (1, -1); Gamma(3, 2) has η = (2, -2).
        // Sum = (3, -3), i.e. Gamma(4, 3).
        let a = GammaNP::new(2.0, 1.0).expect("ctor a");
        let b = GammaNP::new(3.0, 2.0).expect("ctor b");
        let p = a.multiply_naturals(&b).expect("product");
        assert!((p.alpha - 4.0).abs() < 1e-12, "α = {}", p.alpha);
        assert!((p.beta - 3.0).abs() < 1e-12, "β = {}", p.beta);
        // The round trip through natural parameters gives the same result.
        let eta_a = a.natural_params();
        let eta_b = b.natural_params();
        let eta_sum: Vec<f64> = eta_a.iter().zip(eta_b.iter()).map(|(x, y)| x + y).collect();
        let p2 = GammaNP::from_natural(&eta_sum).expect("from nat");
        assert!((p2.alpha - p.alpha).abs() < 1e-12);
        assert!((p2.beta - p.beta).abs() < 1e-12);
    }

    #[test]
    fn gamma_kl_is_zero_for_self_positive_otherwise() {
        let g = GammaNP::new(3.0, 2.0).expect("ctor");
        let self_kl = g.kl_to(&g);
        assert!(self_kl.abs() < 1e-10, "KL(g||g) = {}", self_kl);

        let other = GammaNP::new(1.5, 4.0).expect("ctor other");
        let kl = g.kl_to(&other);
        assert!(kl > 0.0, "KL(g||other) should be positive, got {}", kl);

        // KL is asymmetric, but the reverse direction must also be strictly
        // positive for two distinct distributions.
        let kl_rev = other.kl_to(&g);
        assert!(
            kl_rev > 0.0,
            "KL(other||g) should be positive, got {}",
            kl_rev
        );
    }

    #[test]
    fn gamma_poisson_posterior_adds_sum_and_count() {
        let prior = GammaNP::new(1.0, 1.0).expect("prior");
        let obs: [u64; 3] = [3, 5, 2];
        let post = posterior_from_prior_and_observations(&prior, &obs).expect("posterior");
        // Σ y_i = 10, N = 3, so posterior = Gamma(11, 4).
        assert!((post.alpha - 11.0).abs() < 1e-12, "α = {}", post.alpha);
        assert!((post.beta - 4.0).abs() < 1e-12, "β = {}", post.beta);

        // An empty batch leaves the prior unchanged.
        let same = posterior_from_prior_and_observations(&prior, &[]).expect("empty batch");
        assert!((same.alpha - prior.alpha).abs() < 1e-12);
        assert!((same.beta - prior.beta).abs() < 1e-12);
    }

    #[test]
    fn gamma_poisson_factor_matches_batch_posterior() {
        let factor = GammaPoissonObservation::new("lambda", vec![3, 5, 2]);
        assert_eq!(factor.count_sum(), 10);
        assert_eq!(factor.num_observations(), 3);

        let prior = GammaNP::new(1.0, 1.0).expect("prior");
        let post = posterior_from_prior_and_observations(&prior, &factor.observations)
            .expect("posterior");
        assert!((post.alpha - (prior.alpha + factor.count_sum() as f64)).abs() < 1e-12);
        assert!((post.beta - (prior.beta + factor.num_observations() as f64)).abs() < 1e-12);
    }

    #[test]
    fn gamma_log_partition_matches_closed_form() {
        // A(η) = ln Γ(α) − α ln β.
        let g = GammaNP::new(2.5, 3.0).expect("ctor");
        let eta = g.natural_params();
        let a = g.log_partition(&eta).expect("lp");
        let expected = ln_gamma(2.5) - 2.5 * 3.0_f64.ln();
        assert!(
            (a - expected).abs() < 1e-12,
            "A(η) = {}, expected {}",
            a,
            expected
        );

        // ∂A/∂η₁ = ψ(α) − ln β = E[log x].
        // ∂A/∂η₂: since β = −η₂, ∂A/∂η₂ = (∂A/∂β)(∂β/∂η₂)
        //        = (−α/β)(−1) = α/β = E[x].
        let h = 1e-6;
        let a_plus_1 = g.log_partition(&[eta[0] + h, eta[1]]).expect("lp+1");
        let a_minus_1 = g.log_partition(&[eta[0] - h, eta[1]]).expect("lp-1");
        let d1 = (a_plus_1 - a_minus_1) / (2.0 * h);
        let a_plus_2 = g.log_partition(&[eta[0], eta[1] + h]).expect("lp+2");
        let a_minus_2 = g.log_partition(&[eta[0], eta[1] - h]).expect("lp-2");
        let d2 = (a_plus_2 - a_minus_2) / (2.0 * h);
        assert!(
            (d1 - g.expected_log_x()).abs() < 1e-5,
            "dA/dη1 = {}, expected {}",
            d1,
            g.expected_log_x()
        );
        assert!(
            (d2 - g.expected_x()).abs() < 1e-5,
            "dA/dη2 = {}, expected {}",
            d2,
            g.expected_x()
        );
    }

    #[test]
    fn gamma_natural_round_trip() {
        let g = GammaNP::new(4.5, 2.25).expect("ctor");
        let eta = g.natural_params();
        let back = GammaNP::from_natural(&eta).expect("round trip");
        assert!((back.alpha - 4.5).abs() < 1e-12);
        assert!((back.beta - 2.25).abs() < 1e-12);
    }

    #[test]
    fn gamma_set_natural_rejects_out_of_domain_and_keeps_state() {
        let mut g = GammaNP::new(2.0, 1.0).expect("ctor");
        assert!(g.set_natural(&[-1.5, -1.0]).is_err()); // α = -0.5
        assert!(g.set_natural(&[0.5, 1.0]).is_err()); // β = -1.0
        assert!(g.set_natural(&[f64::NAN, -1.0]).is_err()); // non-finite
        assert!(g.set_natural(&[0.5]).is_err()); // wrong length
        // Rejected updates must not have modified the distribution.
        assert!((g.alpha - 2.0).abs() < 1e-15 && (g.beta - 1.0).abs() < 1e-15);

        assert!(g.set_natural(&[0.5, -1.0]).is_ok()); // α = 1.5, β = 1.0
        assert!((g.alpha - 1.5).abs() < 1e-12, "α = {}", g.alpha);
        assert!((g.beta - 1.0).abs() < 1e-12, "β = {}", g.beta);
    }
}