Skip to main content

tensorlogic_quantrs_hooks/
sampling.rs

//! Sampling-based inference methods for probabilistic graphical models (PGMs).
//!
//! This module provides MCMC, importance sampling, likelihood weighting, and particle
//! filtering algorithms for approximate inference.
//!
//! # Algorithms
//!
//! - **Gibbs Sampling**: MCMC method for sampling from joint distributions
//! - **Importance Sampling**: Weighted sampling with proposal distributions
//! - **Likelihood Weighting**: Importance sampling specialised to Bayesian networks with evidence
//! - **Particle Filter**: Sequential Monte Carlo for temporal models
11
12use scirs2_core::ndarray::ArrayD;
13use scirs2_core::random::{thread_rng, Rng};
14use std::collections::HashMap;
15
16use crate::error::{PgmError, Result};
17use crate::graph::FactorGraph;
18
/// Mapping from variable name to the state index assigned to that variable.
///
/// The samplers in this module draw these indices from `0..cardinality` of each variable.
pub type Assignment = HashMap<String, usize>;
21
/// Gibbs sampler for approximate marginal inference on a factor graph.
///
/// Runs a Markov chain that repeatedly resamples each variable from its conditional
/// distribution given all the others, then estimates marginals from the visited states.
pub struct GibbsSampler {
    /// Sweeps run (and discarded) before any sample is recorded.
    pub burn_in: usize,
    /// Number of samples to record after burn-in.
    pub num_samples: usize,
    /// Record one sample every `thinning` sweeps.
    pub thinning: usize,
}
33
34impl Default for GibbsSampler {
35    fn default() -> Self {
36        Self {
37            burn_in: 100,
38            num_samples: 1000,
39            thinning: 1,
40        }
41    }
42}
43
44impl GibbsSampler {
45    /// Create with custom parameters.
46    pub fn new(burn_in: usize, num_samples: usize, thinning: usize) -> Self {
47        Self {
48            burn_in,
49            num_samples,
50            thinning,
51        }
52    }
53
54    /// Run Gibbs sampling to approximate marginals.
55    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
56        // Initialize random assignment
57        let mut current_assignment = self.initialize_assignment(graph)?;
58
59        // Burn-in phase
60        for _ in 0..self.burn_in {
61            self.gibbs_step(graph, &mut current_assignment)?;
62        }
63
64        // Collect samples
65        let mut samples = Vec::new();
66        for i in 0..self.num_samples * self.thinning {
67            self.gibbs_step(graph, &mut current_assignment)?;
68
69            // Keep sample if it's at thinning interval
70            if i % self.thinning == 0 {
71                samples.push(current_assignment.clone());
72            }
73        }
74
75        // Compute empirical marginals from samples
76        self.compute_empirical_marginals(graph, &samples)
77    }
78
79    /// Initialize random assignment for all variables.
80    fn initialize_assignment(&self, graph: &FactorGraph) -> Result<Assignment> {
81        let mut rng = thread_rng();
82        let mut assignment = Assignment::new();
83
84        for var_name in graph.variable_names() {
85            if let Some(var_node) = graph.get_variable(var_name) {
86                let random_value = rng.gen_range(0..var_node.cardinality);
87                assignment.insert(var_name.clone(), random_value);
88            }
89        }
90
91        Ok(assignment)
92    }
93
94    /// Perform one Gibbs sampling step (resample all variables).
95    fn gibbs_step(&self, graph: &FactorGraph, assignment: &mut Assignment) -> Result<()> {
96        // Resample each variable conditioned on others
97        for var_name in graph.variable_names() {
98            self.resample_variable(graph, var_name, assignment)?;
99        }
100
101        Ok(())
102    }
103
104    /// Resample a single variable given current assignment of others.
105    fn resample_variable(
106        &self,
107        graph: &FactorGraph,
108        var_name: &str,
109        assignment: &mut Assignment,
110    ) -> Result<()> {
111        let var_node = graph
112            .get_variable(var_name)
113            .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;
114
115        // Compute conditional distribution P(X | others)
116        let mut conditional_probs = vec![0.0; var_node.cardinality];
117
118        for (value, prob) in conditional_probs
119            .iter_mut()
120            .enumerate()
121            .take(var_node.cardinality)
122        {
123            assignment.insert(var_name.to_string(), value);
124            *prob = self.compute_joint_probability(graph, assignment)?;
125        }
126
127        // Normalize
128        let sum: f64 = conditional_probs.iter().sum();
129        if sum > 0.0 {
130            for prob in &mut conditional_probs {
131                *prob /= sum;
132            }
133        } else {
134            // Fallback to uniform if all zero
135            let uniform_prob = 1.0 / var_node.cardinality as f64;
136            conditional_probs = vec![uniform_prob; var_node.cardinality];
137        }
138
139        // Sample from conditional distribution
140        let sampled_value = self.sample_from_distribution(&conditional_probs);
141        assignment.insert(var_name.to_string(), sampled_value);
142
143        Ok(())
144    }
145
146    /// Compute joint probability for a full assignment.
147    fn compute_joint_probability(
148        &self,
149        graph: &FactorGraph,
150        assignment: &Assignment,
151    ) -> Result<f64> {
152        let mut prob = 1.0;
153
154        for factor_id in graph.factor_ids() {
155            if let Some(factor) = graph.get_factor(factor_id) {
156                // Build index for this factor
157                let mut indices = Vec::new();
158                for var in &factor.variables {
159                    if let Some(&value) = assignment.get(var) {
160                        indices.push(value);
161                    } else {
162                        return Err(PgmError::VariableNotFound(var.clone()));
163                    }
164                }
165
166                prob *= factor.values[indices.as_slice()];
167            }
168        }
169
170        Ok(prob)
171    }
172
173    /// Sample from a discrete probability distribution.
174    fn sample_from_distribution(&self, probs: &[f64]) -> usize {
175        let mut rng = thread_rng();
176        let u: f64 = rng.random();
177
178        let mut cumulative = 0.0;
179        for (idx, &prob) in probs.iter().enumerate() {
180            cumulative += prob;
181            if u < cumulative {
182                return idx;
183            }
184        }
185
186        // Fallback to last index
187        probs.len() - 1
188    }
189
190    /// Compute empirical marginals from collected samples.
191    fn compute_empirical_marginals(
192        &self,
193        graph: &FactorGraph,
194        samples: &[Assignment],
195    ) -> Result<HashMap<String, ArrayD<f64>>> {
196        let mut marginals = HashMap::new();
197
198        for var_name in graph.variable_names() {
199            if let Some(var_node) = graph.get_variable(var_name) {
200                let mut counts = vec![0; var_node.cardinality];
201
202                // Count occurrences
203                for sample in samples {
204                    if let Some(&value) = sample.get(var_name) {
205                        counts[value] += 1;
206                    }
207                }
208
209                // Normalize to probabilities
210                let total = samples.len() as f64;
211                let probs: Vec<f64> = counts.iter().map(|&c| c as f64 / total).collect();
212
213                marginals.insert(
214                    var_name.clone(),
215                    ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
216                );
217            }
218        }
219
220        Ok(marginals)
221    }
222
223    /// Get all samples (for analysis).
224    pub fn get_samples(&self, graph: &FactorGraph) -> Result<Vec<Assignment>> {
225        let mut current_assignment = self.initialize_assignment(graph)?;
226
227        // Burn-in
228        for _ in 0..self.burn_in {
229            self.gibbs_step(graph, &mut current_assignment)?;
230        }
231
232        // Collect samples
233        let mut samples = Vec::new();
234        for i in 0..self.num_samples * self.thinning {
235            self.gibbs_step(graph, &mut current_assignment)?;
236
237            if i % self.thinning == 0 {
238                samples.push(current_assignment.clone());
239            }
240        }
241
242        Ok(samples)
243    }
244}
245
246impl From<scirs2_core::ndarray::ShapeError> for PgmError {
247    fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
248        PgmError::InvalidDistribution(format!("Shape error: {}", err))
249    }
250}
251
252/// Weighted sample for importance sampling.
253#[derive(Debug, Clone)]
254pub struct WeightedSample {
255    /// The assignment of values to variables
256    pub assignment: Assignment,
257    /// The unnormalized importance weight
258    pub weight: f64,
259    /// Log weight for numerical stability
260    pub log_weight: f64,
261}
262
/// Importance sampler for approximate marginal inference.
///
/// Draws states from a proposal distribution `q(x)` and weights each one by `p(x) / q(x)`,
/// where `p` is the unnormalized product of the graph's factors. The weighted samples
/// estimate quantities under `p`.
///
/// # Example
///
/// ```
/// use tensorlogic_quantrs_hooks::{FactorGraph, ImportanceSampler, ProposalDistribution};
///
/// let mut graph = FactorGraph::new();
/// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
///
/// let sampler = ImportanceSampler::new(1000);
/// let result = sampler.run(&graph, ProposalDistribution::Uniform);
/// ```
pub struct ImportanceSampler {
    /// How many samples to draw per run.
    pub num_samples: usize,
    /// Whether marginals are computed with self-normalized importance sampling.
    pub self_normalize: bool,
}
285
/// Proposal distribution `q(x)` used by `ImportanceSampler`.
///
/// Each variable is proposed independently of the others.
#[derive(Debug, Clone)]
pub enum ProposalDistribution {
    /// Each variable's state is drawn uniformly from `0..cardinality`.
    Uniform,
    /// Per-variable weights keyed by variable name; they need not sum to one.
    /// Variables without an entry are proposed uniformly.
    Custom(HashMap<String, Vec<f64>>),
    /// Intended to propose from the model's prior.
    /// NOTE(review): `ImportanceSampler` currently treats this exactly like `Uniform`.
    Prior,
}
296
297impl Default for ImportanceSampler {
298    fn default() -> Self {
299        Self {
300            num_samples: 1000,
301            self_normalize: true,
302        }
303    }
304}
305
306impl ImportanceSampler {
307    /// Create a new importance sampler with specified number of samples.
308    pub fn new(num_samples: usize) -> Self {
309        Self {
310            num_samples,
311            self_normalize: true,
312        }
313    }
314
315    /// Set whether to use self-normalized importance sampling.
316    pub fn with_self_normalize(mut self, self_normalize: bool) -> Self {
317        self.self_normalize = self_normalize;
318        self
319    }
320
321    /// Run importance sampling to approximate marginals.
322    pub fn run(
323        &self,
324        graph: &FactorGraph,
325        proposal: ProposalDistribution,
326    ) -> Result<HashMap<String, ArrayD<f64>>> {
327        let samples = self.draw_weighted_samples(graph, &proposal)?;
328        self.compute_weighted_marginals(graph, &samples)
329    }
330
331    /// Draw weighted samples from the proposal distribution.
332    pub fn draw_weighted_samples(
333        &self,
334        graph: &FactorGraph,
335        proposal: &ProposalDistribution,
336    ) -> Result<Vec<WeightedSample>> {
337        let mut samples = Vec::with_capacity(self.num_samples);
338        let mut rng = thread_rng();
339
340        for _ in 0..self.num_samples {
341            // Sample from proposal
342            let (assignment, proposal_prob) =
343                self.sample_from_proposal(graph, proposal, &mut rng)?;
344
345            // Compute target probability
346            let target_prob = self.compute_target_probability(graph, &assignment)?;
347
348            // Compute importance weight
349            let weight = if proposal_prob > 0.0 {
350                target_prob / proposal_prob
351            } else {
352                0.0
353            };
354
355            let log_weight = if proposal_prob > 0.0 && target_prob > 0.0 {
356                target_prob.ln() - proposal_prob.ln()
357            } else {
358                f64::NEG_INFINITY
359            };
360
361            samples.push(WeightedSample {
362                assignment,
363                weight,
364                log_weight,
365            });
366        }
367
368        Ok(samples)
369    }
370
371    /// Sample from the proposal distribution.
372    fn sample_from_proposal(
373        &self,
374        graph: &FactorGraph,
375        proposal: &ProposalDistribution,
376        rng: &mut impl Rng,
377    ) -> Result<(Assignment, f64)> {
378        let mut assignment = Assignment::new();
379        let mut proposal_prob = 1.0;
380
381        for var_name in graph.variable_names() {
382            if let Some(var_node) = graph.get_variable(var_name) {
383                let (value, prob) = match proposal {
384                    ProposalDistribution::Uniform => {
385                        let value = rng.random_range(0..var_node.cardinality);
386                        let prob = 1.0 / var_node.cardinality as f64;
387                        (value, prob)
388                    }
389                    ProposalDistribution::Custom(weights) => {
390                        if let Some(var_weights) = weights.get(var_name) {
391                            let (value, prob) = self.sample_categorical(var_weights, rng);
392                            (value, prob)
393                        } else {
394                            let value = rng.random_range(0..var_node.cardinality);
395                            let prob = 1.0 / var_node.cardinality as f64;
396                            (value, prob)
397                        }
398                    }
399                    ProposalDistribution::Prior => {
400                        // Use uniform for now; could be extended to use prior factors
401                        let value = rng.random_range(0..var_node.cardinality);
402                        let prob = 1.0 / var_node.cardinality as f64;
403                        (value, prob)
404                    }
405                };
406
407                assignment.insert(var_name.clone(), value);
408                proposal_prob *= prob;
409            }
410        }
411
412        Ok((assignment, proposal_prob))
413    }
414
415    /// Sample from a categorical distribution given weights.
416    fn sample_categorical(&self, weights: &[f64], rng: &mut impl Rng) -> (usize, f64) {
417        let total: f64 = weights.iter().sum();
418        if total <= 0.0 {
419            return (0, 1.0 / weights.len() as f64);
420        }
421
422        let normalized: Vec<f64> = weights.iter().map(|w| w / total).collect();
423        let u: f64 = rng.random();
424
425        let mut cumulative = 0.0;
426        for (idx, &prob) in normalized.iter().enumerate() {
427            cumulative += prob;
428            if u < cumulative {
429                return (idx, prob);
430            }
431        }
432
433        (weights.len() - 1, *normalized.last().unwrap_or(&0.0))
434    }
435
436    /// Compute target probability (unnormalized) for an assignment.
437    fn compute_target_probability(
438        &self,
439        graph: &FactorGraph,
440        assignment: &Assignment,
441    ) -> Result<f64> {
442        let mut prob = 1.0;
443
444        for factor_id in graph.factor_ids() {
445            if let Some(factor) = graph.get_factor(factor_id) {
446                let mut indices = Vec::new();
447                for var in &factor.variables {
448                    if let Some(&value) = assignment.get(var) {
449                        indices.push(value);
450                    } else {
451                        return Err(PgmError::VariableNotFound(var.clone()));
452                    }
453                }
454                prob *= factor.values[indices.as_slice()];
455            }
456        }
457
458        Ok(prob)
459    }
460
461    /// Compute weighted marginals from importance samples.
462    fn compute_weighted_marginals(
463        &self,
464        graph: &FactorGraph,
465        samples: &[WeightedSample],
466    ) -> Result<HashMap<String, ArrayD<f64>>> {
467        let mut marginals = HashMap::new();
468
469        // Compute total weight for self-normalization
470        let total_weight: f64 = samples.iter().map(|s| s.weight).sum();
471
472        for var_name in graph.variable_names() {
473            if let Some(var_node) = graph.get_variable(var_name) {
474                let mut weighted_counts = vec![0.0; var_node.cardinality];
475
476                // Accumulate weighted counts
477                for sample in samples {
478                    if let Some(&value) = sample.assignment.get(var_name) {
479                        weighted_counts[value] += sample.weight;
480                    }
481                }
482
483                // Normalize
484                let probs: Vec<f64> = if self.self_normalize && total_weight > 0.0 {
485                    weighted_counts.iter().map(|&c| c / total_weight).collect()
486                } else {
487                    let sum: f64 = weighted_counts.iter().sum();
488                    if sum > 0.0 {
489                        weighted_counts.iter().map(|&c| c / sum).collect()
490                    } else {
491                        vec![1.0 / var_node.cardinality as f64; var_node.cardinality]
492                    }
493                };
494
495                marginals.insert(
496                    var_name.clone(),
497                    ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
498                );
499            }
500        }
501
502        Ok(marginals)
503    }
504
505    /// Get all weighted samples for analysis.
506    pub fn get_weighted_samples(
507        &self,
508        graph: &FactorGraph,
509        proposal: &ProposalDistribution,
510    ) -> Result<Vec<WeightedSample>> {
511        self.draw_weighted_samples(graph, proposal)
512    }
513
514    /// Compute the effective sample size (ESS).
515    ///
516    /// ESS measures the efficiency of importance sampling.
517    /// Higher ESS indicates better proposal distribution.
518    pub fn effective_sample_size(samples: &[WeightedSample]) -> f64 {
519        let weights: Vec<f64> = samples.iter().map(|s| s.weight).collect();
520        let sum_w: f64 = weights.iter().sum();
521        let sum_w2: f64 = weights.iter().map(|w| w * w).sum();
522
523        if sum_w2 > 0.0 {
524            (sum_w * sum_w) / sum_w2
525        } else {
526            0.0
527        }
528    }
529
530    /// Compute the coefficient of variation of weights.
531    pub fn weight_coefficient_of_variation(samples: &[WeightedSample]) -> f64 {
532        let n = samples.len() as f64;
533        let weights: Vec<f64> = samples.iter().map(|s| s.weight).collect();
534        let mean = weights.iter().sum::<f64>() / n;
535        let variance = weights.iter().map(|w| (w - mean).powi(2)).sum::<f64>() / n;
536        let std_dev = variance.sqrt();
537
538        if mean > 0.0 {
539            std_dev / mean
540        } else {
541            0.0
542        }
543    }
544
545    /// Resample particles based on their weights (for particle filtering).
546    pub fn resample(samples: &[WeightedSample]) -> Vec<WeightedSample> {
547        let n = samples.len();
548        if n == 0 {
549            return Vec::new();
550        }
551
552        let mut rng = thread_rng();
553        let total_weight: f64 = samples.iter().map(|s| s.weight).sum();
554
555        if total_weight <= 0.0 {
556            return samples.to_vec();
557        }
558
559        let normalized_weights: Vec<f64> =
560            samples.iter().map(|s| s.weight / total_weight).collect();
561
562        // Systematic resampling
563        let mut resampled = Vec::with_capacity(n);
564        let u0: f64 = rng.random::<f64>() / n as f64;
565
566        let mut cumulative = 0.0;
567        let mut j = 0;
568
569        for i in 0..n {
570            let u = u0 + (i as f64) / (n as f64);
571            while cumulative + normalized_weights[j] < u && j < n - 1 {
572                cumulative += normalized_weights[j];
573                j += 1;
574            }
575
576            resampled.push(WeightedSample {
577                assignment: samples[j].assignment.clone(),
578                weight: 1.0,
579                log_weight: 0.0,
580            });
581        }
582
583        resampled
584    }
585}
586
587/// Particle for particle filtering.
588#[derive(Debug, Clone)]
589pub struct Particle {
590    /// Current state assignment
591    pub state: Assignment,
592    /// Particle weight
593    pub weight: f64,
594    /// Log weight for numerical stability
595    pub log_weight: f64,
596    /// History of states (optional)
597    pub history: Vec<Assignment>,
598}
599
600/// Particle filter (Sequential Monte Carlo) for temporal inference.
601///
602/// Particle filtering is used for inference in dynamic systems
603/// where the state evolves over time.
604///
605/// # Example
606///
607/// ```no_run
608/// use tensorlogic_quantrs_hooks::{ParticleFilter, HiddenMarkovModel, Assignment};
609/// use std::collections::HashMap;
610///
611/// // Create HMM with 2 states, 3 observations, 10 time steps
612/// let hmm = HiddenMarkovModel::new(2, 3, 10);
613///
614/// // Create particle filter
615/// let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
616///
617/// // Initialize particles
618/// let cardinalities: HashMap<String, usize> = [("state".to_string(), 2)].into_iter().collect();
619/// pf.initialize(&cardinalities);
620/// ```
621pub struct ParticleFilter {
622    /// Number of particles
623    pub num_particles: usize,
624    /// Current particles
625    pub particles: Vec<Particle>,
626    /// State variable names
627    pub state_variables: Vec<String>,
628    /// Effective sample size threshold for resampling
629    pub ess_threshold: f64,
630    /// Whether to track history
631    pub track_history: bool,
632}
633
634impl ParticleFilter {
635    /// Create a new particle filter.
636    pub fn new(num_particles: usize, state_variables: Vec<String>) -> Self {
637        Self {
638            num_particles,
639            particles: Vec::new(),
640            state_variables,
641            ess_threshold: 0.5,
642            track_history: false,
643        }
644    }
645
646    /// Set the ESS threshold for resampling (as fraction of num_particles).
647    pub fn with_ess_threshold(mut self, threshold: f64) -> Self {
648        self.ess_threshold = threshold;
649        self
650    }
651
652    /// Enable history tracking.
653    pub fn with_history(mut self, track: bool) -> Self {
654        self.track_history = track;
655        self
656    }
657
658    /// Initialize particles uniformly.
659    pub fn initialize(&mut self, cardinalities: &HashMap<String, usize>) {
660        let mut rng = thread_rng();
661        self.particles = Vec::with_capacity(self.num_particles);
662
663        for _ in 0..self.num_particles {
664            let mut state = Assignment::new();
665
666            for var_name in &self.state_variables {
667                if let Some(&card) = cardinalities.get(var_name) {
668                    let value = rng.gen_range(0..card);
669                    state.insert(var_name.clone(), value);
670                }
671            }
672
673            self.particles.push(Particle {
674                state,
675                weight: 1.0 / self.num_particles as f64,
676                log_weight: -(self.num_particles as f64).ln(),
677                history: Vec::new(),
678            });
679        }
680    }
681
682    /// Initialize particles from a prior distribution.
683    pub fn initialize_from_prior(&mut self, prior: &[f64], cardinalities: &HashMap<String, usize>) {
684        let mut rng = thread_rng();
685        self.particles = Vec::with_capacity(self.num_particles);
686
687        let total: f64 = prior.iter().sum();
688        let normalized: Vec<f64> = prior.iter().map(|p| p / total).collect();
689
690        for _ in 0..self.num_particles {
691            let mut state = Assignment::new();
692
693            // Sample state from prior (assuming single state variable)
694            if let Some(var_name) = self.state_variables.first() {
695                let u: f64 = rng.random();
696                let mut cumulative = 0.0;
697                let mut value = 0;
698
699                for (idx, &prob) in normalized.iter().enumerate() {
700                    cumulative += prob;
701                    if u < cumulative {
702                        value = idx;
703                        break;
704                    }
705                }
706
707                state.insert(var_name.clone(), value);
708            }
709
710            // Initialize other variables uniformly
711            for var_name in self.state_variables.iter().skip(1) {
712                if let Some(&card) = cardinalities.get(var_name) {
713                    let value = rng.gen_range(0..card);
714                    state.insert(var_name.clone(), value);
715                }
716            }
717
718            self.particles.push(Particle {
719                state,
720                weight: 1.0 / self.num_particles as f64,
721                log_weight: -(self.num_particles as f64).ln(),
722                history: Vec::new(),
723            });
724        }
725    }
726
727    /// Predict step: propagate particles through transition model.
728    ///
729    /// The transition function takes a state and a random seed, returning the next state.
730    pub fn predict(
731        &mut self,
732        transition: &dyn Fn(&Assignment, u64) -> Assignment,
733        cardinalities: &HashMap<String, usize>,
734    ) {
735        let mut rng = thread_rng();
736
737        for particle in &mut self.particles {
738            if self.track_history {
739                particle.history.push(particle.state.clone());
740            }
741
742            // Generate a random seed for the transition
743            let seed: u64 = rng.random();
744            particle.state = transition(&particle.state, seed);
745
746            // Ensure state values are within bounds
747            for var_name in &self.state_variables {
748                if let Some(&card) = cardinalities.get(var_name) {
749                    if let Some(value) = particle.state.get_mut(var_name) {
750                        *value = (*value).min(card.saturating_sub(1));
751                    }
752                }
753            }
754        }
755    }
756
757    /// Update step: weight particles based on observation likelihood.
758    pub fn update<F>(&mut self, observation: &Assignment, likelihood: F)
759    where
760        F: Fn(&Assignment, &Assignment) -> f64,
761    {
762        // Compute likelihood for each particle
763        for particle in &mut self.particles {
764            let lik = likelihood(&particle.state, observation);
765            particle.weight *= lik;
766            if lik > 0.0 {
767                particle.log_weight += lik.ln();
768            } else {
769                particle.log_weight = f64::NEG_INFINITY;
770            }
771        }
772
773        // Normalize weights
774        self.normalize_weights();
775
776        // Resample if ESS is too low
777        let ess = self.effective_sample_size();
778        if ess < self.ess_threshold * self.num_particles as f64 {
779            self.resample();
780        }
781    }
782
783    /// Normalize particle weights.
784    fn normalize_weights(&mut self) {
785        let total: f64 = self.particles.iter().map(|p| p.weight).sum();
786        if total > 0.0 {
787            for particle in &mut self.particles {
788                particle.weight /= total;
789            }
790        }
791    }
792
793    /// Compute effective sample size.
794    pub fn effective_sample_size(&self) -> f64 {
795        let sum_w2: f64 = self.particles.iter().map(|p| p.weight * p.weight).sum();
796        if sum_w2 > 0.0 {
797            1.0 / sum_w2
798        } else {
799            0.0
800        }
801    }
802
803    /// Resample particles using systematic resampling.
804    pub fn resample(&mut self) {
805        let n = self.num_particles;
806        let mut rng = thread_rng();
807
808        // Build CDF
809        let mut cdf = Vec::with_capacity(n);
810        let mut cumulative = 0.0;
811        for particle in &self.particles {
812            cumulative += particle.weight;
813            cdf.push(cumulative);
814        }
815
816        // Systematic resampling
817        let u0: f64 = rng.random::<f64>() / n as f64;
818        let mut new_particles = Vec::with_capacity(n);
819
820        let mut j = 0;
821        for i in 0..n {
822            let u = u0 + (i as f64) / (n as f64);
823            while j < n - 1 && cdf[j] < u {
824                j += 1;
825            }
826
827            new_particles.push(Particle {
828                state: self.particles[j].state.clone(),
829                weight: 1.0 / n as f64,
830                log_weight: -(n as f64).ln(),
831                history: if self.track_history {
832                    self.particles[j].history.clone()
833                } else {
834                    Vec::new()
835                },
836            });
837        }
838
839        self.particles = new_particles;
840    }
841
842    /// Estimate marginal distribution from particles.
843    pub fn estimate_marginal(&self, var_name: &str, cardinality: usize) -> Vec<f64> {
844        let mut counts = vec![0.0; cardinality];
845
846        for particle in &self.particles {
847            if let Some(&value) = particle.state.get(var_name) {
848                if value < cardinality {
849                    counts[value] += particle.weight;
850                }
851            }
852        }
853
854        // Normalize
855        let total: f64 = counts.iter().sum();
856        if total > 0.0 {
857            counts.iter().map(|c| c / total).collect()
858        } else {
859            vec![1.0 / cardinality as f64; cardinality]
860        }
861    }
862
863    /// Estimate expected value of a function over the particle distribution.
864    pub fn estimate_expectation<F>(&self, func: F) -> f64
865    where
866        F: Fn(&Assignment) -> f64,
867    {
868        self.particles
869            .iter()
870            .map(|p| p.weight * func(&p.state))
871            .sum()
872    }
873
874    /// Get the MAP (most likely) state.
875    pub fn map_estimate(&self) -> Option<&Assignment> {
876        self.particles
877            .iter()
878            .max_by(|a, b| {
879                a.weight
880                    .partial_cmp(&b.weight)
881                    .unwrap_or(std::cmp::Ordering::Equal)
882            })
883            .map(|p| &p.state)
884    }
885
886    /// Get entropy of the particle distribution.
887    pub fn entropy(&self) -> f64 {
888        self.particles
889            .iter()
890            .filter(|p| p.weight > 0.0)
891            .map(|p| -p.weight * p.weight.ln())
892            .sum()
893    }
894
895    /// Run particle filter on a sequence of observations.
896    ///
897    /// The transition function takes a state and a random seed.
898    /// The likelihood function computes P(observation | state).
899    pub fn run_sequence(
900        &mut self,
901        observations: &[Assignment],
902        transition: &dyn Fn(&Assignment, u64) -> Assignment,
903        likelihood: &dyn Fn(&Assignment, &Assignment) -> f64,
904        cardinalities: &HashMap<String, usize>,
905    ) -> Vec<Vec<f64>> {
906        let mut marginals = Vec::with_capacity(observations.len());
907
908        for obs in observations {
909            // Predict
910            self.predict(transition, cardinalities);
911
912            // Update
913            self.update(obs, likelihood);
914
915            // Record marginal for first state variable
916            if let Some(var_name) = self.state_variables.first() {
917                if let Some(&card) = cardinalities.get(var_name) {
918                    marginals.push(self.estimate_marginal(var_name, card));
919                }
920            }
921        }
922
923        marginals
924    }
925}
926
/// Evidence-clamped importance sampler (likelihood-weighting style).
///
/// Each sample is built in three steps:
/// - every evidence variable is fixed to its observed value;
/// - every other graph variable is drawn **uniformly at random** over its
///   cardinality (not from a prior; a factor graph has no directed prior);
/// - the sample is weighted by the product of *all* factor values at the
///   resulting full assignment.
///
/// The uniform proposal has constant density, so it cancels once weights are
/// normalized. The weights are then proportional to the unnormalized joint
/// evaluated with the evidence fixed.
#[derive(Debug, Clone)]
pub struct LikelihoodWeighting {
    /// Number of weighted samples drawn per call to `run`.
    pub num_samples: usize,
}

