1use crate::sql::parser::ast::{SelectStatement, SqlExpression, WhereClause};
2use std::collections::{HashMap, HashSet};
3
4#[derive(Debug, Clone)]
6pub struct WorkUnit {
7 pub id: String,
9
10 pub work_type: WorkUnitType,
12
13 pub expression: WorkUnitExpression,
15
16 pub dependencies: Vec<String>,
18
19 pub parallelizable: bool,
21
22 pub cost_estimate: Option<f64>,
24}
25
/// Category tag attached to every [`WorkUnit`].
///
/// `QueryAnalyzer::analyze` in this file emits `TableScan`, `Filter`,
/// `Aggregate`, `Sort` and `Projection`; the remaining variants are not
/// constructed anywhere in the visible code.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkUnitType {
    /// Reads the table named in the statement's `FROM`.
    TableScan,

    /// Presumably a common table expression — confirm against producers.
    CTE,

    /// Applies the statement's `WHERE` clause.
    Filter,

    /// Represents the statement's `GROUP BY`.
    Aggregate,

    /// Represents the statement's `ORDER BY`.
    Sort,

    /// Presumably a join between inputs — confirm against producers.
    Join,

    /// Presumably a window-function step — confirm against producers.
    Window,

    /// Presumably a standalone expression evaluation — confirm against producers.
    Expression,

    /// Final `SELECT` projection stage.
    Projection,
}
56
57#[derive(Debug, Clone)]
59pub enum WorkUnitExpression {
60 Select(SelectStatement),
62
63 Expression(SqlExpression),
65
66 WhereClause(WhereClause),
68
69 TableName(String),
71
72 Custom(String),
74}
75
76#[derive(Debug)]
78pub struct QueryPlan {
79 pub units: Vec<WorkUnit>,
81
82 pub dependency_graph: DependencyGraph,
84
85 pub total_cost: Option<f64>,
87
88 pub original_query: String,
90
91 pub metadata: PlanMetadata,
93}
94
95impl QueryPlan {
96 pub fn new(original_query: String) -> Self {
98 QueryPlan {
99 units: Vec::new(),
100 dependency_graph: DependencyGraph::new(),
101 original_query,
102 total_cost: None,
103 metadata: PlanMetadata::default(),
104 }
105 }
106
107 pub fn add_unit(&mut self, unit: WorkUnit) {
109 for dep in &unit.dependencies {
111 self.dependency_graph.add_edge(dep.clone(), unit.id.clone());
112 }
113
114 self.units.push(unit);
116 }
117
118 pub fn get_execution_order(&self) -> Result<Vec<String>, String> {
120 self.dependency_graph.topological_sort()
121 }
122
123 pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
125 self.dependency_graph.get_parallel_groups()
126 }
127
128 pub fn optimize(&mut self) -> Result<(), String> {
130 Ok(())
135 }
136
137 pub fn explain(&self) -> String {
139 let mut output = String::new();
140 output.push_str("Query Execution Plan:\n");
141 output.push_str("====================\n\n");
142
143 match self.get_execution_order() {
145 Ok(order) => {
146 output.push_str("Execution Order:\n");
147 for (i, unit_id) in order.iter().enumerate() {
148 if let Some(unit) = self.units.iter().find(|u| u.id == *unit_id) {
149 output.push_str(&format!(
150 " {}. {} ({:?})\n",
151 i + 1,
152 unit.id,
153 unit.work_type
154 ));
155
156 if !unit.dependencies.is_empty() {
157 output.push_str(&format!(
158 " Dependencies: {}\n",
159 unit.dependencies.join(", ")
160 ));
161 }
162
163 if unit.parallelizable {
164 output.push_str(" [Parallelizable]\n");
165 }
166 }
167 }
168 }
169 Err(e) => {
170 output.push_str(&format!("Error determining execution order: {}\n", e));
171 }
172 }
173
174 output.push_str("\nParallel Execution Groups:\n");
176 for (i, group) in self.get_parallel_groups().iter().enumerate() {
177 output.push_str(&format!(" Group {}: {}\n", i + 1, group.join(", ")));
178 }
179
180 output
181 }
182}
183
/// Bookkeeping attached to a [`QueryPlan`].
///
/// Every field starts at its `Default` value; nothing in the visible code
/// assigns them afterwards.
#[derive(Debug, Default)]
pub struct PlanMetadata {
    /// Presumably set when expressions were lifted into their own units —
    /// confirm against the code that writes it.
    pub has_lifted_expressions: bool,

    /// Count of parallelization opportunities found while planning.
    /// NOTE(review): how this is counted is not shown in this file.
    pub parallel_opportunities: usize,

    /// Estimated number of rows, when known.
    pub estimated_rows: Option<usize>,

