tensorlogic_quantrs_hooks/
elimination_ordering.rs1use crate::error::{PgmError, Result};
8use crate::graph::FactorGraph;
9use std::collections::{HashMap, HashSet};
10
/// Heuristic used by [`EliminationOrdering`] to decide the order in which
/// variables are eliminated from a factor graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EliminationStrategy {
    /// Greedily eliminate the variable with the fewest current neighbours.
    /// This is the default strategy.
    MinDegree,
    /// Greedily eliminate the variable whose elimination would add the fewest
    /// fill-in edges between its neighbours.
    MinFill,
    /// Like `MinFill`, but the fill-in count is multiplied by a per-variable
    /// weight derived from the cardinalities of the variables in its factors.
    WeightedMinFill,
    /// Greedy elimination scored by the current number of neighbours; in the
    /// current implementation this behaves the same as `MinDegree`.
    MinWidth,
    /// Maximum cardinality search: variables are visited preferring those with
    /// the most already-visited neighbours.
    MaxCardinalitySearch,
}

impl Default for EliminationStrategy {
    /// Returns [`EliminationStrategy::MinDegree`].
    fn default() -> Self {
        EliminationStrategy::MinDegree
    }
}
26
27pub struct EliminationOrdering {
29 strategy: EliminationStrategy,
30}
31
32impl Default for EliminationOrdering {
33 fn default() -> Self {
34 Self::new(EliminationStrategy::default())
35 }
36}
37
38impl EliminationOrdering {
39 pub fn new(strategy: EliminationStrategy) -> Self {
41 Self { strategy }
42 }
43
44 pub fn compute_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
46 match self.strategy {
47 EliminationStrategy::MinDegree => self.min_degree_order(graph, vars),
48 EliminationStrategy::MinFill => self.min_fill_order(graph, vars),
49 EliminationStrategy::WeightedMinFill => self.weighted_min_fill_order(graph, vars),
50 EliminationStrategy::MinWidth => self.min_width_order(graph, vars),
51 EliminationStrategy::MaxCardinalitySearch => self.max_cardinality_search(graph, vars),
52 }
53 }
54
55 fn min_degree_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
59 let mut remaining: HashSet<String> = vars.iter().cloned().collect();
60 let mut order = Vec::new();
61
62 let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
64
65 while !remaining.is_empty() {
66 let min_var = remaining
68 .iter()
69 .min_by_key(|v| adjacency.get(*v).map(|s| s.len()).unwrap_or(0))
70 .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
71 .clone();
72
73 order.push(min_var.clone());
74 remaining.remove(&min_var);
75
76 self.update_adjacency_after_elimination(&mut adjacency, &min_var);
78 }
79
80 Ok(order)
81 }
82
83 fn min_fill_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
88 let mut remaining: HashSet<String> = vars.iter().cloned().collect();
89 let mut order = Vec::new();
90
91 let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
93
94 while !remaining.is_empty() {
95 let min_var = remaining
97 .iter()
98 .min_by_key(|v| self.compute_fill(&adjacency, v))
99 .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
100 .clone();
101
102 order.push(min_var.clone());
103 remaining.remove(&min_var);
104
105 self.update_adjacency_after_elimination(&mut adjacency, &min_var);
107 }
108
109 Ok(order)
110 }
111
112 fn weighted_min_fill_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
117 let mut remaining: HashSet<String> = vars.iter().cloned().collect();
118 let mut order = Vec::new();
119
120 let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
122 let weights = self.compute_variable_weights(graph, vars)?;
123
124 while !remaining.is_empty() {
125 let min_var = remaining
127 .iter()
128 .min_by_key(|v| {
129 let fill = self.compute_fill(&adjacency, v);
130 let weight = weights.get(*v).copied().unwrap_or(1);
131 fill * weight
132 })
133 .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
134 .clone();
135
136 order.push(min_var.clone());
137 remaining.remove(&min_var);
138
139 self.update_adjacency_after_elimination(&mut adjacency, &min_var);
141 }
142
143 Ok(order)
144 }
145
146 fn min_width_order(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
150 let mut remaining: HashSet<String> = vars.iter().cloned().collect();
151 let mut order = Vec::new();
152
153 let mut adjacency = self.build_adjacency_graph(graph, &remaining)?;
155
156 while !remaining.is_empty() {
157 let min_var = remaining
159 .iter()
160 .min_by_key(|v| {
161 let neighbors = adjacency.get(*v).map(|s| s.len()).unwrap_or(0);
162 neighbors
163 })
164 .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
165 .clone();
166
167 order.push(min_var.clone());
168 remaining.remove(&min_var);
169
170 self.update_adjacency_after_elimination(&mut adjacency, &min_var);
172 }
173
174 Ok(order)
175 }
176
177 fn max_cardinality_search(&self, graph: &FactorGraph, vars: &[String]) -> Result<Vec<String>> {
182 let mut remaining: HashSet<String> = vars.iter().cloned().collect();
183 let mut order = Vec::new();
184 let mut cardinality: HashMap<String, usize> = HashMap::new();
185
186 for var in vars {
188 cardinality.insert(var.clone(), 0);
189 }
190
191 let adjacency = self.build_adjacency_graph(graph, &remaining)?;
193
194 while !remaining.is_empty() {
195 let max_var = remaining
197 .iter()
198 .max_by_key(|v| cardinality.get(*v).copied().unwrap_or(0))
199 .ok_or_else(|| PgmError::InvalidGraph("No variables to eliminate".to_string()))?
200 .clone();
201
202 order.push(max_var.clone());
203 remaining.remove(&max_var);
204
205 if let Some(neighbors) = adjacency.get(&max_var) {
207 for neighbor in neighbors {
208 if remaining.contains(neighbor) {
209 *cardinality.entry(neighbor.clone()).or_insert(0) += 1;
210 }
211 }
212 }
213 }
214
215 Ok(order)
216 }
217
218 fn build_adjacency_graph(
220 &self,
221 graph: &FactorGraph,
222 vars: &HashSet<String>,
223 ) -> Result<HashMap<String, HashSet<String>>> {
224 let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
225
226 for var in vars {
228 adjacency.insert(var.clone(), HashSet::new());
229 }
230
231 for factor_id in graph.factor_ids() {
233 if let Some(factor) = graph.get_factor(factor_id) {
234 let factor_vars: Vec<String> = factor
235 .variables
236 .iter()
237 .filter(|v| vars.contains(*v))
238 .cloned()
239 .collect();
240
241 for i in 0..factor_vars.len() {
243 for j in (i + 1)..factor_vars.len() {
244 let v1 = &factor_vars[i];
245 let v2 = &factor_vars[j];
246
247 adjacency.entry(v1.clone()).or_default().insert(v2.clone());
248 adjacency.entry(v2.clone()).or_default().insert(v1.clone());
249 }
250 }
251 }
252 }
253
254 Ok(adjacency)
255 }
256
257 fn compute_fill(&self, adjacency: &HashMap<String, HashSet<String>>, var: &str) -> usize {
261 let neighbors = match adjacency.get(var) {
262 Some(n) => n,
263 None => return 0,
264 };
265
266 if neighbors.is_empty() {
267 return 0;
268 }
269
270 let mut fill = 0;
272 let neighbors_vec: Vec<_> = neighbors.iter().collect();
273
274 for i in 0..neighbors_vec.len() {
275 for j in (i + 1)..neighbors_vec.len() {
276 let v1 = neighbors_vec[i];
277 let v2 = neighbors_vec[j];
278
279 if let Some(adj_v1) = adjacency.get(v1) {
281 if !adj_v1.contains(v2) {
282 fill += 1;
283 }
284 }
285 }
286 }
287
288 fill
289 }
290
291 fn update_adjacency_after_elimination(
293 &self,
294 adjacency: &mut HashMap<String, HashSet<String>>,
295 var: &str,
296 ) {
297 let neighbors = match adjacency.remove(var) {
298 Some(n) => n,
299 None => return,
300 };
301
302 for neighbor in &neighbors {
304 if let Some(adj) = adjacency.get_mut(neighbor) {
305 adj.remove(var);
306 }
307 }
308
309 let neighbors_vec: Vec<_> = neighbors.iter().cloned().collect();
311 for i in 0..neighbors_vec.len() {
312 for j in (i + 1)..neighbors_vec.len() {
313 let v1 = &neighbors_vec[i];
314 let v2 = &neighbors_vec[j];
315
316 if let Some(adj_v1) = adjacency.get_mut(v1) {
318 adj_v1.insert(v2.clone());
319 }
320 if let Some(adj_v2) = adjacency.get_mut(v2) {
321 adj_v2.insert(v1.clone());
322 }
323 }
324 }
325 }
326
327 fn compute_variable_weights(
329 &self,
330 graph: &FactorGraph,
331 vars: &[String],
332 ) -> Result<HashMap<String, usize>> {
333 let mut weights = HashMap::new();
334
335 for var in vars {
336 let mut weight = 1;
337
338 if let Some(factors) = graph.get_adjacent_factors(var) {
339 for factor_id in factors {
340 if let Some(factor) = graph.get_factor(factor_id) {
341 for factor_var in &factor.variables {
343 if let Some(var_node) = graph.get_variable(factor_var) {
344 weight *= var_node.cardinality;
345 }
346 }
347 }
348 }
349 }
350
351 weights.insert(var.clone(), weight);
352 }
353
354 Ok(weights)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Factor;
    use scirs2_core::ndarray::Array;

    /// Every strategy, so behaviour common to all of them is checked in one loop.
    const ALL_STRATEGIES: [EliminationStrategy; 5] = [
        EliminationStrategy::MinDegree,
        EliminationStrategy::MinFill,
        EliminationStrategy::WeightedMinFill,
        EliminationStrategy::MinWidth,
        EliminationStrategy::MaxCardinalitySearch,
    ];

    /// Build the chain X - Y - Z: three binary variables joined by the
    /// factors `f_xy` and `f_yz`.
    fn create_test_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();

        graph.add_variable_with_card("X".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Y".to_string(), "Domain".to_string(), 2);
        graph.add_variable_with_card("Z".to_string(), "Domain".to_string(), 2);

        let f_xy = Factor::new(
            "f_xy".to_string(),
            vec!["X".to_string(), "Y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        let f_yz = Factor::new(
            "f_yz".to_string(),
            vec!["Y".to_string(), "Z".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.5, 0.6, 0.7, 0.8])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();

        graph.add_factor(f_xy).unwrap();
        graph.add_factor(f_yz).unwrap();

        graph
    }

    /// The variable names of the test chain.
    fn chain_vars() -> Vec<String> {
        vec!["X".to_string(), "Y".to_string(), "Z".to_string()]
    }

    /// Run `strategy` on `vars` and unwrap the resulting order.
    fn order_for(strategy: EliminationStrategy, graph: &FactorGraph, vars: &[String]) -> Vec<String> {
        EliminationOrdering::new(strategy)
            .compute_order(graph, vars)
            .unwrap()
    }

    /// Assert that `order` contains exactly the variables of `vars`, each once.
    fn assert_is_permutation(order: &[String], vars: &[String]) {
        let mut got = order.to_vec();
        got.sort();
        let mut expected = vars.to_vec();
        expected.sort();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_min_degree_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let order = order_for(EliminationStrategy::MinDegree, &graph, &vars);

        assert_is_permutation(&order, &vars);
        // X and Z have one neighbour each, Y has two, so Y cannot go first.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_fill_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let order = order_for(EliminationStrategy::MinFill, &graph, &vars);

        assert_is_permutation(&order, &vars);
        // Eliminating X or Z adds no fill edge; eliminating Y adds X - Z.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_weighted_min_fill_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let order = order_for(EliminationStrategy::WeightedMinFill, &graph, &vars);

        assert_is_permutation(&order, &vars);
        // X and Z score 0 (no fill); Y has fill 1 times a positive weight.
        assert!(order[0] == "X" || order[0] == "Z");
    }

    #[test]
    fn test_min_width_ordering() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let order = order_for(EliminationStrategy::MinWidth, &graph, &vars);

        assert_is_permutation(&order, &vars);
    }

    #[test]
    fn test_max_cardinality_search() {
        let graph = create_test_graph();
        let vars = chain_vars();

        let order = order_for(EliminationStrategy::MaxCardinalitySearch, &graph, &vars);

        assert_is_permutation(&order, &vars);
    }

    #[test]
    fn test_empty_variable_list_gives_empty_order() {
        let graph = create_test_graph();
        let vars: Vec<String> = Vec::new();

        for strategy in ALL_STRATEGIES {
            let order = order_for(strategy, &graph, &vars);
            assert!(order.is_empty(), "{:?} produced {:?}", strategy, order);
        }
    }

    #[test]
    fn test_default_uses_min_degree() {
        assert_eq!(EliminationStrategy::default(), EliminationStrategy::MinDegree);
        assert_eq!(
            EliminationOrdering::default().strategy,
            EliminationStrategy::MinDegree
        );
    }
}