impl Default for LikelihoodWeighting {
    /// Defaults to 1000 samples.
    fn default() -> Self {
        Self { num_samples: 1000 }
    }
}
942
943impl LikelihoodWeighting {
944    /// Create a new likelihood weighting sampler.
945    pub fn new(num_samples: usize) -> Self {
946        Self { num_samples }
947    }
948
949    /// Run likelihood weighting with evidence.
950    pub fn run(
951        &self,
952        graph: &FactorGraph,
953        evidence: &Assignment,
954    ) -> Result<HashMap<String, ArrayD<f64>>> {
955        let mut weighted_samples = Vec::with_capacity(self.num_samples);
956        let mut rng = thread_rng();
957
958        for _ in 0..self.num_samples {
959            let (assignment, weight) = self.sample_with_evidence(graph, evidence, &mut rng)?;
960
961            weighted_samples.push(WeightedSample {
962                assignment,
963                weight,
964                log_weight: if weight > 0.0 {
965                    weight.ln()
966                } else {
967                    f64::NEG_INFINITY
968                },
969            });
970        }
971
972        // Compute marginals
973        let sampler = ImportanceSampler::new(self.num_samples);
974        sampler.compute_weighted_marginals(graph, &weighted_samples)
975    }
976
977    /// Sample non-evidence variables and compute weight from evidence.
978    fn sample_with_evidence(
979        &self,
980        graph: &FactorGraph,
981        evidence: &Assignment,
982        rng: &mut impl Rng,
983    ) -> Result<(Assignment, f64)> {
984        let mut assignment = Assignment::new();
985        let mut weight = 1.0;
986
987        // Set evidence variables
988        for (var, value) in evidence {
989            assignment.insert(var.clone(), *value);
990        }
991
992        // Sample non-evidence variables uniformly
993        for var_name in graph.variable_names() {
994            if !evidence.contains_key(var_name) {
995                if let Some(var_node) = graph.get_variable(var_name) {
996                    let value = rng.random_range(0..var_node.cardinality);
997                    assignment.insert(var_name.clone(), value);
998                }
999            }
1000        }
1001
1002        // Compute weight as product of factors
1003        for factor_id in graph.factor_ids() {
1004            if let Some(factor) = graph.get_factor(factor_id) {
1005                let mut indices = Vec::new();
1006                for var in &factor.variables {
1007                    if let Some(&value) = assignment.get(var) {
1008                        indices.push(value);
1009                    } else {
1010                        return Err(PgmError::VariableNotFound(var.clone()));
1011                    }
1012                }
1013                weight *= factor.values[indices.as_slice()];
1014            }
1015        }
1016
1017        Ok((assignment, weight))
1018    }
1019}
1020
#[cfg(test)]
mod tests {
    // Unit tests for the samplers in this module. Sampling results are random,
    // so most assertions check structural properties (normalization, sample
    // counts, key presence) rather than exact probability values.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_gibbs_sampler_single_variable() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));

