//! Sampling-based approximate inference over factor graphs: Gibbs sampling,
//! importance sampling, particle filtering, and likelihood weighting.

use scirs2_core::ndarray::ArrayD;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::HashMap;

use crate::error::{PgmError, Result};
use crate::graph::FactorGraph;

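/// A complete assignment mapping each variable name to one of its discrete values.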
pub type Assignment = HashMap<String, usize>;

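/// Markov chain Monte Carlo inference via Gibbs sampling.
///
/// Starting from a random assignment, each sweep resamples every variable
/// from its conditional distribution given the current values of all other
/// variables. After `burn_in` warm-up sweeps, samples are collected every
/// `thinning` sweeps until `num_samples` have been gathered, and marginals
/// are estimated from the empirical value counts.
///
/// A minimal usage sketch, mirroring the unit tests below (imports omitted):
///
/// ```ignore
/// let mut graph = FactorGraph::new();
/// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
///
/// let sampler = GibbsSampler::new(100, 1000, 1);
/// let marginals = sampler.run(&graph).unwrap();
/// assert!(marginals.contains_key("x"));
/// ```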
pub struct GibbsSampler {
    /// Number of warm-up sweeps discarded before samples are collected.
    pub burn_in: usize,
    /// Number of samples to collect after burn-in.
    pub num_samples: usize,
    /// Keep only every `thinning`-th sweep to reduce autocorrelation.
    pub thinning: usize,
}

impl Default for GibbsSampler {
    fn default() -> Self {
        Self {
            burn_in: 100,
            num_samples: 1000,
            thinning: 1,
        }
    }
}

impl GibbsSampler {
    pub fn new(burn_in: usize, num_samples: usize, thinning: usize) -> Self {
        Self {
            burn_in,
            num_samples,
            thinning,
        }
    }

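    /// Runs the chain and returns empirical marginals for every variable.
    ///
    /// Performs `burn_in` discarded sweeps, then `num_samples * thinning`
    /// further sweeps of which every `thinning`-th assignment is kept;
    /// marginals are the normalized value counts over the kept samples.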
    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut current_assignment = self.initialize_assignment(graph)?;

        // Burn-in: let the chain move toward its stationary distribution.
        for _ in 0..self.burn_in {
            self.gibbs_step(graph, &mut current_assignment)?;
        }

        // Collect exactly `num_samples` assignments, keeping every `thinning`-th sweep.
        let mut samples = Vec::new();
        for i in 0..self.num_samples * self.thinning {
            self.gibbs_step(graph, &mut current_assignment)?;

            if i % self.thinning == 0 {
                samples.push(current_assignment.clone());
            }
        }

        self.compute_empirical_marginals(graph, &samples)
    }

    fn initialize_assignment(&self, graph: &FactorGraph) -> Result<Assignment> {
        let mut rng = thread_rng();
        let mut assignment = Assignment::new();

        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let random_value = rng.random_range(0..var_node.cardinality);
                assignment.insert(var_name.clone(), random_value);
            }
        }

        Ok(assignment)
    }

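    /// One Gibbs sweep: resamples each variable in turn from its conditional
    /// distribution given the current values of all other variables.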
    fn gibbs_step(&self, graph: &FactorGraph, assignment: &mut Assignment) -> Result<()> {
        for var_name in graph.variable_names() {
            self.resample_variable(graph, var_name, assignment)?;
        }

        Ok(())
    }

    fn resample_variable(
        &self,
        graph: &FactorGraph,
        var_name: &str,
        assignment: &mut Assignment,
    ) -> Result<()> {
        let var_node = graph
            .get_variable(var_name)
            .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;

        // Score each candidate value under the full joint; factors that do not
        // involve `var_name` cancel after normalization, so this is a correct
        // (if not the cheapest) way to obtain the conditional.
        let mut conditional_probs = vec![0.0; var_node.cardinality];

        for (value, prob) in conditional_probs.iter_mut().enumerate() {
            assignment.insert(var_name.to_string(), value);
            *prob = self.compute_joint_probability(graph, assignment)?;
        }

        // Normalize, falling back to uniform if every candidate has zero mass.
        let sum: f64 = conditional_probs.iter().sum();
        if sum > 0.0 {
            for prob in &mut conditional_probs {
                *prob /= sum;
            }
        } else {
            let uniform_prob = 1.0 / var_node.cardinality as f64;
            conditional_probs = vec![uniform_prob; var_node.cardinality];
        }

        let sampled_value = self.sample_from_distribution(&conditional_probs);
        assignment.insert(var_name.to_string(), sampled_value);

        Ok(())
    }

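    /// Computes the unnormalized joint probability of a full assignment as
    /// the product of all factor values it selects. The normalization
    /// constant is never needed: Gibbs sampling only compares these scores
    /// across values of a single variable, so the constant cancels.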
    fn compute_joint_probability(
        &self,
        graph: &FactorGraph,
        assignment: &Assignment,
    ) -> Result<f64> {
        let mut prob = 1.0;

        for factor_id in graph.factor_ids() {
            if let Some(factor) = graph.get_factor(factor_id) {
                let mut indices = Vec::new();
                for var in &factor.variables {
                    if let Some(&value) = assignment.get(var) {
                        indices.push(value);
                    } else {
                        return Err(PgmError::VariableNotFound(var.clone()));
                    }
                }

                prob *= factor.values[indices.as_slice()];
            }
        }

        Ok(prob)
    }

    fn sample_from_distribution(&self, probs: &[f64]) -> usize {
        let mut rng = thread_rng();
        let u: f64 = rng.random();

        // Inverse-CDF sampling over the discrete distribution.
        let mut cumulative = 0.0;
        for (idx, &prob) in probs.iter().enumerate() {
            cumulative += prob;
            if u < cumulative {
                return idx;
            }
        }

        // Guard against round-off leaving `u` above the final cumulative sum.
        probs.len() - 1
    }

    fn compute_empirical_marginals(
        &self,
        graph: &FactorGraph,
        samples: &[Assignment],
    ) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut marginals = HashMap::new();

        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let mut counts = vec![0; var_node.cardinality];

                for sample in samples {
                    if let Some(&value) = sample.get(var_name) {
                        counts[value] += 1;
                    }
                }

                let total = samples.len() as f64;
                let probs: Vec<f64> = counts.iter().map(|&c| c as f64 / total).collect();

                marginals.insert(
                    var_name.clone(),
                    ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
                );
            }
        }

        Ok(marginals)
    }

    /// Like [`GibbsSampler::run`], but returns the raw thinned samples
    /// instead of empirical marginals.
    pub fn get_samples(&self, graph: &FactorGraph) -> Result<Vec<Assignment>> {
        let mut current_assignment = self.initialize_assignment(graph)?;

        for _ in 0..self.burn_in {
            self.gibbs_step(graph, &mut current_assignment)?;
        }

        let mut samples = Vec::new();
        for i in 0..self.num_samples * self.thinning {
            self.gibbs_step(graph, &mut current_assignment)?;

            if i % self.thinning == 0 {
                samples.push(current_assignment.clone());
            }
        }

        Ok(samples)
    }
}

impl From<scirs2_core::ndarray::ShapeError> for PgmError {
    fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
        PgmError::InvalidDistribution(format!("Shape error: {}", err))
    }
}

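/// A sampled assignment paired with its importance weight.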
#[derive(Debug, Clone)]
pub struct WeightedSample {
    /// The sampled assignment of values to variables.
    pub assignment: Assignment,
    /// Importance weight `target / proposal` for this sample.
    pub weight: f64,
    /// Log of the importance weight, for numerically stable accumulation.
    pub log_weight: f64,
}

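/// Importance sampler for approximate marginal inference.
///
/// Assignments are drawn from a proposal distribution `q` and weighted by
/// `w(x) = p(x) / q(x)`, where `p` is the unnormalized product of factor
/// values. With `self_normalize` enabled, marginal estimates divide by the
/// total weight, so the unknown normalization constant of `p` drops out.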
pub struct ImportanceSampler {
    /// Number of weighted samples to draw.
    pub num_samples: usize,
    /// If true, normalize weights by their sum when estimating marginals.
    pub self_normalize: bool,
}

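/// Proposal distribution used to draw candidate assignments.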
#[derive(Debug, Clone)]
pub enum ProposalDistribution {
    /// Sample each variable uniformly over its cardinality.
    Uniform,
    /// Sample each variable from user-supplied categorical weights.
    Custom(HashMap<String, Vec<f64>>),
    /// Sample from the model prior (currently falls back to uniform).
    Prior,
}

impl Default for ImportanceSampler {
    fn default() -> Self {
        Self {
            num_samples: 1000,
            self_normalize: true,
        }
    }
}

impl ImportanceSampler {
    pub fn new(num_samples: usize) -> Self {
        Self {
            num_samples,
            self_normalize: true,
        }
    }

    pub fn with_self_normalize(mut self, self_normalize: bool) -> Self {
        self.self_normalize = self_normalize;
        self
    }

    pub fn run(
        &self,
        graph: &FactorGraph,
        proposal: ProposalDistribution,
    ) -> Result<HashMap<String, ArrayD<f64>>> {
        let samples = self.draw_weighted_samples(graph, &proposal)?;
        self.compute_weighted_marginals(graph, &samples)
    }

    pub fn draw_weighted_samples(
        &self,
        graph: &FactorGraph,
        proposal: &ProposalDistribution,
    ) -> Result<Vec<WeightedSample>> {
        let mut samples = Vec::with_capacity(self.num_samples);
        let mut rng = thread_rng();

        for _ in 0..self.num_samples {
            let (assignment, proposal_prob) =
                self.sample_from_proposal(graph, proposal, &mut rng)?;

            let target_prob = self.compute_target_probability(graph, &assignment)?;

            // Importance weight: unnormalized target over proposal density.
            let weight = if proposal_prob > 0.0 {
                target_prob / proposal_prob
            } else {
                0.0
            };

            let log_weight = if proposal_prob > 0.0 && target_prob > 0.0 {
                target_prob.ln() - proposal_prob.ln()
            } else {
                f64::NEG_INFINITY
            };

            samples.push(WeightedSample {
                assignment,
                weight,
                log_weight,
            });
        }

        Ok(samples)
    }

    fn sample_from_proposal(
        &self,
        graph: &FactorGraph,
        proposal: &ProposalDistribution,
        rng: &mut impl Rng,
    ) -> Result<(Assignment, f64)> {
        let mut assignment = Assignment::new();
        let mut proposal_prob = 1.0;

        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let (value, prob) = match proposal {
                    ProposalDistribution::Uniform => {
                        let value = rng.random_range(0..var_node.cardinality);
                        (value, 1.0 / var_node.cardinality as f64)
                    }
                    ProposalDistribution::Custom(weights) => {
                        if let Some(var_weights) = weights.get(var_name) {
                            self.sample_categorical(var_weights, rng)
                        } else {
                            // No weights supplied for this variable: fall back to uniform.
                            let value = rng.random_range(0..var_node.cardinality);
                            (value, 1.0 / var_node.cardinality as f64)
                        }
                    }
                    ProposalDistribution::Prior => {
                        // Sampling from the model prior is not implemented; fall back to uniform.
                        let value = rng.random_range(0..var_node.cardinality);
                        (value, 1.0 / var_node.cardinality as f64)
                    }
                };

                assignment.insert(var_name.clone(), value);
                proposal_prob *= prob;
            }
        }

        Ok((assignment, proposal_prob))
    }

    fn sample_categorical(&self, weights: &[f64], rng: &mut impl Rng) -> (usize, f64) {
        let total: f64 = weights.iter().sum();
        if total <= 0.0 {
            return (0, 1.0 / weights.len() as f64);
        }

        let normalized: Vec<f64> = weights.iter().map(|w| w / total).collect();
        let u: f64 = rng.random();

        let mut cumulative = 0.0;
        for (idx, &prob) in normalized.iter().enumerate() {
            cumulative += prob;
            if u < cumulative {
                return (idx, prob);
            }
        }

        (weights.len() - 1, *normalized.last().unwrap_or(&0.0))
    }

    fn compute_target_probability(
        &self,
        graph: &FactorGraph,
        assignment: &Assignment,
    ) -> Result<f64> {
        let mut prob = 1.0;

        for factor_id in graph.factor_ids() {
            if let Some(factor) = graph.get_factor(factor_id) {
                let mut indices = Vec::new();
                for var in &factor.variables {
                    if let Some(&value) = assignment.get(var) {
                        indices.push(value);
                    } else {
                        return Err(PgmError::VariableNotFound(var.clone()));
                    }
                }
                prob *= factor.values[indices.as_slice()];
            }
        }

        Ok(prob)
    }

    fn compute_weighted_marginals(
        &self,
        graph: &FactorGraph,
        samples: &[WeightedSample],
    ) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut marginals = HashMap::new();

        let total_weight: f64 = samples.iter().map(|s| s.weight).sum();

        for var_name in graph.variable_names() {
            if let Some(var_node) = graph.get_variable(var_name) {
                let mut weighted_counts = vec![0.0; var_node.cardinality];

                for sample in samples {
                    if let Some(&value) = sample.assignment.get(var_name) {
                        weighted_counts[value] += sample.weight;
                    }
                }

                let probs: Vec<f64> = if self.self_normalize && total_weight > 0.0 {
                    weighted_counts.iter().map(|&c| c / total_weight).collect()
                } else {
                    let sum: f64 = weighted_counts.iter().sum();
                    if sum > 0.0 {
                        weighted_counts.iter().map(|&c| c / sum).collect()
                    } else {
                        // All weights are zero: return a uniform distribution.
                        vec![1.0 / var_node.cardinality as f64; var_node.cardinality]
                    }
                };

                marginals.insert(
                    var_name.clone(),
                    ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
                );
            }
        }

        Ok(marginals)
    }

    pub fn get_weighted_samples(
        &self,
        graph: &FactorGraph,
        proposal: &ProposalDistribution,
    ) -> Result<Vec<WeightedSample>> {
        self.draw_weighted_samples(graph, proposal)
    }

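    /// Effective sample size of a weighted sample set:
    /// `ESS = (Σ wᵢ)² / Σ wᵢ²`. It equals the sample count when all weights
    /// are equal and shrinks toward 1 as the weights become more skewed.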
    pub fn effective_sample_size(samples: &[WeightedSample]) -> f64 {
        let weights: Vec<f64> = samples.iter().map(|s| s.weight).collect();
        let sum_w: f64 = weights.iter().sum();
        let sum_w2: f64 = weights.iter().map(|w| w * w).sum();

        if sum_w2 > 0.0 {
            (sum_w * sum_w) / sum_w2
        } else {
            0.0
        }
    }

    pub fn weight_coefficient_of_variation(samples: &[WeightedSample]) -> f64 {
        let n = samples.len() as f64;
        let weights: Vec<f64> = samples.iter().map(|s| s.weight).collect();
        let mean = weights.iter().sum::<f64>() / n;
        let variance = weights.iter().map(|w| (w - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();

        if mean > 0.0 {
            std_dev / mean
        } else {
            0.0
        }
    }

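    /// Resamples the set in proportion to the weights using systematic
    /// resampling: one uniform offset in `[0, 1/n)` generates `n` evenly
    /// spaced points over the cumulative weight distribution, duplicating
    /// heavy samples and dropping light ones. Returned samples carry equal
    /// weight, so downstream self-normalized estimates remain valid.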
    pub fn resample(samples: &[WeightedSample]) -> Vec<WeightedSample> {
        let n = samples.len();
        if n == 0 {
            return Vec::new();
        }

        let mut rng = thread_rng();
        let total_weight: f64 = samples.iter().map(|s| s.weight).sum();

        if total_weight <= 0.0 {
            return samples.to_vec();
        }

        let normalized_weights: Vec<f64> =
            samples.iter().map(|s| s.weight / total_weight).collect();

        // Systematic resampling: a single random offset generates n evenly
        // spaced points, each selecting one sample from the weight CDF.
        let mut resampled = Vec::with_capacity(n);
        let u0: f64 = rng.random::<f64>() / n as f64;

        let mut cumulative = 0.0;
        let mut j = 0;

        for i in 0..n {
            let u = u0 + (i as f64) / (n as f64);
            while cumulative + normalized_weights[j] < u && j < n - 1 {
                cumulative += normalized_weights[j];
                j += 1;
            }

            // Resampled draws carry equal weight.
            resampled.push(WeightedSample {
                assignment: samples[j].assignment.clone(),
                weight: 1.0,
                log_weight: 0.0,
            });
        }

        resampled
    }
}

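/// A single particle: one weighted hypothesis about the hidden state.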
#[derive(Debug, Clone)]
pub struct Particle {
    /// Current assignment of the tracked state variables.
    pub state: Assignment,
    /// Normalized importance weight of this particle.
    pub weight: f64,
    /// Log-weight, kept alongside `weight` for numerical bookkeeping.
    pub log_weight: f64,
    /// Past states, recorded only when history tracking is enabled.
    pub history: Vec<Assignment>,
}

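/// Sequential Monte Carlo (particle filter) for state estimation over
/// discrete state variables.
///
/// The filter alternates `predict` (propagate each particle through a
/// transition function) and `update` (reweight particles by an observation
/// likelihood), resampling whenever the effective sample size falls below
/// `ess_threshold * num_particles`.
///
/// A minimal sketch, assuming a single binary state variable; the transition
/// and likelihood closures here are illustrative placeholders:
///
/// ```ignore
/// let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
/// let cards: HashMap<String, usize> = [("state".to_string(), 2)].into_iter().collect();
/// pf.initialize(&cards);
///
/// let transition = |state: &Assignment, _seed: u64| state.clone(); // identity dynamics
/// let likelihood = |state: &Assignment, obs: &Assignment| {
///     if state.get("state") == obs.get("state") { 0.9 } else { 0.1 }
/// };
///
/// let mut obs = Assignment::new();
/// obs.insert("state".to_string(), 1);
/// pf.predict(&transition, &cards);
/// pf.update(&obs, likelihood);
/// let marginal = pf.estimate_marginal("state", 2);
/// ```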
pub struct ParticleFilter {
    /// Number of particles maintained by the filter.
    pub num_particles: usize,
    /// The current particle population.
    pub particles: Vec<Particle>,
    /// Names of the state variables tracked by each particle.
    pub state_variables: Vec<String>,
    /// Resample when ESS falls below this fraction of `num_particles`.
    pub ess_threshold: f64,
    /// Whether each particle records its past states.
    pub track_history: bool,
}

impl ParticleFilter {
    pub fn new(num_particles: usize, state_variables: Vec<String>) -> Self {
        Self {
            num_particles,
            particles: Vec::new(),
            state_variables,
            ess_threshold: 0.5,
            track_history: false,
        }
    }

    pub fn with_ess_threshold(mut self, threshold: f64) -> Self {
        self.ess_threshold = threshold;
        self
    }

    pub fn with_history(mut self, track: bool) -> Self {
        self.track_history = track;
        self
    }

    pub fn initialize(&mut self, cardinalities: &HashMap<String, usize>) {
        let mut rng = thread_rng();
        self.particles = Vec::with_capacity(self.num_particles);

        for _ in 0..self.num_particles {
            let mut state = Assignment::new();

            for var_name in &self.state_variables {
                if let Some(&card) = cardinalities.get(var_name) {
                    let value = rng.random_range(0..card);
                    state.insert(var_name.clone(), value);
                }
            }

            self.particles.push(Particle {
                state,
                weight: 1.0 / self.num_particles as f64,
                log_weight: -(self.num_particles as f64).ln(),
                history: Vec::new(),
            });
        }
    }

    pub fn initialize_from_prior(
        &mut self,
        prior: &[f64],
        cardinalities: &HashMap<String, usize>,
    ) {
        let mut rng = thread_rng();
        self.particles = Vec::with_capacity(self.num_particles);

        // Normalize the prior, falling back to uniform if it has no mass.
        let total: f64 = prior.iter().sum();
        let normalized: Vec<f64> = if total > 0.0 {
            prior.iter().map(|p| p / total).collect()
        } else {
            vec![1.0 / prior.len() as f64; prior.len()]
        };

        for _ in 0..self.num_particles {
            let mut state = Assignment::new();

            // The first state variable is drawn from the supplied prior.
            if let Some(var_name) = self.state_variables.first() {
                let u: f64 = rng.random();
                let mut cumulative = 0.0;
                let mut value = 0;

                for (idx, &prob) in normalized.iter().enumerate() {
                    cumulative += prob;
                    if u < cumulative {
                        value = idx;
                        break;
                    }
                }

                state.insert(var_name.clone(), value);
            }

            // Any remaining state variables are initialized uniformly.
            for var_name in self.state_variables.iter().skip(1) {
                if let Some(&card) = cardinalities.get(var_name) {
                    let value = rng.random_range(0..card);
                    state.insert(var_name.clone(), value);
                }
            }

            self.particles.push(Particle {
                state,
                weight: 1.0 / self.num_particles as f64,
                log_weight: -(self.num_particles as f64).ln(),
                history: Vec::new(),
            });
        }
    }

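    /// Prediction step: advances each particle through the supplied
    /// transition function. A fresh `u64` seed is drawn per particle so the
    /// transition can be stochastic while staying a plain `Fn`. Propagated
    /// values are clamped to each variable's valid range.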
    pub fn predict(
        &mut self,
        transition: &dyn Fn(&Assignment, u64) -> Assignment,
        cardinalities: &HashMap<String, usize>,
    ) {
        let mut rng = thread_rng();

        for particle in &mut self.particles {
            if self.track_history {
                particle.history.push(particle.state.clone());
            }

            let seed: u64 = rng.random();
            particle.state = transition(&particle.state, seed);

            // Clamp propagated values into each variable's valid range.
            for var_name in &self.state_variables {
                if let Some(&card) = cardinalities.get(var_name) {
                    if let Some(value) = particle.state.get_mut(var_name) {
                        *value = (*value).min(card.saturating_sub(1));
                    }
                }
            }
        }
    }

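    /// Measurement update: multiplies each particle's weight by the
    /// likelihood of `observation` given its state, renormalizes, and
    /// resamples if the effective sample size falls below
    /// `ess_threshold * num_particles`.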
    pub fn update<F>(&mut self, observation: &Assignment, likelihood: F)
    where
        F: Fn(&Assignment, &Assignment) -> f64,
    {
        // Reweight every particle by the observation likelihood.
        for particle in &mut self.particles {
            let lik = likelihood(&particle.state, observation);
            particle.weight *= lik;
            if lik > 0.0 {
                particle.log_weight += lik.ln();
            } else {
                particle.log_weight = f64::NEG_INFINITY;
            }
        }

        self.normalize_weights();

        // Resample if too few particles carry meaningful weight.
        let ess = self.effective_sample_size();
        if ess < self.ess_threshold * self.num_particles as f64 {
            self.resample();
        }
    }

    fn normalize_weights(&mut self) {
        let total: f64 = self.particles.iter().map(|p| p.weight).sum();
        if total > 0.0 {
            for particle in &mut self.particles {
                particle.weight /= total;
            }
        }
    }

    /// Effective sample size for normalized weights: `1 / Σ wᵢ²`.
    pub fn effective_sample_size(&self) -> f64 {
        let sum_w2: f64 = self.particles.iter().map(|p| p.weight * p.weight).sum();
        if sum_w2 > 0.0 {
            1.0 / sum_w2
        } else {
            0.0
        }
    }

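    /// Systematic resampling of the particle population: `n` evenly spaced
    /// points from a single random offset walk the cumulative weight
    /// distribution, after which all particles carry weight `1/n`.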
    pub fn resample(&mut self) {
        let n = self.num_particles;
        let mut rng = thread_rng();

        // Build the cumulative distribution over particle weights.
        let mut cdf = Vec::with_capacity(n);
        let mut cumulative = 0.0;
        for particle in &self.particles {
            cumulative += particle.weight;
            cdf.push(cumulative);
        }

        let u0: f64 = rng.random::<f64>() / n as f64;
        let mut new_particles = Vec::with_capacity(n);

        let mut j = 0;
        for i in 0..n {
            let u = u0 + (i as f64) / (n as f64);
            while j < n - 1 && cdf[j] < u {
                j += 1;
            }

            new_particles.push(Particle {
                state: self.particles[j].state.clone(),
                weight: 1.0 / n as f64,
                log_weight: -(n as f64).ln(),
                history: if self.track_history {
                    self.particles[j].history.clone()
                } else {
                    Vec::new()
                },
            });
        }

        self.particles = new_particles;
    }

    pub fn estimate_marginal(&self, var_name: &str, cardinality: usize) -> Vec<f64> {
        let mut counts = vec![0.0; cardinality];

        for particle in &self.particles {
            if let Some(&value) = particle.state.get(var_name) {
                if value < cardinality {
                    counts[value] += particle.weight;
                }
            }
        }

        let total: f64 = counts.iter().sum();
        if total > 0.0 {
            counts.iter().map(|c| c / total).collect()
        } else {
            vec![1.0 / cardinality as f64; cardinality]
        }
    }

    pub fn estimate_expectation<F>(&self, func: F) -> f64
    where
        F: Fn(&Assignment) -> f64,
    {
        self.particles
            .iter()
            .map(|p| p.weight * func(&p.state))
            .sum()
    }

    /// Returns the state of the highest-weight particle.
    pub fn map_estimate(&self) -> Option<&Assignment> {
        self.particles
            .iter()
            .max_by(|a, b| {
                a.weight
                    .partial_cmp(&b.weight)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|p| &p.state)
    }

    /// Shannon entropy of the particle weight distribution, in nats.
    pub fn entropy(&self) -> f64 {
        self.particles
            .iter()
            .filter(|p| p.weight > 0.0)
            .map(|p| -p.weight * p.weight.ln())
            .sum()
    }

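    /// Runs one predict/update cycle per observation and returns, for each
    /// step, the estimated marginal of the first state variable.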
    pub fn run_sequence(
        &mut self,
        observations: &[Assignment],
        transition: &dyn Fn(&Assignment, u64) -> Assignment,
        likelihood: &dyn Fn(&Assignment, &Assignment) -> f64,
        cardinalities: &HashMap<String, usize>,
    ) -> Vec<Vec<f64>> {
        let mut marginals = Vec::with_capacity(observations.len());

        for obs in observations {
            // Propagate particles through the transition model.
            self.predict(transition, cardinalities);

            // Reweight against the observation, resampling if needed.
            self.update(obs, likelihood);

            if let Some(var_name) = self.state_variables.first() {
                if let Some(&card) = cardinalities.get(var_name) {
                    marginals.push(self.estimate_marginal(var_name, card));
                }
            }
        }

        marginals
    }
}

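/// Likelihood-weighting style inference with fixed evidence.
///
/// Evidence variables are clamped to their observed values, the remaining
/// variables are sampled uniformly, and each sample is weighted by the
/// product of all factor values at the resulting assignment. This amounts to
/// self-normalized importance sampling with a uniform proposal over the
/// non-evidence variables.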
pub struct LikelihoodWeighting {
    /// Number of weighted samples to draw.
    pub num_samples: usize,
}

impl Default for LikelihoodWeighting {
    fn default() -> Self {
        Self { num_samples: 1000 }
    }
}

impl LikelihoodWeighting {
    pub fn new(num_samples: usize) -> Self {
        Self { num_samples }
    }

    pub fn run(
        &self,
        graph: &FactorGraph,
        evidence: &Assignment,
    ) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut weighted_samples = Vec::with_capacity(self.num_samples);
        let mut rng = thread_rng();

        for _ in 0..self.num_samples {
            let (assignment, weight) = self.sample_with_evidence(graph, evidence, &mut rng)?;

            weighted_samples.push(WeightedSample {
                assignment,
                weight,
                log_weight: if weight > 0.0 {
                    weight.ln()
                } else {
                    f64::NEG_INFINITY
                },
            });
        }

        // Reuse the importance sampler's marginal computation on the weighted samples.
        let sampler = ImportanceSampler::new(self.num_samples);
        sampler.compute_weighted_marginals(graph, &weighted_samples)
    }

    fn sample_with_evidence(
        &self,
        graph: &FactorGraph,
        evidence: &Assignment,
        rng: &mut impl Rng,
    ) -> Result<(Assignment, f64)> {
        let mut assignment = Assignment::new();
        let mut weight = 1.0;

        // Clamp the evidence variables to their observed values.
        for (var, value) in evidence {
            assignment.insert(var.clone(), *value);
        }

        // Sample every non-evidence variable uniformly.
        for var_name in graph.variable_names() {
            if !evidence.contains_key(var_name) {
                if let Some(var_node) = graph.get_variable(var_name) {
                    let value = rng.random_range(0..var_node.cardinality);
                    assignment.insert(var_name.clone(), value);
                }
            }
        }

        // Weight the sample by the product of all factor values at this assignment.
        for factor_id in graph.factor_ids() {
            if let Some(factor) = graph.get_factor(factor_id) {
                let mut indices = Vec::new();
                for var in &factor.variables {
                    if let Some(&value) = assignment.get(var) {
                        indices.push(value);
                    } else {
                        return Err(PgmError::VariableNotFound(var.clone()));
                    }
                }
                weight *= factor.values[indices.as_slice()];
            }
        }

        Ok((assignment, weight))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_gibbs_sampler_single_variable() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));

        let dist = &marginals["x"];
        let sum: f64 = dist.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gibbs_sampler_multiple_variables() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(20, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert_eq!(marginals.len(), 2);
    }

    #[test]
    fn test_sample_collection() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 50, 1);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());
        assert_eq!(samples.unwrap().len(), 50);
    }

    #[test]
    fn test_gibbs_with_thinning() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 50, 5);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());
        assert_eq!(samples.unwrap().len(), 50);
    }

    #[test]
    fn test_importance_sampler_uniform() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = ImportanceSampler::new(100);
        let result = sampler.run(&graph, ProposalDistribution::Uniform);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));

        let dist = &marginals["x"];
        let sum: f64 = dist.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_importance_sampler_custom_proposal() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let mut custom_weights = HashMap::new();
        custom_weights.insert("x".to_string(), vec![0.8, 0.2]);

        let sampler = ImportanceSampler::new(100);
        let result = sampler.run(&graph, ProposalDistribution::Custom(custom_weights));
        assert!(result.is_ok());

        let marginals = result.unwrap();
        let sum: f64 = marginals["x"].iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_effective_sample_size() {
        let samples = vec![
            WeightedSample {
                assignment: HashMap::new(),
                weight: 0.5,
                log_weight: 0.5_f64.ln(),
            },
            WeightedSample {
                assignment: HashMap::new(),
                weight: 0.5,
                log_weight: 0.5_f64.ln(),
            },
        ];

        let ess = ImportanceSampler::effective_sample_size(&samples);
        assert_abs_diff_eq!(ess, 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_initialization() {
        let mut pf = ParticleFilter::new(10, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 3)].into_iter().collect();
        pf.initialize(&cardinalities);

        assert_eq!(pf.particles.len(), 10);

        for particle in &pf.particles {
            assert_abs_diff_eq!(particle.weight, 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_particle_filter_estimate_marginal() {
        let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        let marginal = pf.estimate_marginal("state", 2);
        assert_eq!(marginal.len(), 2);

        let sum: f64 = marginal.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_ess() {
        let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        let ess = pf.effective_sample_size();
        assert!(ess > 90.0);
    }

    #[test]
    fn test_particle_filter_resample() {
        let mut pf = ParticleFilter::new(10, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        for (i, particle) in pf.particles.iter_mut().enumerate() {
            particle.weight = if i == 0 { 1.0 } else { 0.0 };
        }

        pf.resample();

        for particle in &pf.particles {
            assert_abs_diff_eq!(particle.weight, 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_likelihood_weighting() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let mut evidence = Assignment::new();
        evidence.insert("y".to_string(), 1);

        let lw = LikelihoodWeighting::new(100);
        let result = lw.run(&graph, &evidence);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));
    }

    #[test]
    fn test_importance_sampler_weighted_samples() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = ImportanceSampler::new(50);
        let samples = sampler
            .get_weighted_samples(&graph, &ProposalDistribution::Uniform)
            .unwrap();

        assert_eq!(samples.len(), 50);

        for sample in &samples {
            assert!(sample.assignment.contains_key("x"));
        }
    }

    #[test]
    fn test_weight_coefficient_of_variation() {
        let samples = vec![
            WeightedSample {
                assignment: HashMap::new(),
                weight: 1.0,
                log_weight: 0.0,
            },
            WeightedSample {
                assignment: HashMap::new(),
                weight: 1.0,
                log_weight: 0.0,
            },
        ];

        let cv = ImportanceSampler::weight_coefficient_of_variation(&samples);
        assert_abs_diff_eq!(cv, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_with_history() {
        let pf = ParticleFilter::new(5, vec!["state".to_string()])
            .with_history(true)
            .with_ess_threshold(0.3);

        assert!(pf.track_history);
        assert_abs_diff_eq!(pf.ess_threshold, 0.3, epsilon = 1e-6);
    }
}