1use std::collections::{HashMap, HashSet, VecDeque};
4
5use super::operator_node::OperatorNode;
6
7#[derive(Debug, thiserror::Error)]
9pub enum DagError {
10 #[error("Node {0} not found")]
11 NodeNotFound(usize),
12 #[error("Adding edge would create cycle")]
13 CycleDetected,
14 #[error("Invalid operation: {0}")]
15 InvalidOperation(String),
16 #[error("DAG has cycles, cannot perform topological sort")]
17 HasCycles,
18}
19
20#[derive(Debug, Clone)]
22pub struct QueryDag {
23 pub(crate) nodes: HashMap<usize, OperatorNode>,
24 pub(crate) edges: HashMap<usize, Vec<usize>>, pub(crate) reverse_edges: HashMap<usize, Vec<usize>>, pub(crate) root: Option<usize>,
27 next_id: usize,
28}
29
30impl QueryDag {
31 pub fn new() -> Self {
33 Self {
34 nodes: HashMap::new(),
35 edges: HashMap::new(),
36 reverse_edges: HashMap::new(),
37 root: None,
38 next_id: 0,
39 }
40 }
41
42 pub fn add_node(&mut self, mut node: OperatorNode) -> usize {
44 let id = self.next_id;
45 self.next_id += 1;
46 node.id = id;
47
48 self.nodes.insert(id, node);
49 self.edges.insert(id, Vec::new());
50 self.reverse_edges.insert(id, Vec::new());
51
52 if self.nodes.len() == 1 {
54 self.root = Some(id);
55 }
56
57 id
58 }
59
60 pub fn add_edge(&mut self, parent: usize, child: usize) -> Result<(), DagError> {
62 if !self.nodes.contains_key(&parent) {
64 return Err(DagError::NodeNotFound(parent));
65 }
66 if !self.nodes.contains_key(&child) {
67 return Err(DagError::NodeNotFound(child));
68 }
69
70 if self.would_create_cycle(parent, child) {
72 return Err(DagError::CycleDetected);
73 }
74
75 self.edges.get_mut(&parent).unwrap().push(child);
77 self.reverse_edges.get_mut(&child).unwrap().push(parent);
78
79 if self.root == Some(child) && !self.reverse_edges[&child].is_empty() {
81 self.root = self
83 .nodes
84 .keys()
85 .find(|&&id| self.reverse_edges[&id].is_empty())
86 .copied();
87 }
88
89 Ok(())
90 }
91
92 pub fn remove_node(&mut self, id: usize) -> Option<OperatorNode> {
94 let node = self.nodes.remove(&id)?;
95
96 if let Some(children) = self.edges.remove(&id) {
98 for child in children {
99 if let Some(parents) = self.reverse_edges.get_mut(&child) {
100 parents.retain(|&p| p != id);
101 }
102 }
103 }
104
105 if let Some(parents) = self.reverse_edges.remove(&id) {
106 for parent in parents {
107 if let Some(children) = self.edges.get_mut(&parent) {
108 children.retain(|&c| c != id);
109 }
110 }
111 }
112
113 if self.root == Some(id) {
115 self.root = self
116 .nodes
117 .keys()
118 .find(|&&nid| self.reverse_edges[&nid].is_empty())
119 .copied();
120 }
121
122 Some(node)
123 }
124
125 pub fn get_node(&self, id: usize) -> Option<&OperatorNode> {
127 self.nodes.get(&id)
128 }
129
130 pub fn get_node_mut(&mut self, id: usize) -> Option<&mut OperatorNode> {
132 self.nodes.get_mut(&id)
133 }
134
135 pub fn children(&self, id: usize) -> &[usize] {
137 self.edges.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
138 }
139
140 pub fn parents(&self, id: usize) -> &[usize] {
142 self.reverse_edges
143 .get(&id)
144 .map(|v| v.as_slice())
145 .unwrap_or(&[])
146 }
147
148 pub fn root(&self) -> Option<usize> {
150 self.root
151 }
152
153 pub fn leaves(&self) -> Vec<usize> {
155 self.nodes
156 .keys()
157 .filter(|&&id| self.edges[&id].is_empty())
158 .copied()
159 .collect()
160 }
161
162 pub fn node_count(&self) -> usize {
164 self.nodes.len()
165 }
166
167 pub fn edge_count(&self) -> usize {
169 self.edges.values().map(|v| v.len()).sum()
170 }
171
172 pub fn node_ids(&self) -> impl Iterator<Item = usize> + '_ {
174 self.nodes.keys().copied()
175 }
176
177 pub fn nodes(&self) -> impl Iterator<Item = &OperatorNode> + '_ {
179 self.nodes.values()
180 }
181
182 fn would_create_cycle(&self, from: usize, to: usize) -> bool {
184 self.can_reach(to, from)
186 }
187
188 fn can_reach(&self, from: usize, to: usize) -> bool {
190 if from == to {
191 return true;
192 }
193
194 let mut visited = HashSet::new();
195 let mut queue = VecDeque::new();
196 queue.push_back(from);
197 visited.insert(from);
198
199 while let Some(current) = queue.pop_front() {
200 if current == to {
201 return true;
202 }
203
204 if let Some(children) = self.edges.get(¤t) {
205 for &child in children {
206 if visited.insert(child) {
207 queue.push_back(child);
208 }
209 }
210 }
211 }
212
213 false
214 }
215
216 pub fn compute_depths(&self) -> HashMap<usize, usize> {
218 let mut depths = HashMap::new();
219 let mut visited = HashSet::new();
220
221 let leaves = self.leaves();
223 let mut queue: VecDeque<(usize, usize)> = leaves.iter().map(|&id| (id, 0)).collect();
224
225 for &leaf in &leaves {
226 visited.insert(leaf);
227 depths.insert(leaf, 0);
228 }
229
230 while let Some((node, depth)) = queue.pop_front() {
231 depths.insert(node, depth);
232
233 if let Some(parents) = self.reverse_edges.get(&node) {
235 for &parent in parents {
236 if visited.insert(parent) {
237 queue.push_back((parent, depth + 1));
238 } else {
239 let current_depth = depths.get(&parent).copied().unwrap_or(0);
241 if depth + 1 > current_depth {
242 depths.insert(parent, depth + 1);
243 queue.push_back((parent, depth + 1));
244 }
245 }
246 }
247 }
248 }
249
250 depths
251 }
252
253 pub fn ancestors(&self, id: usize) -> HashSet<usize> {
255 let mut result = HashSet::new();
256 let mut queue = VecDeque::new();
257
258 if let Some(parents) = self.reverse_edges.get(&id) {
259 for &parent in parents {
260 queue.push_back(parent);
261 result.insert(parent);
262 }
263 }
264
265 while let Some(node) = queue.pop_front() {
266 if let Some(parents) = self.reverse_edges.get(&node) {
267 for &parent in parents {
268 if result.insert(parent) {
269 queue.push_back(parent);
270 }
271 }
272 }
273 }
274
275 result
276 }
277
278 pub fn descendants(&self, id: usize) -> HashSet<usize> {
280 let mut result = HashSet::new();
281 let mut queue = VecDeque::new();
282
283 if let Some(children) = self.edges.get(&id) {
284 for &child in children {
285 queue.push_back(child);
286 result.insert(child);
287 }
288 }
289
290 while let Some(node) = queue.pop_front() {
291 if let Some(children) = self.edges.get(&node) {
292 for &child in children {
293 if result.insert(child) {
294 queue.push_back(child);
295 }
296 }
297 }
298 }
299
300 result
301 }
302
303 pub fn topological_sort(&self) -> Result<Vec<usize>, DagError> {
305 let mut result = Vec::new();
306 let mut in_degree: HashMap<usize, usize> = self
307 .nodes
308 .keys()
309 .map(|&id| (id, self.reverse_edges[&id].len()))
310 .collect();
311
312 let mut queue: VecDeque<usize> = in_degree
313 .iter()
314 .filter(|(_, °ree)| degree == 0)
315 .map(|(&id, _)| id)
316 .collect();
317
318 while let Some(node) = queue.pop_front() {
319 result.push(node);
320
321 if let Some(children) = self.edges.get(&node) {
322 for &child in children {
323 let degree = in_degree.get_mut(&child).unwrap();
324 *degree -= 1;
325 if *degree == 0 {
326 queue.push_back(child);
327 }
328 }
329 }
330 }
331
332 if result.len() != self.nodes.len() {
333 return Err(DagError::HasCycles);
334 }
335
336 Ok(result)
337 }
338}
339
340impl Default for QueryDag {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
#[cfg(test)]
mod tests {
    //! Unit tests for `QueryDag`: construction, edge bookkeeping, cycle
    //! rejection, root tracking, traversal helpers and ordering.

    use super::*;
    use crate::OperatorNode;

    /// Builds the three-node chain `scan -> filter -> sort` and returns the
    /// DAG together with the ids in that order.
    fn chain() -> (QueryDag, usize, usize, usize) {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        (dag, id1, id2, id3)
    }

    #[test]
    fn test_new_dag() {
        let dag = QueryDag::new();
        assert_eq!(dag.node_count(), 0);
        assert_eq!(dag.edge_count(), 0);
        assert_eq!(dag.root(), None);
    }

    #[test]
    fn test_add_nodes() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));

        assert_eq!(dag.node_count(), 2);
        assert!(dag.get_node(id1).is_some());
        assert!(dag.get_node(id2).is_some());
        // The first node inserted into an empty DAG becomes the root.
        assert_eq!(dag.root(), Some(id1));
    }

    #[test]
    fn test_add_edges() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));

        assert!(dag.add_edge(id1, id2).is_ok());
        assert_eq!(dag.edge_count(), 1);
        assert_eq!(dag.children(id1), &[id2]);
        assert_eq!(dag.parents(id2), &[id1]);
    }

    #[test]
    fn test_add_edge_unknown_node() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));

        assert!(matches!(
            dag.add_edge(id1, 99),
            Err(DagError::NodeNotFound(99))
        ));
        assert!(matches!(
            dag.add_edge(99, id1),
            Err(DagError::NodeNotFound(99))
        ));
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_cycle_detection() {
        let (mut dag, id1, _id2, id3) = chain();

        assert!(matches!(
            dag.add_edge(id3, id1),
            Err(DagError::CycleDetected)
        ));
    }

    #[test]
    fn test_self_loop_rejected() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));

        assert!(matches!(
            dag.add_edge(id1, id1),
            Err(DagError::CycleDetected)
        ));
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_root_moves_when_root_gains_parent() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        assert_eq!(dag.root(), Some(id1));

        // id2 becomes the only node without parents.
        dag.add_edge(id2, id1).unwrap();
        assert_eq!(dag.root(), Some(id2));
    }

    #[test]
    fn test_topological_sort() {
        let (dag, id1, id2, id3) = chain();

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let pos1 = sorted.iter().position(|&x| x == id1).unwrap();
        let pos2 = sorted.iter().position(|&x| x == id2).unwrap();
        let pos3 = sorted.iter().position(|&x| x == id3).unwrap();

        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    #[test]
    fn test_remove_node() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));

        dag.add_edge(id1, id2).unwrap();

        let removed = dag.remove_node(id1);
        assert!(removed.is_some());
        assert_eq!(dag.node_count(), 1);
        assert_eq!(dag.edge_count(), 0);

        // The surviving node lost its parent and takes over as root.
        assert!(dag.get_node(id1).is_none());
        assert!(dag.parents(id2).is_empty());
        assert_eq!(dag.root(), Some(id2));

        // Removing an id that is no longer present is a no-op.
        assert!(dag.remove_node(id1).is_none());
    }

    #[test]
    fn test_leaves() {
        let (dag, _id1, _id2, id3) = chain();
        assert_eq!(dag.leaves(), vec![id3]);
    }

    #[test]
    fn test_compute_depths_uses_longest_path() {
        let (mut dag, id1, id2, id3) = chain();
        // Shortcut edge: id1 reaches the leaf in one hop and in two hops.
        dag.add_edge(id1, id3).unwrap();

        let depths = dag.compute_depths();
        assert_eq!(depths.get(&id3), Some(&0));
        assert_eq!(depths.get(&id2), Some(&1));
        assert_eq!(depths.get(&id1), Some(&2));
    }

    #[test]
    fn test_ancestors_descendants() {
        let (dag, id1, id2, id3) = chain();

        let ancestors = dag.ancestors(id3);
        assert!(ancestors.contains(&id1));
        assert!(ancestors.contains(&id2));
        assert!(!ancestors.contains(&id3));

        let descendants = dag.descendants(id1);
        assert!(descendants.contains(&id2));
        assert!(descendants.contains(&id3));
        assert!(!descendants.contains(&id1));
    }
}
452}