tensorlogic_quantrs_hooks/vmp/exponential_family.rs
1//! Exponential family trait for Variational Message Passing.
2//!
3//! In VMP (Winn & Bishop, 2005) every variable has a variational distribution that
4//! belongs to a known exponential family. This module defines the minimal contract
5//! every such distribution must satisfy so that the engine can perform coordinate
6//! ascent purely in natural-parameter space.
7//!
8//! The canonical form of an exponential family distribution is
9//!
10//! ```text
11//! p(x | η) = h(x) · exp(ηᵀ u(x) − A(η))
12//! ```
13//!
14//! where
15//! - `η` is the vector of **natural parameters**,
16//! - `u(x)` is the vector of **sufficient statistics**,
17//! - `A(η)` is the **log partition function** (a.k.a. cumulant function), and
18//! - `h(x)` is the base measure (which cancels out of every VMP update and is
19//! therefore not part of the trait).
20//!
21//! A key property we rely on is
22//!
23//! ```text
24//! E_q[u(x)] = ∇_η A(η).
25//! ```
26//!
27//! That identity is what lets the engine send "expected sufficient statistic"
28//! messages between factors without ever touching raw probability tables.
29
30use crate::error::{PgmError, Result};
31
32/// Minimal contract every VMP-compatible distribution must implement.
33///
34/// Implementations represent a *variational distribution* over one variable, fully
35/// parameterised by its natural parameters (η). The trait is deliberately kept
36/// small — it only exposes what the coordinate-ascent update and ELBO computation
37/// need:
38///
39/// 1. Read / write the natural parameter vector (`to_natural` / `update_natural`).
40/// 2. Evaluate sufficient statistics at a point value (`sufficient_statistics`).
41/// 3. Evaluate the log partition (`log_partition`) and expected sufficient
42/// statistics (`expected_sufficient_statistics`) at the current η.
43///
44/// All vector shapes must match the fixed `natural_dim()` of the implementation.
45pub trait ExponentialFamily: Clone {
46 /// Name of the family (used for error messages, e.g. "Gaussian").
47 fn family_name(&self) -> &'static str;
48
49 /// Dimensionality of the natural-parameter vector.
50 fn natural_dim(&self) -> usize;
51
52 /// Return a copy of the current natural parameters.
53 fn natural_params(&self) -> Vec<f64>;
54
55 /// Read-only view of the natural parameter vector.
56 ///
57 /// Provided by default via `natural_params` but implementations may override
58 /// to avoid the allocation if they already store η contiguously.
59 fn to_natural(&self) -> Vec<f64> {
60 self.natural_params()
61 }
62
63 /// Replace the natural parameters with `new_eta`.
64 ///
65 /// Returns an error if the dimensions do not match `natural_dim()`.
66 fn set_natural(&mut self, new_eta: &[f64]) -> Result<()>;
67
68 /// Additively update the natural parameters: η ← η + δ.
69 ///
70 /// Returns an error if the dimensions do not match `natural_dim()`.
71 fn update_natural(&mut self, delta: &[f64]) -> Result<()> {
72 if delta.len() != self.natural_dim() {
73 return Err(PgmError::DimensionMismatch {
74 expected: vec![self.natural_dim()],
75 got: vec![delta.len()],
76 });
77 }
78 let mut eta = self.natural_params();
79 for (a, b) in eta.iter_mut().zip(delta.iter()) {
80 *a += *b;
81 }
82 self.set_natural(&eta)
83 }
84
85 /// Sufficient statistics `u(x)` evaluated at the scalar or categorical value `value`.
86 ///
87 /// For discrete families `value.floor() as usize` is the category index; for
88 /// continuous ones it is the raw real value. Returning a `Vec<f64>` keeps the
89 /// interface uniform at the cost of one small heap allocation per call — this
90 /// is only invoked in ELBO paths, not in the hot inner loop.
91 fn sufficient_statistics(&self, value: f64) -> Vec<f64>;
92
93 /// Log partition function `A(η)`.
94 fn log_partition(&self, natural_params: &[f64]) -> Result<f64>;
95
96 /// Expected sufficient statistics `E_q[u(x)] = ∇_η A(η)`.
97 ///
98 /// Computed from the *current* η stored inside `self`.
99 fn expected_sufficient_statistics(&self) -> Vec<f64>;
100
101 /// Differential entropy `H(q) = A(η) − ηᵀ E_q[u(x)] − E_q[log h(x)]`.
102 ///
103 /// The last term is zero for every family we ship (Gaussian with fixed
104 /// precision, Categorical, Dirichlet); if a future family needs a non-trivial
105 /// base measure it must override this default.
106 fn entropy(&self) -> Result<f64> {
107 let eta = self.natural_params();
108 let a = self.log_partition(&eta)?;
109 let ess = self.expected_sufficient_statistics();
110 debug_assert_eq!(eta.len(), ess.len());
111 let dot: f64 = eta.iter().zip(ess.iter()).map(|(e, s)| e * s).sum();
112 Ok(a - dot)
113 }
114}