        // The graph has no factors, so the marginal should be roughly uniform;
        // only normalization is asserted because the estimate is random.
        let dist = &marginals["x"];
        let sum: f64 = dist.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gibbs_sampler_multiple_variables() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(20, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        // Expect exactly one marginal per variable.
        let marginals = result.unwrap();
        assert_eq!(marginals.len(), 2);
    }

    #[test]
    fn test_sample_collection() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        // Burn-in samples are not returned: only `num_samples` (50) come back.
        let sampler = GibbsSampler::new(10, 50, 1);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());
        assert_eq!(samples.unwrap().len(), 50);
    }

    #[test]
    fn test_gibbs_with_thinning() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        // Thinning changes the spacing between kept samples, not how many are
        // returned: still `num_samples` (50).
        let sampler = GibbsSampler::new(10, 50, 5);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());
        assert_eq!(samples.unwrap().len(), 50);
    }

    #[test]
    fn test_importance_sampler_uniform() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = ImportanceSampler::new(100);
        let result = sampler.run(&graph, ProposalDistribution::Uniform);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));

        // Weighted marginal must be a normalized distribution.
        let dist = &marginals["x"];
        let sum: f64 = dist.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_importance_sampler_custom_proposal() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        // Proposal skewed toward x = 0; the weighted marginal must still be
        // normalized.
        let mut custom_weights = HashMap::new();
        custom_weights.insert("x".to_string(), vec![0.8, 0.2]);

        let sampler = ImportanceSampler::new(100);
        let result = sampler.run(&graph, ProposalDistribution::Custom(custom_weights));
        assert!(result.is_ok());

        let marginals = result.unwrap();
        let sum: f64 = marginals["x"].iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_effective_sample_size() {
        let samples = vec![
            WeightedSample {
                assignment: HashMap::new(),
                weight: 0.5,
                log_weight: 0.5_f64.ln(),
            },
            WeightedSample {
                assignment: HashMap::new(),
                weight: 0.5,
                log_weight: 0.5_f64.ln(),
            },
        ];

        let ess = ImportanceSampler::effective_sample_size(&samples);
        // Equal weights should give ESS = N
        assert_abs_diff_eq!(ess, 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_initialization() {
        let mut pf = ParticleFilter::new(10, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 3)].into_iter().collect();
        pf.initialize(&cardinalities);

        assert_eq!(pf.particles.len(), 10);

        // All particles should start with equal weight 1/N = 0.1
        for particle in &pf.particles {
            assert_abs_diff_eq!(particle.weight, 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_particle_filter_estimate_marginal() {
        let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        // One probability per state value.
        let marginal = pf.estimate_marginal("state", 2);
        assert_eq!(marginal.len(), 2);

        // Should sum to 1
        let sum: f64 = marginal.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_ess() {
        let mut pf = ParticleFilter::new(100, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        let ess = pf.effective_sample_size();
        // Freshly initialized (uniform) weights should give ESS close to N
        assert!(ess > 90.0);
    }

    #[test]
    fn test_particle_filter_resample() {
        let mut pf = ParticleFilter::new(10, vec!["state".to_string()]);
        let cardinalities: HashMap<String, usize> =
            [("state".to_string(), 2)].into_iter().collect();
        pf.initialize(&cardinalities);

        // Manually set unequal weights: all mass on particle 0 (fully
        // degenerate distribution)
        for (i, particle) in pf.particles.iter_mut().enumerate() {
            particle.weight = if i == 0 { 1.0 } else { 0.0 };
        }

        pf.resample();

        // After resampling, weights should be reset to 1/N = 0.1
        for particle in &pf.particles {
            assert_abs_diff_eq!(particle.weight, 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_likelihood_weighting() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        // y is observed; x is unobserved and must still receive a marginal.
        let mut evidence = Assignment::new();
        evidence.insert("y".to_string(), 1);

        let lw = LikelihoodWeighting::new(100);
        let result = lw.run(&graph, &evidence);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));
    }

    #[test]
    fn test_importance_sampler_weighted_samples() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = ImportanceSampler::new(50);
        let samples = sampler
            .get_weighted_samples(&graph, &ProposalDistribution::Uniform)
            .unwrap();

        assert_eq!(samples.len(), 50);

        // All samples should have valid assignments
        for sample in &samples {
            assert!(sample.assignment.contains_key("x"));
        }
    }

    #[test]
    fn test_weight_coefficient_of_variation() {
        let samples = vec![
            WeightedSample {
                assignment: HashMap::new(),
                weight: 1.0,
                log_weight: 0.0,
            },
            WeightedSample {
                assignment: HashMap::new(),
                weight: 1.0,
                log_weight: 0.0,
            },
        ];

        let cv = ImportanceSampler::weight_coefficient_of_variation(&samples);
        // Equal weights should give CV = 0
        assert_abs_diff_eq!(cv, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_particle_filter_with_history() {
        // Builder methods should set the corresponding fields.
        let pf = ParticleFilter::new(5, vec!["state".to_string()])
            .with_history(true)
            .with_ess_threshold(0.3);

        assert!(pf.track_history);
        assert_abs_diff_eq!(pf.ess_threshold, 0.3, epsilon = 1e-6);
    }
}