Skip to main content

tensorlogic_quantrs_hooks/
elimination_ordering.rs

1//! Elimination ordering heuristics for variable elimination.
2//!
3//! Different heuristics can produce significantly different elimination orders,
4//! which affects the computational cost of variable elimination. This module
5//! provides several classic ordering heuristics.
6
7use crate::error::{PgmError, Result};
8use crate::graph::FactorGraph;
9use std::collections::{HashMap, HashSet};
10
/// Strategy for computing a variable elimination ordering.
///
/// All strategies are greedy heuristics: each picks one variable at a time from
/// the interaction graph, so none of them is guaranteed to produce an optimal
/// ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EliminationStrategy {
    /// Min-degree (min-neighbors): eliminate the variable with the fewest
    /// neighbors in the current interaction graph. This is the default strategy.
    #[default]
    MinDegree,
    /// Min-fill: eliminate the variable whose elimination adds the fewest new
    /// (fill) edges between its current neighbors.
    MinFill,
    /// Weighted min-fill: the min-fill count scaled by a per-variable weight
    /// derived from the cardinalities of the variables in its adjacent factors.
    WeightedMinFill,
    /// Min-width: keep the clique created by each elimination small. That clique
    /// is the eliminated variable plus its current neighbors, so this greedy rule
    /// uses the same selection criterion as [`EliminationStrategy::MinDegree`].
    MinWidth,
    /// Max-cardinality search: variables are ranked by how many of their
    /// neighbors have already been selected.
    MaxCardinalitySearch,
}
26
27/// Compute elimination ordering for variable elimination.
28pub struct EliminationOrdering {
29    strategy: EliminationStrategy,
30}
31
32impl Default for EliminationOrdering {
33    fn default() -> Self {
34        Self::new(EliminationStrategy::default())
35    }
36}
37
38impl EliminationOrdering {
39    /// Create with a specific strategy.
40    pub fn new(strategy: EliminationStrategy) -> Self {
41        Self { strategy }
42    }
43
44    /// Compute elimination order for the given variables.
45    pub fn compute_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
46        match self.strategy {
47            EliminationStrategy::MinDegree => self.min_degree_order(graph, vars),
48            EliminationStrategy::MinFill => self.min_fill_order(graph, vars),
49            EliminationStrategy::WeightedMinFill => self.weighted_min_fill_order(graph, vars),
50            EliminationStrategy::MinWidth => self.min_width_order(graph, vars),
51            EliminationStrategy::MaxCardinalitySearch => self.max_cardinality_search(graph, vars),
52        }
53    }
54
55    /// Min-degree heuristic: Choose variable with fewest neighbors.
56    ///
57    /// This is a simple and fast heuristic that works well in many cases.
58    fn min_degree_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
59        let mut remaining: HashSet<String> = vars.iter().cloned().collect();
60        let mut order = Vec::new();
61
62        // Build initial adjacency graph
63        let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
64
65        while !remaining.is_empty() {
66            // Find variable with minimum degree
67            let min_var = remaining
68                .iter()
69                .min_by_key(|v| adjacency.get(*v).map(|s| s.len()).unwrap_or(0))
70                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
71                .clone();
72
73            order.push(min_var.clone());
74            remaining.remove(&min_var);
75
76            // Update adjacency after elimination
77            self.update_adjacency_after_elimination(&mut adjacency, &min_var);
78        }
79
80        Ok(order)
81    }
82
83    /// Min-fill heuristic: Choose variable that introduces fewest new edges.
84    ///
85    /// When a variable is eliminated, its neighbors become fully connected.
86    /// This heuristic minimizes the number of new edges (fill) created.
87    fn min_fill_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
88        let mut remaining: HashSet<String> = vars.iter().cloned().collect();
89        let mut order = Vec::new();
90
91        // Build initial adjacency graph
92        let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
93
94        while !remaining.is_empty() {
95            // Find variable that introduces minimum fill
96            let min_var = remaining
97                .iter()
98                .min_by_key(|v| self.compute_fill(&adjacency, v))
99                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
100                .clone();
101
102            order.push(min_var.clone());
103            remaining.remove(&min_var);
104
105            // Update adjacency after elimination
106            self.update_adjacency_after_elimination(&mut adjacency, &min_var);
107        }
108
109        Ok(order)
110    }
111
112    /// Weighted min-fill: Min-fill weighted by factor sizes.
113    ///
114    /// Similar to min-fill, but weights the fill by the product of factor sizes.
115    /// This tries to minimize the computational cost more directly.
116    fn weighted_min_fill_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
117        let mut remaining: HashSet<String> = vars.iter().cloned().collect();
118        let mut order = Vec::new();
119
120        // Build initial adjacency graph with weights
121        let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
122        let weights = self.compute_variable_weights(graph, vars)?;
123
124        while !remaining.is_empty() {
125            // Find variable that introduces minimum weighted fill
126            let min_var = remaining
127                .iter()
128                .min_by_key(|v| {
129                    let fill = self.compute_fill(&adjacency, v);
130                    let weight = weights.get(*v).copied().unwrap_or(1);
131                    fill * weight
132                })
133                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
134                .clone();
135
136            order.push(min_var.clone());
137            remaining.remove(&min_var);
138
139            // Update adjacency after elimination
140            self.update_adjacency_after_elimination(&mut adjacency, &min_var);
141        }
142
143        Ok(order)
144    }
145
146    /// Min-width heuristic: Minimize the width of the induced tree.
147    ///
148    /// Width is the size of the largest clique created during elimination.
149    fn min_width_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
150        let mut remaining: HashSet<String> = vars.iter().cloned().collect();
151        let mut order = Vec::new();
152
153        // Build initial adjacency graph
154        let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
155
156        while !remaining.is_empty() {
157            // Find variable that minimizes induced width
158            let min_var = remaining
159                .iter()
160                .min_by_key(|v| {
161                    let neighbors = adjacency.get(*v).map(|s| s.len()).unwrap_or(0);
162                    neighbors
163                })
164                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
165                .clone();
166
167            order.push(min_var.clone());
168            remaining.remove(&min_var);
169
170            // Update adjacency after elimination
171            self.update_adjacency_after_elimination(&mut adjacency, &min_var);
172        }
173
174        Ok(order)
175    }
176
177    /// Max-cardinality search: Greedy algorithm that produces good orderings.
178    ///
179    /// This algorithm iteratively selects variables with maximum cardinality
180    /// (number of already-selected neighbors).
181    fn max_cardinality_search(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
182        let mut remaining: HashSet<String> = vars.iter().cloned().collect();
183        let mut order = Vec::new();
184        let mut cardinality: HashMap<String, usize> = HashMap::new();
185
186        // Initialize cardinality to 0
187        for var in vars {
188            cardinality.insert(var.clone(), 0);
189        }
190
191        // Build adjacency graph
192        let adjacency = self.build_adjacency_graph(graph, &remaining)?;
193
194        while !remaining.is_empty() {
195            // Find variable with maximum cardinality
196            let max_var = remaining
197                .iter()
198                .max_by_key(|v| cardinality.get(*v).copied().unwrap_or(0))
199                .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
200                .clone();
201
202            order.push(max_var.clone());
203            remaining.remove(&max_var);
204
205            // Update cardinality of neighbors
206            if let Some(neighbors) = adjacency.get(&max_var) {
207                for neighbor in neighbors {
208                    if remaining.contains(neighbor) {
209                        *cardinality.entry(neighbor.clone()).or_insert(0) += 1;
210                    }
211                }
212            }
213        }
214
215        Ok(order)
216    }
217
218    /// Build adjacency graph from factor graph.
219    fn build_adjacency_graph(
220        &self,
221        graph: &FactorGraph,
222        vars: &HashSet<String>,
223    ) -> Result<HashMap<String, HashSet<String>>> {
224        let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
225
226        // Initialize empty sets
227        for var in vars {
228            adjacency.insert(var.clone(), HashSet::new());
229        }
230
231        // Add edges based on factors
232        for factor_id in graph.factor_ids() {
233            if let Some(factor) = graph.get_factor(factor_id) {
234                let factor_vars: Vec<String> = factor
235                    .variables
236                    .iter()
237                    .filter(|v| vars.contains(*v))
238                    .cloned()
239                    .collect();
240
241                // Connect all pairs of variables in the factor
242                for i in 0..factor_vars.len() {
243                    for j in (i + 1)..factor_vars.len() {
244                        let v1 = &factor_vars[i];
245                        let v2 = &factor_vars[j];
246
247                        adjacency.entry(v1.clone()).or_default().insert(v2.clone());
248                        adjacency.entry(v2.clone()).or_default().insert(v1.clone());
249                    }
250                }
251            }
252        }
253
254        Ok(adjacency)
255    }
256
257    /// Compute fill for eliminating a variable.
258    ///
259    /// Fill is the number of new edges that would be created.
260    fn compute_fill(&self, adjacency: &HashMap<String, HashSet<String>>, var: &str) -> usize {
261        let neighbors = match adjacency.get(var) {
262            Some(n) => n,
263            None => return 0,
264        };
265
266        if neighbors.is_empty() {
267            return 0;
268        }
269
270        // Count pairs of neighbors that are not already connected
271        let mut fill = 0;
272        let neighbors_vec: Vec<_> = neighbors.iter().collect();
273
274        for i in 0..neighbors_vec.len() {
275            for j in (i + 1)..neighbors_vec.len() {
276                let v1 = neighbors_vec[i];
277                let v2 = neighbors_vec[j];
278
279                // Check if edge exists
280                if let Some(adj_v1) = adjacency.get(v1) {
281                    if !adj_v1.contains(v2) {
282                        fill += 1;
283                    }
284                }
285            }
286        }
287
288        fill
289    }
290
291    /// Update adjacency graph after eliminating a variable.
292    fn update_adjacency_after_elimination(
293        &self,
294        adjacency: &mut HashMap<String, HashSet<String>>,
295        var: &str,
296    ) {
297        let neighbors = match adjacency.remove(var) {
298            Some(n) => n,
299            None => return,
300        };
301
302        // Remove var from all neighbor lists
303        for neighbor in &neighbors {
304            if let Some(adj) = adjacency.get_mut(neighbor) {
305                adj.remove(var);
306            }
307        }
308
309        // Connect all pairs of neighbors (create fill edges)
310        let neighbors_vec: Vec<_> = neighbors.iter().cloned().collect();
311        for i in 0..neighbors_vec.len() {
312            for j in (i + 1)..neighbors_vec.len() {
313                let v1 = &neighbors_vec[i];
314                let v2 = &neighbors_vec[j];
315
316                // Add edge v1 <-> v2
317                if let Some(adj_v1) = adjacency.get_mut(v1) {
318                    adj_v1.insert(v2.clone());
319                }
320                if let Some(adj_v2) = adjacency.get_mut(v2) {
321                    adj_v2.insert(v1.clone());
322                }
323            }
324        }
325    }
326
327    /// Compute weights for variables based on factor sizes.
328    fn compute_variable_weights(
329        &self,
330        graph: &FactorGraph,
331        vars: &[String],
332    ) -> Result<HashMap<String, usize>> {
333        let mut weights = HashMap::new();
334
335        for var in vars {
336            let mut weight = 1;
337
338            if let Some(factors) = graph.get_adjacent_factors(var) {
339                for factor_id in factors {
340                    if let Some(factor) = graph.get_factor(factor_id) {
341                        // Weight by product of variable cardinalities
342                        for factor_var in &factor.variables {
343                            if let Some(var_node) = graph.get_variable(factor_var) {
344                                weight *= var_node.cardinality;
345                            }
346                        }
347                    }
348                }
349            }
350
351            weights.insert(var.clone(), weight);
352        }
353
354        Ok(weights)
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Factor;
    use scirs2_core::ndarray::Array;

    fn create_test_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();

        // Create a simple chain: X - Y - Z
        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Z".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        let f_yz = Factor::new(
            "f_yz".to_string(),
            vec!["Y".to_string(), "Z".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.5, 0.6, 0.7, 0.8])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        graph.add_factor(f_xy).unwrap();
        graph.add_factor(f_yz).unwrap();

        graph
    }

    /// The variables of the test chain, in declaration order.
    fn chain_vars() -> Vec<String> {
        vec!["X".to_string(), "Y".to_string(), "Z".to_string()]
    }

    /// Assert that `order` contains each of `vars` exactly once.
    fn assert_is_permutation(order: &[String], vars: &[String]) {
        let mut got = order.to_vec();
        got.sort();
        let mut expected = vars.to_vec();
        expected.sort();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_min_degree_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let ordering = EliminationOrdering::new(EliminationStrategy::MinDegree);
        let order = ordering.compute_order(&graph, &vars).unwrap();

        assert_eq!(order.len(), 3);
        assert_is_permutation(&order, &vars);
        // X and Z have degree 1, Y has degree 2
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_fill_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let ordering = EliminationOrdering::new(EliminationStrategy::MinFill);
        let order = ordering.compute_order(&graph, &vars).unwrap();

        assert_eq!(order.len(), 3);
        assert_is_permutation(&order, &vars);
        // Eliminating Y first would add the fill edge X - Z; X and Z add none.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_weighted_min_fill_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let ordering = EliminationOrdering::new(EliminationStrategy::WeightedMinFill);
        let order = ordering.compute_order(&graph, &vars).unwrap();

        assert_eq!(order.len(), 3);
        assert_is_permutation(&order, &vars);
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_width_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let ordering = EliminationOrdering::new(EliminationStrategy::MinWidth);
        let order = ordering.compute_order(&graph, &vars).unwrap();

        assert_eq!(order.len(), 3);
        assert_is_permutation(&order, &vars);
        // Eliminating X or Z creates a 2-clique; eliminating Y creates a 3-clique.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_max_cardinality_search() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let ordering = EliminationOrdering::new(EliminationStrategy::MaxCardinalitySearch);
        let order = ordering.compute_order(&graph, &vars).unwrap();

        assert_eq!(order.len(), 3);
        assert_is_permutation(&order, &vars);
    }

    #[test]
    fn test_default_strategy_is_min_degree() {
        assert_eq!(EliminationStrategy::default(), EliminationStrategy::MinDegree);
    }

    #[test]
    fn test_empty_and_duplicate_inputs() {
        let graph = create_test_graph();
        let strategies = [
            EliminationStrategy::MinDegree,
            EliminationStrategy::MinFill,
            EliminationStrategy::WeightedMinFill,
            EliminationStrategy::MinWidth,
            EliminationStrategy::MaxCardinalitySearch,
        ];

        for strategy in strategies {
            let ordering = EliminationOrdering::new(strategy);

            let empty = ordering.compute_order(&graph, &[]).unwrap();
            assert!(empty.is_empty());

            let dup = vec!["X".to_string(), "X".to_string(), "Y".to_string()];
            let order = ordering.compute_order(&graph, &dup).unwrap();
            assert_is_permutation(&order, &["X".to_string(), "Y".to_string()]);
        }
    }
}