Skip to main content

ruvector_dag/dag/
query_dag.rs

1//! Core query DAG data structure
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use super::operator_node::OperatorNode;
6
7/// Error types for DAG operations
8#[derive(Debug, thiserror::Error)]
9pub enum DagError {
10    #[error("Node {0} not found")]
11    NodeNotFound(usize),
12    #[error("Adding edge would create cycle")]
13    CycleDetected,
14    #[error("Invalid operation: {0}")]
15    InvalidOperation(String),
16    #[error("DAG has cycles, cannot perform topological sort")]
17    HasCycles,
18}
19
20/// A Directed Acyclic Graph representing a query plan
21#[derive(Debug, Clone)]
22pub struct QueryDag {
23    pub(crate) nodes: HashMap<usize, OperatorNode>,
24    pub(crate) edges: HashMap<usize, Vec<usize>>, // parent -> children
25    pub(crate) reverse_edges: HashMap<usize, Vec<usize>>, // child -> parents
26    pub(crate) root: Option<usize>,
27    next_id: usize,
28}
29
30impl QueryDag {
31    /// Create a new empty DAG
32    pub fn new() -> Self {
33        Self {
34            nodes: HashMap::new(),
35            edges: HashMap::new(),
36            reverse_edges: HashMap::new(),
37            root: None,
38            next_id: 0,
39        }
40    }
41
42    /// Add a node to the DAG, returns the node ID
43    pub fn add_node(&mut self, mut node: OperatorNode) -> usize {
44        let id = self.next_id;
45        self.next_id += 1;
46        node.id = id;
47
48        self.nodes.insert(id, node);
49        self.edges.insert(id, Vec::new());
50        self.reverse_edges.insert(id, Vec::new());
51
52        // If this is the first node, set it as root
53        if self.nodes.len() == 1 {
54            self.root = Some(id);
55        }
56
57        id
58    }
59
60    /// Add an edge from parent to child
61    pub fn add_edge(&mut self, parent: usize, child: usize) -> Result<(), DagError> {
62        // Check both nodes exist
63        if !self.nodes.contains_key(&parent) {
64            return Err(DagError::NodeNotFound(parent));
65        }
66        if !self.nodes.contains_key(&child) {
67            return Err(DagError::NodeNotFound(child));
68        }
69
70        // Check if adding this edge would create a cycle
71        if self.would_create_cycle(parent, child) {
72            return Err(DagError::CycleDetected);
73        }
74
75        // Add edge
76        self.edges.get_mut(&parent).unwrap().push(child);
77        self.reverse_edges.get_mut(&child).unwrap().push(parent);
78
79        // Update root if child was previously root and now has parents
80        if self.root == Some(child) && !self.reverse_edges[&child].is_empty() {
81            // Find new root (node with no parents)
82            self.root = self
83                .nodes
84                .keys()
85                .find(|&&id| self.reverse_edges[&id].is_empty())
86                .copied();
87        }
88
89        Ok(())
90    }
91
92    /// Remove a node from the DAG
93    pub fn remove_node(&mut self, id: usize) -> Option<OperatorNode> {
94        let node = self.nodes.remove(&id)?;
95
96        // Remove all edges involving this node
97        if let Some(children) = self.edges.remove(&id) {
98            for child in children {
99                if let Some(parents) = self.reverse_edges.get_mut(&child) {
100                    parents.retain(|&p| p != id);
101                }
102            }
103        }
104
105        if let Some(parents) = self.reverse_edges.remove(&id) {
106            for parent in parents {
107                if let Some(children) = self.edges.get_mut(&parent) {
108                    children.retain(|&c| c != id);
109                }
110            }
111        }
112
113        // Update root if necessary
114        if self.root == Some(id) {
115            self.root = self
116                .nodes
117                .keys()
118                .find(|&&nid| self.reverse_edges[&nid].is_empty())
119                .copied();
120        }
121
122        Some(node)
123    }
124
125    /// Get a reference to a node
126    pub fn get_node(&self, id: usize) -> Option<&OperatorNode> {
127        self.nodes.get(&id)
128    }
129
130    /// Get a mutable reference to a node
131    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut OperatorNode> {
132        self.nodes.get_mut(&id)
133    }
134
135    /// Get children of a node
136    pub fn children(&self, id: usize) -> &[usize] {
137        self.edges.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
138    }
139
140    /// Get parents of a node
141    pub fn parents(&self, id: usize) -> &[usize] {
142        self.reverse_edges
143            .get(&id)
144            .map(|v| v.as_slice())
145            .unwrap_or(&[])
146    }
147
148    /// Get the root node ID
149    pub fn root(&self) -> Option<usize> {
150        self.root
151    }
152
153    /// Get all leaf nodes (nodes with no children)
154    pub fn leaves(&self) -> Vec<usize> {
155        self.nodes
156            .keys()
157            .filter(|&&id| self.edges[&id].is_empty())
158            .copied()
159            .collect()
160    }
161
162    /// Get the number of nodes
163    pub fn node_count(&self) -> usize {
164        self.nodes.len()
165    }
166
167    /// Get the number of edges
168    pub fn edge_count(&self) -> usize {
169        self.edges.values().map(|v| v.len()).sum()
170    }
171
172    /// Get iterator over node IDs
173    pub fn node_ids(&self) -> impl Iterator<Item = usize> + '_ {
174        self.nodes.keys().copied()
175    }
176
177    /// Get iterator over all nodes
178    pub fn nodes(&self) -> impl Iterator<Item = &OperatorNode> + '_ {
179        self.nodes.values()
180    }
181
182    /// Check if adding an edge would create a cycle
183    fn would_create_cycle(&self, from: usize, to: usize) -> bool {
184        // If 'to' can reach 'from', adding edge from->to would create cycle
185        self.can_reach(to, from)
186    }
187
188    /// Check if 'from' can reach 'to' through existing edges
189    fn can_reach(&self, from: usize, to: usize) -> bool {
190        if from == to {
191            return true;
192        }
193
194        let mut visited = HashSet::new();
195        let mut queue = VecDeque::new();
196        queue.push_back(from);
197        visited.insert(from);
198
199        while let Some(current) = queue.pop_front() {
200            if current == to {
201                return true;
202            }
203
204            if let Some(children) = self.edges.get(&current) {
205                for &child in children {
206                    if visited.insert(child) {
207                        queue.push_back(child);
208                    }
209                }
210            }
211        }
212
213        false
214    }
215
216    /// Compute depth of each node from leaves (leaves have depth 0)
217    pub fn compute_depths(&self) -> HashMap<usize, usize> {
218        let mut depths = HashMap::new();
219        let mut visited = HashSet::new();
220
221        // Start from leaves
222        let leaves = self.leaves();
223        let mut queue: VecDeque<(usize, usize)> = leaves.iter().map(|&id| (id, 0)).collect();
224
225        for &leaf in &leaves {
226            visited.insert(leaf);
227            depths.insert(leaf, 0);
228        }
229
230        while let Some((node, depth)) = queue.pop_front() {
231            depths.insert(node, depth);
232
233            // Process parents
234            if let Some(parents) = self.reverse_edges.get(&node) {
235                for &parent in parents {
236                    if visited.insert(parent) {
237                        queue.push_back((parent, depth + 1));
238                    } else {
239                        // Update depth if we found a longer path
240                        let current_depth = depths.get(&parent).copied().unwrap_or(0);
241                        if depth + 1 > current_depth {
242                            depths.insert(parent, depth + 1);
243                            queue.push_back((parent, depth + 1));
244                        }
245                    }
246                }
247            }
248        }
249
250        depths
251    }
252
253    /// Get all ancestors of a node
254    pub fn ancestors(&self, id: usize) -> HashSet<usize> {
255        let mut result = HashSet::new();
256        let mut queue = VecDeque::new();
257
258        if let Some(parents) = self.reverse_edges.get(&id) {
259            for &parent in parents {
260                queue.push_back(parent);
261                result.insert(parent);
262            }
263        }
264
265        while let Some(node) = queue.pop_front() {
266            if let Some(parents) = self.reverse_edges.get(&node) {
267                for &parent in parents {
268                    if result.insert(parent) {
269                        queue.push_back(parent);
270                    }
271                }
272            }
273        }
274
275        result
276    }
277
278    /// Get all descendants of a node
279    pub fn descendants(&self, id: usize) -> HashSet<usize> {
280        let mut result = HashSet::new();
281        let mut queue = VecDeque::new();
282
283        if let Some(children) = self.edges.get(&id) {
284            for &child in children {
285                queue.push_back(child);
286                result.insert(child);
287            }
288        }
289
290        while let Some(node) = queue.pop_front() {
291            if let Some(children) = self.edges.get(&node) {
292                for &child in children {
293                    if result.insert(child) {
294                        queue.push_back(child);
295                    }
296                }
297            }
298        }
299
300        result
301    }
302
303    /// Return nodes in topological order as Vec (dependencies first)
304    pub fn topological_sort(&self) -> Result<Vec<usize>, DagError> {
305        let mut result = Vec::new();
306        let mut in_degree: HashMap<usize, usize> = self
307            .nodes
308            .keys()
309            .map(|&id| (id, self.reverse_edges[&id].len()))
310            .collect();
311
312        let mut queue: VecDeque<usize> = in_degree
313            .iter()
314            .filter(|(_, &degree)| degree == 0)
315            .map(|(&id, _)| id)
316            .collect();
317
318        while let Some(node) = queue.pop_front() {
319            result.push(node);
320
321            if let Some(children) = self.edges.get(&node) {
322                for &child in children {
323                    let degree = in_degree.get_mut(&child).unwrap();
324                    *degree -= 1;
325                    if *degree == 0 {
326                        queue.push_back(child);
327                    }
328                }
329            }
330        }
331
332        if result.len() != self.nodes.len() {
333            return Err(DagError::HasCycles);
334        }
335
336        Ok(result)
337    }
338}
339
340impl Default for QueryDag {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    /// Builds a DAG holding only the `users` scan and the `age > 18` filter,
    /// with no edges; returns the DAG and the two node IDs.
    fn scan_and_filter() -> (QueryDag, usize, usize) {
        let mut dag = QueryDag::new();
        let scan = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let filter = dag.add_node(OperatorNode::filter(0, "age > 18"));
        (dag, scan, filter)
    }

    /// Builds the chain `scan -> filter -> sort` and returns the DAG together
    /// with the three node IDs in that order.
    fn scan_filter_sort_chain() -> (QueryDag, usize, usize, usize) {
        let (mut dag, scan, filter) = scan_and_filter();
        let sort = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));

        dag.add_edge(scan, filter).unwrap();
        dag.add_edge(filter, sort).unwrap();

        (dag, scan, filter, sort)
    }

    #[test]
    fn test_new_dag() {
        let empty = QueryDag::new();
        assert_eq!(empty.node_count(), 0);
        assert_eq!(empty.edge_count(), 0);
    }

    #[test]
    fn test_add_nodes() {
        let (dag, scan, filter) = scan_and_filter();

        assert_eq!(dag.node_count(), 2);
        for id in [scan, filter] {
            assert!(dag.get_node(id).is_some());
        }
    }

    #[test]
    fn test_add_edges() {
        let (mut dag, scan, filter) = scan_and_filter();

        let outcome = dag.add_edge(scan, filter);
        assert!(outcome.is_ok());
        assert_eq!(dag.edge_count(), 1);
        assert_eq!(dag.children(scan), &[filter]);
        assert_eq!(dag.parents(filter), &[scan]);
    }

    #[test]
    fn test_cycle_detection() {
        let (mut dag, scan, _filter, sort) = scan_filter_sort_chain();

        // Closing the chain back onto its first node would create a cycle.
        let outcome = dag.add_edge(sort, scan);
        assert!(matches!(outcome, Err(DagError::CycleDetected)));
    }

    #[test]
    fn test_topological_sort() {
        let (dag, scan, filter, sort) = scan_filter_sort_chain();

        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 3);

        // Position of a node ID inside the sorted output.
        let index_of = |wanted: usize| order.iter().position(|&id| id == wanted).unwrap();

        // scan must come before filter, filter before sort.
        assert!(index_of(scan) < index_of(filter));
        assert!(index_of(filter) < index_of(sort));
    }

    #[test]
    fn test_remove_node() {
        let (mut dag, scan, filter) = scan_and_filter();
        dag.add_edge(scan, filter).unwrap();

        assert!(dag.remove_node(scan).is_some());
        assert_eq!(dag.node_count(), 1);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_ancestors_descendants() {
        let (dag, scan, filter, sort) = scan_filter_sort_chain();

        let above_sort = dag.ancestors(sort);
        assert!(above_sort.contains(&scan));
        assert!(above_sort.contains(&filter));

        let below_scan = dag.descendants(scan);
        assert!(below_scan.contains(&filter));
        assert!(below_scan.contains(&sort));
    }
}