tensorlogic_quantrs_hooks/
sampling.rs

1//! Sampling-based inference methods for PGM.
2//!
3//! This module provides MCMC and other sampling algorithms for approximate inference.
4
5use scirs2_core::ndarray::ArrayD;
6use scirs2_core::random::{thread_rng, Rng};
7use std::collections::HashMap;
8
9use crate::error::{PgmError, Result};
10use crate::graph::FactorGraph;
11
/// A (possibly partial) state of the model: each variable name is mapped to the index of
/// the value currently assigned to it.
pub type Assignment = std::collections::HashMap<String, usize>;
14
/// Gibbs sampling for approximate inference.
///
/// Uses Markov Chain Monte Carlo to sample from the joint distribution defined by the
/// product of a factor graph's factors, one variable at a time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GibbsSampler {
    /// Number of burn-in sweeps (full passes over all variables) to discard.
    pub burn_in: usize,
    /// Number of samples to collect.
    pub num_samples: usize,
    /// Thinning interval (keep every N-th sample).
    pub thinning: usize,
}
26
27impl Default for GibbsSampler {
28    fn default() -> Self {
29        Self {
30            burn_in: 100,
31            num_samples: 1000,
32            thinning: 1,
33        }
34    }
35}
36
37impl GibbsSampler {
38    /// Create with custom parameters.
39    pub fn new(burn_in: usize, num_samples: usize, thinning: usize) -> Self {
40        Self {
41            burn_in,
42            num_samples,
43            thinning,
44        }
45    }
46
47    /// Run Gibbs sampling to approximate marginals.
48    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
49        // Initialize random assignment
50        let mut current_assignment = self.initialize_assignment(graph)?;
51
52        // Burn-in phase
53        for _ in 0..self.burn_in {
54            self.gibbs_step(graph, &mut current_assignment)?;
55        }
56
57        // Collect samples
58        let mut samples = Vec::new();
59        for i in 0..self.num_samples * self.thinning {
60            self.gibbs_step(graph, &mut current_assignment)?;
61
62            // Keep sample if it's at thinning interval
63            if i % self.thinning == 0 {
64                samples.push(current_assignment.clone());
65            }
66        }
67
68        // Compute empirical marginals from samples
69        self.compute_empirical_marginals(graph, &samples)
70    }
71
72    /// Initialize random assignment for all variables.
73    fn initialize_assignment(&self, graph: &FactorGraph) -> Result<Assignment> {
74        let mut rng = thread_rng();
75        let mut assignment = Assignment::new();
76
77        for var_name in graph.variable_names() {
78            if let Some(var_node) = graph.get_variable(var_name) {
79                let random_value = rng.gen_range(0..var_node.cardinality);
80                assignment.insert(var_name.clone(), random_value);
81            }
82        }
83
84        Ok(assignment)
85    }
86
87    /// Perform one Gibbs sampling step (resample all variables).
88    fn gibbs_step(&self, graph: &FactorGraph, assignment: &mut Assignment) -> Result<()> {
89        // Resample each variable conditioned on others
90        for var_name in graph.variable_names() {
91            self.resample_variable(graph, var_name, assignment)?;
92        }
93
94        Ok(())
95    }
96
97    /// Resample a single variable given current assignment of others.
98    fn resample_variable(
99        &self,
100        graph: &FactorGraph,
101        var_name: &str,
102        assignment: &mut Assignment,
103    ) -> Result<()> {
104        let var_node = graph
105            .get_variable(var_name)
106            .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;
107
108        // Compute conditional distribution P(X | others)
109        let mut conditional_probs = vec![0.0; var_node.cardinality];
110
111        for (value, prob) in conditional_probs
112            .iter_mut()
113            .enumerate()
114            .take(var_node.cardinality)
115        {
116            assignment.insert(var_name.to_string(), value);
117            *prob = self.compute_joint_probability(graph, assignment)?;
118        }
119
120        // Normalize
121        let sum: f64 = conditional_probs.iter().sum();
122        if sum > 0.0 {
123            for prob in &mut conditional_probs {
124                *prob /= sum;
125            }
126        } else {
127            // Fallback to uniform if all zero
128            let uniform_prob = 1.0 / var_node.cardinality as f64;
129            conditional_probs = vec![uniform_prob; var_node.cardinality];
130        }
131
132        // Sample from conditional distribution
133        let sampled_value = self.sample_from_distribution(&conditional_probs);
134        assignment.insert(var_name.to_string(), sampled_value);
135
136        Ok(())
137    }
138
139    /// Compute joint probability for a full assignment.
140    fn compute_joint_probability(
141        &self,
142        graph: &FactorGraph,
143        assignment: &Assignment,
144    ) -> Result<f64> {
145        let mut prob = 1.0;
146
147        for factor_id in graph.factor_ids() {
148            if let Some(factor) = graph.get_factor(factor_id) {
149                // Build index for this factor
150                let mut indices = Vec::new();
151                for var in &factor.variables {
152                    if let Some(&value) = assignment.get(var) {
153                        indices.push(value);
154                    } else {
155                        return Err(PgmError::VariableNotFound(var.clone()));
156                    }
157                }
158
159                prob *= factor.values[indices.as_slice()];
160            }
161        }
162
163        Ok(prob)
164    }
165
166    /// Sample from a discrete probability distribution.
167    fn sample_from_distribution(&self, probs: &[f64]) -> usize {
168        let mut rng = thread_rng();
169        let u: f64 = rng.random();
170
171        let mut cumulative = 0.0;
172        for (idx, &prob) in probs.iter().enumerate() {
173            cumulative += prob;
174            if u < cumulative {
175                return idx;
176            }
177        }
178
179        // Fallback to last index
180        probs.len() - 1
181    }
182
183    /// Compute empirical marginals from collected samples.
184    fn compute_empirical_marginals(
185        &self,
186        graph: &FactorGraph,
187        samples: &[Assignment],
188    ) -> Result<HashMap<String, ArrayD<f64>>> {
189        let mut marginals = HashMap::new();
190
191        for var_name in graph.variable_names() {
192            if let Some(var_node) = graph.get_variable(var_name) {
193                let mut counts = vec![0; var_node.cardinality];
194
195                // Count occurrences
196                for sample in samples {
197                    if let Some(&value) = sample.get(var_name) {
198                        counts[value] += 1;
199                    }
200                }
201
202                // Normalize to probabilities
203                let total = samples.len() as f64;
204                let probs: Vec<f64> = counts.iter().map(|&c| c as f64 / total).collect();
205
206                marginals.insert(
207                    var_name.clone(),
208                    ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
209                );
210            }
211        }
212
213        Ok(marginals)
214    }
215
216    /// Get all samples (for analysis).
217    pub fn get_samples(&self, graph: &FactorGraph) -> Result<Vec<Assignment>> {
218        let mut current_assignment = self.initialize_assignment(graph)?;
219
220        // Burn-in
221        for _ in 0..self.burn_in {
222            self.gibbs_step(graph, &mut current_assignment)?;
223        }
224
225        // Collect samples
226        let mut samples = Vec::new();
227        for i in 0..self.num_samples * self.thinning {
228            self.gibbs_step(graph, &mut current_assignment)?;
229
230            if i % self.thinning == 0 {
231                samples.push(current_assignment.clone());
232            }
233        }
234
235        Ok(samples)
236    }
237}
238
239impl From<scirs2_core::ndarray::ShapeError> for PgmError {
240    fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
241        PgmError::InvalidDistribution(format!("Shape error: {}", err))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_gibbs_sampler_single_variable() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert!(marginals.contains_key("x"));

        let dist = &marginals["x"];
        assert_eq!(dist.len(), 2);
        let sum: f64 = dist.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);

        // With no factors, both values of `x` are equally likely, so each empirical
        // probability should be near 0.5. The tolerance is deliberately wide:
        // |p - 0.5| > 0.3 with 100 fair draws is astronomically unlikely, so the test
        // is not flaky.
        for &p in dist.iter() {
            assert!(
                (p - 0.5).abs() < 0.3,
                "expected a roughly uniform marginal, got {}",
                p
            );
        }
    }

    #[test]
    fn test_gibbs_sampler_multiple_variables() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(20, 100, 1);
        let result = sampler.run(&graph);
        assert!(result.is_ok());

        let marginals = result.unwrap();
        assert_eq!(marginals.len(), 2);

        // Every marginal must be a proper distribution over its two values.
        for (name, dist) in &marginals {
            assert_eq!(dist.len(), 2, "marginal for {} has the wrong length", name);
            let sum: f64 = dist.iter().sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_sample_collection() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 50, 1);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());

        let samples = samples.unwrap();
        assert_eq!(samples.len(), 50);
        // Every sample assigns `x` a value within its cardinality.
        assert!(samples
            .iter()
            .all(|s| matches!(s.get("x"), Some(&v) if v < 2)));
    }

    #[test]
    fn test_gibbs_with_thinning() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let sampler = GibbsSampler::new(10, 50, 5);
        let samples = sampler.get_samples(&graph);
        assert!(samples.is_ok());

        let samples = samples.unwrap();
        assert_eq!(samples.len(), 50);
        assert!(samples
            .iter()
            .all(|s| matches!(s.get("x"), Some(&v) if v < 2)));
    }
}