Skip to main content

tensorlogic_quantrs_hooks/vmp/
exponential_family.rs

1//! Exponential family trait for Variational Message Passing.
2//!
3//! In VMP (Winn & Bishop, 2005) every variable has a variational distribution that
4//! belongs to a known exponential family. This module defines the minimal contract
5//! every such distribution must satisfy so that the engine can perform coordinate
6//! ascent purely in natural-parameter space.
7//!
8//! The canonical form of an exponential family distribution is
9//!
10//! ```text
11//!   p(x | η) = h(x) · exp(ηᵀ u(x) − A(η))
12//! ```
13//!
14//! where
15//! - `η` is the vector of **natural parameters**,
16//! - `u(x)` is the vector of **sufficient statistics**,
17//! - `A(η)` is the **log partition function** (a.k.a. cumulant function), and
18//! - `h(x)` is the base measure (which cancels out of every VMP update and is
19//!   therefore not part of the trait).
20//!
21//! A key property we rely on is
22//!
23//! ```text
24//!   E_q[u(x)] = ∇_η A(η).
25//! ```
26//!
27//! That identity is what lets the engine send "expected sufficient statistic"
28//! messages between factors without ever touching raw probability tables.
29
30use crate::error::{PgmError, Result};
31
32/// Minimal contract every VMP-compatible distribution must implement.
33///
34/// Implementations represent a *variational distribution* over one variable, fully
35/// parameterised by its natural parameters (η). The trait is deliberately kept
36/// small — it only exposes what the coordinate-ascent update and ELBO computation
37/// need:
38///
39/// 1. Read / write the natural parameter vector (`to_natural` / `update_natural`).
40/// 2. Evaluate sufficient statistics at a point value (`sufficient_statistics`).
41/// 3. Evaluate the log partition (`log_partition`) and expected sufficient
42///    statistics (`expected_sufficient_statistics`) at the current η.
43///
44/// All vector shapes must match the fixed `natural_dim()` of the implementation.
45pub trait ExponentialFamily: Clone {
46    /// Name of the family (used for error messages, e.g. "Gaussian").
47    fn family_name(&self) -> &'static str;
48
49    /// Dimensionality of the natural-parameter vector.
50    fn natural_dim(&self) -> usize;
51
52    /// Return a copy of the current natural parameters.
53    fn natural_params(&self) -> Vec<f64>;
54
55    /// Read-only view of the natural parameter vector.
56    ///
57    /// Provided by default via `natural_params` but implementations may override
58    /// to avoid the allocation if they already store η contiguously.
59    fn to_natural(&self) -> Vec<f64> {
60        self.natural_params()
61    }
62
63    /// Replace the natural parameters with `new_eta`.
64    ///
65    /// Returns an error if the dimensions do not match `natural_dim()`.
66    fn set_natural(&mut self, new_eta: &[f64]) -> Result<()>;
67
68    /// Additively update the natural parameters: η ← η + δ.
69    ///
70    /// Returns an error if the dimensions do not match `natural_dim()`.
71    fn update_natural(&mut self, delta: &[f64]) -> Result<()> {
72        if delta.len() != self.natural_dim() {
73            return Err(PgmError::DimensionMismatch {
74                expected: vec![self.natural_dim()],
75                got: vec![delta.len()],
76            });
77        }
78        let mut eta = self.natural_params();
79        for (a, b) in eta.iter_mut().zip(delta.iter()) {
80            *a += *b;
81        }
82        self.set_natural(&eta)
83    }
84
85    /// Sufficient statistics `u(x)` evaluated at the scalar or categorical value `value`.
86    ///
87    /// For discrete families `value.floor() as usize` is the category index; for
88    /// continuous ones it is the raw real value. Returning a `Vec<f64>` keeps the
89    /// interface uniform at the cost of one small heap allocation per call — this
90    /// is only invoked in ELBO paths, not in the hot inner loop.
91    fn sufficient_statistics(&self, value: f64) -> Vec<f64>;
92
93    /// Log partition function `A(η)`.
94    fn log_partition(&self, natural_params: &[f64]) -> Result<f64>;
95
96    /// Expected sufficient statistics `E_q[u(x)] = ∇_η A(η)`.
97    ///
98    /// Computed from the *current* η stored inside `self`.
99    fn expected_sufficient_statistics(&self) -> Vec<f64>;
100
101    /// Differential entropy `H(q) = A(η) − ηᵀ E_q[u(x)] − E_q[log h(x)]`.
102    ///
103    /// The last term is zero for every family we ship (Gaussian with fixed
104    /// precision, Categorical, Dirichlet); if a future family needs a non-trivial
105    /// base measure it must override this default.
106    fn entropy(&self) -> Result<f64> {
107        let eta = self.natural_params();
108        let a = self.log_partition(&eta)?;
109        let ess = self.expected_sufficient_statistics();
110        debug_assert_eq!(eta.len(), ess.len());
111        let dot: f64 = eta.iter().zip(ess.iter()).map(|(e, s)| e * s).sum();
112        Ok(a - dot)
113    }
114}