1use crate::error::{PgmError, Result};
26use crate::sampling::Assignment;
27use scirs2_core::ndarray::{Array1, Array2, ArrayD};
28use std::collections::HashMap;
29
/// A discrete hidden Markov model (HMM) over categorical observations.
///
/// All parameters are stored as dense probability tables; each distribution
/// (initial vector, transition rows, emission rows) is expected to sum to 1.
#[derive(Debug, Clone)]
pub struct SimpleHMM {
    /// Number of hidden states.
    pub num_states: usize,
    /// Number of distinct observation symbols.
    pub num_observations: usize,
    /// P(state at t = 0); length `num_states`.
    pub initial_distribution: Array1<f64>,
    /// P(next state | current state); shape `(num_states, num_states)`,
    /// indexed as `[from, to]` (see the forward pass in `BaumWelchLearner`).
    pub transition_probabilities: Array2<f64>,
    /// P(observation | state); shape `(num_states, num_observations)`.
    pub emission_probabilities: Array2<f64>,
}
47
48impl SimpleHMM {
49 pub fn new(num_states: usize, num_observations: usize) -> Self {
51 let initial_distribution = Array1::from_elem(num_states, 1.0 / num_states as f64);
52
53 let transition_probabilities =
54 Array2::from_elem((num_states, num_states), 1.0 / num_states as f64);
55
56 let emission_probabilities = Array2::from_elem(
57 (num_states, num_observations),
58 1.0 / num_observations as f64,
59 );
60
61 Self {
62 num_states,
63 num_observations,
64 initial_distribution,
65 transition_probabilities,
66 emission_probabilities,
67 }
68 }
69
70 pub fn new_random(num_states: usize, num_observations: usize) -> Self {
72 use scirs2_core::random::{thread_rng, Rng};
73
74 let mut rng = thread_rng();
75 let mut hmm = Self::new(num_states, num_observations);
76
77 let mut init_sum = 0.0;
79 for i in 0..num_states {
80 hmm.initial_distribution[i] = rng.random::<f64>();
81 init_sum += hmm.initial_distribution[i];
82 }
83 hmm.initial_distribution /= init_sum;
84
85 for i in 0..num_states {
87 let mut trans_sum = 0.0;
88 for j in 0..num_states {
89 hmm.transition_probabilities[[i, j]] = rng.random::<f64>();
90 trans_sum += hmm.transition_probabilities[[i, j]];
91 }
92 for j in 0..num_states {
93 hmm.transition_probabilities[[i, j]] /= trans_sum;
94 }
95 }
96
97 for i in 0..num_states {
99 let mut emission_sum = 0.0;
100 for j in 0..num_observations {
101 hmm.emission_probabilities[[i, j]] = rng.random::<f64>();
102 emission_sum += hmm.emission_probabilities[[i, j]];
103 }
104 for j in 0..num_observations {
105 hmm.emission_probabilities[[i, j]] /= emission_sum;
106 }
107 }
108
109 hmm
110 }
111}
112
/// Estimates discrete distributions from assignment data by maximum
/// likelihood, optionally with Laplace (add-k) smoothing.
#[derive(Debug, Clone)]
pub struct MaximumLikelihoodEstimator {
    /// When true, every count is initialized to `pseudocount` (smoothing).
    pub use_laplace: bool,
    /// Smoothing pseudocount; only consulted when `use_laplace` is true.
    pub pseudocount: f64,
}
123
124impl MaximumLikelihoodEstimator {
125 pub fn new() -> Self {
127 Self {
128 use_laplace: false,
129 pseudocount: 1.0,
130 }
131 }
132
133 pub fn with_laplace(pseudocount: f64) -> Self {
135 Self {
136 use_laplace: true,
137 pseudocount,
138 }
139 }
140
141 pub fn estimate_marginal(
153 &self,
154 variable: &str,
155 cardinality: usize,
156 data: &[Assignment],
157 ) -> Result<ArrayD<f64>> {
158 let pseudocount = if self.use_laplace {
159 self.pseudocount
160 } else {
161 0.0
162 };
163 let mut counts = vec![pseudocount; cardinality];
164
165 for assignment in data {
167 if let Some(&value) = assignment.get(variable) {
168 if value < cardinality {
169 counts[value] += 1.0;
170 }
171 }
172 }
173
174 let total: f64 = counts.iter().sum();
176 if total == 0.0 {
177 return Err(PgmError::InvalidDistribution(
178 "No data for variable".to_string(),
179 ));
180 }
181
182 let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();
183
184 ArrayD::from_shape_vec(vec![cardinality], probs)
185 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
186 }
187
188 pub fn estimate_conditional(
197 &self,
198 child: &str,
199 parents: &[String],
200 cardinalities: &[usize],
201 data: &[Assignment],
202 ) -> Result<ArrayD<f64>> {
203 if cardinalities.is_empty() {
204 return Err(PgmError::InvalidGraph(
205 "Cardinalities must not be empty".to_string(),
206 ));
207 }
208
209 let pseudocount = if self.use_laplace {
210 self.pseudocount
211 } else {
212 0.0
213 };
214
215 let child_card = cardinalities[0];
216 let parent_cards = &cardinalities[1..];
217
218 let num_parent_configs: usize = parent_cards.iter().product();
220
221 let mut counts = vec![vec![pseudocount; child_card]; num_parent_configs];
223
224 for assignment in data {
226 if let Some(&child_val) = assignment.get(child) {
227 let mut parent_config = 0;
229 let mut multiplier = 1;
230
231 for (i, parent) in parents.iter().enumerate() {
232 if let Some(&parent_val) = assignment.get(parent) {
233 parent_config += parent_val * multiplier;
234 multiplier *= parent_cards[i];
235 } else {
236 continue; }
238 }
239
240 if parent_config < num_parent_configs && child_val < child_card {
241 counts[parent_config][child_val] += 1.0;
242 }
243 }
244 }
245
246 let mut probs = Vec::new();
248 for config_counts in counts {
249 let total: f64 = config_counts.iter().sum();
250 if total > 0.0 {
251 for count in config_counts {
252 probs.push(count / total);
253 }
254 } else {
255 for _ in 0..child_card {
257 probs.push(1.0 / child_card as f64);
258 }
259 }
260 }
261
262 ArrayD::from_shape_vec(cardinalities.to_vec(), probs)
264 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
265 }
266}
267
impl Default for MaximumLikelihoodEstimator {
    /// Defaults to the unsmoothed estimator (`MaximumLikelihoodEstimator::new`).
    fn default() -> Self {
        Self::new()
    }
}
273
/// Bayesian estimator for discrete distributions using a symmetric
/// Dirichlet prior (posterior-mean point estimates).
#[derive(Debug, Clone)]
pub struct BayesianEstimator {
    /// Total Dirichlet concentration; split evenly across all categories of
    /// a variable (alpha per category = prior_strength / cardinality).
    pub prior_strength: f64,
}
282
283impl BayesianEstimator {
284 pub fn new(prior_strength: f64) -> Self {
290 Self { prior_strength }
291 }
292
293 pub fn estimate_marginal(
295 &self,
296 variable: &str,
297 cardinality: usize,
298 data: &[Assignment],
299 ) -> Result<ArrayD<f64>> {
300 let alpha = self.prior_strength / cardinality as f64;
302 let mut counts = vec![alpha; cardinality];
303
304 for assignment in data {
306 if let Some(&value) = assignment.get(variable) {
307 if value < cardinality {
308 counts[value] += 1.0;
309 }
310 }
311 }
312
313 let total: f64 = counts.iter().sum();
315 let probs: Vec<f64> = counts.iter().map(|&c| c / total).collect();
316
317 ArrayD::from_shape_vec(vec![cardinality], probs)
318 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))
319 }
320}
321
/// Baum-Welch (expectation-maximization) learner for fitting `SimpleHMM`
/// parameters to observation sequences.
#[derive(Debug, Clone)]
pub struct BaumWelchLearner {
    /// Maximum number of EM iterations before giving up on convergence.
    pub max_iterations: usize,
    /// Convergence threshold on the absolute change in average
    /// log-likelihood between iterations.
    pub tolerance: f64,
    /// When true, prints per-iteration progress to stdout.
    pub verbose: bool,
}
339
340impl BaumWelchLearner {
341 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
343 Self {
344 max_iterations,
345 tolerance,
346 verbose: false,
347 }
348 }
349
350 pub fn with_verbose(max_iterations: usize, tolerance: f64) -> Self {
352 Self {
353 max_iterations,
354 tolerance,
355 verbose: true,
356 }
357 }
358
359 pub fn learn(&self, hmm: &mut SimpleHMM, observation_sequences: &[Vec<usize>]) -> Result<f64> {
370 let num_states = hmm.num_states;
371 let num_observations = hmm.num_observations;
372
373 let mut prev_log_likelihood = f64::NEG_INFINITY;
374
375 for iteration in 0..self.max_iterations {
376 let mut initial_counts = vec![0.0; num_states];
378 let mut transition_counts = vec![vec![0.0; num_states]; num_states];
379 let mut emission_counts = vec![vec![0.0; num_observations]; num_states];
380
381 let mut total_log_likelihood = 0.0;
382
383 for sequence in observation_sequences {
384 let (alpha, beta, log_likelihood) = self.forward_backward(hmm, sequence)?;
385 total_log_likelihood += log_likelihood;
386
387 let seq_len = sequence.len();
388
389 for (s, count) in initial_counts.iter_mut().enumerate().take(num_states) {
391 let gamma_0 = self.compute_gamma(&alpha, &beta, 0, s, log_likelihood);
392 *count += gamma_0;
393 }
394
395 for t in 0..(seq_len - 1) {
397 for s1 in 0..num_states {
398 let gamma_t = self.compute_gamma(&alpha, &beta, t, s1, log_likelihood);
399
400 emission_counts[s1][sequence[t]] += gamma_t;
402
403 for s2 in 0..num_states {
405 let xi = self.compute_xi(
406 hmm,
407 &alpha,
408 &beta,
409 t,
410 s1,
411 s2,
412 sequence[t + 1],
413 log_likelihood,
414 );
415 transition_counts[s1][s2] += xi;
416 }
417 }
418 }
419
420 for (s, counts) in emission_counts.iter_mut().enumerate().take(num_states) {
422 let gamma_last =
423 self.compute_gamma(&alpha, &beta, seq_len - 1, s, log_likelihood);
424 counts[sequence[seq_len - 1]] += gamma_last;
425 }
426 }
427
428 self.update_parameters(hmm, &initial_counts, &transition_counts, &emission_counts)?;
430
431 let avg_log_likelihood = total_log_likelihood / observation_sequences.len() as f64;
433
434 if self.verbose {
435 println!(
436 "Iteration {}: log-likelihood = {:.4}",
437 iteration, avg_log_likelihood
438 );
439 }
440
441 if (avg_log_likelihood - prev_log_likelihood).abs() < self.tolerance {
442 if self.verbose {
443 println!("Converged after {} iterations", iteration + 1);
444 }
445 return Ok(avg_log_likelihood);
446 }
447
448 prev_log_likelihood = avg_log_likelihood;
449 }
450
451 if self.verbose {
452 println!("Maximum iterations reached");
453 }
454
455 Ok(prev_log_likelihood)
456 }
457
458 #[allow(clippy::type_complexity)]
460 fn forward_backward(
461 &self,
462 hmm: &SimpleHMM,
463 sequence: &[usize],
464 ) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, f64)> {
465 let num_states = hmm.num_states;
466 let seq_len = sequence.len();
467
468 let mut alpha = vec![vec![0.0; num_states]; seq_len];
470
471 for s in 0..num_states {
473 alpha[0][s] =
474 hmm.initial_distribution[[s]] * hmm.emission_probabilities[[s, sequence[0]]];
475 }
476
477 for t in 1..seq_len {
479 for s2 in 0..num_states {
480 let mut sum = 0.0;
481 for s1 in 0..num_states {
482 sum += alpha[t - 1][s1] * hmm.transition_probabilities[[s1, s2]];
483 }
484 alpha[t][s2] = sum * hmm.emission_probabilities[[s2, sequence[t]]];
485 }
486 }
487
488 let mut beta = vec![vec![0.0; num_states]; seq_len];
490
491 for s in 0..num_states {
493 beta[seq_len - 1][s] = 1.0;
494 }
495
496 for t in (0..(seq_len - 1)).rev() {
498 for s1 in 0..num_states {
499 let mut sum = 0.0;
500 for s2 in 0..num_states {
501 sum += hmm.transition_probabilities[[s1, s2]]
502 * hmm.emission_probabilities[[s2, sequence[t + 1]]]
503 * beta[t + 1][s2];
504 }
505 beta[t][s1] = sum;
506 }
507 }
508
509 let log_likelihood: f64 = alpha[seq_len - 1].iter().sum::<f64>().ln();
511
512 Ok((alpha, beta, log_likelihood))
513 }
514
515 fn compute_gamma(
517 &self,
518 alpha: &[Vec<f64>],
519 beta: &[Vec<f64>],
520 t: usize,
521 s: usize,
522 log_likelihood: f64,
523 ) -> f64 {
524 (alpha[t][s] * beta[t][s]) / log_likelihood.exp()
525 }
526
527 #[allow(clippy::too_many_arguments)]
529 fn compute_xi(
530 &self,
531 hmm: &SimpleHMM,
532 alpha: &[Vec<f64>],
533 beta: &[Vec<f64>],
534 t: usize,
535 s1: usize,
536 s2: usize,
537 next_obs: usize,
538 log_likelihood: f64,
539 ) -> f64 {
540 let numerator = alpha[t][s1]
541 * hmm.transition_probabilities[[s1, s2]]
542 * hmm.emission_probabilities[[s2, next_obs]]
543 * beta[t + 1][s2];
544
545 numerator / log_likelihood.exp()
546 }
547
548 fn update_parameters(
550 &self,
551 hmm: &mut SimpleHMM,
552 initial_counts: &[f64],
553 transition_counts: &[Vec<f64>],
554 emission_counts: &[Vec<f64>],
555 ) -> Result<()> {
556 let num_states = hmm.num_states;
557 let num_observations = hmm.num_observations;
558
559 let initial_sum: f64 = initial_counts.iter().sum();
561 if initial_sum > 0.0 {
562 for (s, &count) in initial_counts.iter().enumerate().take(num_states) {
563 hmm.initial_distribution[[s]] = count / initial_sum;
564 }
565 }
566
567 for (s1, trans_counts) in transition_counts.iter().enumerate().take(num_states) {
569 let trans_sum: f64 = trans_counts.iter().sum();
570 if trans_sum > 0.0 {
571 for (s2, &count) in trans_counts.iter().enumerate().take(num_states) {
572 hmm.transition_probabilities[[s1, s2]] = count / trans_sum;
573 }
574 }
575 }
576
577 for (s, emis_counts) in emission_counts.iter().enumerate().take(num_states) {
579 let emission_sum: f64 = emis_counts.iter().sum();
580 if emission_sum > 0.0 {
581 for (o, &count) in emis_counts.iter().enumerate().take(num_observations) {
582 hmm.emission_probabilities[[s, o]] = count / emission_sum;
583 }
584 }
585 }
586
587 Ok(())
588 }
589}
590
591pub mod utils {
593 use super::*;
594
595 pub fn count_occurrences(variable: &str, data: &[Assignment]) -> HashMap<usize, usize> {
597 let mut counts = HashMap::new();
598
599 for assignment in data {
600 if let Some(&value) = assignment.get(variable) {
601 *counts.entry(value).or_insert(0) += 1;
602 }
603 }
604
605 counts
606 }
607
608 pub fn count_joint_occurrences(
610 var1: &str,
611 var2: &str,
612 data: &[Assignment],
613 ) -> HashMap<(usize, usize), usize> {
614 let mut counts = HashMap::new();
615
616 for assignment in data {
617 if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
618 *counts.entry((v1, v2)).or_insert(0) += 1;
619 }
620 }
621
622 counts
623 }
624
625 pub fn counts_to_distribution(counts: &HashMap<usize, usize>, cardinality: usize) -> Vec<f64> {
627 let total: usize = counts.values().sum();
628 let mut probs = vec![0.0; cardinality];
629
630 for (&value, &count) in counts {
631 if value < cardinality && total > 0 {
632 probs[value] = count as f64 / total as f64;
633 }
634 }
635
636 probs
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` assignments that each map "X" to `value`.
    fn repeated(value: usize, n: usize) -> Vec<HashMap<String, usize>> {
        (0..n)
            .map(|_| {
                let mut assignment = HashMap::new();
                assignment.insert("X".to_string(), value);
                assignment
            })
            .collect()
    }

    #[test]
    fn test_mle_marginal() {
        let estimator = MaximumLikelihoodEstimator::new();

        // 7 observations of X=0 and 3 of X=1.
        let mut data = repeated(0, 7);
        data.extend(repeated(1, 3));

        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.7).abs() < 1e-6);
        assert!((probs[[1]] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_mle_with_laplace() {
        let estimator = MaximumLikelihoodEstimator::with_laplace(1.0);

        // Smoothed counts: [8 + 1, 0 + 1] -> [0.9, 0.1].
        let data = repeated(0, 8);
        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_bayesian_estimator() {
        // prior_strength = 2.0 over 2 values -> alpha = 1.0 per value,
        // matching the Laplace case above.
        let estimator = BayesianEstimator::new(2.0);

        let data = repeated(0, 8);
        let probs = estimator.estimate_marginal("X", 2, &data).unwrap();

        assert!((probs[[0]] - 0.9).abs() < 1e-6);
        assert!((probs[[1]] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_count_occurrences() {
        // Values cycle 0,1,2 over 10 samples: four 0s, three 1s, three 2s.
        let data: Vec<HashMap<String, usize>> = (0..10)
            .map(|i| {
                let mut assignment = HashMap::new();
                assignment.insert("X".to_string(), i % 3);
                assignment
            })
            .collect();

        let counts = utils::count_occurrences("X", &data);

        assert_eq!(counts.get(&0), Some(&4));
        assert_eq!(counts.get(&1), Some(&3));
        assert_eq!(counts.get(&2), Some(&3));
    }

    #[test]
    fn test_counts_to_distribution() {
        let mut counts = HashMap::new();
        counts.insert(0, 7);
        counts.insert(1, 3);

        let probs = utils::counts_to_distribution(&counts, 2);

        assert!((probs[0] - 0.7).abs() < 1e-6);
        assert!((probs[1] - 0.3).abs() < 1e-6);
    }
}