tensorlogic_quantrs_hooks/parameter_learning.rs

//! Parameter learning algorithms for probabilistic graphical models.
//!
//! This module provides algorithms for learning model parameters from data:
//! - **Maximum Likelihood Estimation (MLE)**: For fully observed data
//! - **Expectation-Maximization (EM)**: For partially observed data
//! - **Bayesian Estimation**: With Dirichlet priors
//! - **Baum-Welch Algorithm**: Specialized EM for Hidden Markov Models
//!
//! # Overview
//!
//! Parameter learning is the process of estimating the parameters (probabilities)
//! of a probabilistic model from observed data. This module supports both:
//! - **Complete data**: All variables are observed (use MLE)
//! - **Incomplete data**: Some variables are hidden (use EM)
//!
//! # Examples
//!
//! ```ignore
//! // Learn HMM parameters from observed sequences
//! let mut hmm = SimpleHMM::new(2, 2);
//! let learner = BaumWelchLearner::new(100, 1e-4);
//! learner.learn(&mut hmm, &observation_sequences)?;
//! ```

use crate::error::{PgmError, Result};
use crate::sampling::Assignment;
use scirs2_core::ndarray::{Array1, Array2, ArrayD};
use std::collections::HashMap;

/// Simple HMM representation for parameter learning.
///
/// This is a standalone representation with explicit parameter matrices,
/// designed for efficient parameter learning algorithms like Baum-Welch.
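///
/// # Examples
///
/// A minimal sketch of the uniform initialization:
///
/// ```ignore
/// let hmm = SimpleHMM::new(2, 2);
/// // π, each row of A, and each row of B start uniform and sum to 1
/// assert!((hmm.initial_distribution[0] - 0.5).abs() < 1e-12);
/// assert!((hmm.transition_probabilities[[0, 1]] - 0.5).abs() < 1e-12);
/// ```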
#[derive(Debug, Clone)]
pub struct SimpleHMM {
    /// Number of hidden states
    pub num_states: usize,
    /// Number of observable symbols
    pub num_observations: usize,
    /// Initial state distribution π: [num_states]
    pub initial_distribution: Array1<f64>,
    /// Transition probabilities A: [from_state, to_state]
    pub transition_probabilities: Array2<f64>,
    /// Emission probabilities B: [state, observation]
    pub emission_probabilities: Array2<f64>,
}

impl SimpleHMM {
    /// Create a new SimpleHMM with uniform initialization.
    pub fn new(num_states: usize, num_observations: usize) -> Self {
        let initial_distribution = Array1::from_elem(num_states, 1.0 / num_states as f64);

        let transition_probabilities =
            Array2::from_elem((num_states, num_states), 1.0 / num_states as f64);

        let emission_probabilities = Array2::from_elem(
            (num_states, num_observations),
            1.0 / num_observations as f64,
        );

        Self {
            num_states,
            num_observations,
            initial_distribution,
            transition_probabilities,
            emission_probabilities,
        }
    }

    /// Create an HMM with random initialization.
    pub fn new_random(num_states: usize, num_observations: usize) -> Self {
        use scirs2_core::random::{thread_rng, Rng};

        let mut rng = thread_rng();
        let mut hmm = Self::new(num_states, num_observations);

        // Randomize initial distribution
        let mut init_sum = 0.0;
        for i in 0..num_states {
            hmm.initial_distribution[i] = rng.random::<f64>();
            init_sum += hmm.initial_distribution[i];
        }
        hmm.initial_distribution /= init_sum;

        // Randomize transition probabilities
        for i in 0..num_states {
            let mut trans_sum = 0.0;
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] = rng.random::<f64>();
                trans_sum += hmm.transition_probabilities[[i, j]];
            }
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] /= trans_sum;
            }
        }

        // Randomize emission probabilities
        for i in 0..num_states {
            let mut emission_sum = 0.0;
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] = rng.random::<f64>();
                emission_sum += hmm.emission_probabilities[[i, j]];
            }
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] /= emission_sum;
            }
        }

        hmm
    }
}

/// Maximum Likelihood Estimator for discrete distributions.
///
/// Estimates parameters by counting frequencies in complete data.
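///
/// For a value v observed N_v times out of N total data points, the MLE is
/// P(v) = N_v / N; with Laplace smoothing a pseudocount is added to every
/// value's count before normalizing.
///
/// # Examples
///
/// A sketch of marginal estimation (`data` is assumed to be a prepared
/// `Vec<Assignment>`, as in the tests below):
///
/// ```ignore
/// let estimator = MaximumLikelihoodEstimator::new();
/// // P(X = v) = count(X = v) / N over the observed assignments
/// let probs = estimator.estimate_marginal("X", 2, &data)?;
/// ```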
#[derive(Debug, Clone)]
pub struct MaximumLikelihoodEstimator {
    /// Use Laplace smoothing (add-one smoothing)
    pub use_laplace: bool,
    /// Pseudocount for Laplace smoothing
    pub pseudocount: f64,
}

impl MaximumLikelihoodEstimator {
    /// Create a new MLE estimator.
    pub fn new() -> Self {
        Self {
            use_laplace: false,
            pseudocount: 1.0,
        }
    }

    /// Create an MLE estimator with Laplace smoothing.
    pub fn with_laplace(pseudocount: f64) -> Self {
        Self {
            use_laplace: true,
            pseudocount,
        }
    }

    /// Estimate parameters for a single variable from data.
    ///
    /// # Arguments
    ///
    /// * `variable` - Variable name
    /// * `cardinality` - Number of possible values
    /// * `data` - Observed assignments
    ///
    /// # Returns
    ///
    /// Estimated probability distribution P(variable)
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };
        let mut counts = vec![pseudocount; cardinality];

        // Count occurrences
        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        // Normalize to probabilities
        let total: f64 = counts.iter().sum();
        if total == 0.0 {
            return Err(PgmError::InvalidDistribution(
                "No data for variable".to_string(),
            ));
        }

        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }

    /// Estimate conditional probability table P(child | parents) from data.
    ///
    /// # Arguments
    ///
    /// * `child` - Child variable name
    /// * `parents` - Parent variable names
    /// * `cardinalities` - Cardinalities for [child, parent1, parent2, ...]
    /// * `data` - Observed assignments
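    ///
    /// # Examples
    ///
    /// A sketch for a binary child with two binary parents (`estimator` and
    /// `data` assumed prepared as elsewhere in this module):
    ///
    /// ```ignore
    /// // Cardinalities are ordered [child, parent1, parent2]
    /// let cpt = estimator.estimate_conditional(
    ///     "C",
    ///     &["A".to_string(), "B".to_string()],
    ///     &[2, 2, 2],
    ///     &data,
    /// )?;
    /// // Returned array has shape [parent1_card, parent2_card, child_card]
    /// ```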
    pub fn estimate_conditional(
        &self,
        child: &str,
        parents: &[String],
        cardinalities: &[usize],
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        if cardinalities.is_empty() {
            return Err(PgmError::InvalidGraph(
                "Cardinalities must not be empty".to_string(),
            ));
        }

        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };

        let child_card = cardinalities[0];
        let parent_cards = &cardinalities[1..];

        // Calculate total number of parent configurations
        let num_parent_configs: usize = parent_cards.iter().product();

        // Initialize counts: [parent_config][child_value]
        let mut counts = vec![vec![pseudocount; child_card]; num_parent_configs];

        // Count co-occurrences
        'data: for assignment in data {
            if let Some(&child_val) = assignment.get(child) {
                // Compute parent configuration as a row-major index over the
                // parents in the given order
                let mut parent_config = 0;

                for (i, parent) in parents.iter().enumerate() {
                    if let Some(&parent_val) = assignment.get(parent) {
                        parent_config = parent_config * parent_cards[i] + parent_val;
                    } else {
                        // A partial index would corrupt the counts, so skip the
                        // whole data point if any parent value is missing
                        continue 'data;
                    }
                }

                if parent_config < num_parent_configs && child_val < child_card {
                    counts[parent_config][child_val] += 1.0;
                }
            }
        }

        // Normalize each parent configuration
        let mut probs = Vec::new();
        for config_counts in counts {
            let total: f64 = config_counts.iter().sum();
            if total > 0.0 {
                for count in config_counts {
                    probs.push(count / total);
                }
            } else {
                // Uniform distribution if no data
                for _ in 0..child_card {
                    probs.push(1.0 / child_card as f64);
                }
            }
        }

        // Shape: [parent1_card, ..., parentN_card, child_card], matching the
        // row-major layout of `probs` (child value varies fastest)
        let mut shape: Vec<usize> = parent_cards.to_vec();
        shape.push(child_card);
        ArrayD::from_shape_vec(shape, probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

impl Default for MaximumLikelihoodEstimator {
    fn default() -> Self {
        Self::new()
    }
}

/// Bayesian parameter estimator with Dirichlet priors.
///
/// Uses conjugate Dirichlet priors for robust parameter estimation.
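///
/// With a symmetric Dirichlet(α, ..., α) prior where α = prior_strength / k
/// (k = cardinality), the posterior mean for value v is
/// P(v) = (count(v) + α) / (N + prior_strength).
///
/// # Examples
///
/// A sketch: `prior_strength = 2.0` on a binary variable gives α = 1 per
/// value, i.e. classic add-one smoothing.
///
/// ```ignore
/// let estimator = BayesianEstimator::new(2.0);
/// let probs = estimator.estimate_marginal("X", 2, &data)?;
/// ```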
#[derive(Debug, Clone)]
pub struct BayesianEstimator {
    /// Dirichlet prior hyperparameters (equivalent sample size)
    pub prior_strength: f64,
}

impl BayesianEstimator {
    /// Create a new Bayesian estimator.
    ///
    /// # Arguments
    ///
    /// * `prior_strength` - Strength of the prior (equivalent sample size)
    pub fn new(prior_strength: f64) -> Self {
        Self { prior_strength }
    }

    /// Estimate parameters with Dirichlet prior.
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        // Dirichlet(α, α, ..., α) prior
        let alpha = self.prior_strength / cardinality as f64;
        let mut counts = vec![alpha; cardinality];

        // Add data counts
        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        // Posterior mean of Dirichlet
        let total: f64 = counts.iter().sum();
        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

/// Baum-Welch algorithm for learning HMM parameters.
///
/// This is a specialized EM algorithm for Hidden Markov Models that learns:
/// - Initial state distribution
/// - Transition probabilities
/// - Emission probabilities
///
/// from sequences of observations (even when hidden states are not observed).
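///
/// # Examples
///
/// A sketch of a typical fit (`sequences` assumed to be a `Vec<Vec<usize>>`
/// of observation symbols):
///
/// ```ignore
/// let mut hmm = SimpleHMM::new_random(2, 2);
/// let learner = BaumWelchLearner::new(100, 1e-4);
/// let log_likelihood = learner.learn(&mut hmm, &sequences)?;
/// ```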
#[derive(Debug, Clone)]
pub struct BaumWelchLearner {
    /// Maximum number of EM iterations
    pub max_iterations: usize,
    /// Convergence tolerance (change in log-likelihood)
    pub tolerance: f64,
    /// Whether to print progress
    pub verbose: bool,
}

impl BaumWelchLearner {
    /// Create a new Baum-Welch learner.
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: false,
        }
    }

    /// Create a verbose learner that prints progress.
    pub fn with_verbose(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: true,
        }
    }

    /// Learn HMM parameters from observation sequences.
    ///
    /// # Arguments
    ///
    /// * `hmm` - HMM model to update (will be modified in place)
    /// * `observation_sequences` - Multiple observation sequences
    ///
    /// # Returns
    ///
    /// Final log-likelihood
    pub fn learn(&self, hmm: &mut SimpleHMM, observation_sequences: &[Vec<usize>]) -> Result<f64> {
        // Guard against an empty data set, which would otherwise yield a
        // NaN average log-likelihood
        if observation_sequences.is_empty() {
            return Err(PgmError::InvalidDistribution(
                "No observation sequences provided".to_string(),
            ));
        }

        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for iteration in 0..self.max_iterations {
            // E-step: Compute expected counts
            let mut initial_counts = vec![0.0; num_states];
            let mut transition_counts = vec![vec![0.0; num_states]; num_states];
            let mut emission_counts = vec![vec![0.0; num_observations]; num_states];

            let mut total_log_likelihood = 0.0;

            for sequence in observation_sequences {
                let (alpha, beta, log_likelihood) = self.forward_backward(hmm, sequence)?;
                total_log_likelihood += log_likelihood;

                let seq_len = sequence.len();

                // Expected counts for initial state
                for (s, count) in initial_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_0 = self.compute_gamma(&alpha, &beta, 0, s, log_likelihood);
                    *count += gamma_0;
                }

                // Expected counts for transitions and emissions
                for t in 0..(seq_len - 1) {
                    for s1 in 0..num_states {
                        let gamma_t = self.compute_gamma(&alpha, &beta, t, s1, log_likelihood);

                        // Emission count
                        emission_counts[s1][sequence[t]] += gamma_t;

                        // Transition counts
                        for s2 in 0..num_states {
                            let xi = self.compute_xi(
                                hmm,
                                &alpha,
                                &beta,
                                t,
                                s1,
                                s2,
                                sequence[t + 1],
                                log_likelihood,
                            );
                            transition_counts[s1][s2] += xi;
                        }
                    }
                }

                // Last time step emission
                for (s, counts) in emission_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_last =
                        self.compute_gamma(&alpha, &beta, seq_len - 1, s, log_likelihood);
                    counts[sequence[seq_len - 1]] += gamma_last;
                }
            }

            // M-step: Update parameters
            self.update_parameters(hmm, &initial_counts, &transition_counts, &emission_counts)?;

            // Check convergence
            let avg_log_likelihood = total_log_likelihood / observation_sequences.len() as f64;

            if self.verbose {
                println!(
                    "Iteration {}: log-likelihood = {:.4}",
                    iteration, avg_log_likelihood
                );
            }

            if (avg_log_likelihood - prev_log_likelihood).abs() < self.tolerance {
                if self.verbose {
                    println!("Converged after {} iterations", iteration + 1);
                }
                return Ok(avg_log_likelihood);
            }

            prev_log_likelihood = avg_log_likelihood;
        }

        if self.verbose {
            println!("Maximum iterations reached");
        }

        Ok(prev_log_likelihood)
    }

    /// Forward-backward algorithm.
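    ///
    /// Computes unscaled forward and backward messages:
    /// α_0(s) = π(s) · B(s, o_0),
    /// α_t(s) = [Σ_{s'} α_{t-1}(s') · A(s', s)] · B(s, o_t),
    /// β_{T-1}(s) = 1,
    /// β_t(s) = Σ_{s'} A(s, s') · B(s', o_{t+1}) · β_{t+1}(s').
    ///
    /// No per-step rescaling is applied, so very long sequences may
    /// underflow; this implementation assumes moderate sequence lengths.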
    #[allow(clippy::type_complexity)]
    fn forward_backward(
        &self,
        hmm: &SimpleHMM,
        sequence: &[usize],
    ) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, f64)> {
        // Guard: the recursions below index sequence[0] and seq_len - 1
        if sequence.is_empty() {
            return Err(PgmError::InvalidDistribution(
                "Observation sequence must not be empty".to_string(),
            ));
        }

        let num_states = hmm.num_states;
        let seq_len = sequence.len();

        // Forward pass
        let mut alpha = vec![vec![0.0; num_states]; seq_len];

        // Initialize
        for s in 0..num_states {
            alpha[0][s] =
                hmm.initial_distribution[[s]] * hmm.emission_probabilities[[s, sequence[0]]];
        }

        // Forward recursion
        for t in 1..seq_len {
            for s2 in 0..num_states {
                let mut sum = 0.0;
                for s1 in 0..num_states {
                    sum += alpha[t - 1][s1] * hmm.transition_probabilities[[s1, s2]];
                }
                alpha[t][s2] = sum * hmm.emission_probabilities[[s2, sequence[t]]];
            }
        }

        // Backward pass
        let mut beta = vec![vec![0.0; num_states]; seq_len];

        // Initialize
        for s in 0..num_states {
            beta[seq_len - 1][s] = 1.0;
        }

        // Backward recursion
        for t in (0..(seq_len - 1)).rev() {
            for s1 in 0..num_states {
                let mut sum = 0.0;
                for s2 in 0..num_states {
                    sum += hmm.transition_probabilities[[s1, s2]]
                        * hmm.emission_probabilities[[s2, sequence[t + 1]]]
                        * beta[t + 1][s2];
                }
                beta[t][s1] = sum;
            }
        }

        // Compute log-likelihood
        let log_likelihood: f64 = alpha[seq_len - 1].iter().sum::<f64>().ln();

        Ok((alpha, beta, log_likelihood))
    }

    /// Compute gamma (state occupation probability).
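    ///
    /// γ_t(s) = α_t(s) · β_t(s) / P(O | λ), where the likelihood P(O | λ)
    /// is recovered as exp(log_likelihood).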
    fn compute_gamma(
        &self,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s: usize,
        log_likelihood: f64,
    ) -> f64 {
        (alpha[t][s] * beta[t][s]) / log_likelihood.exp()
    }

    /// Compute xi (state transition probability).
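    ///
    /// ξ_t(s1, s2) = α_t(s1) · A(s1, s2) · B(s2, o_{t+1}) · β_{t+1}(s2) / P(O | λ).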
    #[allow(clippy::too_many_arguments)]
    fn compute_xi(
        &self,
        hmm: &SimpleHMM,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s1: usize,
        s2: usize,
        next_obs: usize,
        log_likelihood: f64,
    ) -> f64 {
        let numerator = alpha[t][s1]
            * hmm.transition_probabilities[[s1, s2]]
            * hmm.emission_probabilities[[s2, next_obs]]
            * beta[t + 1][s2];

        numerator / log_likelihood.exp()
    }

    /// Update HMM parameters (M-step).
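    ///
    /// Each parameter is the normalized expected count from the E-step, e.g.
    /// Â(s1, s2) = E[# transitions s1 → s2] / Σ_{s2'} E[# transitions s1 → s2'].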
    fn update_parameters(
        &self,
        hmm: &mut SimpleHMM,
        initial_counts: &[f64],
        transition_counts: &[Vec<f64>],
        emission_counts: &[Vec<f64>],
    ) -> Result<()> {
        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        // Update initial distribution
        let initial_sum: f64 = initial_counts.iter().sum();
        if initial_sum > 0.0 {
            for (s, &count) in initial_counts.iter().enumerate().take(num_states) {
                hmm.initial_distribution[[s]] = count / initial_sum;
            }
        }

        // Update transition probabilities
        for (s1, trans_counts) in transition_counts.iter().enumerate().take(num_states) {
            let trans_sum: f64 = trans_counts.iter().sum();
            if trans_sum > 0.0 {
                for (s2, &count) in trans_counts.iter().enumerate().take(num_states) {
                    hmm.transition_probabilities[[s1, s2]] = count / trans_sum;
                }
            }
        }

        // Update emission probabilities
        for (s, emis_counts) in emission_counts.iter().enumerate().take(num_states) {
            let emission_sum: f64 = emis_counts.iter().sum();
            if emission_sum > 0.0 {
                for (o, &count) in emis_counts.iter().enumerate().take(num_observations) {
                    hmm.emission_probabilities[[s, o]] = count / emission_sum;
                }
            }
        }

        Ok(())
    }
}

/// Utilities for parameter learning.
pub mod utils {
    use super::*;

    /// Count variable occurrences in data.
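    ///
    /// # Examples
    ///
    /// A sketch (`data` assumed to be a `Vec<Assignment>`):
    ///
    /// ```ignore
    /// let counts = count_occurrences("X", &data);
    /// let n_zeros = counts.get(&0).copied().unwrap_or(0);
    /// ```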
    pub fn count_occurrences(variable: &str, data: &[Assignment]) -> HashMap<usize, usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                *counts.entry(value).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Count co-occurrences of two variables.
    pub fn count_joint_occurrences(
        var1: &str,
        var2: &str,
        data: &[Assignment],
    ) -> HashMap<(usize, usize), usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
                *counts.entry((v1, v2)).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Compute empirical distribution from counts.
    pub fn counts_to_distribution(counts: &HashMap<usize, usize>, cardinality: usize) -> Vec<f64> {
        let total: usize = counts.values().sum();
        let mut probs = vec![0.0; cardinality];

        for (&value, &count) in counts {
            if value < cardinality && total > 0 {
                probs[value] = count as f64 / total as f64;
            }
        }

        probs
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mle_marginal() {
        let estimator = MaximumLikelihoodEstimator::new();

        let mut data = Vec::new();
        for _ in 0..7 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }
        for _ in 0..3 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 1);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.7).abs() < 1e-6);
        assert!((probs[[1]] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_mle_with_laplace() {
        let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }
        // No observations of X=1

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        // With Laplace: (8+1)/(8+1+0+1) = 9/10 = 0.9
        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_bayesian_estimator() {
        let estimator = BayesianEstimator::new(2.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        // Prior: Dirichlet(1, 1), Data: 8, 0
        // Posterior: (8+1, 0+1) / (8+1+0+1) = (9, 1) / 10
        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_count_occurrences() {
        let mut data = Vec::new();
        for i in 0..10 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), i % 3);
            data.push(assignment);
        }

        let counts = utils::count_occurrences("X", &data);

        assert_eq!(counts.get(&0), Some(&4)); // 0, 3, 6, 9
        assert_eq!(counts.get(&1), Some(&3)); // 1, 4, 7
        assert_eq!(counts.get(&2), Some(&3)); // 2, 5, 8
    }

    #[test]
    fn test_counts_to_distribution() {
        let mut counts = HashMap::new();
        counts.insert(0, 7);
        counts.insert(1, 3);

        let probs = utils::counts_to_distribution(&counts, 2);

        assert!((probs[0] - 0.7).abs() < 1e-6);
        assert!((probs[1] - 0.3).abs() < 1e-6);
    }
}
731}