//! Parameter learning for probabilistic graphical models: maximum-likelihood
//! and Bayesian estimation from complete data, plus Baum-Welch (EM) training
//! for a simple hidden Markov model.

use crate::error::{PgmError, Result};
use crate::sampling::Assignment;
use scirs2_core::ndarray::{Array1, Array2, ArrayD};
use std::collections::HashMap;

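/// A discrete hidden Markov model with dense parameter tables: an initial
/// state distribution, a `num_states x num_states` transition matrix, and a
/// `num_states x num_observations` emission matrix.
///
/// A minimal construction sketch (the import path is illustrative and
/// depends on this crate's name):
///
/// ```ignore
/// let hmm = SimpleHMM::new(2, 3);
/// assert_eq!(hmm.transition_probabilities.shape(), &[2, 2]);
/// assert!((hmm.initial_distribution.sum() - 1.0).abs() < 1e-12);
/// ```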
#[derive(Debug, Clone)]
pub struct SimpleHMM {
    /// Number of hidden states.
    pub num_states: usize,
    /// Number of distinct observation symbols.
    pub num_observations: usize,
    /// Initial state distribution pi (length `num_states`).
    pub initial_distribution: Array1<f64>,
    /// State transition matrix A (`num_states x num_states`); rows sum to 1.
    pub transition_probabilities: Array2<f64>,
    /// Emission matrix B (`num_states x num_observations`); rows sum to 1.
    pub emission_probabilities: Array2<f64>,
}

impl SimpleHMM {
    /// Creates an HMM with uniform initial, transition, and emission
    /// distributions.
    pub fn new(num_states: usize, num_observations: usize) -> Self {
        let initial_distribution = Array1::from_elem(num_states, 1.0 / num_states as f64);

        let transition_probabilities =
            Array2::from_elem((num_states, num_states), 1.0 / num_states as f64);

        let emission_probabilities = Array2::from_elem(
            (num_states, num_observations),
            1.0 / num_observations as f64,
        );

        Self {
            num_states,
            num_observations,
            initial_distribution,
            transition_probabilities,
            emission_probabilities,
        }
    }

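    /// Creates an HMM with randomly initialized parameters; each row of
    /// each table is renormalized to sum to 1.
    ///
    /// A common use is seeding Baum-Welch with a random restart (sketch
    /// only; `BaumWelchLearner` is defined later in this module):
    ///
    /// ```ignore
    /// let mut hmm = SimpleHMM::new_random(2, 3);
    /// let learner = BaumWelchLearner::new(100, 1e-4);
    /// let log_likelihood = learner.learn(&mut hmm, &[vec![0, 1, 2, 1]]).unwrap();
    /// ```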
    pub fn new_random(num_states: usize, num_observations: usize) -> Self {
        use scirs2_core::random::thread_rng;

        let mut rng = thread_rng();
        let mut hmm = Self::new(num_states, num_observations);

        let mut init_sum = 0.0;
        for i in 0..num_states {
            hmm.initial_distribution[i] = rng.random::<f64>();
            init_sum += hmm.initial_distribution[i];
        }
        hmm.initial_distribution /= init_sum;

        for i in 0..num_states {
            let mut trans_sum = 0.0;
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] = rng.random::<f64>();
                trans_sum += hmm.transition_probabilities[[i, j]];
            }
            for j in 0..num_states {
                hmm.transition_probabilities[[i, j]] /= trans_sum;
            }
        }

        for i in 0..num_states {
            let mut emission_sum = 0.0;
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] = rng.random::<f64>();
                emission_sum += hmm.emission_probabilities[[i, j]];
            }
            for j in 0..num_observations {
                hmm.emission_probabilities[[i, j]] /= emission_sum;
            }
        }

        hmm
    }
}

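/// Maximum-likelihood estimation of discrete distributions from complete
/// data, with optional Laplace (additive) smoothing.
///
/// A minimal sketch, assuming `Assignment` is the `HashMap<String, usize>`
/// sample type used by this module's tests (import path illustrative):
///
/// ```ignore
/// let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);
/// let mut sample = HashMap::new();
/// sample.insert("X".to_string(), 0);
/// let probs = estimator.estimate_marginal("X", 2, &[sample]).unwrap();
/// // counts: [1 + 1, 0 + 1] => probs: [2/3, 1/3]
/// ```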
#[derive(Debug, Clone)]
pub struct MaximumLikelihoodEstimator {
    /// Whether to apply Laplace (additive) smoothing.
    pub use_laplace: bool,
    /// Pseudocount added to every cell when smoothing is enabled.
    pub pseudocount: f64,
}

impl MaximumLikelihoodEstimator {
    /// Creates an estimator without smoothing.
    pub fn new() -> Self {
        Self {
            use_laplace: false,
            pseudocount: 1.0,
        }
    }

    /// Creates an estimator with Laplace smoothing using the given
    /// pseudocount.
    pub fn with_laplace(pseudocount: f64) -> Self {
        Self {
            use_laplace: true,
            pseudocount,
        }
    }

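    /// Estimates a marginal distribution P(X) by normalized counting:
    /// `P(x) = (N_x + alpha) / (N + k * alpha)`, where `alpha` is the
    /// pseudocount when Laplace smoothing is enabled (0 otherwise) and `k`
    /// is the cardinality. Returns an error if no counts are observed and
    /// no smoothing is applied.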
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };
        let mut counts = vec![pseudocount; cardinality];

        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        let total: f64 = counts.iter().sum();
        if total == 0.0 {
            return Err(PgmError::InvalidDistribution(format!(
                "No data for variable {}",
                variable
            )));
        }

        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }

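    /// Estimates a conditional distribution P(child | parents) from counts.
    ///
    /// `cardinalities[0]` is the child's cardinality and `cardinalities[1..]`
    /// are the parents' cardinalities, in the same order as `parents`. The
    /// returned array has shape `[child, parent_1, ..., parent_k]` in
    /// row-major order. Samples with a missing or out-of-range parent value
    /// are skipped, and unseen parent configurations fall back to a uniform
    /// distribution over the child.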
    pub fn estimate_conditional(
        &self,
        child: &str,
        parents: &[String],
        cardinalities: &[usize],
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        if cardinalities.is_empty() {
            return Err(PgmError::InvalidGraph(
                "Cardinalities must not be empty".to_string(),
            ));
        }
        if cardinalities.len() != parents.len() + 1 {
            return Err(PgmError::InvalidGraph(
                "Expected one cardinality for the child plus one per parent".to_string(),
            ));
        }

        let pseudocount = if self.use_laplace {
            self.pseudocount
        } else {
            0.0
        };

        let child_card = cardinalities[0];
        let parent_cards = &cardinalities[1..];

        let num_parent_configs: usize = parent_cards.iter().product();

        // counts[config][child_value]: one row per joint parent configuration.
        let mut counts = vec![vec![pseudocount; child_card]; num_parent_configs];

        for assignment in data {
            if let Some(&child_val) = assignment.get(child) {
                // Encode the parent values as a single row-major index (the
                // last parent varies fastest). Skip the whole sample when any
                // parent is missing or out of range; silently treating it as
                // value 0 would bias the counts.
                let mut parent_config = 0;
                let mut complete = true;

                for (i, parent) in parents.iter().enumerate() {
                    match assignment.get(parent) {
                        Some(&v) if v < parent_cards[i] => {
                            parent_config = parent_config * parent_cards[i] + v;
                        }
                        _ => {
                            complete = false;
                            break;
                        }
                    }
                }

                if complete && child_val < child_card {
                    counts[parent_config][child_val] += 1.0;
                }
            }
        }

        // Normalize each parent configuration; configurations that were
        // never observed (and received no pseudocount) fall back to uniform.
        for config_counts in counts.iter_mut() {
            let total: f64 = config_counts.iter().sum();
            if total > 0.0 {
                for count in config_counts.iter_mut() {
                    *count /= total;
                }
            } else {
                for count in config_counts.iter_mut() {
                    *count = 1.0 / child_card as f64;
                }
            }
        }

        // Emit values in row-major order for shape [child, parent_1, ..., parent_k].
        let mut probs = Vec::with_capacity(child_card * num_parent_configs);
        for child_val in 0..child_card {
            for config in 0..num_parent_configs {
                probs.push(counts[config][child_val]);
            }
        }

        ArrayD::from_shape_vec(cardinalities.to_vec(), probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

impl Default for MaximumLikelihoodEstimator {
    fn default() -> Self {
        Self::new()
    }
}

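/// Bayesian estimation of a discrete distribution under a symmetric
/// Dirichlet prior; the posterior mean is returned.
///
/// A minimal sketch (import path illustrative; `data` is a slice of
/// `Assignment` samples):
///
/// ```ignore
/// // prior_strength 2.0 over 2 outcomes => pseudocount 1.0 per outcome
/// let estimator = BayesianEstimator::new(2.0);
/// let probs = estimator.estimate_marginal("X", 2, &data).unwrap();
/// ```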
#[derive(Debug, Clone)]
pub struct BayesianEstimator {
    /// Total prior strength (equivalent sample size) of the symmetric
    /// Dirichlet prior; it is split evenly across outcomes.
    pub prior_strength: f64,
}

impl BayesianEstimator {
    /// Creates an estimator with the given total prior strength.
    pub fn new(prior_strength: f64) -> Self {
        Self { prior_strength }
    }

    /// Returns the posterior mean of P(X): each count starts at
    /// `prior_strength / cardinality` before the data are tallied.
    pub fn estimate_marginal(
        &self,
        variable: &str,
        cardinality: usize,
        data: &[Assignment],
    ) -> Result<ArrayD<f64>> {
        // Symmetric Dirichlet prior: the same pseudocount for every cell.
        let alpha = self.prior_strength / cardinality as f64;
        let mut counts = vec![alpha; cardinality];

        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                if value < cardinality {
                    counts[value] += 1.0;
                }
            }
        }

        // The prior mass keeps the total positive, so no zero check is needed.
        let total: f64 = counts.iter().sum();
        let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();

        ArrayD::from_shape_vec(vec![cardinality], probs)
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
    }
}

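/// Baum-Welch (expectation-maximization) training for `SimpleHMM`. Each
/// iteration runs the forward-backward algorithm over every observation
/// sequence (E-step), then re-estimates the initial, transition, and
/// emission distributions from the expected counts (M-step), stopping when
/// the change in average log-likelihood falls below `tolerance`.
///
/// A minimal sketch (import path illustrative):
///
/// ```ignore
/// let mut hmm = SimpleHMM::new_random(2, 2);
/// let sequences = vec![vec![0, 0, 1, 1], vec![0, 1, 1, 0]];
/// let learner = BaumWelchLearner::new(50, 1e-4);
/// let final_ll = learner.learn(&mut hmm, &sequences).unwrap();
/// ```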
#[derive(Debug, Clone)]
pub struct BaumWelchLearner {
    /// Maximum number of EM iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the change in average log-likelihood.
    pub tolerance: f64,
    /// Whether to print progress each iteration.
    pub verbose: bool,
}

impl BaumWelchLearner {
    /// Creates a learner with the given iteration cap and convergence
    /// tolerance.
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: false,
        }
    }

    /// Like [`new`](Self::new), but prints progress each iteration.
    pub fn with_verbose(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            verbose: true,
        }
    }

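    /// Runs EM until convergence or `max_iterations`, updating `hmm` in
    /// place and returning the final average log-likelihood per sequence.
    ///
    /// Per iteration: for each sequence, accumulate expected initial-state
    /// occupancies (gamma at t = 0), expected transitions (xi summed over
    /// t), and expected emissions (gamma summed over t); then normalize
    /// those counts into new parameters. Returns an error if no sequences
    /// are provided or a sequence is empty.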
    #[allow(clippy::needless_range_loop)]
    pub fn learn(&self, hmm: &mut SimpleHMM, observation_sequences: &[Vec<usize>]) -> Result<f64> {
        if observation_sequences.is_empty() {
            return Err(PgmError::InvalidDistribution(
                "No observation sequences provided".to_string(),
            ));
        }

        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for iteration in 0..self.max_iterations {
            // E-step accumulators: expected counts under the current parameters.
            let mut initial_counts = vec![0.0; num_states];
            let mut transition_counts = vec![vec![0.0; num_states]; num_states];
            let mut emission_counts = vec![vec![0.0; num_observations]; num_states];

            let mut total_log_likelihood = 0.0;

            for sequence in observation_sequences {
                let (alpha, beta, log_likelihood) = self.forward_backward(hmm, sequence)?;
                total_log_likelihood += log_likelihood;

                let seq_len = sequence.len();

                // Expected initial-state occupancy: gamma_0(s).
                for (s, count) in initial_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_0 = self.compute_gamma(&alpha, &beta, 0, s, log_likelihood);
                    *count += gamma_0;
                }

                // Expected transitions (xi) and emissions (gamma) for t < T-1.
                for t in 0..(seq_len - 1) {
                    for s1 in 0..num_states {
                        let gamma_t = self.compute_gamma(&alpha, &beta, t, s1, log_likelihood);

                        emission_counts[s1][sequence[t]] += gamma_t;

                        for s2 in 0..num_states {
                            let xi = self.compute_xi(
                                hmm,
                                &alpha,
                                &beta,
                                t,
                                s1,
                                s2,
                                sequence[t + 1],
                                log_likelihood,
                            );
                            transition_counts[s1][s2] += xi;
                        }
                    }
                }

                // The emission at the final time step is handled separately,
                // since the xi loop above stops at T-2.
                for (s, counts) in emission_counts.iter_mut().enumerate().take(num_states) {
                    let gamma_last =
                        self.compute_gamma(&alpha, &beta, seq_len - 1, s, log_likelihood);
                    counts[sequence[seq_len - 1]] += gamma_last;
                }
            }

            // M-step: re-estimate parameters from the expected counts.
            self.update_parameters(hmm, &initial_counts, &transition_counts, &emission_counts)?;

            let avg_log_likelihood = total_log_likelihood / observation_sequences.len() as f64;

            if self.verbose {
                println!(
                    "Iteration {}: log-likelihood = {:.4}",
                    iteration, avg_log_likelihood
                );
            }

            // Stop once the average log-likelihood has stabilized.
            if (avg_log_likelihood - prev_log_likelihood).abs() < self.tolerance {
                if self.verbose {
                    println!("Converged after {} iterations", iteration + 1);
                }
                return Ok(avg_log_likelihood);
            }

            prev_log_likelihood = avg_log_likelihood;
        }

        if self.verbose {
            println!("Maximum iterations reached");
        }

        Ok(prev_log_likelihood)
    }

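    /// Computes unscaled forward and backward variables plus the sequence
    /// log-likelihood:
    ///
    /// - `alpha[t][s] = P(o_1, ..., o_t, S_t = s)`, built left to right via
    ///   `alpha[t][s2] = (sum_{s1} alpha[t-1][s1] * A[s1][s2]) * B[s2][o_t]`;
    /// - `beta[t][s] = P(o_{t+1}, ..., o_T | S_t = s)`, built right to left;
    /// - `log P(O) = ln(sum_s alpha[T-1][s])`.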
    #[allow(clippy::type_complexity, clippy::needless_range_loop)]
    fn forward_backward(
        &self,
        hmm: &SimpleHMM,
        sequence: &[usize],
    ) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, f64)> {
        if sequence.is_empty() {
            return Err(PgmError::InvalidDistribution(
                "Cannot run forward-backward on an empty sequence".to_string(),
            ));
        }
        if let Some(&bad) = sequence.iter().find(|&&o| o >= hmm.num_observations) {
            return Err(PgmError::InvalidDistribution(format!(
                "Observation symbol {} is out of range (num_observations = {})",
                bad, hmm.num_observations
            )));
        }

        let num_states = hmm.num_states;
        let seq_len = sequence.len();

        // Forward pass.
        let mut alpha = vec![vec![0.0; num_states]; seq_len];

        for s in 0..num_states {
            alpha[0][s] =
                hmm.initial_distribution[[s]] * hmm.emission_probabilities[[s, sequence[0]]];
        }

        for t in 1..seq_len {
            for s2 in 0..num_states {
                let mut sum = 0.0;
                for s1 in 0..num_states {
                    sum += alpha[t - 1][s1] * hmm.transition_probabilities[[s1, s2]];
                }
                alpha[t][s2] = sum * hmm.emission_probabilities[[s2, sequence[t]]];
            }
        }

        // Backward pass.
        let mut beta = vec![vec![0.0; num_states]; seq_len];

        for s in 0..num_states {
            beta[seq_len - 1][s] = 1.0;
        }

        for t in (0..(seq_len - 1)).rev() {
            for s1 in 0..num_states {
                let mut sum = 0.0;
                for s2 in 0..num_states {
                    sum += hmm.transition_probabilities[[s1, s2]]
                        * hmm.emission_probabilities[[s2, sequence[t + 1]]]
                        * beta[t + 1][s2];
                }
                beta[t][s1] = sum;
            }
        }

        // P(O) = sum_s alpha[T-1][s]. Note: alpha/beta are unscaled, so very
        // long sequences can underflow; a scaled variant would avoid this.
        let log_likelihood: f64 = alpha[seq_len - 1].iter().sum::<f64>().ln();

        Ok((alpha, beta, log_likelihood))
    }

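    /// State-occupancy posterior
    /// `gamma_t(s) = P(S_t = s | O) = alpha[t][s] * beta[t][s] / P(O)`.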
    fn compute_gamma(
        &self,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s: usize,
        log_likelihood: f64,
    ) -> f64 {
        // exp(log_likelihood) recovers P(O), matching the unscaled alpha/beta.
        (alpha[t][s] * beta[t][s]) / log_likelihood.exp()
    }

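    /// Transition posterior
    /// `xi_t(s1, s2) = P(S_t = s1, S_{t+1} = s2 | O)
    ///   = alpha[t][s1] * A[s1][s2] * B[s2][o_{t+1}] * beta[t+1][s2] / P(O)`.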
    #[allow(clippy::too_many_arguments)]
    fn compute_xi(
        &self,
        hmm: &SimpleHMM,
        alpha: &[Vec<f64>],
        beta: &[Vec<f64>],
        t: usize,
        s1: usize,
        s2: usize,
        next_obs: usize,
        log_likelihood: f64,
    ) -> f64 {
        let numerator = alpha[t][s1]
            * hmm.transition_probabilities[[s1, s2]]
            * hmm.emission_probabilities[[s2, next_obs]]
            * beta[t + 1][s2];

        // exp(log_likelihood) recovers P(O) for the unscaled variables.
        numerator / log_likelihood.exp()
    }

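    /// M-step: normalizes the expected counts into new parameters. Rows
    /// with zero total (states never visited) keep their previous values
    /// rather than being overwritten with zeros.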
    fn update_parameters(
        &self,
        hmm: &mut SimpleHMM,
        initial_counts: &[f64],
        transition_counts: &[Vec<f64>],
        emission_counts: &[Vec<f64>],
    ) -> Result<()> {
        let num_states = hmm.num_states;
        let num_observations = hmm.num_observations;

        // pi[s]: expected count of starting in s, normalized.
        let initial_sum: f64 = initial_counts.iter().sum();
        if initial_sum > 0.0 {
            for (s, &count) in initial_counts.iter().enumerate().take(num_states) {
                hmm.initial_distribution[[s]] = count / initial_sum;
            }
        }

        // A[s1][s2]: expected transitions s1 -> s2, normalized per row.
        for (s1, trans_counts) in transition_counts.iter().enumerate().take(num_states) {
            let trans_sum: f64 = trans_counts.iter().sum();
            if trans_sum > 0.0 {
                for (s2, &count) in trans_counts.iter().enumerate().take(num_states) {
                    hmm.transition_probabilities[[s1, s2]] = count / trans_sum;
                }
            }
        }

        // B[s][o]: expected emissions of o from s, normalized per state.
        for (s, emis_counts) in emission_counts.iter().enumerate().take(num_states) {
            let emission_sum: f64 = emis_counts.iter().sum();
            if emission_sum > 0.0 {
                for (o, &count) in emis_counts.iter().enumerate().take(num_observations) {
                    hmm.emission_probabilities[[s, o]] = count / emission_sum;
                }
            }
        }

        Ok(())
    }
}

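/// Counting helpers shared by the estimators above.
///
/// A minimal sketch (import path illustrative; `data` is a slice of
/// `Assignment` samples):
///
/// ```ignore
/// let counts = utils::count_occurrences("X", &data);
/// let probs = utils::counts_to_distribution(&counts, 2);
/// ```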
pub mod utils {
    use super::*;

    /// Counts how often each value of `variable` appears in `data`.
    pub fn count_occurrences(variable: &str, data: &[Assignment]) -> HashMap<usize, usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let Some(&value) = assignment.get(variable) {
                *counts.entry(value).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Counts joint occurrences of `(var1, var2)` value pairs, skipping
    /// samples where either variable is missing.
    pub fn count_joint_occurrences(
        var1: &str,
        var2: &str,
        data: &[Assignment],
    ) -> HashMap<(usize, usize), usize> {
        let mut counts = HashMap::new();

        for assignment in data {
            if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
                *counts.entry((v1, v2)).or_insert(0) += 1;
            }
        }

        counts
    }

    /// Converts raw counts into a normalized distribution over
    /// `0..cardinality`; unseen values get probability 0.
    pub fn counts_to_distribution(counts: &HashMap<usize, usize>, cardinality: usize) -> Vec<f64> {
        let total: usize = counts.values().sum();
        let mut probs = vec![0.0; cardinality];

        for (&value, &count) in counts {
            if value < cardinality && total > 0 {
                probs[value] = count as f64 / total as f64;
            }
        }

        probs
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mle_marginal() {
        let estimator = MaximumLikelihoodEstimator::new();

        let mut data = Vec::new();
        for _ in 0..7 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }
        for _ in 0..3 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 1);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.7).abs() < 1e-6);
        assert!((probs[[1]] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_mle_with_laplace() {
        let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }

        // Smoothed counts: (8 + 1) / 10 and (0 + 1) / 10.
        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_bayesian_estimator() {
        let estimator = BayesianEstimator::new(2.0);

        let mut data = Vec::new();
        for _ in 0..8 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), 0);
            data.push(assignment);
        }

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        // prior_strength 2.0 over 2 outcomes adds 1.0 to each count.
        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_count_occurrences() {
        let mut data = Vec::new();
        for i in 0..10 {
            let mut assignment = HashMap::new();
            assignment.insert("X".to_string(), i % 3);
            data.push(assignment);
        }

        let counts = utils::count_occurrences("X", &data);

        assert_eq!(counts.get(&0), Some(&4));
        assert_eq!(counts.get(&1), Some(&3));
        assert_eq!(counts.get(&2), Some(&3));
    }

    #[test]
    fn test_counts_to_distribution() {
        let mut counts = HashMap::new();
        counts.insert(0, 7);
        counts.insert(1, 3);

        let probs = utils::counts_to_distribution(&counts, 2);

        assert!((probs[0] - 0.7).abs() < 1e-6);
        assert!((probs[1] - 0.3).abs() < 1e-6);
    }
}