ruvector_dag/dag/
traversal.rs1use std::collections::{HashSet, VecDeque};
4
5use super::query_dag::{DagError, QueryDag};
6
7pub struct TopologicalIterator<'a> {
9 #[allow(dead_code)]
10 dag: &'a QueryDag,
11 sorted: Vec<usize>,
12 index: usize,
13}
14
15impl<'a> TopologicalIterator<'a> {
16 pub(crate) fn new(dag: &'a QueryDag) -> Result<Self, DagError> {
17 let sorted = dag.topological_sort()?;
18 Ok(Self {
19 dag,
20 sorted,
21 index: 0,
22 })
23 }
24}
25
26impl<'a> Iterator for TopologicalIterator<'a> {
27 type Item = usize;
28
29 fn next(&mut self) -> Option<Self::Item> {
30 if self.index < self.sorted.len() {
31 let id = self.sorted[self.index];
32 self.index += 1;
33 Some(id)
34 } else {
35 None
36 }
37 }
38}
39
40pub struct DfsIterator<'a> {
42 dag: &'a QueryDag,
43 stack: Vec<usize>,
44 visited: HashSet<usize>,
45}
46
47impl<'a> DfsIterator<'a> {
48 pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
49 let mut stack = Vec::new();
50 let visited = HashSet::new();
51
52 if dag.get_node(start).is_some() {
53 stack.push(start);
54 }
55
56 Self {
57 dag,
58 stack,
59 visited,
60 }
61 }
62}
63
64impl<'a> Iterator for DfsIterator<'a> {
65 type Item = usize;
66
67 fn next(&mut self) -> Option<Self::Item> {
68 while let Some(node) = self.stack.pop() {
69 if self.visited.insert(node) {
70 if let Some(children) = self.dag.edges.get(&node) {
72 for &child in children.iter().rev() {
73 if !self.visited.contains(&child) {
74 self.stack.push(child);
75 }
76 }
77 }
78 return Some(node);
79 }
80 }
81 None
82 }
83}
84
85pub struct BfsIterator<'a> {
87 dag: &'a QueryDag,
88 queue: VecDeque<usize>,
89 visited: HashSet<usize>,
90}
91
92impl<'a> BfsIterator<'a> {
93 pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
94 let mut queue = VecDeque::new();
95 let visited = HashSet::new();
96
97 if dag.get_node(start).is_some() {
98 queue.push_back(start);
99 }
100
101 Self {
102 dag,
103 queue,
104 visited,
105 }
106 }
107}
108
109impl<'a> Iterator for BfsIterator<'a> {
110 type Item = usize;
111
112 fn next(&mut self) -> Option<Self::Item> {
113 while let Some(node) = self.queue.pop_front() {
114 if self.visited.insert(node) {
115 if let Some(children) = self.dag.edges.get(&node) {
117 for &child in children {
118 if !self.visited.contains(&child) {
119 self.queue.push_back(child);
120 }
121 }
122 }
123 return Some(node);
124 }
125 }
126 None
127 }
128}
129
130impl QueryDag {
131 pub fn topological_iter(&self) -> Result<TopologicalIterator<'_>, DagError> {
133 TopologicalIterator::new(self)
134 }
135
136 pub fn dfs_iter(&self, start: usize) -> DfsIterator<'_> {
138 DfsIterator::new(self, start)
139 }
140
141 pub fn bfs_iter(&self, start: usize) -> BfsIterator<'_> {
143 BfsIterator::new(self, start)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    /// Builds the linear plan `scan -> filter -> sort -> limit`.
    ///
    /// Returns the DAG together with the ids `add_node` assigned, in chain
    /// order, so tests do not depend on how ids are numbered.
    fn create_test_dag() -> (QueryDag, [usize; 4]) {
        let mut dag = QueryDag::new();
        let ids = [
            dag.add_node(OperatorNode::seq_scan(0, "users")),
            dag.add_node(OperatorNode::filter(0, "age > 18")),
            dag.add_node(OperatorNode::sort(0, vec!["name".to_string()])),
            dag.add_node(OperatorNode::limit(0, 10)),
        ];
        for pair in ids.windows(2) {
            dag.add_edge(pair[0], pair[1]).unwrap();
        }
        (dag, ids)
    }

    /// Asserts that every `(from, to)` edge has `from` placed before `to`
    /// in `order`. Panics if an endpoint is missing from `order`.
    fn assert_respects_edges(order: &[usize], edges: &[(usize, usize)]) {
        let pos = |id: usize| {
            order
                .iter()
                .position(|&x| x == id)
                .expect("edge endpoint missing from order")
        };
        for &(from, to) in edges {
            assert!(
                pos(from) < pos(to),
                "edge {} -> {} violated in {:?}",
                from,
                to,
                order
            );
        }
    }

    #[test]
    fn test_topological_iterator() {
        let (dag, ids) = create_test_dag();
        let nodes: Vec<usize> = dag.topological_iter().unwrap().collect();

        // A chain admits exactly one topological order.
        assert_eq!(nodes, ids.to_vec());
    }

    #[test]
    fn test_dfs_iterator() {
        let (dag, ids) = create_test_dag();
        let nodes: Vec<usize> = dag.dfs_iter(ids[0]).collect();

        assert_eq!(nodes, ids.to_vec());
    }

    #[test]
    fn test_bfs_iterator() {
        let (dag, ids) = create_test_dag();
        let nodes: Vec<usize> = dag.bfs_iter(ids[0]).collect();

        assert_eq!(nodes, ids.to_vec());
    }

    #[test]
    fn test_traversal_from_mid_chain() {
        let (dag, ids) = create_test_dag();

        // Only the start node and its descendants are reachable.
        let dfs: Vec<usize> = dag.dfs_iter(ids[2]).collect();
        let bfs: Vec<usize> = dag.bfs_iter(ids[2]).collect();
        assert_eq!(dfs, ids[2..].to_vec());
        assert_eq!(bfs, ids[2..].to_vec());
    }

    #[test]
    fn test_unknown_start_yields_nothing() {
        let (dag, _) = create_test_dag();

        assert_eq!(dag.dfs_iter(usize::MAX).next(), None);
        assert_eq!(dag.bfs_iter(usize::MAX).next(), None);
    }

    #[test]
    fn test_branching_dag() {
        let mut dag = QueryDag::new();
        let root = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let left1 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let left2 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        let right1 = dag.add_node(OperatorNode::filter(0, "active = true"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));

        let edges = [
            (root, left1),
            (left1, left2),
            (root, right1),
            (left2, join),
            (right1, join),
        ];
        for &(from, to) in &edges {
            dag.add_edge(from, to).unwrap();
        }

        // BFS: root first, then both of its children (in either order), and
        // the shared `join` node exactly once.
        let bfs_nodes: Vec<usize> = dag.bfs_iter(root).collect();
        assert_eq!(bfs_nodes.len(), 5);
        assert_eq!(bfs_nodes[0], root);
        let mut depth_one = bfs_nodes[1..3].to_vec();
        depth_one.sort_unstable();
        let mut expected_depth_one = vec![left1, right1];
        expected_depth_one.sort_unstable();
        assert_eq!(depth_one, expected_depth_one);
        let bfs_unique: std::collections::HashSet<usize> = bfs_nodes.iter().copied().collect();
        assert_eq!(bfs_unique.len(), 5);

        // DFS: root first and every node exactly once.
        let dfs_nodes: Vec<usize> = dag.dfs_iter(root).collect();
        assert_eq!(dfs_nodes.len(), 5);
        assert_eq!(dfs_nodes[0], root);
        let dfs_unique: std::collections::HashSet<usize> = dfs_nodes.iter().copied().collect();
        assert_eq!(dfs_unique.len(), 5);

        // Topological order: every edge points forward.
        let topo_nodes = dag.topological_sort().unwrap();
        assert_eq!(topo_nodes.len(), 5);
        assert_respects_edges(&topo_nodes, &edges);
    }
}