//! Query execution planning (`sql_cli/query_plan/query_plan.rs`).

use crate::sql::parser::ast::{SelectStatement, SqlExpression, WhereClause};
use std::collections::{HashMap, HashSet, VecDeque};
3
/// Represents a unit of work in the query execution pipeline.
#[derive(Debug, Clone)]
pub struct WorkUnit {
    /// Unique identifier for this work unit (e.g. "scan_1", "filter_2";
    /// see `QueryAnalyzer::next_unit_id`).
    pub id: String,

    /// Type of work this unit performs.
    pub work_type: WorkUnitType,

    /// SQL expression or statement to execute.
    pub expression: WorkUnitExpression,

    /// Dependencies - `id`s of work units that must complete before this one.
    pub dependencies: Vec<String>,

    /// Whether this unit can be executed in parallel with siblings.
    pub parallelizable: bool,

    /// Cost estimate for query optimization; `None` when not yet estimated.
    /// NOTE(review): the cost unit (rows, time, abstract) is not fixed
    /// anywhere in this file — confirm before comparing estimates.
    pub cost_estimate: Option<f64>,
}
25
/// Types of work units in the execution pipeline.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkUnitType {
    /// Base table scan (the root of a pipeline; created with no dependencies).
    TableScan,

    /// CTE definition.
    CTE,

    /// Filter operation (WHERE clause).
    Filter,

    /// Aggregation (GROUP BY).
    Aggregate,

    /// Sorting (ORDER BY).
    Sort,

    /// Join operation.
    Join,

    /// Window function computation.
    Window,

    /// Expression evaluation.
    Expression,

    /// Final projection (SELECT); last stage of the pipeline.
    Projection,
}
56
/// Expression or statement payload carried by a work unit.
#[derive(Debug, Clone)]
pub enum WorkUnitExpression {
    /// Full SELECT statement (for CTEs).
    Select(SelectStatement),

    /// Single expression (for filters, projections).
    Expression(SqlExpression),

    /// WHERE clause.
    WhereClause(WhereClause),

    /// Table name for base scans.
    TableName(String),

    /// Custom operation, described by a free-form label (e.g. "GROUP BY").
    Custom(String),
}
75
/// Complete query execution plan.
#[derive(Debug)]
pub struct QueryPlan {
    /// All work units in the plan, in insertion order.
    pub units: Vec<WorkUnit>,

    /// Dependency graph used to derive execution order and parallel groups.
    pub dependency_graph: DependencyGraph,

    /// Estimated total cost; `None` when no estimate has been computed.
    pub total_cost: Option<f64>,

    /// Original query text, kept for reference/diagnostics.
    pub original_query: String,

