Skip to main content

tensorlogic_quantrs_hooks/vmp/
engine.rs

1//! Variational Message Passing engine.
2//!
3//! Implements Winn & Bishop (2005) VMP for conjugate-exponential models built on
4//! three families: Gaussian (mean-unknown / precision-known), Categorical, and
5//! Dirichlet. The engine consumes a structural description via [`VmpConfig`] and
6//! runs coordinate-ascent natural-parameter updates until the ELBO / L∞ residual
7//! converges.
8//!
9//! The algorithm is independent of the discrete factor-potential tables carried
10//! by the `FactorGraph` type: VMP operates purely in continuous natural-parameter
11//! space. The user is therefore required to annotate each variable with its
12//! family and each factor with its conjugate role (see [`VmpFactor`]).
13//!
14//! # High-level flow
15//!
16//! 1. Initialise each variable's variational distribution `q(v)` from its prior.
17//! 2. For each iteration:
18//!    - For every variable `v` in a deterministic order:
19//!      - Accumulate contributions from every adjacent factor (natural-parameter
20//!        deltas).
21//!      - Replace `q(v)`'s natural parameters by `prior_nat + Σ Δ`.
22//!    - Compute the ELBO.
23//! 3. Stop when |ΔELBO| < ε or the maximum natural-parameter residual < ε.
24//!
25//! Divergence (ELBO decreasing by more than a small tolerance) is detected and
26//! surfaced as a `ConvergenceFailure` error so the caller does not silently
27//! consume a broken result.
28
29use std::collections::HashMap;
30
31use crate::error::{PgmError, Result};
32use crate::graph::FactorGraph;
33
34use super::distributions::{
35    categorical_kl, dirichlet_kl, gaussian_kl, CategoricalNP, DirichletNP, GaussianNP,
36};
37use super::exponential_family::ExponentialFamily;
38
39// ---------------------------------------------------------------------------
40// Family tagging
41// ---------------------------------------------------------------------------
42
/// Variational family assigned to a variable.
///
/// This is a payload-free tag: the distribution parameters themselves live in
/// [`VariationalState`], whose `family()` method maps each state variant to the
/// tag of the same name.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Family {
    /// Univariate Gaussian with known precision (only mean is random).
    Gaussian,
    /// Categorical over `k` categories.
    Categorical,
    /// Dirichlet over `k` components (the conjugate prior for Categorical).
    Dirichlet,
}
53
/// Variational state carried for a single variable.
///
/// Every variant stores two distributions of the same family:
///
/// * `q` — the current variational posterior, overwritten by the engine's
///   per-variable update rules;
/// * `prior` — the fixed prior, which the update rules only read and which
///   serves as the reference distribution for `kl_to_prior`.
#[derive(Clone, Debug)]
pub enum VariationalState {
    /// Gaussian variable (mean unknown, precision fixed).
    Gaussian { q: GaussianNP, prior: GaussianNP },
    /// Categorical variable with a Dirichlet-prior parent.
    Categorical {
        q: CategoricalNP,
        prior: CategoricalNP,
    },
    /// Dirichlet variable.
    Dirichlet { q: DirichletNP, prior: DirichletNP },
}
67
68impl VariationalState {
69    /// Family tag.
70    pub fn family(&self) -> Family {
71        match self {
72            Self::Gaussian { .. } => Family::Gaussian,
73            Self::Categorical { .. } => Family::Categorical,
74            Self::Dirichlet { .. } => Family::Dirichlet,
75        }
76    }
77
78    /// Current natural parameter vector.
79    pub fn natural_params(&self) -> Vec<f64> {
80        match self {
81            Self::Gaussian { q, .. } => q.natural_params(),
82            Self::Categorical { q, .. } => q.natural_params(),
83            Self::Dirichlet { q, .. } => q.natural_params(),
84        }
85    }
86
87    /// Current entropy `H(q)`.
88    pub fn entropy(&self) -> Result<f64> {
89        match self {
90            Self::Gaussian { q, .. } => q.entropy(),
91            Self::Categorical { q, .. } => q.entropy(),
92            Self::Dirichlet { q, .. } => q.entropy(),
93        }
94    }
95
96    /// KL from the current variational posterior to its prior.
97    pub fn kl_to_prior(&self) -> Result<f64> {
98        match self {
99            Self::Gaussian { q, prior } => gaussian_kl(q, prior),
100            Self::Categorical { q, prior } => categorical_kl(q, prior),
101            Self::Dirichlet { q, prior } => dirichlet_kl(q, prior),
102        }
103    }
104}
105
106// ---------------------------------------------------------------------------
107// Factor tagging
108// ---------------------------------------------------------------------------
109
/// Conjugate relationship represented by a single factor.
///
/// Factors refer to variables by name; the configuration's validation step
/// checks that every referenced name is registered with the family the factor
/// expects.
#[derive(Clone, Debug)]
pub enum VmpFactor {
    /// Gaussian observation with known precision, centred on another Gaussian
    /// variable. Produces a natural-parameter delta of
    /// `η_delta = [τ_obs · y]` with posterior precision contribution `τ_obs`.
    ///
    /// In the Gaussian update this adds `precision` to the target's posterior
    /// precision and `precision * observation` to its precision-weighted mean.
    GaussianObservation {
        /// Variable whose mean is being inferred.
        target: String,
        /// Observed value y.
        observation: f64,
        /// Known observation precision τ_obs.
        precision: f64,
    },
    /// Gaussian `x_child ~ N(x_parent, 1/τ)` — a "Gaussian step" between two
    /// unknown means sharing a known precision. Both endpoints receive a
    /// symmetric natural-parameter delta driven by the other's expected mean.
    ///
    /// When one endpoint is updated, the factor contributes `precision` to its
    /// posterior precision and `precision * E_q[other]` to its
    /// precision-weighted mean, where `E_q[other]` is the other endpoint's
    /// current posterior mean.
    GaussianStep {
        /// Endpoint 1 variable name.
        lhs: String,
        /// Endpoint 2 variable name.
        rhs: String,
        /// Known precision τ.
        precision: f64,
    },
    /// `x ~ Categorical(π)` where π is itself a Dirichlet variable. Updates the
    /// Dirichlet concentration by `E_q[u(x)] = softmax(η_x)` and contributes
    /// `E_q[log π]` back to the Categorical natural parameters.
    ///
    /// Both directions use the other variable's
    /// `expected_sufficient_statistics()`, added element-wise to the prior
    /// natural parameters; a length mismatch is reported as a
    /// `DimensionMismatch` error during the update.
    DirichletCategorical {
        /// Dirichlet-distributed variable π.
        dirichlet: String,
        /// Categorical-distributed variable x.
        categorical: String,
    },
    /// Observed categorical value (evidence). Only contributes counts to its
    /// Dirichlet parent; the categorical itself is pinned.
    ///
    /// The Dirichlet update adds `1.0` to the natural parameter at index
    /// `observation`.
    CategoricalObservation {
        /// Dirichlet-distributed variable π.
        dirichlet: String,
        /// Observed category index (must be `< num_categories`).
        observation: usize,
        /// Number of categories (must equal the Dirichlet's component count).
        num_categories: usize,
    },
}
155
156// ---------------------------------------------------------------------------
157// Engine configuration
158// ---------------------------------------------------------------------------
159
160/// User-facing configuration describing a VMP problem.
161#[derive(Clone, Debug, Default)]
162pub struct VmpConfig {
163    /// Per-variable variational state.
164    pub states: HashMap<String, VariationalState>,
165    /// Factors (conjugate relationships).
166    pub factors: Vec<VmpFactor>,
167    /// Maximum iterations.
168    pub max_iterations: usize,
169    /// Convergence tolerance on both ELBO change and the max L∞ residual.
170    pub tolerance: f64,
171    /// Maximum allowed ELBO decrease before the engine bails out with a
172    /// `ConvergenceFailure` error (guards against numerical divergence).
173    pub divergence_tolerance: f64,
174}
175
176impl VmpConfig {
177    /// Build an empty configuration with sensible defaults.
178    pub fn new() -> Self {
179        Self {
180            states: HashMap::new(),
181            factors: Vec::new(),
182            max_iterations: 100,
183            tolerance: 1e-6,
184            divergence_tolerance: 1e-4,
185        }
186    }
187
188    /// Register a Gaussian variable with a prior `N(prior_mean, 1/precision)`.
189    pub fn with_gaussian(mut self, name: &str, prior_mean: f64, precision: f64) -> Result<Self> {
190        let prior = GaussianNP::new(prior_mean, precision)?;
191        let q = prior.clone();
192        self.states
193            .insert(name.to_string(), VariationalState::Gaussian { q, prior });
194        Ok(self)
195    }
196
197    /// Register a Categorical variable with a flat prior over `k` categories.
198    pub fn with_categorical(mut self, name: &str, num_categories: usize) -> Result<Self> {
199        if num_categories == 0 {
200            return Err(PgmError::InvalidDistribution(
201                "Categorical needs at least one category".to_string(),
202            ));
203        }
204        let probs = vec![1.0 / num_categories as f64; num_categories];
205        let prior = CategoricalNP::from_probs(&probs)?;
206        let q = prior.clone();
207        self.states
208            .insert(name.to_string(), VariationalState::Categorical { q, prior });
209        Ok(self)
210    }
211
212    /// Register a Dirichlet variable with prior concentration α.
213    pub fn with_dirichlet(mut self, name: &str, concentration: Vec<f64>) -> Result<Self> {
214        let prior = DirichletNP::new(concentration)?;
215        let q = prior.clone();
216        self.states
217            .insert(name.to_string(), VariationalState::Dirichlet { q, prior });
218        Ok(self)
219    }
220
221    /// Append a VMP factor.
222    pub fn with_factor(mut self, factor: VmpFactor) -> Self {
223        self.factors.push(factor);
224        self
225    }
226
227    /// Override the max iterations / tolerance pair.
228    pub fn with_limits(mut self, max_iterations: usize, tolerance: f64) -> Self {
229        self.max_iterations = max_iterations;
230        self.tolerance = tolerance;
231        self
232    }
233
234    /// Ensure every variable appearing in a factor is registered with a family
235    /// and that its family matches the factor's expectation.
236    fn validate(&self) -> Result<()> {
237        for f in &self.factors {
238            match f {
239                VmpFactor::GaussianObservation { target, .. } => {
240                    let state = self
241                        .states
242                        .get(target)
243                        .ok_or_else(|| PgmError::VariableNotFound(target.clone()))?;
244                    if !matches!(state, VariationalState::Gaussian { .. }) {
245                        return Err(PgmError::InvalidGraph(format!(
246                            "GaussianObservation on non-Gaussian variable '{}'",
247                            target
248                        )));
249                    }
250                }
251                VmpFactor::GaussianStep { lhs, rhs, .. } => {
252                    for v in [lhs, rhs] {
253                        let state = self
254                            .states
255                            .get(v)
256                            .ok_or_else(|| PgmError::VariableNotFound(v.clone()))?;
257                        if !matches!(state, VariationalState::Gaussian { .. }) {
258                            return Err(PgmError::InvalidGraph(format!(
259                                "GaussianStep on non-Gaussian variable '{}'",
260                                v
261                            )));
262                        }
263                    }
264                }
265                VmpFactor::DirichletCategorical {
266                    dirichlet,
267                    categorical,
268                } => {
269                    let d = self
270                        .states
271                        .get(dirichlet)
272                        .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
273                    let c = self
274                        .states
275                        .get(categorical)
276                        .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?;
277                    if !matches!(d, VariationalState::Dirichlet { .. }) {
278                        return Err(PgmError::InvalidGraph(format!(
279                            "DirichletCategorical: '{}' is not a Dirichlet variable",
280                            dirichlet
281                        )));
282                    }
283                    if !matches!(c, VariationalState::Categorical { .. }) {
284                        return Err(PgmError::InvalidGraph(format!(
285                            "DirichletCategorical: '{}' is not a Categorical variable",
286                            categorical
287                        )));
288                    }
289                }
290                VmpFactor::CategoricalObservation {
291                    dirichlet,
292                    num_categories,
293                    observation,
294                } => {
295                    let d = self
296                        .states
297                        .get(dirichlet)
298                        .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
299                    match d {
300                        VariationalState::Dirichlet { q, .. } => {
301                            if q.concentration.len() != *num_categories {
302                                return Err(PgmError::DimensionMismatch {
303                                    expected: vec![*num_categories],
304                                    got: vec![q.concentration.len()],
305                                });
306                            }
307                            if observation >= num_categories {
308                                return Err(PgmError::InvalidDistribution(format!(
309                                    "observation {} out of range for {} categories",
310                                    observation, num_categories
311                                )));
312                            }
313                        }
314                        _ => {
315                            return Err(PgmError::InvalidGraph(format!(
316                                "CategoricalObservation: '{}' is not a Dirichlet variable",
317                                dirichlet
318                            )));
319                        }
320                    }
321                }
322            }
323        }
324        Ok(())
325    }
326}
327
328// ---------------------------------------------------------------------------
329// Result
330// ---------------------------------------------------------------------------
331
/// Summary of a VMP run.
///
/// Produced by the engine's `run` method once the loop stops, either because
/// a tolerance criterion was met or because the iteration limit was reached.
#[derive(Clone, Debug)]
pub struct VmpResult {
    /// Final variational distributions, keyed by variable name.
    pub states: HashMap<String, VariationalState>,
    /// ELBO value at each iteration (length = iterations run + 1).
    /// The first entry is the ELBO evaluated before any sweep.
    pub elbo_history: Vec<f64>,
    /// Iterations actually run.
    pub iterations: usize,
    /// Whether the run met the tolerance criterion.
    /// `false` means the loop stopped at the iteration limit instead.
    pub converged: bool,
}
344
345// ---------------------------------------------------------------------------
346// Engine
347// ---------------------------------------------------------------------------
348
/// The VMP coordinate-ascent engine.
///
/// The engine itself is a thin state machine that holds the [`VmpConfig`] and
/// drives the per-variable update rules defined by each `VmpFactor` variant.
pub struct VariationalMessagePassing {
    // Validated problem description. The variational states it holds are
    // updated in place by the per-variable update rules.
    config: VmpConfig,
    // Variable names in sorted order, fixed at construction, so that every
    // sweep visits the variables in the same sequence.
    update_order: Vec<String>,
}
357
358impl VariationalMessagePassing {
359    /// Build an engine from a validated configuration.
360    pub fn new(config: VmpConfig) -> Result<Self> {
361        config.validate()?;
362        let mut keys: Vec<String> = config.states.keys().cloned().collect();
363        keys.sort(); // deterministic update order
364        Ok(Self {
365            config,
366            update_order: keys,
367        })
368    }
369
370    /// Construct an engine from an already-existing `FactorGraph` (structure
371    /// only) plus the VMP annotations. The graph is consulted to validate that
372    /// every factor references variables that the user *also* registered with
373    /// the factor-graph API, which is useful when VMP is layered on top of the
374    /// generic PGM pipeline.
375    pub fn with_graph(graph: &FactorGraph, config: VmpConfig) -> Result<Self> {
376        for v in config.states.keys() {
377            if graph.get_variable(v).is_none() {
378                return Err(PgmError::VariableNotFound(format!(
379                    "'{}' declared in VmpConfig but missing from FactorGraph",
380                    v
381                )));
382            }
383        }
384        Self::new(config)
385    }
386
387    /// Run the coordinate-ascent loop.
388    pub fn run(&mut self) -> Result<VmpResult> {
389        let elbo0 = self.compute_elbo()?;
390        let mut elbo_history = vec![elbo0];
391        let mut converged = false;
392        let mut iterations = 0;
393
394        for iter in 0..self.config.max_iterations {
395            let snapshot = self.snapshot_natural_params();
396            self.coordinate_sweep()?;
397            let elbo_new = self.compute_elbo()?;
398            let prev = *elbo_history
399                .last()
400                .ok_or_else(|| PgmError::ConvergenceFailure("elbo history is empty".into()))?;
401
402            // Divergence check: the ELBO is guaranteed to be non-decreasing for
403            // exact conjugate VMP, so a drop larger than the divergence
404            // tolerance is a red flag (numerical breakdown or ill-posed model).
405            if elbo_new < prev - self.config.divergence_tolerance {
406                return Err(PgmError::ConvergenceFailure(format!(
407                    "VMP ELBO decreased from {} to {} at iteration {}",
408                    prev, elbo_new, iter
409                )));
410            }
411
412            elbo_history.push(elbo_new);
413            iterations = iter + 1;
414
415            let linf = self.linf_from_snapshot(&snapshot);
416            let elbo_delta = (elbo_new - prev).abs();
417            if elbo_delta < self.config.tolerance || linf < self.config.tolerance {
418                converged = true;
419                break;
420            }
421        }
422
423        Ok(VmpResult {
424            states: self.config.states.clone(),
425            elbo_history,
426            iterations,
427            converged,
428        })
429    }
430
431    /// Single coordinate sweep across all variables in deterministic order.
432    fn coordinate_sweep(&mut self) -> Result<()> {
433        // We iterate by name over a cloned order so that the engine can mutate
434        // `self.config.states` freely inside the loop.
435        let order = self.update_order.clone();
436        for var in order {
437            self.update_variable(&var)?;
438        }
439        Ok(())
440    }
441
442    /// Compute the natural-parameter update for one variable by aggregating the
443    /// contributions of every adjacent factor.
444    fn update_variable(&mut self, var: &str) -> Result<()> {
445        let state = self
446            .config
447            .states
448            .get(var)
449            .cloned()
450            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
451        match state.family() {
452            Family::Gaussian => self.update_gaussian(var),
453            Family::Categorical => self.update_categorical(var),
454            Family::Dirichlet => self.update_dirichlet(var),
455        }
456    }
457
458    fn update_gaussian(&mut self, var: &str) -> Result<()> {
459        // Pull the prior natural parameters.
460        let (mut posterior_precision, mut posterior_natural_mean) = match self
461            .config
462            .states
463            .get(var)
464            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
465        {
466            VariationalState::Gaussian { prior, .. } => {
467                // prior contributes η_prior = τ_prior · μ_prior and precision τ_prior.
468                (prior.precision, prior.precision * prior.mean)
469            }
470            _ => unreachable!("non-Gaussian state reached update_gaussian"),
471        };
472
473        for factor in &self.config.factors {
474            match factor {
475                VmpFactor::GaussianObservation {
476                    target,
477                    observation,
478                    precision,
479                } if target == var => {
480                    posterior_precision += precision;
481                    posterior_natural_mean += precision * observation;
482                }
483                VmpFactor::GaussianStep {
484                    lhs,
485                    rhs,
486                    precision,
487                } => {
488                    // Symmetric Gaussian step: each endpoint observes E[q(other)]
489                    // with the given precision τ.
490                    let (other, is_self) = if lhs == var {
491                        (rhs, true)
492                    } else if rhs == var {
493                        (lhs, true)
494                    } else {
495                        (lhs, false)
496                    };
497                    if is_self {
498                        let other_mean = match self
499                            .config
500                            .states
501                            .get(other)
502                            .ok_or_else(|| PgmError::VariableNotFound(other.clone()))?
503                        {
504                            VariationalState::Gaussian { q, .. } => q.mean,
505                            _ => {
506                                return Err(PgmError::InvalidGraph(format!(
507                                    "GaussianStep neighbour '{}' is not Gaussian",
508                                    other
509                                )));
510                            }
511                        };
512                        posterior_precision += precision;
513                        posterior_natural_mean += precision * other_mean;
514                    }
515                }
516                _ => {}
517            }
518        }
519
520        if posterior_precision <= 0.0 || !posterior_precision.is_finite() {
521            return Err(PgmError::InvalidDistribution(format!(
522                "Gaussian posterior precision must be positive (got {})",
523                posterior_precision
524            )));
525        }
526
527        let new_mean = posterior_natural_mean / posterior_precision;
528        let state = self
529            .config
530            .states
531            .get_mut(var)
532            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
533        if let VariationalState::Gaussian { q, .. } = state {
534            // The stored q carries the *effective* posterior precision derived
535            // from the prior plus every adjacent observation / step factor; the
536            // prior's precision is preserved separately, which is what keeps
537            // the KL to the prior well defined (see `VariationalState::kl_to_prior`).
538            q.precision = posterior_precision;
539            q.mean = new_mean;
540        }
541        Ok(())
542    }
543
544    fn update_categorical(&mut self, var: &str) -> Result<()> {
545        let num_categories = match self
546            .config
547            .states
548            .get(var)
549            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
550        {
551            VariationalState::Categorical { q, .. } => q.num_categories(),
552            _ => unreachable!(),
553        };
554
555        // Start from the prior natural parameters.
556        let mut natural = match self
557            .config
558            .states
559            .get(var)
560            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
561        {
562            VariationalState::Categorical { prior, .. } => prior.natural_params(),
563            _ => unreachable!(),
564        };
565
566        for factor in &self.config.factors {
567            if let VmpFactor::DirichletCategorical {
568                dirichlet,
569                categorical,
570            } = factor
571            {
572                if categorical == var {
573                    let dir_state = self
574                        .config
575                        .states
576                        .get(dirichlet)
577                        .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
578                    if let VariationalState::Dirichlet { q, .. } = dir_state {
579                        let e_log_pi = q.expected_sufficient_statistics();
580                        if e_log_pi.len() != num_categories {
581                            return Err(PgmError::DimensionMismatch {
582                                expected: vec![num_categories],
583                                got: vec![e_log_pi.len()],
584                            });
585                        }
586                        for (a, b) in natural.iter_mut().zip(e_log_pi.iter()) {
587                            *a += *b;
588                        }
589                    }
590                }
591            }
592        }
593
594        let state = self
595            .config
596            .states
597            .get_mut(var)
598            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
599        if let VariationalState::Categorical { q, .. } = state {
600            q.set_natural(&natural)?;
601        }
602        Ok(())
603    }
604
605    fn update_dirichlet(&mut self, var: &str) -> Result<()> {
606        // Start from prior natural parameters.
607        let mut natural = match self
608            .config
609            .states
610            .get(var)
611            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
612        {
613            VariationalState::Dirichlet { prior, .. } => prior.natural_params(),
614            _ => unreachable!(),
615        };
616
617        let num_components = natural.len();
618        for factor in &self.config.factors {
619            match factor {
620                VmpFactor::DirichletCategorical {
621                    dirichlet,
622                    categorical,
623                } if dirichlet == var => {
624                    let cat_state = self
625                        .config
626                        .states
627                        .get(categorical)
628                        .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?;
629                    if let VariationalState::Categorical { q, .. } = cat_state {
630                        let expected_counts = q.expected_sufficient_statistics();
631                        if expected_counts.len() != num_components {
632                            return Err(PgmError::DimensionMismatch {
633                                expected: vec![num_components],
634                                got: vec![expected_counts.len()],
635                            });
636                        }
637                        for (a, b) in natural.iter_mut().zip(expected_counts.iter()) {
638                            *a += *b;
639                        }
640                    }
641                }
642                VmpFactor::CategoricalObservation {
643                    dirichlet,
644                    observation,
645                    num_categories,
646                } if dirichlet == var => {
647                    if *num_categories != num_components {
648                        return Err(PgmError::DimensionMismatch {
649                            expected: vec![num_components],
650                            got: vec![*num_categories],
651                        });
652                    }
653                    if let Some(slot) = natural.get_mut(*observation) {
654                        *slot += 1.0;
655                    } else {
656                        return Err(PgmError::InvalidDistribution(format!(
657                            "observation {} out of range for {} categories",
658                            observation, num_categories
659                        )));
660                    }
661                }
662                _ => {}
663            }
664        }
665
666        let state = self
667            .config
668            .states
669            .get_mut(var)
670            .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
671        if let VariationalState::Dirichlet { q, .. } = state {
672            q.set_natural(&natural)?;
673        }
674        Ok(())
675    }
676
677    // ---------------------------------------------------------------------
678    // ELBO
679    // ---------------------------------------------------------------------
680
681    /// Evidence Lower Bound `L(q) = E_q[log p(x, z)] − E_q[log q(z)]`.
682    ///
683    /// For the three conjugate relationships shipped in v0.2.0 the ELBO
684    /// decomposes as `Σ E_q[log p(factor)] − Σ KL(q(v) || prior(v))` because
685    /// each prior cancels with the log p(z) term.
686    pub fn compute_elbo(&self) -> Result<f64> {
687        let mut elbo = 0.0;
688        // Likelihood contributions from each factor.
689        for factor in &self.config.factors {
690            elbo += self.factor_expected_log_joint(factor)?;
691        }
692        // − KL(q(v) || prior(v)) for every variable.
693        for state in self.config.states.values() {
694            elbo -= state.kl_to_prior()?;
695        }
696        Ok(elbo)
697    }
698
699    fn factor_expected_log_joint(&self, factor: &VmpFactor) -> Result<f64> {
700        match factor {
701            VmpFactor::GaussianObservation {
702                target,
703                observation,
704                precision,
705            } => {
706                let state = self
707                    .config
708                    .states
709                    .get(target)
710                    .ok_or_else(|| PgmError::VariableNotFound(target.clone()))?;
711                if let VariationalState::Gaussian { q, .. } = state {
712                    // E_q[log N(y | μ, 1/τ)] = ½ log(τ / 2π) − (τ/2) (E[μ²] + y² − 2 y E[μ]).
713                    // For q with precision τ_q (posterior effective), E[μ] = μ_q, Var[μ] = 1/τ_q.
714                    let e_mu = q.mean;
715                    let e_mu2 = q.mean * q.mean + 1.0 / q.precision;
716                    let y = *observation;
717                    let p = *precision;
718                    let coef = 0.5 * p;
719                    let log_norm = 0.5 * (p / (2.0 * std::f64::consts::PI)).ln();
720                    Ok(log_norm - coef * (e_mu2 + y * y - 2.0 * y * e_mu))
721                } else {
722                    Err(PgmError::InvalidGraph(format!(
723                        "GaussianObservation target '{}' is not Gaussian",
724                        target
725                    )))
726                }
727            }
728            VmpFactor::GaussianStep {
729                lhs,
730                rhs,
731                precision,
732            } => {
733                let lq = match self
734                    .config
735                    .states
736                    .get(lhs)
737                    .ok_or_else(|| PgmError::VariableNotFound(lhs.clone()))?
738                {
739                    VariationalState::Gaussian { q, .. } => q,
740                    _ => {
741                        return Err(PgmError::InvalidGraph(format!(
742                            "GaussianStep endpoint '{}' is not Gaussian",
743                            lhs
744                        )));
745                    }
746                };
747                let rq = match self
748                    .config
749                    .states
750                    .get(rhs)
751                    .ok_or_else(|| PgmError::VariableNotFound(rhs.clone()))?
752                {
753                    VariationalState::Gaussian { q, .. } => q,
754                    _ => {
755                        return Err(PgmError::InvalidGraph(format!(
756                            "GaussianStep endpoint '{}' is not Gaussian",
757                            rhs
758                        )));
759                    }
760                };
761                // E_q[log N(lhs | rhs, 1/τ)]
762                let e_l = lq.mean;
763                let e_l2 = lq.mean * lq.mean + 1.0 / lq.precision;
764                let e_r = rq.mean;
765                let e_r2 = rq.mean * rq.mean + 1.0 / rq.precision;
766                let log_norm = 0.5 * (precision / (2.0 * std::f64::consts::PI)).ln();
767                let coef = 0.5 * precision;
768                Ok(log_norm - coef * (e_l2 - 2.0 * e_l * e_r + e_r2))
769            }
770            VmpFactor::DirichletCategorical {
771                dirichlet,
772                categorical,
773            } => {
774                let d = match self
775                    .config
776                    .states
777                    .get(dirichlet)
778                    .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?
779                {
780                    VariationalState::Dirichlet { q, .. } => q,
781                    _ => {
782                        return Err(PgmError::InvalidGraph(format!(
783                            "DirichletCategorical: '{}' not Dirichlet",
784                            dirichlet
785                        )));
786                    }
787                };
788                let c = match self
789                    .config
790                    .states
791                    .get(categorical)
792                    .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?
793                {
794                    VariationalState::Categorical { q, .. } => q,
795                    _ => {
796                        return Err(PgmError::InvalidGraph(format!(
797                            "DirichletCategorical: '{}' not Categorical",
798                            categorical
799                        )));
800                    }
801                };
802                let e_log_pi = d.expected_sufficient_statistics();
803                let probs = c.probs();
804                if e_log_pi.len() != probs.len() {
805                    return Err(PgmError::DimensionMismatch {
806                        expected: vec![probs.len()],
807                        got: vec![e_log_pi.len()],
808                    });
809                }
810                Ok(e_log_pi.iter().zip(probs.iter()).map(|(l, p)| l * p).sum())
811            }
812            VmpFactor::CategoricalObservation {
813                dirichlet,
814                observation,
815                ..
816            } => {
817                let d = match self
818                    .config
819                    .states
820                    .get(dirichlet)
821                    .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?
822                {
823                    VariationalState::Dirichlet { q, .. } => q,
824                    _ => {
825                        return Err(PgmError::InvalidGraph(format!(
826                            "CategoricalObservation: '{}' not Dirichlet",
827                            dirichlet
828                        )));
829                    }
830                };
831                let e_log_pi = d.expected_sufficient_statistics();
832                e_log_pi.get(*observation).cloned().ok_or_else(|| {
833                    PgmError::InvalidDistribution(format!(
834                        "observation {} out of range for Dirichlet with {} components",
835                        observation,
836                        e_log_pi.len()
837                    ))
838                })
839            }
840        }
841    }
842
843    // ---------------------------------------------------------------------
844    // Helpers
845    // ---------------------------------------------------------------------
846
847    fn snapshot_natural_params(&self) -> HashMap<String, Vec<f64>> {
848        self.config
849            .states
850            .iter()
851            .map(|(k, v)| (k.clone(), v.natural_params()))
852            .collect()
853    }
854
855    fn linf_from_snapshot(&self, snapshot: &HashMap<String, Vec<f64>>) -> f64 {
856        let mut max = 0.0f64;
857        for (k, v) in &self.config.states {
858            let before = match snapshot.get(k) {
859                Some(vec) => vec,
860                None => continue,
861            };
862            for (a, b) in v.natural_params().iter().zip(before.iter()) {
863                max = max.max((a - b).abs());
864            }
865        }
866        max
867    }
868
869    /// Read-only access to the current states.
870    pub fn states(&self) -> &HashMap<String, VariationalState> {
871        &self.config.states
872    }
873}