tensorlogic_quantrs_hooks/parameter_learning.rs

//! Parameter learning algorithms for probabilistic graphical models.
//!
//! This module provides algorithms for learning model parameters from data:
//! - **Maximum Likelihood Estimation (MLE)**: For fully observed data
//! - **Expectation-Maximization (EM)**: For partially observed data
//! - **Bayesian Estimation**: With Dirichlet priors
//! - **Baum-Welch Algorithm**: Specialized EM for Hidden Markov Models
//!
//! # Overview
//!
//! Parameter learning is the process of estimating the parameters (probabilities)
//! of a probabilistic model from observed data. This module supports both:
//! - **Complete data**: All variables are observed (use MLE)
//! - **Incomplete data**: Some variables are hidden (use EM)
//!
//! # Examples
//!
//! ```ignore
//! // Learn HMM parameters from observed sequences
//! let mut hmm = SimpleHMM::new(2, 2);
//! let learner = BaumWelchLearner::new(100, 1e-4);
//! learner.learn(&mut hmm, &observation_sequences)?;
//! ```
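//!
//! For complete data, MLE reduces to frequency counting. A minimal sketch,
//! assuming `data` is a `Vec<Assignment>` (the `HashMap<String, usize>` used
//! throughout this module); the variable name and cardinality are illustrative:
//!
//! ```ignore
//! let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);
//! let probs = estimator.estimate_marginal("X", 2, &data)?;
//! ```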

use crate::error::{PgmError, Result};
use crate::sampling::Assignment;
use scirs2_core::ndarray::{Array1, Array2, ArrayD};
use std::collections::HashMap;

/// Simple HMM representation for parameter learning.
///
/// This is a standalone representation with explicit parameter matrices,
/// designed for efficient parameter learning algorithms like Baum-Welch.
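///
/// # Examples
///
/// A sketch of constructing a model; the asserted shapes follow directly
/// from the uniform initialization in [`SimpleHMM::new`]:
///
/// ```ignore
/// let hmm = SimpleHMM::new(2, 3); // 2 hidden states, 3 observable symbols
/// assert_eq!(hmm.initial_distribution.len(), 2);
/// assert_eq!(hmm.transition_probabilities.dim(), (2, 2));
/// assert_eq!(hmm.emission_probabilities.dim(), (2, 3));
/// ```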
#[derive(Debug, Clone)]
pub struct SimpleHMM {
    /// Number of hidden states
    pub num_states: usize,
    /// Number of observable symbols
    pub num_observations: usize,
    /// Initial state distribution π: `[num_states]`
    pub initial_distribution: Array1<f64>,
    /// Transition probabilities A: `[from_state, to_state]`
    pub transition_probabilities: Array2<f64>,
    /// Emission probabilities B: `[state, observation]`
    pub emission_probabilities: Array2<f64>,
}

impl SimpleHMM {
    /// Create a new SimpleHMM with uniform initialization.
    pub fn new(num_states: usize, num_observations: usize) -> Self {
        let initial_distribution = Array1::from_elem(num_states, 1.0 / num_states as f64);

        let transition_probabilities =
            Array2::from_elem((num_states, num_states), 1.0 / num_states as f64);

        let emission_probabilities = Array2::from_elem(
            (num_states, num_observations),
            1.0 / num_observations as f64,
        );

        Self {
            num_states,
            num_observations,
            initial_distribution,
            transition_probabilities,
            emission_probabilities,
        }
    }

    /// Create an HMM with random initialization.
    pub fn new_random(num_states: usize, num_observations: usize) -> Self {
        use scirs2_core::random::thread_rng;

        let mut rng = thread_rng();
        let mut hmm = Self::new(num_states, num_observations);

        // Randomize initial distribution
        let mut init_sum = 0.0;
        for i in 0..num_states {
            hmm.initial_distribution[i] = rng.random::<f64>();
            init_sum += hmm.initial_distribution[i];
        }
        hmm.initial_distribution /= init_sum;

        // Randomize transition probabilities
        for i in 0..num_states {
            let mut trans_sum = 0.0;
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] = rng.random::<f64>();
                trans_sum += hmm.transition_probabilities[[i, j]];
            }
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] /= trans_sum;
            }
        }

        // Randomize emission probabilities
        for i in 0..num_states {
            let mut emission_sum = 0.0;
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] = rng.random::<f64>();
                emission_sum += hmm.emission_probabilities[[i, j]];
            }
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] /= emission_sum;
            }
        }

        hmm
    }
}

/// Maximum Likelihood Estimator for discrete distributions.
///
/// Estimates parameters by counting frequencies in complete data.
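///
/// # Examples
///
/// A minimal sketch (variable name and cardinality are illustrative):
///
/// ```ignore
/// // Plain MLE: P(X = x) = count(x) / N.
/// let mle = MaximumLikelihoodEstimator::new();
/// // Laplace smoothing: P(X = x) = (count(x) + 1) / (N + cardinality),
/// // which keeps unseen values from getting probability zero.
/// let smoothed = MaximumLikelihoodEstimator::with_laplace(1.0);
/// let probs = smoothed.estimate_marginal("X", 2, &data)?;
/// ```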
#[derive(Debug, Clone)]
pub struct MaximumLikelihoodEstimator {
    /// Use Laplace smoothing (add-one smoothing)
    pub use_laplace: bool,
    /// Pseudocount for Laplace smoothing
    pub pseudocount: f64,
}

impl MaximumLikelihoodEstimator {
    /// Create a new MLE estimator without smoothing
    /// (`pseudocount` is ignored unless `use_laplace` is set).
    pub fn new() -> Self {
        Self {
            use_laplace: false,
            pseudocount: 1.0,
        }
    }

    /// Create an MLE estimator with Laplace smoothing.
    pub fn with_laplace(pseudocount: f64) -> Self {
        Self {
            use_laplace: true,
            pseudocount,
        }
    }

    /// Estimate parameters for a single variable from data.
    ///
    /// # Arguments
    ///
    /// * `variable` - Variable name
    /// * `cardinality` - Number of possible values
    /// * `data` - Observed assignments
    ///
    /// # Returns
    ///
    /// Estimated probability distribution P(variable)
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };
        let mut counts = vec![pseudocount; cardinality];

        // Count occurrences
        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        // Normalize to probabilities
        let total: f64 = counts.iter().sum();
        if total == 0.0 {
            return Err(PgmError::InvalidDistribution(format!(
                "No data for variable {}",
                variable
            )));
        }

        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }

    /// Estimate conditional probability table P(child | parents) from data.
    ///
    /// # Arguments
    ///
    /// * `child` - Child variable name
    /// * `parents` - Parent variable names
    /// * `cardinalities` - Cardinalities for `[child, parent1, parent2, ...]`
    /// * `data` - Observed assignments
    ///
    /// # Returns
    ///
    /// CPT with shape `[parent1_card, ..., parentN_card, child_card]`; the
    /// last axis holds P(child | parents) for each parent configuration.
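    ///
    /// # Examples
    ///
    /// A sketch with one binary parent (names are illustrative):
    ///
    /// ```ignore
    /// // P(Rain | Cloudy): cardinalities = [child, parent] = [2, 2]
    /// let cpt = estimator.estimate_conditional(
    ///     "Rain",
    ///     &["Cloudy".to_string()],
    ///     &[2, 2],
    ///     &data,
    /// )?;
    /// // cpt[[cloudy, rain]] = P(Rain = rain | Cloudy = cloudy)
    /// ```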
    pub fn estimate_conditional(
        &self,
        child: &str,
        parents: &[String],
        cardinalities: &[usize],
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        if cardinalities.len() != parents.len() + 1 {
            return Err(PgmError::InvalidGraph(
                "Cardinalities must have one entry for the child plus one per parent".to_string(),
            ));
        }

        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };

        let child_card = cardinalities[0];
        let parent_cards = &cardinalities[1..];

        // Calculate total number of parent configurations
        let num_parent_configs: usize = parent_cards.iter().product();

        // Initialize counts: [parent_config][child_value]
        let mut counts = vec![vec![pseudocount; child_card]; num_parent_configs];

        // Count co-occurrences. Records with a missing or out-of-range parent
        // value are skipped entirely: a partial parent configuration cannot
        // be indexed consistently.
        'records: for assignment in data {
            let Some(&child_val) = assignment.get(child) else {
                continue;
            };

            // Row-major index over the parent values, so that parent1 is the
            // most significant axis; this matches the output shape below.
            let mut parent_config = 0;
            for (i, parent) in parents.iter().enumerate() {
                match assignment.get(parent) {
                    Some(&parent_val) if parent_val < parent_cards[i] => {
                        parent_config = parent_config * parent_cards[i] + parent_val;
                    }
                    _ => continue 'records,
                }
            }

            if child_val < child_card {
                counts[parent_config][child_val] += 1.0;
            }
        }

        // Normalize each parent configuration
        let mut probs = Vec::new();
        for config_counts in counts {
            let total: f64 = config_counts.iter().sum();
            if total > 0.0 {
                for count in config_counts {
                    probs.push(count / total);
                }
            } else {
                // Uniform distribution if no data
                for _ in 0..child_card {
                    probs.push(1.0 / child_card as f64);
                }
            }
        }

        // Shape: [parent1_card, parent2_card, ..., child_card], with the
        // child on the fastest-varying (last) axis to match `probs` above.
        let mut shape = parent_cards.to_vec();
        shape.push(child_card);
        ArrayD::from_shape_vec(shape, probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

impl Default for MaximumLikelihoodEstimator {
    fn default() -> Self {
        Self::new()
    }
}

/// Bayesian parameter estimator with Dirichlet priors.
///
/// Uses conjugate Dirichlet priors for robust parameter estimation.
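///
/// # Examples
///
/// A sketch of the shrinkage effect; the numbers mirror the unit test at the
/// bottom of this file:
///
/// ```ignore
/// // prior_strength = 2.0 over a binary variable => Dirichlet(1, 1) prior.
/// // With 8 observations of X = 0 and none of X = 1, the posterior mean is
/// // (8 + 1) / 10 = 0.9 for X = 0 and (0 + 1) / 10 = 0.1 for X = 1.
/// let estimator = BayesianEstimator::new(2.0);
/// let probs = estimator.estimate_marginal("X", 2, &data)?;
/// ```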
#[derive(Debug, Clone)]
pub struct BayesianEstimator {
    /// Dirichlet prior hyperparameters (equivalent sample size)
    pub prior_strength: f64,
}

impl BayesianEstimator {
    /// Create a new Bayesian estimator.
    ///
    /// # Arguments
    ///
    /// * `prior_strength` - Strength of the prior (equivalent sample size)
    pub fn new(prior_strength: f64) -> Self {
        Self { prior_strength }
    }

    /// Estimate parameters with Dirichlet prior.
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        // Dirichlet(α, α, ..., α) prior with α = prior_strength / cardinality
        let alpha = self.prior_strength / cardinality as f64;
        let mut counts = vec![alpha; cardinality];

        // Add data counts
        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        // Posterior mean of the Dirichlet: (α + n_k) / Σ_j (α + n_j)
        let total: f64 = counts.iter().sum();
        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

/// Baum-Welch algorithm for learning HMM parameters.
///
/// This is a specialized EM algorithm for Hidden Markov Models that learns:
/// - Initial state distribution
/// - Transition probabilities
/// - Emission probabilities
///
/// from sequences of observations (even when hidden states are not observed).
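///
/// # Examples
///
/// A sketch of a full learning run (the observation alphabet and sequences
/// are illustrative; `new_random` breaks the symmetry of the uniform
/// initialization, which EM cannot escape on its own):
///
/// ```ignore
/// let mut hmm = SimpleHMM::new_random(2, 2);
/// let sequences = vec![vec![0, 1, 0, 0, 1], vec![1, 1, 0, 1, 0]];
/// let learner = BaumWelchLearner::new(100, 1e-4);
/// let log_likelihood = learner.learn(&mut hmm, &sequences)?;
/// ```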
#[derive(Debug, Clone)]
pub struct BaumWelchLearner {
    /// Maximum number of EM iterations
    pub max_iterations: usize,
    /// Convergence tolerance (change in log-likelihood)
    pub tolerance: f64,
    /// Whether to print progress
    pub verbose: bool,
}

impl BaumWelchLearner {
    /// Create a new Baum-Welch learner.
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: false,
        }
    }

    /// Create a verbose learner that prints progress.
    pub fn with_verbose(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: true,
        }
    }

    /// Learn HMM parameters from observation sequences.
    ///
    /// # Arguments
    ///
    /// * `hmm` - HMM model to update (will be modified in place)
    /// * `observation_sequences` - Multiple observation sequences
    ///
    /// # Returns
    ///
    /// Final log-likelihood, averaged over the observation sequences
    #[allow(clippy::needless_range_loop)]
    pub fn learn(&self, hmm: &mut SimpleHMM, observation_sequences: &[Vec<usize>]) -> Result<f64> {
        if observation_sequences.is_empty() {
            return Err(PgmError::InvalidDistribution(
                "At least one observation sequence is required".to_string(),
            ));
        }

        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for iteration in 0..self.max_iterations {
            // E-step: Compute expected counts
            let mut initial_counts = vec![0.0; num_states];
            let mut transition_counts = vec![vec![0.0; num_states]; num_states];
            let mut emission_counts = vec![vec![0.0; num_observations]; num_states];

            let mut total_log_likelihood = 0.0;

            for sequence in observation_sequences {
                let (alpha, beta, log_likelihood) = self.forward_backward(hmm, sequence)?;
                total_log_likelihood += log_likelihood;

                let seq_len = sequence.len();

                // Expected counts for initial state
                for (s, count) in initial_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_0 = self.compute_gamma(&alpha, &beta, 0, s, log_likelihood);
                    *count += gamma_0;
                }

                // Expected counts for transitions and emissions
                for t in 0..(seq_len - 1) {
                    for s1 in 0..num_states {
                        let gamma_t = self.compute_gamma(&alpha, &beta, t, s1, log_likelihood);

                        // Emission count
                        emission_counts[s1][sequence[t]] += gamma_t;

                        // Transition counts
                        for s2 in 0..num_states {
                            let xi = self.compute_xi(
                                hmm,
                                &alpha,
                                &beta,
                                t,
                                s1,
                                s2,
                                sequence[t + 1],
                                log_likelihood,
                            );
                            transition_counts[s1][s2] += xi;
                        }
                    }
                }

                // Last time step emission
                for (s, counts) in emission_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_last =
                        self.compute_gamma(&alpha, &beta, seq_len - 1, s, log_likelihood);
                    counts[sequence[seq_len - 1]] += gamma_last;
                }
            }

            // M-step: Update parameters
            self.update_parameters(hmm, &initial_counts, &transition_counts, &emission_counts)?;

            // Check convergence
            let avg_log_likelihood = total_log_likelihood / observation_sequences.len() as f64;

            if self.verbose {
                println!(
                    "Iteration {}: log-likelihood = {:.4}",
                    iteration, avg_log_likelihood
                );
            }

            if (avg_log_likelihood - prev_log_likelihood).abs() < self.tolerance {
                if self.verbose {
                    println!("Converged after {} iterations", iteration + 1);
                }
                return Ok(avg_log_likelihood);
            }

            prev_log_likelihood = avg_log_likelihood;
        }

        if self.verbose {
            println!("Maximum iterations reached");
        }

        Ok(prev_log_likelihood)
    }

    /// Forward-backward algorithm.
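    ///
    /// Computes the forward probabilities α_t(s) = P(o_1..o_t, q_t = s), the
    /// backward probabilities β_t(s) = P(o_{t+1}..o_T | q_t = s), and the
    /// sequence log-likelihood ln P(o_1..o_T) = ln Σ_s α_T(s).
    ///
    /// Note: α and β are not rescaled per time step, so very long sequences
    /// can underflow; scaled or log-space variants are the usual remedy.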
    #[allow(clippy::type_complexity, clippy::needless_range_loop)]
    fn forward_backward(
        &self,
        hmm: &SimpleHMM,
        sequence: &[usize],
    ) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, f64)> {
        let num_states = hmm.num_states;
        let seq_len = sequence.len();

        if seq_len == 0 {
            return Err(PgmError::InvalidDistribution(
                "Observation sequence must not be empty".to_string(),
            ));
        }

        // Forward pass
        let mut alpha = vec![vec![0.0; num_states]; seq_len];

        // Initialize: α_0(s) = π(s) · B[s, o_0]
        for s in 0..num_states {
            alpha[0][s] =
                hmm.initial_distribution[[s]] * hmm.emission_probabilities[[s, sequence[0]]];
        }

        // Forward recursion: α_t(s2) = (Σ_s1 α_{t-1}(s1) · A[s1, s2]) · B[s2, o_t]
        for t in 1..seq_len {
            for s2 in 0..num_states {
                let mut sum = 0.0;
                for s1 in 0..num_states {
                    sum += alpha[t - 1][s1] * hmm.transition_probabilities[[s1, s2]];
                }
                alpha[t][s2] = sum * hmm.emission_probabilities[[s2, sequence[t]]];
            }
        }

        // Backward pass
        let mut beta = vec![vec![0.0; num_states]; seq_len];

        // Initialize: β_{T-1}(s) = 1
        for s in 0..num_states {
            beta[seq_len - 1][s] = 1.0;
        }

        // Backward recursion: β_t(s1) = Σ_s2 A[s1, s2] · B[s2, o_{t+1}] · β_{t+1}(s2)
        for t in (0..(seq_len - 1)).rev() {
            for s1 in 0..num_states {
                let mut sum = 0.0;
                for s2 in 0..num_states {
                    sum += hmm.transition_probabilities[[s1, s2]]
                        * hmm.emission_probabilities[[s2, sequence[t + 1]]]
                        * beta[t + 1][s2];
                }
                beta[t][s1] = sum;
            }
        }

        // Log-likelihood: ln P(O) = ln Σ_s α_{T-1}(s)
        let log_likelihood: f64 = alpha[seq_len - 1].iter().sum::<f64>().ln();

        Ok((alpha, beta, log_likelihood))
    }

    /// Compute gamma (state occupation probability).
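    ///
    /// γ_t(s) = α_t(s) · β_t(s) / P(O), where P(O) = exp(log_likelihood).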
    fn compute_gamma(
        &self,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s: usize,
        log_likelihood: f64,
    ) -> f64 {
        (alpha[t][s] * beta[t][s]) / log_likelihood.exp()
    }

    /// Compute xi (state transition probability).
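    ///
    /// ξ_t(s1, s2) = α_t(s1) · A[s1, s2] · B[s2, o_{t+1}] · β_{t+1}(s2) / P(O).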
    #[allow(clippy::too_many_arguments)]
    fn compute_xi(
        &self,
        hmm: &SimpleHMM,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s1: usize,
        s2: usize,
        next_obs: usize,
        log_likelihood: f64,
    ) -> f64 {
        let numerator = alpha[t][s1]
            * hmm.transition_probabilities[[s1, s2]]
            * hmm.emission_probabilities[[s2, next_obs]]
            * beta[t + 1][s2];

        numerator / log_likelihood.exp()
    }

    /// Update HMM parameters (M-step).
    fn update_parameters(
        &self,
        hmm: &mut SimpleHMM,
        initial_counts: &[f64],
        transition_counts: &[Vec<f64>],
        emission_counts: &[Vec<f64>],
    ) -> Result<()> {
        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        // Update initial distribution
        let initial_sum: f64 = initial_counts.iter().sum();
        if initial_sum > 0.0 {
            for (s, &count) in initial_counts.iter().enumerate().take(num_states) {
                hmm.initial_distribution[[s]] = count / initial_sum;
            }
        }

        // Update transition probabilities
        for (s1, trans_counts) in transition_counts.iter().enumerate().take(num_states) {
            let trans_sum: f64 = trans_counts.iter().sum();
            if trans_sum > 0.0 {
                for (s2, &count) in trans_counts.iter().enumerate().take(num_states) {
                    hmm.transition_probabilities[[s1, s2]] = count / trans_sum;
                }
            }
        }

        // Update emission probabilities
        for (s, emis_counts) in emission_counts.iter().enumerate().take(num_states) {
            let emission_sum: f64 = emis_counts.iter().sum();
            if emission_sum > 0.0 {
                for (o, &count) in emis_counts.iter().enumerate().take(num_observations) {
                    hmm.emission_probabilities[[s, o]] = count / emission_sum;
                }
            }
        }

        Ok(())
    }
}

/// Utilities for parameter learning.
pub mod utils {
    use super::*;

    /// Count variable occurrences in data.
    pub fn count_occurrences(variable: &str, data: &[Assignment]) -> HashMap<usize, usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                *counts.entry(value).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Count co-occurrences of two variables.
    pub fn count_joint_occurrences(
        var1: &str,
        var2: &str,
        data: &[Assignment],
    ) -> HashMap<(usize, usize), usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
                *counts.entry((v1, v2)).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Compute empirical distribution from counts.
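    ///
    /// A sketch (the values mirror the unit test below):
    ///
    /// ```ignore
    /// let mut counts = HashMap::new();
    /// counts.insert(0, 7);
    /// counts.insert(1, 3);
    /// let probs = counts_to_distribution(&counts, 2); // [0.7, 0.3]
    /// ```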
    pub fn counts_to_distribution(counts: &HashMap<usize, usize>, cardinality: usize) -> Vec<f64> {
        let total: usize = counts.values().sum();
        let mut probs = vec![0.0; cardinality];

        for (&value, &count) in counts {
            if value < cardinality && total > 0 {
                probs[value] = count as f64 / total as f64;
            }
        }

        probs
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mle_marginal() {
        let estimator = MaximumLikelihoodEstimator::new();

        let mut data = Vec::new();
        for _ in 0..7 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }
        for _ in 0..3 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 1);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.7).abs() < 1e-6);
        assert!((probs[[1]] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_mle_with_laplace() {
        let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }
        // No observations of X=1

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        // With Laplace: (8+1)/(8+1+0+1) = 9/10 = 0.9
        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_bayesian_estimator() {
        let estimator = BayesianEstimator::new(2.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        // Prior: Dirichlet(1, 1), Data: 8, 0
        // Posterior: (8+1, 0+1) / (8+1+0+1) = (9, 1) / 10
        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_count_occurrences() {
        let mut data = Vec::new();
        for i in 0..10 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), i % 3);
            data.push(assignment);
        }

        let counts = utils::count_occurrences("X", &data);

        assert_eq!(counts.get(&0), Some(&4)); // 0, 3, 6, 9
        assert_eq!(counts.get(&1), Some(&3)); // 1, 4, 7
        assert_eq!(counts.get(&2), Some(&3)); // 2, 5, 8
    }

    #[test]
    fn test_counts_to_distribution() {
        let mut counts = HashMap::new();
        counts.insert(0, 7);
        counts.insert(1, 3);

        let probs = utils::counts_to_distribution(&counts, 2);

        assert!((probs[0] - 0.7).abs() < 1e-6);
        assert!((probs[1] - 0.3).abs() < 1e-6);
    }
}