    /// Metadata about the plan.
    pub metadata: PlanMetadata,
}
94
95impl QueryPlan {
96    /// Create a new empty query plan
97    pub fn new(original_query: String) -> Self {
98        QueryPlan {
99            units: Vec::new(),
100            dependency_graph: DependencyGraph::new(),
101            original_query,
102            total_cost: None,
103            metadata: PlanMetadata::default(),
104        }
105    }
106
107    /// Add a work unit to the plan
108    pub fn add_unit(&mut self, unit: WorkUnit) {
109        // Add to dependency graph
110        for dep in &unit.dependencies {
111            self.dependency_graph.add_edge(dep.clone(), unit.id.clone());
112        }
113
114        // Store the unit
115        self.units.push(unit);
116    }
117
118    /// Get execution order respecting dependencies
119    pub fn get_execution_order(&self) -> Result<Vec<String>, String> {
120        self.dependency_graph.topological_sort()
121    }
122
123    /// Get units that can be executed in parallel
124    pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
125        self.dependency_graph.get_parallel_groups()
126    }
127
128    /// Optimize the plan (placeholder for future optimization logic)
129    pub fn optimize(&mut self) -> Result<(), String> {
130        // Future: implement cost-based optimization
131        // - Reorder operations when possible
132        // - Push down filters
133        // - Merge adjacent operations
134        Ok(())
135    }
136
137    /// Generate a human-readable representation of the plan
138    pub fn explain(&self) -> String {
139        let mut output = String::new();
140        output.push_str("Query Execution Plan:\n");
141        output.push_str("====================\n\n");
142
143        // Show execution order
144        match self.get_execution_order() {
145            Ok(order) => {
146                output.push_str("Execution Order:\n");
147                for (i, unit_id) in order.iter().enumerate() {
148                    if let Some(unit) = self.units.iter().find(|u| u.id == *unit_id) {
149                        output.push_str(&format!(
150                            "  {}. {} ({:?})\n",
151                            i + 1,
152                            unit.id,
153                            unit.work_type
154                        ));
155
156                        if !unit.dependencies.is_empty() {
157                            output.push_str(&format!(
158                                "     Dependencies: {}\n",
159                                unit.dependencies.join(", ")
160                            ));
161                        }
162
163                        if unit.parallelizable {
164                            output.push_str("     [Parallelizable]\n");
165                        }
166                    }
167                }
168            }
169            Err(e) => {
170                output.push_str(&format!("Error determining execution order: {}\n", e));
171            }
172        }
173
174        // Show parallel groups
175        output.push_str("\nParallel Execution Groups:\n");
176        for (i, group) in self.get_parallel_groups().iter().enumerate() {
177            output.push_str(&format!("  Group {}: {}\n", i + 1, group.join(", ")));
178        }
179
180        output
181    }
182}
183
/// Metadata about the query plan.
///
/// `Default` yields all-false/zero/`None`, i.e. "nothing recorded yet".
#[derive(Debug, Default)]
pub struct PlanMetadata {
    /// Whether CTEs were lifted from the WHERE clause.
    pub has_lifted_expressions: bool,

    /// Number of parallel execution opportunities.
    pub parallel_opportunities: usize,

    /// Estimated row count, if one was computed.
    pub estimated_rows: Option<usize>,

    /// Planning time in milliseconds, if measured.
    pub planning_time_ms: Option<u64>,
}
199
/// Dependency graph for work units.
///
/// Edges point from prerequisite to dependent: an edge `source -> target`
/// means `source` must complete before `target` (see `add_edge`).
#[derive(Debug)]
pub struct DependencyGraph {
    /// Adjacency list: source node -> set of nodes that depend on it.
    edges: HashMap<String, HashSet<String>>,

