Skip to main content

ruvector_dag/dag/
traversal.rs

1//! DAG traversal algorithms and iterators
2
3use std::collections::{HashSet, VecDeque};
4
5use super::query_dag::{DagError, QueryDag};
6
7/// Iterator for topological order traversal (dependencies first)
8pub struct TopologicalIterator<'a> {
9    #[allow(dead_code)]
10    dag: &'a QueryDag,
11    sorted: Vec<usize>,
12    index: usize,
13}
14
15impl<'a> TopologicalIterator<'a> {
16    pub(crate) fn new(dag: &'a QueryDag) -> Result<Self, DagError> {
17        let sorted = dag.topological_sort()?;
18        Ok(Self {
19            dag,
20            sorted,
21            index: 0,
22        })
23    }
24}
25
26impl<'a> Iterator for TopologicalIterator<'a> {
27    type Item = usize;
28
29    fn next(&mut self) -> Option<Self::Item> {
30        if self.index < self.sorted.len() {
31            let id = self.sorted[self.index];
32            self.index += 1;
33            Some(id)
34        } else {
35            None
36        }
37    }
38}
39
40/// Iterator for depth-first search traversal
41pub struct DfsIterator<'a> {
42    dag: &'a QueryDag,
43    stack: Vec<usize>,
44    visited: HashSet<usize>,
45}
46
47impl<'a> DfsIterator<'a> {
48    pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
49        let mut stack = Vec::new();
50        let visited = HashSet::new();
51
52        if dag.get_node(start).is_some() {
53            stack.push(start);
54        }
55
56        Self {
57            dag,
58            stack,
59            visited,
60        }
61    }
62}
63
64impl<'a> Iterator for DfsIterator<'a> {
65    type Item = usize;
66
67    fn next(&mut self) -> Option<Self::Item> {
68        while let Some(node) = self.stack.pop() {
69            if self.visited.insert(node) {
70                // Add children to stack (in reverse order so they're processed in order)
71                if let Some(children) = self.dag.edges.get(&node) {
72                    for &child in children.iter().rev() {
73                        if !self.visited.contains(&child) {
74                            self.stack.push(child);
75                        }
76                    }
77                }
78                return Some(node);
79            }
80        }
81        None
82    }
83}
84
85/// Iterator for breadth-first search traversal
86pub struct BfsIterator<'a> {
87    dag: &'a QueryDag,
88    queue: VecDeque<usize>,
89    visited: HashSet<usize>,
90}
91
92impl<'a> BfsIterator<'a> {
93    pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
94        let mut queue = VecDeque::new();
95        let visited = HashSet::new();
96
97        if dag.get_node(start).is_some() {
98            queue.push_back(start);
99        }
100
101        Self {
102            dag,
103            queue,
104            visited,
105        }
106    }
107}
108
109impl<'a> Iterator for BfsIterator<'a> {
110    type Item = usize;
111
112    fn next(&mut self) -> Option<Self::Item> {
113        while let Some(node) = self.queue.pop_front() {
114            if self.visited.insert(node) {
115                // Add children to queue
116                if let Some(children) = self.dag.edges.get(&node) {
117                    for &child in children {
118                        if !self.visited.contains(&child) {
119                            self.queue.push_back(child);
120                        }
121                    }
122                }
123                return Some(node);
124            }
125        }
126        None
127    }
128}
129
130impl QueryDag {
131    /// Create an iterator for topological order traversal
132    pub fn topological_iter(&self) -> Result<TopologicalIterator<'_>, DagError> {
133        TopologicalIterator::new(self)
134    }
135
136    /// Create an iterator for depth-first search starting from a node
137    pub fn dfs_iter(&self, start: usize) -> DfsIterator<'_> {
138        DfsIterator::new(self, start)
139    }
140
141    /// Create an iterator for breadth-first search starting from a node
142    pub fn bfs_iter(&self, start: usize) -> BfsIterator<'_> {
143        BfsIterator::new(self, start)
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    /// Builds the linear plan `seq_scan -> filter -> sort -> limit`.
    ///
    /// The chain tests below address nodes by the literal ids `0..=3`, i.e.
    /// they rely on `add_node` handing out ids sequentially starting at 0.
    fn create_test_dag() -> QueryDag {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        let id4 = dag.add_node(OperatorNode::limit(0, 10));

        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        dag.add_edge(id3, id4).unwrap();

        dag
    }

    #[test]
    fn test_topological_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.topological_iter().unwrap().collect();

        assert_eq!(nodes.len(), 4);

        // Check ordering constraints
        let pos: Vec<usize> = (0..4)
            .map(|i| nodes.iter().position(|&x| x == i).unwrap())
            .collect();

        assert!(pos[0] < pos[1]); // 0 before 1
        assert!(pos[1] < pos[2]); // 1 before 2
        assert!(pos[2] < pos[3]); // 2 before 3
    }

    #[test]
    fn test_dfs_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.dfs_iter(0).collect();

        // A chain offers exactly one path, so the whole order is determined.
        assert_eq!(nodes, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_bfs_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.bfs_iter(0).collect();

        // A chain offers exactly one path, so the whole order is determined.
        assert_eq!(nodes, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_unknown_start_yields_nothing() {
        let dag = create_test_dag();

        // No node carries this id, so both traversals start out exhausted.
        assert_eq!(dag.dfs_iter(usize::MAX).count(), 0);
        assert_eq!(dag.bfs_iter(usize::MAX).count(), 0);
    }

    #[test]
    fn test_branching_dag() {
        let mut dag = QueryDag::new();
        let root = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let left1 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let left2 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        let right1 = dag.add_node(OperatorNode::filter(0, "active = true"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));

        dag.add_edge(root, left1).unwrap();
        dag.add_edge(left1, left2).unwrap();
        dag.add_edge(root, right1).unwrap();
        dag.add_edge(left2, join).unwrap();
        dag.add_edge(right1, join).unwrap();

        // BFS should visit level by level. Sibling order within a level is
        // left open; only the level grouping is pinned down:
        //   level 0: root | level 1: left1, right1 | level 2: left2, join
        let bfs_nodes: Vec<usize> = dag.bfs_iter(root).collect();
        assert_eq!(bfs_nodes.len(), 5);

        let bfs_pos = |id: usize| bfs_nodes.iter().position(|&x| x == id).unwrap();
        assert_eq!(bfs_pos(root), 0);
        assert!((1..=2).contains(&bfs_pos(left1)));
        assert!((1..=2).contains(&bfs_pos(right1)));
        assert!((3..=4).contains(&bfs_pos(left2)));
        assert!((3..=4).contains(&bfs_pos(join)));

        // DFS reaches `join` along two paths but must yield it only once.
        let dfs_nodes: Vec<usize> = dag.dfs_iter(root).collect();
        assert_eq!(dfs_nodes.len(), 5);
        assert_eq!(dfs_nodes[0], root);
        let distinct: std::collections::HashSet<usize> = dfs_nodes.iter().copied().collect();
        assert_eq!(distinct.len(), 5);

        // Topological sort should respect dependencies
        let topo_nodes = dag.topological_sort().unwrap();
        assert_eq!(topo_nodes.len(), 5);

        let pos_root = topo_nodes.iter().position(|&x| x == root).unwrap();
        let pos_join = topo_nodes.iter().position(|&x| x == join).unwrap();
        assert!(pos_root < pos_join);
    }
}