tensorlogic_quantrs_hooks/
sampling.rs1use scirs2_core::ndarray::ArrayD;
6use scirs2_core::random::{thread_rng, Rng};
7use std::collections::HashMap;
8
9use crate::error::{PgmError, Result};
10use crate::graph::FactorGraph;
11
/// Maps each variable name to the index of the state currently assigned to it.
pub type Assignment = HashMap<String, usize>;
14
/// Configuration for a Gibbs sampling run over a factor graph.
///
/// All fields are public, so a sampler can be built with a struct literal
/// as well as through its constructors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GibbsSampler {
    /// Number of initial sweeps whose states are discarded before any
    /// sample is recorded.
    pub burn_in: usize,
    /// Number of samples to record after burn-in.
    pub num_samples: usize,
    /// Spacing, in sweeps, between recorded samples (`1` records every sweep).
    pub thinning: usize,
}
26
27impl Default for GibbsSampler {
28 fn default() -> Self {
29 Self {
30 burn_in: 100,
31 num_samples: 1000,
32 thinning: 1,
33 }
34 }
35}
36
37impl GibbsSampler {
38 pub fn new(burn_in: usize, num_samples: usize, thinning: usize) -> Self {
40 Self {
41 burn_in,
42 num_samples,
43 thinning,
44 }
45 }
46
47 pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
49 let mut current_assignment = self.initialize_assignment(graph)?;
51
52 for _ in 0..self.burn_in {
54 self.gibbs_step(graph, &mut current_assignment)?;
55 }
56
57 let mut samples = Vec::new();
59 for i in 0..self.num_samples * self.thinning {
60 self.gibbs_step(graph, &mut current_assignment)?;
61
62 if i % self.thinning == 0 {
64 samples.push(current_assignment.clone());
65 }
66 }
67
68 self.compute_empirical_marginals(graph, &samples)
70 }
71
72 fn initialize_assignment(&self, graph: &FactorGraph) -> Result<Assignment> {
74 let mut rng = thread_rng();
75 let mut assignment = Assignment::new();
76
77 for var_name in graph.variable_names() {
78 if let Some(var_node) = graph.get_variable(var_name) {
79 let random_value = rng.gen_range(0..var_node.cardinality);
80 assignment.insert(var_name.clone(), random_value);
81 }
82 }
83
84 Ok(assignment)
85 }
86
87 fn gibbs_step(&self, graph: &FactorGraph, assignment: &mut Assignment) -> Result<()> {
89 for var_name in graph.variable_names() {
91 self.resample_variable(graph, var_name, assignment)?;
92 }
93
94 Ok(())
95 }
96
97 fn resample_variable(
99 &self,
100 graph: &FactorGraph,
101 var_name: &str,
102 assignment: &mut Assignment,
103 ) -> Result<()> {
104 let var_node = graph
105 .get_variable(var_name)
106 .ok_or_else(|| PgmError::VariableNotFound(var_name.to_string()))?;
107
108 let mut conditional_probs = vec![0.0; var_node.cardinality];
110
111 for (value, prob) in conditional_probs
112 .iter_mut()
113 .enumerate()
114 .take(var_node.cardinality)
115 {
116 assignment.insert(var_name.to_string(), value);
117 *prob = self.compute_joint_probability(graph, assignment)?;
118 }
119
120 let sum: f64 = conditional_probs.iter().sum();
122 if sum > 0.0 {
123 for prob in &mut conditional_probs {
124 *prob /= sum;
125 }
126 } else {
127 let uniform_prob = 1.0 / var_node.cardinality as f64;
129 conditional_probs = vec![uniform_prob; var_node.cardinality];
130 }
131
132 let sampled_value = self.sample_from_distribution(&conditional_probs);
134 assignment.insert(var_name.to_string(), sampled_value);
135
136 Ok(())
137 }
138
139 fn compute_joint_probability(
141 &self,
142 graph: &FactorGraph,
143 assignment: &Assignment,
144 ) -> Result<f64> {
145 let mut prob = 1.0;
146
147 for factor_id in graph.factor_ids() {
148 if let Some(factor) = graph.get_factor(factor_id) {
149 let mut indices = Vec::new();
151 for var in &factor.variables {
152 if let Some(&value) = assignment.get(var) {
153 indices.push(value);
154 } else {
155 return Err(PgmError::VariableNotFound(var.clone()));
156 }
157 }
158
159 prob *= factor.values[indices.as_slice()];
160 }
161 }
162
163 Ok(prob)
164 }
165
166 fn sample_from_distribution(&self, probs: &[f64]) -> usize {
168 let mut rng = thread_rng();
169 let u: f64 = rng.random();
170
171 let mut cumulative = 0.0;
172 for (idx, &prob) in probs.iter().enumerate() {
173 cumulative += prob;
174 if u < cumulative {
175 return idx;
176 }
177 }
178
179 probs.len() - 1
181 }
182
183 fn compute_empirical_marginals(
185 &self,
186 graph: &FactorGraph,
187 samples: &[Assignment],
188 ) -> Result<HashMap<String, ArrayD<f64>>> {
189 let mut marginals = HashMap::new();
190
191 for var_name in graph.variable_names() {
192 if let Some(var_node) = graph.get_variable(var_name) {
193 let mut counts = vec![0; var_node.cardinality];
194
195 for sample in samples {
197 if let Some(&value) = sample.get(var_name) {
198 counts[value] += 1;
199 }
200 }
201
202 let total = samples.len() as f64;
204 let probs: Vec<f64> = counts.iter().map(|&c| c as f64 / total).collect();
205
206 marginals.insert(
207 var_name.clone(),
208 ArrayD::from_shape_vec(vec![var_node.cardinality], probs)?,
209 );
210 }
211 }
212
213 Ok(marginals)
214 }
215
216 pub fn get_samples(&self, graph: &FactorGraph) -> Result<Vec<Assignment>> {
218 let mut current_assignment = self.initialize_assignment(graph)?;
219
220 for _ in 0..self.burn_in {
222 self.gibbs_step(graph, &mut current_assignment)?;
223 }
224
225 let mut samples = Vec::new();
227 for i in 0..self.num_samples * self.thinning {
228 self.gibbs_step(graph, &mut current_assignment)?;
229
230 if i % self.thinning == 0 {
231 samples.push(current_assignment.clone());
232 }
233 }
234
235 Ok(samples)
236 }
237}
238
239impl From<scirs2_core::ndarray::ShapeError> for PgmError {
240 fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
241 PgmError::InvalidDistribution(format!("Shape error: {}", err))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a factor-free graph with one binary variable per given name.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for &name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// A single variable yields a marginal that sums to one.
    #[test]
    fn test_gibbs_sampler_single_variable() {
        let graph = binary_graph(&["x"]);

        let outcome = GibbsSampler::new(10, 100, 1).run(&graph);
        assert!(outcome.is_ok());

        let marginals = outcome.unwrap();
        assert!(marginals.contains_key("x"));

        let mass = marginals["x"].iter().sum::<f64>();
        assert_abs_diff_eq!(mass, 1.0, epsilon = 1e-6);
    }

    /// Every variable in the graph gets its own marginal.
    #[test]
    fn test_gibbs_sampler_multiple_variables() {
        let graph = binary_graph(&["x", "y"]);

        let outcome = GibbsSampler::new(20, 100, 1).run(&graph);
        assert!(outcome.is_ok());

        assert_eq!(outcome.unwrap().len(), 2);
    }

    /// Without thinning, exactly `num_samples` samples are returned.
    #[test]
    fn test_sample_collection() {
        let graph = binary_graph(&["x"]);

        let collected = GibbsSampler::new(10, 50, 1).get_samples(&graph);
        assert!(collected.is_ok());
        assert_eq!(collected.unwrap().len(), 50);
    }

    /// Thinning changes the spacing between samples, not their number.
    #[test]
    fn test_gibbs_with_thinning() {
        let graph = binary_graph(&["x"]);

        let collected = GibbsSampler::new(10, 50, 5).get_samples(&graph);
        assert!(collected.is_ok());
        assert_eq!(collected.unwrap().len(), 50);
    }
}