    /// All nodes in the graph (including nodes with no edges).
    nodes: HashSet<String>,
}
209
210impl DependencyGraph {
211    /// Create a new empty dependency graph
212    pub fn new() -> Self {
213        DependencyGraph {
214            edges: HashMap::new(),
215            nodes: HashSet::new(),
216        }
217    }
218
219    /// Add an edge from source to target (source must complete before target)
220    pub fn add_edge(&mut self, source: String, target: String) {
221        self.nodes.insert(source.clone());
222        self.nodes.insert(target.clone());
223
224        self.edges
225            .entry(source)
226            .or_insert_with(HashSet::new)
227            .insert(target);
228    }
229
230    /// Perform topological sort to get valid execution order
231    pub fn topological_sort(&self) -> Result<Vec<String>, String> {
232        let mut in_degree: HashMap<String, usize> = HashMap::new();
233        let mut result = Vec::new();
234
235        // Initialize in-degrees
236        for node in &self.nodes {
237            in_degree.insert(node.clone(), 0);
238        }
239
240        // Calculate in-degrees
241        for (_, targets) in &self.edges {
242            for target in targets {
243                *in_degree.get_mut(target).unwrap() += 1;
244            }
245        }
246
247        // Find nodes with no dependencies
248        let mut queue: Vec<String> = in_degree
249            .iter()
250            .filter(|(_, &degree)| degree == 0)
251            .map(|(node, _)| node.clone())
252            .collect();
253
254        // Process nodes
255        while !queue.is_empty() {
256            let node = queue.remove(0);
257            result.push(node.clone());
258
259            // Update in-degrees of dependent nodes
260            if let Some(targets) = self.edges.get(&node) {
261                for target in targets {
262                    let degree = in_degree.get_mut(target).unwrap();
263                    *degree -= 1;
264                    if *degree == 0 {
265                        queue.push(target.clone());
266                    }
267                }
268            }
269        }
270
271        // Check for cycles
272        if result.len() != self.nodes.len() {
273            return Err("Dependency cycle detected in query plan".to_string());
274        }
275
276        Ok(result)
277    }
278
279    /// Get groups of units that can be executed in parallel
280    pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
281        let mut groups = Vec::new();
282        let mut remaining = self.nodes.clone();
283        let mut completed = HashSet::new();
284
285        while !remaining.is_empty() {
286            let mut current_group = Vec::new();
287
288            // Find all nodes whose dependencies are satisfied
289            for node in &remaining {
290                let deps_satisfied = self
291                    .edges
292                    .iter()
293                    .filter(|(_, targets)| targets.contains(node))
294                    .all(|(source, _)| completed.contains(source));
295
296                if deps_satisfied {
297                    current_group.push(node.clone());
298                }
299            }
300
301            // If no nodes can be executed, we have a problem
302            if current_group.is_empty() && !remaining.is_empty() {
303                // This shouldn't happen if topological sort succeeds
304                break;
305            }
306
307            // Mark these nodes as completed
308            for node in &current_group {
309                completed.insert(node.clone());
310                remaining.remove(node);
311            }
312
313            if !current_group.is_empty() {
314                groups.push(current_group);
315            }
316        }
317
318        groups
319    }
320
321    /// Check if the graph has cycles
322    pub fn has_cycles(&self) -> bool {
323        self.topological_sort().is_err()
324    }
325}
326
/// Query analyzer that builds execution plans.
pub struct QueryAnalyzer {
    /// Monotonically increasing counter used to generate unique work unit
    /// IDs (see `next_unit_id`).
    unit_counter: usize,
}
332
333impl QueryAnalyzer {
334    /// Create a new query analyzer
335    pub fn new() -> Self {
336        QueryAnalyzer { unit_counter: 0 }
337    }
338
339    /// Generate a unique ID for a work unit
340    fn next_unit_id(&mut self, prefix: &str) -> String {
341        self.unit_counter += 1;
342        format!("{}_{}", prefix, self.unit_counter)
343    }
344
345    /// Analyze a SELECT statement and build an execution plan
346    pub fn analyze(&mut self, stmt: &SelectStatement, query: String) -> Result<QueryPlan, String> {
347        let mut plan = QueryPlan::new(query);
348
349        // Phase 1: Add base table scan
350        let table_unit = WorkUnit {
351            id: self.next_unit_id("scan"),
352            work_type: WorkUnitType::TableScan,
353            expression: WorkUnitExpression::TableName(
354                stmt.from_table
355                    .clone()
356                    .unwrap_or_else(|| "unknown".to_string()),
357            ),
358            dependencies: Vec::new(),
359            parallelizable: false,
360            cost_estimate: None,
361        };
362        let table_id = table_unit.id.clone();
363        plan.add_unit(table_unit);
364
365        // Phase 2: Analyze WHERE clause for liftable expressions
366        let mut filter_id = None;
367        if let Some(ref where_clause) = stmt.where_clause {
368            // TODO: Implement expression lifting logic here
369            // For now, just add as a simple filter
370            let filter_unit = WorkUnit {
371                id: self.next_unit_id("filter"),
372                work_type: WorkUnitType::Filter,
373                expression: WorkUnitExpression::WhereClause(where_clause.clone()),
374                dependencies: vec![table_id.clone()],
375                parallelizable: false,
376                cost_estimate: None,
377            };
378            filter_id = Some(filter_unit.id.clone());
379            plan.add_unit(filter_unit);
380        }
381
382        // Phase 3: Handle GROUP BY
383        let mut group_id = None;
384        if stmt.group_by.as_ref().map_or(false, |g| !g.is_empty()) {
385            let dependencies = vec![filter_id.clone().unwrap_or(table_id.clone())];
386            let group_unit = WorkUnit {
387                id: self.next_unit_id("group"),
388                work_type: WorkUnitType::Aggregate,
389                expression: WorkUnitExpression::Custom("GROUP BY".to_string()),
390                dependencies,
391                parallelizable: false,
392                cost_estimate: None,
393            };
394            group_id = Some(group_unit.id.clone());
395            plan.add_unit(group_unit);
396        }
397
398        // Phase 4: Handle ORDER BY
399        let mut sort_id = None;
400        if stmt.order_by.as_ref().map_or(false, |o| !o.is_empty()) {
401            let dependencies = vec![group_id
402                .clone()
403                .or(filter_id.clone())
404                .unwrap_or(table_id.clone())];
405            let sort_unit = WorkUnit {
406                id: self.next_unit_id("sort"),
407                work_type: WorkUnitType::Sort,
408                expression: WorkUnitExpression::Custom("ORDER BY".to_string()),
409                dependencies,
410                parallelizable: false,
411                cost_estimate: None,
412            };
413            sort_id = Some(sort_unit.id.clone());
414            plan.add_unit(sort_unit);
415        }
416
417        // Phase 5: Final projection
418        let dependencies = vec![sort_id.or(group_id).or(filter_id).unwrap_or(table_id)];
419        let projection_unit = WorkUnit {
420            id: self.next_unit_id("project"),
421            work_type: WorkUnitType::Projection,
422            expression: WorkUnitExpression::Custom("SELECT".to_string()),
423            dependencies,
424            parallelizable: false,
425            cost_estimate: None,
426        };
427        plan.add_unit(projection_unit);
428
429        // Optimize the plan
430        plan.optimize()?;
431
432        Ok(plan)
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Topological order of a diamond DAG must respect every edge.
    #[test]
    fn test_dependency_graph() {
        let mut graph = DependencyGraph::new();

        // Diamond: A -> {B, C} -> D
        graph.add_edge("A".to_string(), "B".to_string());
        graph.add_edge("A".to_string(), "C".to_string());
        graph.add_edge("B".to_string(), "D".to_string());
        graph.add_edge("C".to_string(), "D".to_string());

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 4);

        let pos = |name: &str| order.iter().position(|x| x == name).unwrap();

        // Every edge source must sort before its target.
        assert!(pos("A") < pos("B"));
        assert!(pos("A") < pos("C"));
        assert!(pos("B") < pos("D"));
        assert!(pos("C") < pos("D"));
    }

    /// A 3-cycle must be reported as cyclic.
    #[test]
    fn test_cycle_detection() {
        let mut graph = DependencyGraph::new();

        // A -> B -> C -> A
        graph.add_edge("A".to_string(), "B".to_string());
        graph.add_edge("B".to_string(), "C".to_string());
        graph.add_edge("C".to_string(), "A".to_string());

        assert!(graph.has_cycles());
    }

    /// Two independent branches: assert the exact group structure instead of
    /// the previous weak `groups.len() >= 3` lower bound, which would have
    /// passed for many wrong groupings.
    #[test]
    fn test_parallel_groups() {
        let mut graph = DependencyGraph::new();

        // A -> B -> D and A -> C -> E
        graph.add_edge("A".to_string(), "B".to_string());
        graph.add_edge("A".to_string(), "C".to_string());
        graph.add_edge("B".to_string(), "D".to_string());
        graph.add_edge("C".to_string(), "E".to_string());

        let mut groups = graph.get_parallel_groups();
        // Within-group order comes from HashSet iteration and is arbitrary;
        // normalize it before comparing.
        for group in &mut groups {
            group.sort();
        }

        assert_eq!(
            groups,
            vec![
                vec!["A".to_string()],
                vec!["B".to_string(), "C".to_string()],
                vec!["D".to_string(), "E".to_string()],
            ]
        );
    }
}