Skip to main content

tensorlogic_quantrs_hooks/
variable_elimination.rs

1//! Variable Elimination algorithm for exact inference.
2//!
3//! Variable Elimination is a classic exact inference algorithm that eliminates
4//! variables one by one from the factor graph. The complexity depends on the
5//! elimination order.
6
7use scirs2_core::ndarray::ArrayD;
8use std::collections::{HashMap, HashSet};
9
10use crate::error::{PgmError, Result};
11use crate::factor::Factor;
12use crate::graph::FactorGraph;
13
/// Variable elimination algorithm for exact inference.
///
/// Computes marginal probabilities by eliminating variables in a specific order.
/// `Default` yields the same configuration as `VariableElimination::new()`
/// (no custom order).
#[derive(Debug, Clone, Default)]
pub struct VariableElimination {
    /// Elimination order to use.
    ///
    /// When `None`, an order is computed heuristically: variables that appear
    /// in fewer factors of the graph are eliminated first.
    pub elimination_order: Option<Vec<String>>,
}
27
28impl VariableElimination {
29    /// Create a new variable elimination algorithm.
30    pub fn new() -> Self {
31        Self {
32            elimination_order: None,
33        }
34    }
35
36    /// Create with a specific elimination order.
37    pub fn with_order(order: Vec<String>) -> Self {
38        Self {
39            elimination_order: Some(order),
40        }
41    }
42
43    /// Compute marginal for a single query variable.
44    pub fn marginalize(&self, graph: &FactorGraph, query_var: &str) -> Result<ArrayD<f64>> {
45        // Check if query variable exists
46        let query_node = graph
47            .get_variable(query_var)
48            .ok_or_else(|| PgmError::VariableNotFound(query_var.to_string()))?;
49
50        // Get all factors as a working set
51        let mut factors: Vec<Factor> = graph
52            .factor_ids()
53            .filter_map(|id| graph.get_factor(id).cloned())
54            .collect();
55
56        // If no factors, return uniform distribution
57        if factors.is_empty() {
58            let uniform = ArrayD::from_elem(
59                vec![query_node.cardinality],
60                1.0 / query_node.cardinality as f64,
61            );
62            return Ok(uniform);
63        }
64
65        // Determine elimination order
66        let all_vars: HashSet<String> = graph.variable_names().cloned().collect();
67        let vars_to_eliminate: Vec<String> =
68            all_vars.into_iter().filter(|v| v != query_var).collect();
69
70        let order = if let Some(ref custom_order) = self.elimination_order {
71            custom_order
72                .iter()
73                .filter(|v| vars_to_eliminate.contains(v))
74                .cloned()
75                .collect()
76        } else {
77            self.compute_elimination_order(graph, &vars_to_eliminate)?
78        };
79
80        // Eliminate variables one by one
81        for var in &order {
82            factors = self.eliminate_variable(&factors, var)?;
83        }
84
85        // Multiply remaining factors and marginalize to query variable
86        let mut result = self.multiply_all_factors(&factors)?;
87
88        // If result contains more than just the query variable, marginalize others
89        let vars_to_remove: Vec<String> = result
90            .variables
91            .iter()
92            .filter(|v| *v != query_var)
93            .cloned()
94            .collect();
95
96        for var in vars_to_remove {
97            result = result.marginalize_out(&var)?;
98        }
99
100        // Normalize
101        result.normalize();
102
103        Ok(result.values)
104    }
105
106    /// Compute marginals for all variables.
107    pub fn marginalize_all(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
108        let mut marginals = HashMap::new();
109
110        for var_name in graph.variable_names() {
111            let marginal = self.marginalize(graph, var_name)?;
112            marginals.insert(var_name.clone(), marginal);
113        }
114
115        Ok(marginals)
116    }
117
118    /// Eliminate a single variable from a set of factors.
119    fn eliminate_variable(&self, factors: &[Factor], var: &str) -> Result<Vec<Factor>> {
120        // Find all factors containing this variable
121        let (containing, not_containing): (Vec<Factor>, Vec<Factor>) = factors
122            .iter()
123            .cloned()
124            .partition(|f| f.variables.contains(&var.to_string()));
125
126        if containing.is_empty() {
127            // Variable not in any factor, nothing to eliminate
128            return Ok(factors.to_vec());
129        }
130
131        // Multiply all factors containing the variable
132        let mut product = containing[0].clone();
133        for factor in &containing[1..] {
134            product = product.product(factor)?;
135        }
136
137        // Marginalize out the variable
138        let marginalized = product.marginalize_out(var)?;
139
140        // Combine with factors that didn't contain the variable
141        let mut result = not_containing;
142        if !marginalized.variables.is_empty() {
143            result.push(marginalized);
144        }
145
146        Ok(result)
147    }
148
149    /// Multiply all factors together.
150    fn multiply_all_factors(&self, factors: &[Factor]) -> Result<Factor> {
151        if factors.is_empty() {
152            return Err(PgmError::InvalidGraph("No factors to multiply".to_string()));
153        }
154
155        let mut result = factors[0].clone();
156        for factor in &factors[1..] {
157            result = result.product(factor)?;
158        }
159
160        Ok(result)
161    }
162
163    /// Compute elimination order using min-degree heuristic.
164    ///
165    /// Chooses variables in order of fewest connections (smallest induced clique size).
166    fn compute_elimination_order(
167        &self,
168        graph: &FactorGraph,
169        vars: &[String],
170    ) -> Result<Vec<String>> {
171        // Simple heuristic: eliminate variables in the order they appear
172        // More sophisticated: use min-degree, min-fill, or max-cardinality search
173        let mut order = vars.to_vec();
174
175        // Sort by number of factors containing each variable
176        order.sort_by_key(|v| {
177            graph
178                .get_adjacent_factors(v)
179                .map(|factors| factors.len())
180                .unwrap_or(0)
181        });
182
183        Ok(order)
184    }
185
186    /// Compute joint probability for a specific assignment.
187    pub fn joint_probability(
188        &self,
189        graph: &FactorGraph,
190        assignment: &HashMap<String, usize>,
191    ) -> Result<f64> {
192        let mut prob = 1.0;
193
194        for factor_id in graph.factor_ids() {
195            if let Some(factor) = graph.get_factor(factor_id) {
196                // Build index for this factor
197                let mut indices = Vec::new();
198                for var in &factor.variables {
199                    if let Some(&value) = assignment.get(var) {
200                        indices.push(value);
201                    } else {
202                        return Err(PgmError::VariableNotFound(var.clone()));
203                    }
204                }
205
206                prob *= factor.values[indices.as_slice()];
207            }
208        }
209
210        Ok(prob)
211    }
212
213    /// Compute MAP (Maximum A Posteriori) assignment using variable elimination.
214    pub fn map(&self, graph: &FactorGraph) -> Result<HashMap<String, usize>> {
215        // Get all factors
216        let mut factors: Vec<Factor> = graph
217            .factor_ids()
218            .filter_map(|id| graph.get_factor(id).cloned())
219            .collect();
220
221        // Get elimination order (all variables)
222        let all_vars: Vec<String> = graph.variable_names().cloned().collect();
223        let order = if let Some(ref custom_order) = self.elimination_order {
224            custom_order.clone()
225        } else {
226            self.compute_elimination_order(graph, &all_vars)?
227        };
228
229        let mut assignment = HashMap::new();
230
231        // Eliminate variables using MAX instead of SUM
232        for var in order.iter().rev() {
233            // Find factors containing this variable
234            let (containing, not_containing): (Vec<Factor>, Vec<Factor>) = factors
235                .iter()
236                .cloned()
237                .partition(|f| f.variables.contains(&var.to_string()));
238
239            if containing.is_empty() {
240                continue;
241            }
242
243            // Multiply factors
244            let mut product = containing[0].clone();
245            for factor in &containing[1..] {
246                product = product.product(factor)?;
247            }
248
249            // Find max value for this variable
250            let var_node = graph
251                .get_variable(var)
252                .ok_or_else(|| PgmError::VariableNotFound(var.clone()))?;
253
254            let mut max_val = f64::NEG_INFINITY;
255            let mut max_idx = 0;
256
257            for val in 0..var_node.cardinality {
258                let reduced = product.reduce(var, val)?;
259                let prob: f64 = reduced.values.iter().product();
260
261                if prob > max_val {
262                    max_val = prob;
263                    max_idx = val;
264                }
265            }
266
267            assignment.insert(var.clone(), max_idx);
268
269            // Reduce factor to this assignment
270            let reduced = product.reduce(var, max_idx)?;
271            factors = not_containing;
272            if !reduced.variables.is_empty() {
273                factors.push(reduced);
274            }
275        }
276
277        Ok(assignment)
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Build the two-variable chain `P(x) * P(y | x)` over binary variables.
    ///
    /// `pyx` is row-major over `(x, y)`: `[P(y0|x0), P(y1|x0), P(y0|x1), P(y1|x1)]`.
    fn chain_graph(px: [f64; 2], pyx: [f64; 4]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let px_values = Array::from_shape_vec(vec![2], px.to_vec())
            .unwrap()
            .into_dyn();
        let px = Factor::new("P(x)".to_string(), vec!["x".to_string()], px_values).unwrap();
        graph.add_factor(px).unwrap();

        let pyx_values = Array::from_shape_vec(vec![2, 2], pyx.to_vec())
            .unwrap()
            .into_dyn();
        let pyx = Factor::new(
            "P(y|x)".to_string(),
            vec!["x".to_string(), "y".to_string()],
            pyx_values,
        )
        .unwrap();
        graph.add_factor(pyx).unwrap();

        graph
    }

    /// Assert element-wise closeness of a marginal to the expected values.
    fn assert_close(actual: &ArrayD<f64>, expected: &[f64]) {
        let actual: Vec<f64> = actual.iter().copied().collect();
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert_abs_diff_eq!(*a, *e, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_variable_elimination_single_variable() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);

        let factor = Factor::uniform("P(x)".to_string(), vec!["x".to_string()], 2);
        graph.add_factor(factor).unwrap();

        let ve = VariableElimination::new();
        let marginal = ve.marginalize(&graph, "x").unwrap();

        assert_close(&marginal, &[0.5, 0.5]);
    }

    #[test]
    fn test_variable_elimination_chain() {
        let graph = chain_graph([0.6, 0.4], [0.9, 0.1, 0.2, 0.8]);
        let ve = VariableElimination::new();

        // P(y) = sum_x P(x) P(y|x) = [0.6*0.9 + 0.4*0.2, 0.6*0.1 + 0.4*0.8]
        let marginal_y = ve.marginalize(&graph, "y").unwrap();
        assert_close(&marginal_y, &[0.62, 0.38]);

        // Summing y out of P(y|x) leaves P(x) unchanged.
        let marginal_x = ve.marginalize(&graph, "x").unwrap();
        assert_close(&marginal_x, &[0.6, 0.4]);
    }

    #[test]
    fn test_marginalize_all() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let ve = VariableElimination::new();
        let marginals = ve.marginalize_all(&graph).unwrap();

        assert_eq!(marginals.len(), 2);
        assert_close(&marginals["x"], &[0.5, 0.5]);
        assert_close(&marginals["y"], &[0.5, 0.5]);
    }

    #[test]
    fn test_joint_probability() {
        let graph = chain_graph([0.6, 0.4], [0.9, 0.1, 0.2, 0.8]);
        let ve = VariableElimination::new();

        let mut assignment = HashMap::new();
        assignment.insert("x".to_string(), 0);
        assignment.insert("y".to_string(), 1);
        // P(x=0, y=1) = P(x=0) * P(y=1|x=0) = 0.6 * 0.1
        let prob = ve.joint_probability(&graph, &assignment).unwrap();
        assert_abs_diff_eq!(prob, 0.06, epsilon = 1e-6);

        assignment.insert("x".to_string(), 1);
        // P(x=1, y=1) = 0.4 * 0.8
        let prob = ve.joint_probability(&graph, &assignment).unwrap();
        assert_abs_diff_eq!(prob, 0.32, epsilon = 1e-6);
    }

    #[test]
    fn test_custom_elimination_order() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("z".to_string(), "Binary".to_string(), 2);

        let order = vec!["x".to_string(), "y".to_string()];
        let ve = VariableElimination::with_order(order);

        let marginal = ve.marginalize(&graph, "z").unwrap();
        assert_close(&marginal, &[0.5, 0.5]);
    }

    #[test]
    fn test_custom_order_matches_default_order() {
        let graph = chain_graph([0.6, 0.4], [0.9, 0.1, 0.2, 0.8]);
        let ve = VariableElimination::with_order(vec!["x".to_string(), "y".to_string()]);

        let marginal_y = ve.marginalize(&graph, "y").unwrap();
        assert_close(&marginal_y, &[0.62, 0.38]);
    }

    #[test]
    fn test_map_inference() {
        // Joint: (x0,y0)=0.24, (x0,y1)=0.06, (x1,y0)=0.07, (x1,y1)=0.63
        let graph = chain_graph([0.3, 0.7], [0.8, 0.2, 0.1, 0.9]);

        let ve = VariableElimination::new();
        let map_assignment = ve.map(&graph).unwrap();

        assert_eq!(map_assignment.get("x"), Some(&1));
        assert_eq!(map_assignment.get("y"), Some(&1));
    }
}
426
427        assert!(map_assignment.contains_key("x"));
428        assert!(map_assignment.contains_key("y"));
429    }
430}