tensorlogic_quantrs_hooks/
variable_elimination.rs1use scirs2_core::ndarray::ArrayD;
8use std::collections::{HashMap, HashSet};
9
10use crate::error::{PgmError, Result};
11use crate::factor::Factor;
12use crate::graph::FactorGraph;
13
/// Exact inference over a factor graph by variable elimination.
///
/// The struct only carries configuration; every query receives the graph to
/// operate on as an argument.
#[derive(Debug, Clone)]
pub struct VariableElimination {
    /// Optional user-supplied elimination order.
    ///
    /// When `None`, the order is chosen by a heuristic based on how many
    /// factors are adjacent to each variable.
    pub elimination_order: Option<Vec<String>>,
}
21
22impl Default for VariableElimination {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl VariableElimination {
29 pub fn new() -> Self {
31 Self {
32 elimination_order: None,
33 }
34 }
35
36 pub fn with_order(order: Vec<String>) -> Self {
38 Self {
39 elimination_order: Some(order),
40 }
41 }
42
43 pub fn marginalize(&self, graph: &FactorGraph, query_var: &str) -> Result<ArrayD<f64>> {
45 let query_node = graph
47 .get_variable(query_var)
48 .ok_or_else(|| PgmError::VariableNotFound(query_var.to_string()))?;
49
50 let mut factors: Vec<Factor> = graph
52 .factor_ids()
53 .filter_map(|id| graph.get_factor(id).cloned())
54 .collect();
55
56 if factors.is_empty() {
58 let uniform = ArrayD::from_elem(
59 vec![query_node.cardinality],
60 1.0 / query_node.cardinality as f64,
61 );
62 return Ok(uniform);
63 }
64
65 let all_vars: HashSet<String> = graph.variable_names().cloned().collect();
67 let vars_to_eliminate: Vec<String> =
68 all_vars.into_iter().filter(|v| v != query_var).collect();
69
70 let order = if let Some(ref custom_order) = self.elimination_order {
71 custom_order
72 .iter()
73 .filter(|v| vars_to_eliminate.contains(v))
74 .cloned()
75 .collect()
76 } else {
77 self.compute_elimination_order(graph, &vars_to_eliminate)?
78 };
79
80 for var in &order {
82 factors = self.eliminate_variable(&factors, var)?;
83 }
84
85 let mut result = self.multiply_all_factors(&factors)?;
87
88 let vars_to_remove: Vec<String> = result
90 .variables
91 .iter()
92 .filter(|v| *v != query_var)
93 .cloned()
94 .collect();
95
96 for var in vars_to_remove {
97 result = result.marginalize_out(&var)?;
98 }
99
100 result.normalize();
102
103 Ok(result.values)
104 }
105
106 pub fn marginalize_all(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
108 let mut marginals = HashMap::new();
109
110 for var_name in graph.variable_names() {
111 let marginal = self.marginalize(graph, var_name)?;
112 marginals.insert(var_name.clone(), marginal);
113 }
114
115 Ok(marginals)
116 }
117
118 fn eliminate_variable(&self, factors: &[Factor], var: &str) -> Result<Vec<Factor>> {
120 let (containing, not_containing): (Vec<Factor>, Vec<Factor>) = factors
122 .iter()
123 .cloned()
124 .partition(|f| f.variables.contains(&var.to_string()));
125
126 if containing.is_empty() {
127 return Ok(factors.to_vec());
129 }
130
131 let mut product = containing[0].clone();
133 for factor in &containing[1..] {
134 product = product.product(factor)?;
135 }
136
137 let marginalized = product.marginalize_out(var)?;
139
140 let mut result = not_containing;
142 if !marginalized.variables.is_empty() {
143 result.push(marginalized);
144 }
145
146 Ok(result)
147 }
148
149 fn multiply_all_factors(&self, factors: &[Factor]) -> Result<Factor> {
151 if factors.is_empty() {
152 return Err(PgmError::InvalidGraph("No factors to multiply".to_string()));
153 }
154
155 let mut result = factors[0].clone();
156 for factor in &factors[1..] {
157 result = result.product(factor)?;
158 }
159
160 Ok(result)
161 }
162
163 fn compute_elimination_order(
167 &self,
168 graph: &FactorGraph,
169 vars: &[String],
170 ) -> Result<Vec<String>> {
171 let mut order = vars.to_vec();
174
175 order.sort_by_key(|v| {
177 graph
178 .get_adjacent_factors(v)
179 .map(|factors| factors.len())
180 .unwrap_or(0)
181 });
182
183 Ok(order)
184 }
185
186 pub fn joint_probability(
188 &self,
189 graph: &FactorGraph,
190 assignment: &HashMap<String, usize>,
191 ) -> Result<f64> {
192 let mut prob = 1.0;
193
194 for factor_id in graph.factor_ids() {
195 if let Some(factor) = graph.get_factor(factor_id) {
196 let mut indices = Vec::new();
198 for var in &factor.variables {
199 if let Some(&value) = assignment.get(var) {
200 indices.push(value);
201 } else {
202 return Err(PgmError::VariableNotFound(var.clone()));
203 }
204 }
205
206 prob *= factor.values[indices.as_slice()];
207 }
208 }
209
210 Ok(prob)
211 }
212
213 pub fn map(&self, graph: &FactorGraph) -> Result<HashMap<String, usize>> {
215 let mut factors: Vec<Factor> = graph
217 .factor_ids()
218 .filter_map(|id| graph.get_factor(id).cloned())
219 .collect();
220
221 let all_vars: Vec<String> = graph.variable_names().cloned().collect();
223 let order = if let Some(ref custom_order) = self.elimination_order {
224 custom_order.clone()
225 } else {
226 self.compute_elimination_order(graph, &all_vars)?
227 };
228
229 let mut assignment = HashMap::new();
230
231 for var in order.iter().rev() {
233 let (containing, not_containing): (Vec<Factor>, Vec<Factor>) = factors
235 .iter()
236 .cloned()
237 .partition(|f| f.variables.contains(&var.to_string()));
238
239 if containing.is_empty() {
240 continue;
241 }
242
243 let mut product = containing[0].clone();
245 for factor in &containing[1..] {
246 product = product.product(factor)?;
247 }
248
249 let var_node = graph
251 .get_variable(var)
252 .ok_or_else(|| PgmError::VariableNotFound(var.clone()))?;
253
254 let mut max_val = f64::NEG_INFINITY;
255 let mut max_idx = 0;
256
257 for val in 0..var_node.cardinality {
258 let reduced = product.reduce(var, val)?;
259 let prob: f64 = reduced.values.iter().product();
260
261 if prob > max_val {
262 max_val = prob;
263 max_idx = val;
264 }
265 }
266
267 assignment.insert(var.clone(), max_idx);
268
269 let reduced = product.reduce(var, max_idx)?;
271 factors = not_containing;
272 if !reduced.variables.is_empty() {
273 factors.push(reduced);
274 }
275 }
276
277 Ok(assignment)
278 }
279}
280
281#[cfg(test)]
282mod tests {
283 use super::*;
284 use approx::assert_abs_diff_eq;
285 use scirs2_core::ndarray::Array;
286
/// A single binary variable with a uniform factor must have the marginal
/// [0.5, 0.5], not merely something that sums to one.
#[test]
fn test_variable_elimination_single_variable() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

    let factor = Factor::uniform("P(x)".to_string(), vec!["x".to_string()], 2);
    graph.add_factor(factor).unwrap();

    let ve = VariableElimination::new();
    let marginal = ve.marginalize(&graph, "x").unwrap();

    assert_eq!(marginal.len(), 2);
    let sum: f64 = marginal.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);

    // Uniform potential => every state is equally likely.
    for &p in marginal.iter() {
        assert_abs_diff_eq!(p, 0.5, epsilon = 1e-6);
    }
}
303
/// Two-node chain x -> y: checks the actual marginal of y, not only that
/// it is normalized.
#[test]
fn test_variable_elimination_chain() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

    let px_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
        .unwrap()
        .into_dyn();
    let px = Factor::new("P(x)".to_string(), vec!["x".to_string()], px_values).unwrap();
    graph.add_factor(px).unwrap();

    // Rows are indexed by x, columns by y.
    let pyx_values = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
        .unwrap()
        .into_dyn();
    let pyx = Factor::new(
        "P(y|x)".to_string(),
        vec!["x".to_string(), "y".to_string()],
        pyx_values,
    )
    .unwrap();
    graph.add_factor(pyx).unwrap();

    let ve = VariableElimination::new();
    let marginal_y = ve.marginalize(&graph, "y").unwrap();

    assert_eq!(marginal_y.len(), 2);
    let sum: f64 = marginal_y.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);

    // P(y=0) = 0.6*0.9 + 0.4*0.2 = 0.62
    // P(y=1) = 0.6*0.1 + 0.4*0.8 = 0.38
    let probs: Vec<f64> = marginal_y.iter().copied().collect();
    assert_abs_diff_eq!(probs[0], 0.62, epsilon = 1e-6);
    assert_abs_diff_eq!(probs[1], 0.38, epsilon = 1e-6);
}
336
/// With no factors in the graph every variable gets a uniform marginal;
/// checks the values as well as the presence of each key.
#[test]
fn test_marginalize_all() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

    let ve = VariableElimination::new();
    let marginals = ve.marginalize_all(&graph).unwrap();

    assert_eq!(marginals.len(), 2);
    for name in ["x", "y"].iter() {
        let marginal = marginals
            .get(*name)
            .expect("every graph variable must have a marginal");
        assert_eq!(marginal.len(), 2);
        for &p in marginal.iter() {
            assert_abs_diff_eq!(p, 0.5, epsilon = 1e-12);
        }
    }
}
350
/// P(x=0, y=1) = P(x=0) * P(y=1 | x=0) = 0.6 * 0.1 = 0.06.
#[test]
fn test_joint_probability() {
    let mut model = FactorGraph::new();
    model.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    model.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

    // Prior over x.
    let prior_table = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
        .unwrap()
        .into_dyn();
    let prior = Factor::new("P(x)".to_string(), vec!["x".to_string()], prior_table).unwrap();
    model.add_factor(prior).unwrap();

    // Conditional of y given x.
    let conditional_table = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
        .unwrap()
        .into_dyn();
    let scope = vec!["x".to_string(), "y".to_string()];
    let conditional = Factor::new("P(y|x)".to_string(), scope, conditional_table).unwrap();
    model.add_factor(conditional).unwrap();

    let states: HashMap<String, usize> = [("x", 0), ("y", 1)]
        .iter()
        .map(|&(name, state)| (name.to_string(), state))
        .collect();

    let engine = VariableElimination::new();
    let joint = engine.joint_probability(&model, &states).unwrap();

    assert_abs_diff_eq!(joint, 0.06, epsilon = 1e-6);
}
385
/// Exercises a custom elimination order on a chain x -> y -> z.
///
/// The graph needs factors: without any, `marginalize` returns early with
/// a uniform distribution and the custom order is never consulted.
#[test]
fn test_custom_elimination_order() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("z".to_string(), "Binary".to_string(), 2);

    let px_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
        .unwrap()
        .into_dyn();
    let px = Factor::new("P(x)".to_string(), vec!["x".to_string()], px_values).unwrap();
    graph.add_factor(px).unwrap();

    let pyx_values = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
        .unwrap()
        .into_dyn();
    let pyx = Factor::new(
        "P(y|x)".to_string(),
        vec!["x".to_string(), "y".to_string()],
        pyx_values,
    )
    .unwrap();
    graph.add_factor(pyx).unwrap();

    let pzy_values = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
        .unwrap()
        .into_dyn();
    let pzy = Factor::new(
        "P(z|y)".to_string(),
        vec!["y".to_string(), "z".to_string()],
        pzy_values,
    )
    .unwrap();
    graph.add_factor(pzy).unwrap();

    let order = vec!["x".to_string(), "y".to_string()];
    let ve = VariableElimination::with_order(order);

    let marginal = ve.marginalize(&graph, "z").unwrap();
    assert_eq!(marginal.len(), 2);

    // P(y) = [0.62, 0.38]
    // P(z=0) = 0.62*0.7 + 0.38*0.4 = 0.586
    // P(z=1) = 0.62*0.3 + 0.38*0.6 = 0.414
    let probs: Vec<f64> = marginal.iter().copied().collect();
    assert_abs_diff_eq!(probs[0], 0.586, epsilon = 1e-6);
    assert_abs_diff_eq!(probs[1], 0.414, epsilon = 1e-6);

    // The elimination order must not change the answer.
    let reversed = VariableElimination::with_order(vec!["y".to_string(), "x".to_string()]);
    let marginal_reversed = reversed.marginalize(&graph, "z").unwrap();
    let probs_reversed: Vec<f64> = marginal_reversed.iter().copied().collect();
    assert_abs_diff_eq!(probs_reversed[0], 0.586, epsilon = 1e-6);
    assert_abs_diff_eq!(probs_reversed[1], 0.414, epsilon = 1e-6);
}
399
/// MAP on a two-node chain: checks the chosen states, not only that both
/// variables are present.
#[test]
fn test_map_inference() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

    let px_values = Array::from_shape_vec(vec![2], vec![0.3, 0.7])
        .unwrap()
        .into_dyn();
    let px = Factor::new("P(x)".to_string(), vec!["x".to_string()], px_values).unwrap();
    graph.add_factor(px).unwrap();

    let pyx_values = Array::from_shape_vec(vec![2, 2], vec![0.8, 0.2, 0.1, 0.9])
        .unwrap()
        .into_dyn();
    let pyx = Factor::new(
        "P(y|x)".to_string(),
        vec!["x".to_string(), "y".to_string()],
        pyx_values,
    )
    .unwrap();
    graph.add_factor(pyx).unwrap();

    let ve = VariableElimination::new();
    let map_assignment = ve.map(&graph).unwrap();

    assert!(map_assignment.contains_key("x"));
    assert!(map_assignment.contains_key("y"));

    // Joint table: (0,0)=0.24, (0,1)=0.06, (1,0)=0.07, (1,1)=0.63,
    // so the most probable assignment is x=1, y=1.
    assert_eq!(map_assignment.get("x"), Some(&1));
    assert_eq!(map_assignment.get("y"), Some(&1));
}
430}