    /// Planning time in milliseconds, when measured.
    pub planning_time_ms: Option<u64>,
}
199
/// Directed graph of string-identified nodes, where an edge `source -> target`
/// means "`target` depends on `source`".
#[derive(Debug, Default)]
pub struct DependencyGraph {
    /// Adjacency list: node id -> ids of the nodes that depend on it.
    edges: HashMap<String, HashSet<String>>,

    /// Every known node id (at least all edge endpoints).
    nodes: HashSet<String>,
}

impl DependencyGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `target` depends on `source`, registering both as nodes.
    /// Adding the same edge twice has no further effect.
    pub fn add_edge(&mut self, source: String, target: String) {
        self.nodes.insert(source.clone());
        self.nodes.insert(target.clone());

        self.edges.entry(source).or_default().insert(target);
    }

    /// Counts, for every node, how many distinct nodes it directly depends on.
    fn in_degrees(&self) -> HashMap<&str, usize> {
        let mut in_degree: HashMap<&str, usize> =
            self.nodes.iter().map(|node| (node.as_str(), 0)).collect();

        for target in self.edges.values().flatten() {
            *in_degree.entry(target.as_str()).or_insert(0) += 1;
        }

        in_degree
    }

    /// Marks `node` as done: decrements the in-degree of each direct
    /// dependent and returns, sorted, those whose in-degree reached zero.
    fn release_dependents<'a>(
        &'a self,
        node: &str,
        in_degree: &mut HashMap<&'a str, usize>,
    ) -> Vec<&'a str> {
        let mut released: Vec<&'a str> = Vec::new();

        if let Some(targets) = self.edges.get(node) {
            for target in targets {
                if let Some(degree) = in_degree.get_mut(target.as_str()) {
                    *degree -= 1;
                    if *degree == 0 {
                        released.push(target.as_str());
                    }
                }
            }
        }

        // Hash-set iteration order is arbitrary; sort so callers get
        // reproducible output.
        released.sort_unstable();
        released
    }

    /// Returns all node ids ordered so that every node appears after
    /// everything it depends on (Kahn's algorithm).
    ///
    /// Ties are broken lexicographically, so the result is the same on every
    /// call for the same graph.
    ///
    /// # Errors
    ///
    /// Returns `"Dependency cycle detected in query plan"` when the graph
    /// contains a cycle.
    pub fn topological_sort(&self) -> Result<Vec<String>, String> {
        use std::collections::VecDeque;

        let mut in_degree = self.in_degrees();
        let total = in_degree.len();

        let mut ready: Vec<&str> = in_degree
            .iter()
            .filter(|(_, degree)| **degree == 0)
            .map(|(node, _)| *node)
            .collect();
        ready.sort_unstable();

        let mut queue: VecDeque<&str> = VecDeque::from(ready);
        let mut order: Vec<String> = Vec::with_capacity(total);

        while let Some(node) = queue.pop_front() {
            order.push(node.to_owned());
            queue.extend(self.release_dependents(node, &mut in_degree));
        }

        // Nodes on a cycle never reach in-degree zero, so they are never
        // emitted.
        if order.len() != total {
            return Err("Dependency cycle detected in query plan".to_string());
        }

        Ok(order)
    }

    /// Partitions the nodes into successive waves: a node lands in the first
    /// wave in which everything it depends on belongs to an earlier wave.
    ///
    /// Each wave is sorted lexicographically. Nodes that sit on, or depend
    /// on, a cycle can never be scheduled and are omitted from the result.
    pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
        let mut in_degree = self.in_degrees();
        let mut groups: Vec<Vec<String>> = Vec::new();

        let mut current: Vec<&str> = in_degree
            .iter()
            .filter(|(_, degree)| **degree == 0)
            .map(|(node, _)| *node)
            .collect();
        current.sort_unstable();

        while !current.is_empty() {
            let mut next: Vec<&str> = Vec::new();
            for node in &current {
                next.extend(self.release_dependents(node, &mut in_degree));
            }
            next.sort_unstable();

            groups.push(current.iter().map(|node| (*node).to_owned()).collect());
            current = next;
        }

        groups
    }

    /// Returns `true` when the graph contains at least one cycle.
    pub fn has_cycles(&self) -> bool {
        self.topological_sort().is_err()
    }
}
326
327pub struct QueryAnalyzer {
329 unit_counter: usize,
331}
332
333impl QueryAnalyzer {
334 pub fn new() -> Self {
336 QueryAnalyzer { unit_counter: 0 }
337 }
338
339 fn next_unit_id(&mut self, prefix: &str) -> String {
341 self.unit_counter += 1;
342 format!("{}_{}", prefix, self.unit_counter)
343 }
344
345 pub fn analyze(&mut self, stmt: &SelectStatement, query: String) -> Result<QueryPlan, String> {
347 let mut plan = QueryPlan::new(query);
348
349 let table_unit = WorkUnit {
351 id: self.next_unit_id("scan"),
352 work_type: WorkUnitType::TableScan,
353 expression: WorkUnitExpression::TableName(
354 stmt.from_table
355 .clone()
356 .unwrap_or_else(|| "unknown".to_string()),
357 ),
358 dependencies: Vec::new(),
359 parallelizable: false,
360 cost_estimate: None,
361 };
362 let table_id = table_unit.id.clone();
363 plan.add_unit(table_unit);
364
365 let mut filter_id = None;
367 if let Some(ref where_clause) = stmt.where_clause {
368 let filter_unit = WorkUnit {
371 id: self.next_unit_id("filter"),
372 work_type: WorkUnitType::Filter,
373 expression: WorkUnitExpression::WhereClause(where_clause.clone()),
374 dependencies: vec![table_id.clone()],
375 parallelizable: false,
376 cost_estimate: None,
377 };
378 filter_id = Some(filter_unit.id.clone());
379 plan.add_unit(filter_unit);
380 }
381
382 let mut group_id = None;
384 if stmt.group_by.as_ref().map_or(false, |g| !g.is_empty()) {
385 let dependencies = vec![filter_id.clone().unwrap_or(table_id.clone())];
386 let group_unit = WorkUnit {
387 id: self.next_unit_id("group"),
388 work_type: WorkUnitType::Aggregate,
389 expression: WorkUnitExpression::Custom("GROUP BY".to_string()),
390 dependencies,
391 parallelizable: false,
392 cost_estimate: None,
393 };
394 group_id = Some(group_unit.id.clone());
395 plan.add_unit(group_unit);
396 }
397
398 let mut sort_id = None;
400 if stmt.order_by.as_ref().map_or(false, |o| !o.is_empty()) {
401 let dependencies = vec![group_id
402 .clone()
403 .or(filter_id.clone())
404 .unwrap_or(table_id.clone())];
405 let sort_unit = WorkUnit {
406 id: self.next_unit_id("sort"),
407 work_type: WorkUnitType::Sort,
408 expression: WorkUnitExpression::Custom("ORDER BY".to_string()),
409 dependencies,
410 parallelizable: false,
411 cost_estimate: None,
412 };
413 sort_id = Some(sort_unit.id.clone());
414 plan.add_unit(sort_unit);
415 }
416
417 let dependencies = vec![sort_id.or(group_id).or(filter_id).unwrap_or(table_id)];
419 let projection_unit = WorkUnit {
420 id: self.next_unit_id("project"),
421 work_type: WorkUnitType::Projection,
422 expression: WorkUnitExpression::Custom("SELECT".to_string()),
423 dependencies,
424 parallelizable: false,
425 cost_estimate: None,
426 };
427 plan.add_unit(projection_unit);
428
429 plan.optimize()?;
431
432 Ok(plan)
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `(source, target)` pairs.
    fn graph_of(edges: &[(&str, &str)]) -> DependencyGraph {
        let mut graph = DependencyGraph::new();
        for (source, target) in edges {
            graph.add_edge(source.to_string(), target.to_string());
        }
        graph
    }

    #[test]
    fn test_dependency_graph() {
        // Diamond: A feeds B and C, which both feed D.
        let graph = graph_of(&[("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")]);

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 4);

        let pos = |name: &str| order.iter().position(|x| x == name).unwrap();

        assert!(pos("A") < pos("B"));
        assert!(pos("A") < pos("C"));
        assert!(pos("B") < pos("D"));
        assert!(pos("C") < pos("D"));
    }

    #[test]
    fn test_cycle_detection() {
        let cyclic = graph_of(&[("A", "B"), ("B", "C"), ("C", "A")]);
        assert!(cyclic.has_cycles());
        assert_eq!(
            cyclic.topological_sort(),
            Err("Dependency cycle detected in query plan".to_string())
        );

        // The same chain without the closing edge must not be reported.
        let acyclic = graph_of(&[("A", "B"), ("B", "C")]);
        assert!(!acyclic.has_cycles());
    }

    #[test]
    fn test_parallel_groups() {
        let graph = graph_of(&[("A", "B"), ("A", "C"), ("B", "D"), ("C", "E")]);

        // Order inside a group is not part of the contract, so normalize it.
        let groups: Vec<Vec<String>> = graph
            .get_parallel_groups()
            .into_iter()
            .map(|mut group| {
                group.sort();
                group
            })
            .collect();

        assert_eq!(
            groups,
            vec![vec!["A"], vec!["B", "C"], vec!["D", "E"]]
        );
    }
}