tensorlogic_quantrs_hooks/loopy_bp/
cycle.rs1use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet, VecDeque};
5
6use crate::graph::FactorGraph;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct CycleAnalysis {
11 pub has_cycles: bool,
13 pub girth: Option<usize>,
15 pub cycle_rank: usize,
17 pub is_tree: bool,
19 pub num_components: usize,
21}
22
23pub struct CycleDetector<'a> {
26 graph: &'a FactorGraph,
27}
28
29impl<'a> CycleDetector<'a> {
30 pub fn new(graph: &'a FactorGraph) -> Self {
32 Self { graph }
33 }
34
35 pub fn analyse(&self) -> CycleAnalysis {
37 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
40
41 for var_name in self.graph.variable_names() {
42 let v_node = format!("v:{}", var_name);
43 if let Some(factors) = self.graph.get_adjacent_factors(var_name) {
44 for f_id in factors {
45 let f_node = format!("f:{}", f_id);
46 adj.entry(v_node.clone()).or_default().push(f_node.clone());
47 adj.entry(f_node).or_default().push(v_node.clone());
48 }
49 }
50 }
51
52 if adj.is_empty() {
53 return CycleAnalysis {
54 has_cycles: false,
55 girth: None,
56 cycle_rank: 0,
57 is_tree: false,
58 num_components: 0,
59 };
60 }
61
62 let all_nodes: Vec<String> = adj.keys().cloned().collect();
63 let num_nodes = all_nodes.len();
64
65 let num_edges: usize = adj.values().map(|v| v.len()).sum::<usize>() / 2;
67
68 let mut visited: HashSet<String> = HashSet::new();
69 let mut num_components = 0usize;
70 let mut has_cycles = false;
71 let mut min_girth: Option<usize> = None;
72
73 for start in &all_nodes {
74 if visited.contains(start) {
75 continue;
76 }
77 num_components += 1;
78
79 let mut depth: HashMap<String, usize> = HashMap::new();
82 let mut parent: HashMap<String, Option<String>> = HashMap::new();
83 let mut queue: VecDeque<String> = VecDeque::new();
84
85 depth.insert(start.clone(), 0);
86 parent.insert(start.clone(), None);
87 queue.push_back(start.clone());
88 visited.insert(start.clone());
89
90 while let Some(cur) = queue.pop_front() {
91 let cur_depth = depth.get(&cur).copied().unwrap_or(0);
92 if let Some(neighbours) = adj.get(&cur) {
93 for nb in neighbours {
94 if !visited.contains(nb) {
95 visited.insert(nb.clone());
96 depth.insert(nb.clone(), cur_depth + 1);
97 parent.insert(nb.clone(), Some(cur.clone()));
98 queue.push_back(nb.clone());
99 } else {
100 let is_parent = parent
102 .get(&cur)
103 .and_then(|p| p.as_ref())
104 .map(|p| p == nb)
105 .unwrap_or(false);
106 if !is_parent {
107 has_cycles = true;
108 let cycle_len = cur_depth + depth.get(nb).copied().unwrap_or(0) + 1;
110 min_girth =
111 Some(min_girth.map(|g| g.min(cycle_len)).unwrap_or(cycle_len));
112 }
113 }
114 }
115 }
116 }
117 }
118
119 let cycle_rank = ((num_edges as isize) - (num_nodes as isize) + (num_components as isize))
123 .max(0) as usize;
124 let is_tree = num_components == 1 && !has_cycles;
125
126 CycleAnalysis {
127 has_cycles,
128 girth: min_girth,
129 cycle_rank,
130 is_tree,
131 num_components,
132 }
133 }
134}