1use anyhow::{anyhow, Result};
5use std::collections::{HashMap, HashSet, VecDeque};
6use tracing::{debug, info};
7
8use crate::sql::recursive_parser::{Parser, SelectStatement, SqlExpression, TableSource};
9
/// Table names referenced by a single SQL statement, split by access kind.
///
/// When populated through `add_read` / `add_write`, both lists keep
/// first-seen order and contain no duplicates.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TableReferences {
    /// Tables the statement reads from (FROM, JOINs, CTE bodies, WHERE subqueries).
    pub reads: Vec<String>,
    /// Tables the statement writes to (the `INTO` target).
    pub writes: Vec<String>,
}
18
19impl TableReferences {
20 fn new() -> Self {
21 Self {
22 reads: Vec::new(),
23 writes: Vec::new(),
24 }
25 }
26
27 fn add_read(&mut self, table: String) {
28 if !self.reads.contains(&table) {
29 self.reads.push(table);
30 }
31 }
32
33 fn add_write(&mut self, table: String) {
34 if !self.writes.contains(&table) {
35 self.writes.push(table);
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct DependencyStatement {
43 pub number: usize,
45 pub sql: String,
47 pub references: TableReferences,
49 pub creates_temp_table: bool,
51}
52
/// Which statements must run in order to execute a target statement, and
/// which ones can be skipped.
#[derive(Clone, Debug)]
pub struct ExecutionPlan {
    /// Statement numbers to run, sorted ascending; includes the target itself.
    pub statements_to_execute: Vec<usize>,
    /// All remaining statement numbers, ascending.
    pub statements_to_skip: Vec<usize>,
    /// The statement number the plan was computed for.
    pub target_statement: usize,
    /// Statement number -> earlier statements it directly depends on.
    /// Statements without dependencies have no entry.
    pub dependency_graph: HashMap<usize, Vec<usize>>,
}
66
67impl ExecutionPlan {
68 pub fn format_debug_trace(&self, statements: &[DependencyStatement]) -> String {
70 let mut output = Vec::new();
71
72 output.push("=== Execution Plan Debug Trace ===\n".to_string());
73 output.push(format!("Target Statement: #{}\n", self.target_statement));
74 output.push(format!(
75 "Statements to Execute: {:?}\n",
76 self.statements_to_execute
77 ));
78 output.push(format!(
79 "Statements to Skip: {:?}\n\n",
80 self.statements_to_skip
81 ));
82
83 output.push("--- Dependency Graph ---\n".to_string());
84 for stmt_num in &self.statements_to_execute {
85 if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
86 output.push(format!("\nStatement #{}: ", stmt_num));
87 if stmt.creates_temp_table {
88 output.push("[TEMP TABLE] ".to_string());
89 }
90 output.push(format!("\n Reads: {:?}", stmt.references.reads));
91 output.push(format!("\n Writes: {:?}", stmt.references.writes));
92
93 if let Some(deps) = self.dependency_graph.get(stmt_num) {
94 if !deps.is_empty() {
95 output.push(format!("\n Depends on: {:?}", deps));
96 }
97 }
98 output.push("\n SQL: ".to_string());
99 output.push(
100 stmt.sql
101 .lines()
102 .map(|line| format!(" {}", line))
103 .collect::<Vec<_>>()
104 .join("\n"),
105 );
106 }
107 }
108
109 output.push("\n\n--- Skipped Statements ---\n".to_string());
110 for stmt_num in &self.statements_to_skip {
111 if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
112 output.push(format!("\nStatement #{}: [SKIPPED]\n", stmt_num));
113 output.push(format!(" Reads: {:?}\n", stmt.references.reads));
114 output.push(format!(" Writes: {:?}\n", stmt.references.writes));
115 }
116 }
117
118 output.join("")
119 }
120}
121
122pub struct DependencyAnalyzer;
124
125impl DependencyAnalyzer {
126 pub fn analyze_statements(statements: &[String]) -> Result<Vec<DependencyStatement>> {
128 let mut analyzed = Vec::new();
129
130 for (idx, sql) in statements.iter().enumerate() {
131 let number = idx + 1; let mut parser = Parser::new(sql);
135 let ast = parser
136 .parse()
137 .map_err(|e| anyhow!("Failed to parse statement {}: {}", number, e))?;
138
139 let creates_temp_table = ast.into_table.is_some() || Self::is_create_temp_table(sql);
141
142 let references = Self::extract_table_references(&ast)?;
144
145 analyzed.push(DependencyStatement {
146 number,
147 sql: sql.clone(),
148 references,
149 creates_temp_table,
150 });
151 }
152
153 Ok(analyzed)
154 }
155
156 fn extract_table_references(ast: &SelectStatement) -> Result<TableReferences> {
158 let mut refs = TableReferences::new();
159
160 if let Some(ref into_table) = ast.into_table {
162 refs.add_write(into_table.name.clone());
163 }
164
165 if let Some(table) = &ast.from_table {
167 refs.add_read(table.clone());
168 }
169
170 if let Some(subquery) = &ast.from_subquery {
172 let subquery_refs = Self::extract_table_references(subquery)?;
173 for table in subquery_refs.reads {
174 refs.add_read(table);
175 }
176 }
177
178 if let Some(_function) = &ast.from_function {
180 }
182
183 for join in &ast.joins {
185 Self::extract_from_table_source(&join.table, &mut refs)?;
186 }
187
188 for cte in &ast.ctes {
190 match &cte.cte_type {
191 crate::sql::parser::ast::CTEType::Standard(stmt) => {
192 let cte_refs = Self::extract_table_references(stmt)?;
193 for table in cte_refs.reads {
194 refs.add_read(table);
195 }
196 }
197 _ => {} }
199 }
200
201 if let Some(where_clause) = &ast.where_clause {
203 for condition in &where_clause.conditions {
204 Self::extract_from_expression(&condition.expr, &mut refs)?;
205 }
206 }
207
208 Ok(refs)
209 }
210
211 fn extract_from_table_source(
213 table_source: &TableSource,
214 refs: &mut TableReferences,
215 ) -> Result<()> {
216 match table_source {
217 TableSource::Table(name) => {
218 refs.add_read(name.clone());
219 }
220 TableSource::DerivedTable { query, .. } => {
221 let subquery_refs = Self::extract_table_references(query)?;
222 for table in subquery_refs.reads {
223 refs.add_read(table);
224 }
225 }
226 TableSource::Pivot { source, .. } => {
227 Self::extract_from_table_source(source, refs)?;
229 }
230 }
231 Ok(())
232 }
233
234 fn extract_from_expression(expr: &SqlExpression, refs: &mut TableReferences) -> Result<()> {
236 match expr {
237 SqlExpression::ScalarSubquery { query } => {
238 let subquery_refs = Self::extract_table_references(query)?;
239 for table in subquery_refs.reads {
240 refs.add_read(table);
241 }
242 }
243 SqlExpression::InSubquery {
244 expr: inner_expr,
245 subquery,
246 } => {
247 Self::extract_from_expression(inner_expr, refs)?;
248 let subquery_refs = Self::extract_table_references(subquery)?;
249 for table in subquery_refs.reads {
250 refs.add_read(table);
251 }
252 }
253 SqlExpression::NotInSubquery {
254 expr: inner_expr,
255 subquery,
256 } => {
257 Self::extract_from_expression(inner_expr, refs)?;
258 let subquery_refs = Self::extract_table_references(subquery)?;
259 for table in subquery_refs.reads {
260 refs.add_read(table);
261 }
262 }
263 SqlExpression::BinaryOp { left, right, .. } => {
264 Self::extract_from_expression(left, refs)?;
265 Self::extract_from_expression(right, refs)?;
266 }
267 SqlExpression::FunctionCall { args, .. } => {
268 for arg in args {
269 Self::extract_from_expression(arg, refs)?;
270 }
271 }
272 SqlExpression::WindowFunction { args, .. } => {
273 for arg in args {
274 Self::extract_from_expression(arg, refs)?;
275 }
276 }
277 SqlExpression::MethodCall { args, .. } => {
278 for arg in args {
279 Self::extract_from_expression(arg, refs)?;
280 }
281 }
282 SqlExpression::ChainedMethodCall { base, args, .. } => {
283 Self::extract_from_expression(base, refs)?;
284 for arg in args {
285 Self::extract_from_expression(arg, refs)?;
286 }
287 }
288 _ => {} }
290 Ok(())
291 }
292
293 fn is_create_temp_table(sql: &str) -> bool {
295 let sql_lower = sql.to_lowercase();
296 sql_lower.contains("create temp table") || sql_lower.contains("create temporary table")
297 }
298
299 pub fn compute_execution_plan(
302 statements: &[DependencyStatement],
303 target_statement_number: usize,
304 ) -> Result<ExecutionPlan> {
305 if target_statement_number == 0 || target_statement_number > statements.len() {
306 return Err(anyhow!(
307 "Invalid target statement number: {}. Must be 1-{}",
308 target_statement_number,
309 statements.len()
310 ));
311 }
312
313 info!(
314 "Computing execution plan for statement #{}",
315 target_statement_number
316 );
317
318 let mut dependency_graph: HashMap<usize, Vec<usize>> = HashMap::new();
320
321 let mut table_creators: HashMap<String, usize> = HashMap::new();
323
324 for stmt in statements {
325 for table in &stmt.references.writes {
327 table_creators.insert(table.clone(), stmt.number);
328 }
329
330 let mut depends_on = Vec::new();
332 for table in &stmt.references.reads {
333 for candidate in statements {
335 if candidate.number >= stmt.number {
336 break; }
338 if candidate.references.writes.contains(table) {
339 if !depends_on.contains(&candidate.number) {
340 depends_on.push(candidate.number);
341 }
342 }
343 }
344 }
345
346 if !depends_on.is_empty() {
347 dependency_graph.insert(stmt.number, depends_on);
348 }
349 }
350
351 debug!("Dependency graph: {:?}", dependency_graph);
352
353 let mut to_execute = HashSet::new();
355 let mut queue = VecDeque::new();
356 queue.push_back(target_statement_number);
357
358 while let Some(stmt_num) = queue.pop_front() {
359 if to_execute.insert(stmt_num) {
360 if let Some(deps) = dependency_graph.get(&stmt_num) {
362 for &dep in deps {
363 queue.push_back(dep);
364 }
365 }
366 }
367 }
368
369 let mut statements_to_execute: Vec<usize> = to_execute.into_iter().collect();
371 statements_to_execute.sort_unstable();
372
373 let statements_to_skip: Vec<usize> = (1..=statements.len())
375 .filter(|n| !statements_to_execute.contains(n))
376 .collect();
377
378 info!(
379 "Execution plan: execute {:?}, skip {:?}",
380 statements_to_execute, statements_to_skip
381 );
382
383 Ok(ExecutionPlan {
384 statements_to_execute,
385 statements_to_skip,
386 target_statement: target_statement_number,
387 dependency_graph,
388 })
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyzes `sqls` as a script, panicking if any statement fails to parse.
    fn analyze(sqls: &[&str]) -> Vec<DependencyStatement> {
        let owned: Vec<String> = sqls.iter().map(|s| s.to_string()).collect();
        DependencyAnalyzer::analyze_statements(&owned).expect("test SQL should parse")
    }

    /// Three-statement script: a temp-table producer, an unrelated query, a consumer.
    fn sample_script() -> Vec<DependencyStatement> {
        analyze(&[
            "SELECT * FROM sales INTO #raw_data",
            "SELECT COUNT(*) FROM customers",
            "SELECT * FROM #raw_data WHERE amount > 100",
        ])
    }

    #[test]
    fn test_simple_dependency() {
        let analyzed = sample_script();
        assert_eq!(analyzed.len(), 3);

        assert_eq!(analyzed[0].references.writes, vec!["#raw_data"]);
        assert_eq!(analyzed[0].references.reads, vec!["sales"]);

        assert_eq!(analyzed[1].references.reads, vec!["customers"]);
        assert!(analyzed[1].references.writes.is_empty());

        assert_eq!(analyzed[2].references.reads, vec!["#raw_data"]);
    }

    #[test]
    fn test_execution_plan() {
        let analyzed = sample_script();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1, 3]);
        assert_eq!(plan.statements_to_skip, vec![2]);
        assert_eq!(plan.target_statement, 3);
    }

    #[test]
    fn test_transitive_dependencies() {
        let analyzed = analyze(&[
            "SELECT * FROM base INTO #t1",
            "SELECT * FROM #t1 INTO #t2",
            "SELECT * FROM #t2 INTO #t3",
            "SELECT * FROM unrelated",
            "SELECT * FROM #t3",
        ]);
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 5).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1, 2, 3, 5]);
        assert_eq!(plan.statements_to_skip, vec![4]);
    }

    #[test]
    fn test_invalid_statement_number() {
        let analyzed = analyze(&["SELECT 1"]);

        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 0).is_err());
        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 5).is_err());
    }

    #[test]
    fn test_target_without_dependencies() {
        let analyzed = sample_script();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 2).unwrap();

        assert_eq!(plan.statements_to_execute, vec![2]);
        assert_eq!(plan.statements_to_skip, vec![1, 3]);
    }

    #[test]
    fn test_every_prior_writer_is_a_dependency() {
        let analyzed = analyze(&[
            "SELECT * FROM sales INTO #t",
            "SELECT * FROM refunds INTO #t",
            "SELECT * FROM #t",
        ]);
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1, 2, 3]);
        assert_eq!(plan.dependency_graph.get(&3), Some(&vec![1, 2]));
    }

    #[test]
    fn test_later_writer_is_not_a_dependency() {
        let analyzed = analyze(&["SELECT * FROM #t", "SELECT * FROM base INTO #t"]);
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 1).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1]);
        assert_eq!(plan.statements_to_skip, vec![2]);
        assert!(plan.dependency_graph.get(&1).is_none());
    }

    #[test]
    fn test_debug_trace_mentions_plan() {
        let analyzed = sample_script();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();
        let trace = plan.format_debug_trace(&analyzed);

        assert!(trace.contains("Target Statement: #3"));
        assert!(trace.contains("[TEMP TABLE]"));
        assert!(trace.contains("Statement #2: [SKIPPED]"));
    }
}