1use std::collections::HashMap;
30
31use crate::error::{PgmError, Result};
32use crate::graph::FactorGraph;
33
34use super::distributions::{
35 categorical_kl, dirichlet_kl, gaussian_kl, CategoricalNP, DirichletNP, GaussianNP,
36};
37use super::exponential_family::ExponentialFamily;
38
/// Exponential-family kind of a variational state.
///
/// Used to dispatch the per-variable coordinate update. The enum carries no
/// data, so it is `Copy` and can be passed around by value freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Family {
    /// Univariate Gaussian.
    Gaussian,
    /// Categorical distribution over a finite set of outcomes.
    Categorical,
    /// Dirichlet distribution over the probability simplex.
    Dirichlet,
}
53
54#[derive(Clone, Debug)]
56pub enum VariationalState {
57 Gaussian { q: GaussianNP, prior: GaussianNP },
59 Categorical {
61 q: CategoricalNP,
62 prior: CategoricalNP,
63 },
64 Dirichlet { q: DirichletNP, prior: DirichletNP },
66}
67
68impl VariationalState {
69 pub fn family(&self) -> Family {
71 match self {
72 Self::Gaussian { .. } => Family::Gaussian,
73 Self::Categorical { .. } => Family::Categorical,
74 Self::Dirichlet { .. } => Family::Dirichlet,
75 }
76 }
77
78 pub fn natural_params(&self) -> Vec<f64> {
80 match self {
81 Self::Gaussian { q, .. } => q.natural_params(),
82 Self::Categorical { q, .. } => q.natural_params(),
83 Self::Dirichlet { q, .. } => q.natural_params(),
84 }
85 }
86
87 pub fn entropy(&self) -> Result<f64> {
89 match self {
90 Self::Gaussian { q, .. } => q.entropy(),
91 Self::Categorical { q, .. } => q.entropy(),
92 Self::Dirichlet { q, .. } => q.entropy(),
93 }
94 }
95
96 pub fn kl_to_prior(&self) -> Result<f64> {
98 match self {
99 Self::Gaussian { q, prior } => gaussian_kl(q, prior),
100 Self::Categorical { q, prior } => categorical_kl(q, prior),
101 Self::Dirichlet { q, prior } => dirichlet_kl(q, prior),
102 }
103 }
104}
105
/// A log-potential connecting one or two named variational variables.
///
/// Variable names refer to keys of `VmpConfig::states`.
#[derive(Clone, Debug)]
pub enum VmpFactor {
    /// Gaussian likelihood `N(observation | target, 1 / precision)` for a
    /// scalar observation of the Gaussian variable `target`.
    GaussianObservation {
        /// Name of the observed Gaussian variable.
        target: String,
        /// Observed value.
        observation: f64,
        /// Observation precision (inverse variance).
        precision: f64,
    },
    /// Gaussian coupling with log-potential `-precision / 2 * (lhs - rhs)^2`
    /// (plus normaliser) between two Gaussian variables.
    GaussianStep {
        /// First Gaussian endpoint.
        lhs: String,
        /// Second Gaussian endpoint.
        rhs: String,
        /// Coupling precision.
        precision: f64,
    },
    /// `categorical ~ Categorical(pi)` where `pi` is the Dirichlet variable
    /// `dirichlet`.
    DirichletCategorical {
        /// Name of the Dirichlet variable holding `pi`.
        dirichlet: String,
        /// Name of the latent categorical variable.
        categorical: String,
    },
    /// An observed category index drawn from `Categorical(pi)` where `pi` is
    /// the Dirichlet variable `dirichlet`.
    CategoricalObservation {
        /// Name of the Dirichlet variable holding `pi`.
        dirichlet: String,
        /// Observed category index (must be `< num_categories`).
        observation: usize,
        /// Expected dimension of the Dirichlet variable.
        num_categories: usize,
    },
}
155
156#[derive(Clone, Debug, Default)]
162pub struct VmpConfig {
163 pub states: HashMap<String, VariationalState>,
165 pub factors: Vec<VmpFactor>,
167 pub max_iterations: usize,
169 pub tolerance: f64,
171 pub divergence_tolerance: f64,
174}
175
176impl VmpConfig {
177 pub fn new() -> Self {
179 Self {
180 states: HashMap::new(),
181 factors: Vec::new(),
182 max_iterations: 100,
183 tolerance: 1e-6,
184 divergence_tolerance: 1e-4,
185 }
186 }
187
188 pub fn with_gaussian(mut self, name: &str, prior_mean: f64, precision: f64) -> Result<Self> {
190 let prior = GaussianNP::new(prior_mean, precision)?;
191 let q = prior.clone();
192 self.states
193 .insert(name.to_string(), VariationalState::Gaussian { q, prior });
194 Ok(self)
195 }
196
197 pub fn with_categorical(mut self, name: &str, num_categories: usize) -> Result<Self> {
199 if num_categories == 0 {
200 return Err(PgmError::InvalidDistribution(
201 "Categorical needs at least one category".to_string(),
202 ));
203 }
204 let probs = vec![1.0 / num_categories as f64; num_categories];
205 let prior = CategoricalNP::from_probs(&probs)?;
206 let q = prior.clone();
207 self.states
208 .insert(name.to_string(), VariationalState::Categorical { q, prior });
209 Ok(self)
210 }
211
212 pub fn with_dirichlet(mut self, name: &str, concentration: Vec<f64>) -> Result<Self> {
214 let prior = DirichletNP::new(concentration)?;
215 let q = prior.clone();
216 self.states
217 .insert(name.to_string(), VariationalState::Dirichlet { q, prior });
218 Ok(self)
219 }
220
221 pub fn with_factor(mut self, factor: VmpFactor) -> Self {
223 self.factors.push(factor);
224 self
225 }
226
227 pub fn with_limits(mut self, max_iterations: usize, tolerance: f64) -> Self {
229 self.max_iterations = max_iterations;
230 self.tolerance = tolerance;
231 self
232 }
233
234 fn validate(&self) -> Result<()> {
237 for f in &self.factors {
238 match f {
239 VmpFactor::GaussianObservation { target, .. } => {
240 let state = self
241 .states
242 .get(target)
243 .ok_or_else(|| PgmError::VariableNotFound(target.clone()))?;
244 if !matches!(state, VariationalState::Gaussian { .. }) {
245 return Err(PgmError::InvalidGraph(format!(
246 "GaussianObservation on non-Gaussian variable '{}'",
247 target
248 )));
249 }
250 }
251 VmpFactor::GaussianStep { lhs, rhs, .. } => {
252 for v in [lhs, rhs] {
253 let state = self
254 .states
255 .get(v)
256 .ok_or_else(|| PgmError::VariableNotFound(v.clone()))?;
257 if !matches!(state, VariationalState::Gaussian { .. }) {
258 return Err(PgmError::InvalidGraph(format!(
259 "GaussianStep on non-Gaussian variable '{}'",
260 v
261 )));
262 }
263 }
264 }
265 VmpFactor::DirichletCategorical {
266 dirichlet,
267 categorical,
268 } => {
269 let d = self
270 .states
271 .get(dirichlet)
272 .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
273 let c = self
274 .states
275 .get(categorical)
276 .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?;
277 if !matches!(d, VariationalState::Dirichlet { .. }) {
278 return Err(PgmError::InvalidGraph(format!(
279 "DirichletCategorical: '{}' is not a Dirichlet variable",
280 dirichlet
281 )));
282 }
283 if !matches!(c, VariationalState::Categorical { .. }) {
284 return Err(PgmError::InvalidGraph(format!(
285 "DirichletCategorical: '{}' is not a Categorical variable",
286 categorical
287 )));
288 }
289 }
290 VmpFactor::CategoricalObservation {
291 dirichlet,
292 num_categories,
293 observation,
294 } => {
295 let d = self
296 .states
297 .get(dirichlet)
298 .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
299 match d {
300 VariationalState::Dirichlet { q, .. } => {
301 if q.concentration.len() != *num_categories {
302 return Err(PgmError::DimensionMismatch {
303 expected: vec![*num_categories],
304 got: vec![q.concentration.len()],
305 });
306 }
307 if observation >= num_categories {
308 return Err(PgmError::InvalidDistribution(format!(
309 "observation {} out of range for {} categories",
310 observation, num_categories
311 )));
312 }
313 }
314 _ => {
315 return Err(PgmError::InvalidGraph(format!(
316 "CategoricalObservation: '{}' is not a Dirichlet variable",
317 dirichlet
318 )));
319 }
320 }
321 }
322 }
323 }
324 Ok(())
325 }
326}
327
328#[derive(Clone, Debug)]
334pub struct VmpResult {
335 pub states: HashMap<String, VariationalState>,
337 pub elbo_history: Vec<f64>,
339 pub iterations: usize,
341 pub converged: bool,
343}
344
345pub struct VariationalMessagePassing {
354 config: VmpConfig,
355 update_order: Vec<String>,
356}
357
358impl VariationalMessagePassing {
359 pub fn new(config: VmpConfig) -> Result<Self> {
361 config.validate()?;
362 let mut keys: Vec<String> = config.states.keys().cloned().collect();
363 keys.sort(); Ok(Self {
365 config,
366 update_order: keys,
367 })
368 }
369
370 pub fn with_graph(graph: &FactorGraph, config: VmpConfig) -> Result<Self> {
376 for v in config.states.keys() {
377 if graph.get_variable(v).is_none() {
378 return Err(PgmError::VariableNotFound(format!(
379 "'{}' declared in VmpConfig but missing from FactorGraph",
380 v
381 )));
382 }
383 }
384 Self::new(config)
385 }
386
387 pub fn run(&mut self) -> Result<VmpResult> {
389 let elbo0 = self.compute_elbo()?;
390 let mut elbo_history = vec![elbo0];
391 let mut converged = false;
392 let mut iterations = 0;
393
394 for iter in 0..self.config.max_iterations {
395 let snapshot = self.snapshot_natural_params();
396 self.coordinate_sweep()?;
397 let elbo_new = self.compute_elbo()?;
398 let prev = *elbo_history
399 .last()
400 .ok_or_else(|| PgmError::ConvergenceFailure("elbo history is empty".into()))?;
401
402 if elbo_new < prev - self.config.divergence_tolerance {
406 return Err(PgmError::ConvergenceFailure(format!(
407 "VMP ELBO decreased from {} to {} at iteration {}",
408 prev, elbo_new, iter
409 )));
410 }
411
412 elbo_history.push(elbo_new);
413 iterations = iter + 1;
414
415 let linf = self.linf_from_snapshot(&snapshot);
416 let elbo_delta = (elbo_new - prev).abs();
417 if elbo_delta < self.config.tolerance || linf < self.config.tolerance {
418 converged = true;
419 break;
420 }
421 }
422
423 Ok(VmpResult {
424 states: self.config.states.clone(),
425 elbo_history,
426 iterations,
427 converged,
428 })
429 }
430
431 fn coordinate_sweep(&mut self) -> Result<()> {
433 let order = self.update_order.clone();
436 for var in order {
437 self.update_variable(&var)?;
438 }
439 Ok(())
440 }
441
442 fn update_variable(&mut self, var: &str) -> Result<()> {
445 let state = self
446 .config
447 .states
448 .get(var)
449 .cloned()
450 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
451 match state.family() {
452 Family::Gaussian => self.update_gaussian(var),
453 Family::Categorical => self.update_categorical(var),
454 Family::Dirichlet => self.update_dirichlet(var),
455 }
456 }
457
458 fn update_gaussian(&mut self, var: &str) -> Result<()> {
459 let (mut posterior_precision, mut posterior_natural_mean) = match self
461 .config
462 .states
463 .get(var)
464 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
465 {
466 VariationalState::Gaussian { prior, .. } => {
467 (prior.precision, prior.precision * prior.mean)
469 }
470 _ => unreachable!("non-Gaussian state reached update_gaussian"),
471 };
472
473 for factor in &self.config.factors {
474 match factor {
475 VmpFactor::GaussianObservation {
476 target,
477 observation,
478 precision,
479 } if target == var => {
480 posterior_precision += precision;
481 posterior_natural_mean += precision * observation;
482 }
483 VmpFactor::GaussianStep {
484 lhs,
485 rhs,
486 precision,
487 } => {
488 let (other, is_self) = if lhs == var {
491 (rhs, true)
492 } else if rhs == var {
493 (lhs, true)
494 } else {
495 (lhs, false)
496 };
497 if is_self {
498 let other_mean = match self
499 .config
500 .states
501 .get(other)
502 .ok_or_else(|| PgmError::VariableNotFound(other.clone()))?
503 {
504 VariationalState::Gaussian { q, .. } => q.mean,
505 _ => {
506 return Err(PgmError::InvalidGraph(format!(
507 "GaussianStep neighbour '{}' is not Gaussian",
508 other
509 )));
510 }
511 };
512 posterior_precision += precision;
513 posterior_natural_mean += precision * other_mean;
514 }
515 }
516 _ => {}
517 }
518 }
519
520 if posterior_precision <= 0.0 || !posterior_precision.is_finite() {
521 return Err(PgmError::InvalidDistribution(format!(
522 "Gaussian posterior precision must be positive (got {})",
523 posterior_precision
524 )));
525 }
526
527 let new_mean = posterior_natural_mean / posterior_precision;
528 let state = self
529 .config
530 .states
531 .get_mut(var)
532 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
533 if let VariationalState::Gaussian { q, .. } = state {
534 q.precision = posterior_precision;
539 q.mean = new_mean;
540 }
541 Ok(())
542 }
543
544 fn update_categorical(&mut self, var: &str) -> Result<()> {
545 let num_categories = match self
546 .config
547 .states
548 .get(var)
549 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
550 {
551 VariationalState::Categorical { q, .. } => q.num_categories(),
552 _ => unreachable!(),
553 };
554
555 let mut natural = match self
557 .config
558 .states
559 .get(var)
560 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
561 {
562 VariationalState::Categorical { prior, .. } => prior.natural_params(),
563 _ => unreachable!(),
564 };
565
566 for factor in &self.config.factors {
567 if let VmpFactor::DirichletCategorical {
568 dirichlet,
569 categorical,
570 } = factor
571 {
572 if categorical == var {
573 let dir_state = self
574 .config
575 .states
576 .get(dirichlet)
577 .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?;
578 if let VariationalState::Dirichlet { q, .. } = dir_state {
579 let e_log_pi = q.expected_sufficient_statistics();
580 if e_log_pi.len() != num_categories {
581 return Err(PgmError::DimensionMismatch {
582 expected: vec![num_categories],
583 got: vec![e_log_pi.len()],
584 });
585 }
586 for (a, b) in natural.iter_mut().zip(e_log_pi.iter()) {
587 *a += *b;
588 }
589 }
590 }
591 }
592 }
593
594 let state = self
595 .config
596 .states
597 .get_mut(var)
598 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
599 if let VariationalState::Categorical { q, .. } = state {
600 q.set_natural(&natural)?;
601 }
602 Ok(())
603 }
604
605 fn update_dirichlet(&mut self, var: &str) -> Result<()> {
606 let mut natural = match self
608 .config
609 .states
610 .get(var)
611 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?
612 {
613 VariationalState::Dirichlet { prior, .. } => prior.natural_params(),
614 _ => unreachable!(),
615 };
616
617 let num_components = natural.len();
618 for factor in &self.config.factors {
619 match factor {
620 VmpFactor::DirichletCategorical {
621 dirichlet,
622 categorical,
623 } if dirichlet == var => {
624 let cat_state = self
625 .config
626 .states
627 .get(categorical)
628 .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?;
629 if let VariationalState::Categorical { q, .. } = cat_state {
630 let expected_counts = q.expected_sufficient_statistics();
631 if expected_counts.len() != num_components {
632 return Err(PgmError::DimensionMismatch {
633 expected: vec![num_components],
634 got: vec![expected_counts.len()],
635 });
636 }
637 for (a, b) in natural.iter_mut().zip(expected_counts.iter()) {
638 *a += *b;
639 }
640 }
641 }
642 VmpFactor::CategoricalObservation {
643 dirichlet,
644 observation,
645 num_categories,
646 } if dirichlet == var => {
647 if *num_categories != num_components {
648 return Err(PgmError::DimensionMismatch {
649 expected: vec![num_components],
650 got: vec![*num_categories],
651 });
652 }
653 if let Some(slot) = natural.get_mut(*observation) {
654 *slot += 1.0;
655 } else {
656 return Err(PgmError::InvalidDistribution(format!(
657 "observation {} out of range for {} categories",
658 observation, num_categories
659 )));
660 }
661 }
662 _ => {}
663 }
664 }
665
666 let state = self
667 .config
668 .states
669 .get_mut(var)
670 .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;
671 if let VariationalState::Dirichlet { q, .. } = state {
672 q.set_natural(&natural)?;
673 }
674 Ok(())
675 }
676
677 pub fn compute_elbo(&self) -> Result<f64> {
687 let mut elbo = 0.0;
688 for factor in &self.config.factors {
690 elbo += self.factor_expected_log_joint(factor)?;
691 }
692 for state in self.config.states.values() {
694 elbo -= state.kl_to_prior()?;
695 }
696 Ok(elbo)
697 }
698
699 fn factor_expected_log_joint(&self, factor: &VmpFactor) -> Result<f64> {
700 match factor {
701 VmpFactor::GaussianObservation {
702 target,
703 observation,
704 precision,
705 } => {
706 let state = self
707 .config
708 .states
709 .get(target)
710 .ok_or_else(|| PgmError::VariableNotFound(target.clone()))?;
711 if let VariationalState::Gaussian { q, .. } = state {
712 let e_mu = q.mean;
715 let e_mu2 = q.mean * q.mean + 1.0 / q.precision;
716 let y = *observation;
717 let p = *precision;
718 let coef = 0.5 * p;
719 let log_norm = 0.5 * (p / (2.0 * std::f64::consts::PI)).ln();
720 Ok(log_norm - coef * (e_mu2 + y * y - 2.0 * y * e_mu))
721 } else {
722 Err(PgmError::InvalidGraph(format!(
723 "GaussianObservation target '{}' is not Gaussian",
724 target
725 )))
726 }
727 }
728 VmpFactor::GaussianStep {
729 lhs,
730 rhs,
731 precision,
732 } => {
733 let lq = match self
734 .config
735 .states
736 .get(lhs)
737 .ok_or_else(|| PgmError::VariableNotFound(lhs.clone()))?
738 {
739 VariationalState::Gaussian { q, .. } => q,
740 _ => {
741 return Err(PgmError::InvalidGraph(format!(
742 "GaussianStep endpoint '{}' is not Gaussian",
743 lhs
744 )));
745 }
746 };
747 let rq = match self
748 .config
749 .states
750 .get(rhs)
751 .ok_or_else(|| PgmError::VariableNotFound(rhs.clone()))?
752 {
753 VariationalState::Gaussian { q, .. } => q,
754 _ => {
755 return Err(PgmError::InvalidGraph(format!(
756 "GaussianStep endpoint '{}' is not Gaussian",
757 rhs
758 )));
759 }
760 };
761 let e_l = lq.mean;
763 let e_l2 = lq.mean * lq.mean + 1.0 / lq.precision;
764 let e_r = rq.mean;
765 let e_r2 = rq.mean * rq.mean + 1.0 / rq.precision;
766 let log_norm = 0.5 * (precision / (2.0 * std::f64::consts::PI)).ln();
767 let coef = 0.5 * precision;
768 Ok(log_norm - coef * (e_l2 - 2.0 * e_l * e_r + e_r2))
769 }
770 VmpFactor::DirichletCategorical {
771 dirichlet,
772 categorical,
773 } => {
774 let d = match self
775 .config
776 .states
777 .get(dirichlet)
778 .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?
779 {
780 VariationalState::Dirichlet { q, .. } => q,
781 _ => {
782 return Err(PgmError::InvalidGraph(format!(
783 "DirichletCategorical: '{}' not Dirichlet",
784 dirichlet
785 )));
786 }
787 };
788 let c = match self
789 .config
790 .states
791 .get(categorical)
792 .ok_or_else(|| PgmError::VariableNotFound(categorical.clone()))?
793 {
794 VariationalState::Categorical { q, .. } => q,
795 _ => {
796 return Err(PgmError::InvalidGraph(format!(
797 "DirichletCategorical: '{}' not Categorical",
798 categorical
799 )));
800 }
801 };
802 let e_log_pi = d.expected_sufficient_statistics();
803 let probs = c.probs();
804 if e_log_pi.len() != probs.len() {
805 return Err(PgmError::DimensionMismatch {
806 expected: vec![probs.len()],
807 got: vec![e_log_pi.len()],
808 });
809 }
810 Ok(e_log_pi.iter().zip(probs.iter()).map(|(l, p)| l * p).sum())
811 }
812 VmpFactor::CategoricalObservation {
813 dirichlet,
814 observation,
815 ..
816 } => {
817 let d = match self
818 .config
819 .states
820 .get(dirichlet)
821 .ok_or_else(|| PgmError::VariableNotFound(dirichlet.clone()))?
822 {
823 VariationalState::Dirichlet { q, .. } => q,
824 _ => {
825 return Err(PgmError::InvalidGraph(format!(
826 "CategoricalObservation: '{}' not Dirichlet",
827 dirichlet
828 )));
829 }
830 };
831 let e_log_pi = d.expected_sufficient_statistics();
832 e_log_pi.get(*observation).cloned().ok_or_else(|| {
833 PgmError::InvalidDistribution(format!(
834 "observation {} out of range for Dirichlet with {} components",
835 observation,
836 e_log_pi.len()
837 ))
838 })
839 }
840 }
841 }
842
843 fn snapshot_natural_params(&self) -> HashMap<String, Vec<f64>> {
848 self.config
849 .states
850 .iter()
851 .map(|(k, v)| (k.clone(), v.natural_params()))
852 .collect()
853 }
854
855 fn linf_from_snapshot(&self, snapshot: &HashMap<String, Vec<f64>>) -> f64 {
856 let mut max = 0.0f64;
857 for (k, v) in &self.config.states {
858 let before = match snapshot.get(k) {
859 Some(vec) => vec,
860 None => continue,
861 };
862 for (a, b) in v.natural_params().iter().zip(before.iter()) {
863 max = max.max((a - b).abs());
864 }
865 }
866 max
867 }
868
869 pub fn states(&self) -> &HashMap<String, VariationalState> {
871 &self.config.states
872 }
873}