1use crate::token::{BoolOp, Source};
4
/// One two-operand node of a [`Graph`].
///
/// The node's result is identified by `id`; other nodes refer to it through
/// `Source::Id(id)`, and `Graph::outputs` lists such ids. `Graph::depth`
/// treats `left` and `right` as the node's two operands.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Node {
    /// Identifier of this node's result within the graph.
    pub id: u16,
    /// Boolean operator of the node (presumably applied to `left` and
    /// `right` — the evaluation code is not in this file; confirm).
    pub op: BoolOp,
    /// First operand: another node/input id, or a constant.
    pub left: Source,
    /// Second operand: another node/input id, or a constant.
    pub right: Source,
}
22
23impl Node {
24 pub fn new(id: u16, op: BoolOp, left: Source, right: Source) -> Self {
26 Self {
27 id,
28 op,
29 left,
30 right,
31 }
32 }
33}
34
/// A graph of [`Node`]s together with its designated input and output ids.
///
/// `Graph::default()` (and `Graph::new()`) is the empty graph.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Graph {
    /// Ids of the graph's inputs; `depth` assigns each of them depth 0.
    pub inputs: Vec<u16>,
    /// The graph's nodes. `depth` walks them in vector order, so it
    /// presumably expects operands to be defined before they are used
    /// (topological order) — TODO confirm with the producer of graphs.
    pub nodes: Vec<Node>,
    /// Ids (of nodes or inputs) whose values are the graph's outputs;
    /// `depth` reports the maximum depth among these.
    pub outputs: Vec<u16>,
}
54
55impl Graph {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn node_count(&self) -> usize {
63 self.nodes.len()
64 }
65
66 pub fn input_count(&self) -> usize {
68 self.inputs.len()
69 }
70
71 pub fn output_count(&self) -> usize {
73 self.outputs.len()
74 }
75
76 pub fn depth(&self) -> usize {
82 if self.nodes.is_empty() {
83 return 0;
84 }
85
86 let mut depths = std::collections::HashMap::new();
90
91 for &id in &self.inputs {
93 depths.insert(id, 0usize);
94 }
95
96 for node in &self.nodes {
98 let left_depth = self.source_depth(&node.left, &depths);
99 let right_depth = self.source_depth(&node.right, &depths);
100 let node_depth = 1 + left_depth.max(right_depth);
101 depths.insert(node.id, node_depth);
102 }
103
104 self.outputs
106 .iter()
107 .filter_map(|id| depths.get(id).copied())
108 .max()
109 .unwrap_or(0)
110 }
111
112 fn source_depth(
114 &self,
115 source: &Source,
116 depths: &std::collections::HashMap<u16, usize>,
117 ) -> usize {
118 match source {
119 Source::Id(id) => depths.get(id).copied().unwrap_or(0),
120 Source::True | Source::False => 0,
121 }
122 }
123
124 pub fn is_input(&self, id: u16) -> bool {
126 self.inputs.contains(&id)
127 }
128
129 pub fn get_node(&self, id: u16) -> Option<&Node> {
131 self.nodes.iter().find(|n| n.id == id)
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph_depth() {
        let graph = Graph::new();
        assert_eq!(graph.depth(), 0);
    }

    #[test]
    fn test_single_node_depth() {
        let graph = Graph {
            inputs: vec![0, 1],
            nodes: vec![Node::new(2, BoolOp::Or, Source::Id(0), Source::Id(1))],
            outputs: vec![2],
        };
        assert_eq!(graph.depth(), 1);
    }

    #[test]
    fn test_chain_depth() {
        // 0 -> 1 -> 2: each node adds one level on top of its deeper operand.
        let graph = Graph {
            inputs: vec![0],
            nodes: vec![
                Node::new(1, BoolOp::Or, Source::Id(0), Source::True),
                Node::new(2, BoolOp::Or, Source::Id(1), Source::False),
            ],
            outputs: vec![2],
        };
        assert_eq!(graph.depth(), 2);
    }

    #[test]
    fn test_parallel_depth() {
        // Two independent single-level nodes: depth is 1, not 2.
        let graph = Graph {
            inputs: vec![0, 1],
            nodes: vec![
                Node::new(2, BoolOp::Or, Source::Id(0), Source::True),
                Node::new(3, BoolOp::Or, Source::Id(1), Source::False),
            ],
            outputs: vec![2, 3],
        };
        assert_eq!(graph.depth(), 1);
    }

    /// Graph with a reconvergent path: 2 = f(0, 1), 3 = f(2, T), 4 = f(2, 3).
    fn diamond() -> Graph {
        Graph {
            inputs: vec![0, 1],
            nodes: vec![
                Node::new(2, BoolOp::Or, Source::Id(0), Source::Id(1)),
                Node::new(3, BoolOp::Or, Source::Id(2), Source::True),
                Node::new(4, BoolOp::Or, Source::Id(2), Source::Id(3)),
            ],
            outputs: vec![4],
        }
    }

    #[test]
    fn test_reconvergent_depth() {
        // Longest path is 0 -> 2 -> 3 -> 4, so the shorter 2 -> 4 edge must
        // not win.
        assert_eq!(diamond().depth(), 3);
    }

    #[test]
    fn test_depth_is_max_over_outputs() {
        // Outputs at depth 2 (node 3) and depth 1 (node 2): the maximum wins
        // regardless of output order.
        let graph = Graph {
            outputs: vec![3, 2],
            ..diamond()
        };
        assert_eq!(graph.depth(), 2);
    }

    #[test]
    fn test_input_and_unknown_outputs_contribute_nothing() {
        // An input used directly as an output is depth 0; an id that is
        // neither an input nor a node is ignored.
        let graph = Graph {
            outputs: vec![0, 99],
            ..diamond()
        };
        assert_eq!(graph.depth(), 0);
    }

    #[test]
    fn test_counts_and_lookups() {
        let graph = diamond();
        assert_eq!(graph.input_count(), 2);
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.output_count(), 1);

        assert!(graph.is_input(0));
        assert!(!graph.is_input(2));

        assert_eq!(graph.get_node(3), Some(&graph.nodes[1]));
        assert_eq!(graph.get_node(42), None);
    }

    #[test]
    fn test_get_node_returns_first_duplicate() {
        let graph = Graph {
            inputs: vec![],
            nodes: vec![
                Node::new(2, BoolOp::Or, Source::True, Source::True),
                Node::new(2, BoolOp::Or, Source::False, Source::False),
            ],
            outputs: vec![2],
        };
        assert_eq!(graph.get_node(2), Some(&graph.nodes[0]));
    }
}