1use anyhow::{anyhow, Result};
5use std::collections::{HashMap, HashSet, VecDeque};
6use tracing::{debug, info};
7
8use crate::sql::recursive_parser::{Parser, SelectStatement, SqlExpression, TableSource};
9
/// Tables touched by a single SQL statement, split by access direction.
///
/// When populated through `add_read` / `add_write`, each list keeps
/// first-insertion order and holds no duplicate names.
#[derive(Debug, Clone, PartialEq)]
pub struct TableReferences {
    /// Tables the statement reads from.
    pub reads: Vec<String>,
    /// Tables the statement writes to (e.g. a `SELECT ... INTO` target).
    pub writes: Vec<String>,
}

impl TableReferences {
    /// Creates an empty reference set.
    fn new() -> Self {
        Self {
            reads: vec![],
            writes: vec![],
        }
    }

    /// Records a read of `table`; a name already recorded is ignored.
    fn add_read(&mut self, table: String) {
        Self::push_unique(&mut self.reads, table);
    }

    /// Records a write to `table`; a name already recorded is ignored.
    fn add_write(&mut self, table: String) {
        Self::push_unique(&mut self.writes, table);
    }

    /// Appends `table` to `list` unless an equal name is already present.
    fn push_unique(list: &mut Vec<String>, table: String) {
        let already_present = list.iter().any(|existing| *existing == table);
        if !already_present {
            list.push(table);
        }
    }
}
39
40#[derive(Debug, Clone)]
42pub struct DependencyStatement {
43 pub number: usize,
45 pub sql: String,
47 pub references: TableReferences,
49 pub creates_temp_table: bool,
51}
52
/// Result of dependency planning for one target statement.
///
/// All statement numbers are 1-based, matching `DependencyStatement::number`.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Statements that must run: the target plus its transitive dependencies.
    pub statements_to_execute: Vec<usize>,
    /// Statements not needed to produce the target's result.
    pub statements_to_skip: Vec<usize>,
    /// The statement this plan was computed for.
    pub target_statement: usize,
    /// Direct dependencies: statement number -> earlier statements it reads from.
    /// Statements without dependencies have no entry.
    pub dependency_graph: HashMap<usize, Vec<usize>>,
}
66
67impl ExecutionPlan {
68 pub fn format_debug_trace(&self, statements: &[DependencyStatement]) -> String {
70 let mut output = Vec::new();
71
72 output.push("=== Execution Plan Debug Trace ===\n".to_string());
73 output.push(format!("Target Statement: #{}\n", self.target_statement));
74 output.push(format!(
75 "Statements to Execute: {:?}\n",
76 self.statements_to_execute
77 ));
78 output.push(format!(
79 "Statements to Skip: {:?}\n\n",
80 self.statements_to_skip
81 ));
82
83 output.push("--- Dependency Graph ---\n".to_string());
84 for stmt_num in &self.statements_to_execute {
85 if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
86 output.push(format!("\nStatement #{}: ", stmt_num));
87 if stmt.creates_temp_table {
88 output.push("[TEMP TABLE] ".to_string());
89 }
90 output.push(format!("\n Reads: {:?}", stmt.references.reads));
91 output.push(format!("\n Writes: {:?}", stmt.references.writes));
92
93 if let Some(deps) = self.dependency_graph.get(stmt_num) {
94 if !deps.is_empty() {
95 output.push(format!("\n Depends on: {:?}", deps));
96 }
97 }
98 output.push("\n SQL: ".to_string());
99 output.push(
100 stmt.sql
101 .lines()
102 .map(|line| format!(" {}", line))
103 .collect::<Vec<_>>()
104 .join("\n"),
105 );
106 }
107 }
108
109 output.push("\n\n--- Skipped Statements ---\n".to_string());
110 for stmt_num in &self.statements_to_skip {
111 if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
112 output.push(format!("\nStatement #{}: [SKIPPED]\n", stmt_num));
113 output.push(format!(" Reads: {:?}\n", stmt.references.reads));
114 output.push(format!(" Writes: {:?}\n", stmt.references.writes));
115 }
116 }
117
118 output.join("")
119 }
120}
121
122pub struct DependencyAnalyzer;
124
125impl DependencyAnalyzer {
126 pub fn analyze_statements(statements: &[String]) -> Result<Vec<DependencyStatement>> {
128 let mut analyzed = Vec::new();
129
130 for (idx, sql) in statements.iter().enumerate() {
131 let number = idx + 1; let mut parser = Parser::new(sql);
135 let ast = parser
136 .parse()
137 .map_err(|e| anyhow!("Failed to parse statement {}: {}", number, e))?;
138
139 let creates_temp_table = ast.into_table.is_some() || Self::is_create_temp_table(sql);
141
142 let references = Self::extract_table_references(&ast)?;
144
145 analyzed.push(DependencyStatement {
146 number,
147 sql: sql.clone(),
148 references,
149 creates_temp_table,
150 });
151 }
152
153 Ok(analyzed)
154 }
155
156 fn extract_table_references(ast: &SelectStatement) -> Result<TableReferences> {
158 let mut refs = TableReferences::new();
159
160 if let Some(ref into_table) = ast.into_table {
162 refs.add_write(into_table.name.clone());
163 }
164
165 if let Some(table) = &ast.from_table {
167 refs.add_read(table.clone());
168 }
169
170 if let Some(subquery) = &ast.from_subquery {
172 let subquery_refs = Self::extract_table_references(subquery)?;
173 for table in subquery_refs.reads {
174 refs.add_read(table);
175 }
176 }
177
178 if let Some(_function) = &ast.from_function {
180 }
182
183 for join in &ast.joins {
185 Self::extract_from_table_source(&join.table, &mut refs)?;
186 }
187
188 for cte in &ast.ctes {
190 match &cte.cte_type {
191 crate::sql::parser::ast::CTEType::Standard(stmt) => {
192 let cte_refs = Self::extract_table_references(stmt)?;
193 for table in cte_refs.reads {
194 refs.add_read(table);
195 }
196 }
197 _ => {} }
199 }
200
201 if let Some(where_clause) = &ast.where_clause {
203 for condition in &where_clause.conditions {
204 Self::extract_from_expression(&condition.expr, &mut refs)?;
205 }
206 }
207
208 Ok(refs)
209 }
210
211 fn extract_from_table_source(
213 table_source: &TableSource,
214 refs: &mut TableReferences,
215 ) -> Result<()> {
216 match table_source {
217 TableSource::Table(name) => {
218 refs.add_read(name.clone());
219 }
220 TableSource::DerivedTable { query, .. } => {
221 let subquery_refs = Self::extract_table_references(query)?;
222 for table in subquery_refs.reads {
223 refs.add_read(table);
224 }
225 }
226 }
227 Ok(())
228 }
229
230 fn extract_from_expression(expr: &SqlExpression, refs: &mut TableReferences) -> Result<()> {
232 match expr {
233 SqlExpression::ScalarSubquery { query } => {
234 let subquery_refs = Self::extract_table_references(query)?;
235 for table in subquery_refs.reads {
236 refs.add_read(table);
237 }
238 }
239 SqlExpression::InSubquery {
240 expr: inner_expr,
241 subquery,
242 } => {
243 Self::extract_from_expression(inner_expr, refs)?;
244 let subquery_refs = Self::extract_table_references(subquery)?;
245 for table in subquery_refs.reads {
246 refs.add_read(table);
247 }
248 }
249 SqlExpression::NotInSubquery {
250 expr: inner_expr,
251 subquery,
252 } => {
253 Self::extract_from_expression(inner_expr, refs)?;
254 let subquery_refs = Self::extract_table_references(subquery)?;
255 for table in subquery_refs.reads {
256 refs.add_read(table);
257 }
258 }
259 SqlExpression::BinaryOp { left, right, .. } => {
260 Self::extract_from_expression(left, refs)?;
261 Self::extract_from_expression(right, refs)?;
262 }
263 SqlExpression::FunctionCall { args, .. } => {
264 for arg in args {
265 Self::extract_from_expression(arg, refs)?;
266 }
267 }
268 SqlExpression::WindowFunction { args, .. } => {
269 for arg in args {
270 Self::extract_from_expression(arg, refs)?;
271 }
272 }
273 SqlExpression::MethodCall { args, .. } => {
274 for arg in args {
275 Self::extract_from_expression(arg, refs)?;
276 }
277 }
278 SqlExpression::ChainedMethodCall { base, args, .. } => {
279 Self::extract_from_expression(base, refs)?;
280 for arg in args {
281 Self::extract_from_expression(arg, refs)?;
282 }
283 }
284 _ => {} }
286 Ok(())
287 }
288
289 fn is_create_temp_table(sql: &str) -> bool {
291 let sql_lower = sql.to_lowercase();
292 sql_lower.contains("create temp table") || sql_lower.contains("create temporary table")
293 }
294
295 pub fn compute_execution_plan(
298 statements: &[DependencyStatement],
299 target_statement_number: usize,
300 ) -> Result<ExecutionPlan> {
301 if target_statement_number == 0 || target_statement_number > statements.len() {
302 return Err(anyhow!(
303 "Invalid target statement number: {}. Must be 1-{}",
304 target_statement_number,
305 statements.len()
306 ));
307 }
308
309 info!(
310 "Computing execution plan for statement #{}",
311 target_statement_number
312 );
313
314 let mut dependency_graph: HashMap<usize, Vec<usize>> = HashMap::new();
316
317 let mut table_creators: HashMap<String, usize> = HashMap::new();
319
320 for stmt in statements {
321 for table in &stmt.references.writes {
323 table_creators.insert(table.clone(), stmt.number);
324 }
325
326 let mut depends_on = Vec::new();
328 for table in &stmt.references.reads {
329 for candidate in statements {
331 if candidate.number >= stmt.number {
332 break; }
334 if candidate.references.writes.contains(table) {
335 if !depends_on.contains(&candidate.number) {
336 depends_on.push(candidate.number);
337 }
338 }
339 }
340 }
341
342 if !depends_on.is_empty() {
343 dependency_graph.insert(stmt.number, depends_on);
344 }
345 }
346
347 debug!("Dependency graph: {:?}", dependency_graph);
348
349 let mut to_execute = HashSet::new();
351 let mut queue = VecDeque::new();
352 queue.push_back(target_statement_number);
353
354 while let Some(stmt_num) = queue.pop_front() {
355 if to_execute.insert(stmt_num) {
356 if let Some(deps) = dependency_graph.get(&stmt_num) {
358 for &dep in deps {
359 queue.push_back(dep);
360 }
361 }
362 }
363 }
364
365 let mut statements_to_execute: Vec<usize> = to_execute.into_iter().collect();
367 statements_to_execute.sort_unstable();
368
369 let statements_to_skip: Vec<usize> = (1..=statements.len())
371 .filter(|n| !statements_to_execute.contains(n))
372 .collect();
373
374 info!(
375 "Execution plan: execute {:?}, skip {:?}",
376 statements_to_execute, statements_to_skip
377 );
378
379 Ok(ExecutionPlan {
380 statements_to_execute,
381 statements_to_skip,
382 target_statement: target_statement_number,
383 dependency_graph,
384 })
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Script shared by the single-hop tests: statement 3 reads the temp table
    /// produced by statement 1, while statement 2 is independent of both.
    const SINGLE_HOP_SCRIPT: [&str; 3] = [
        "SELECT * FROM sales INTO #raw_data",
        "SELECT COUNT(*) FROM customers",
        "SELECT * FROM #raw_data WHERE amount > 100",
    ];

    /// Analyzes borrowed SQL text, panicking if any statement fails to parse.
    fn analyze(sqls: &[&str]) -> Vec<DependencyStatement> {
        let owned: Vec<String> = sqls.iter().map(|s| s.to_string()).collect();
        DependencyAnalyzer::analyze_statements(&owned).unwrap()
    }

    #[test]
    fn test_simple_dependency() {
        let analyzed = analyze(&SINGLE_HOP_SCRIPT);
        assert_eq!(analyzed.len(), 3);

        let (producer, unrelated, consumer) = (&analyzed[0], &analyzed[1], &analyzed[2]);

        assert_eq!(producer.references.writes, vec!["#raw_data"]);
        assert_eq!(producer.references.reads, vec!["sales"]);

        assert_eq!(unrelated.references.reads, vec!["customers"]);
        assert!(unrelated.references.writes.is_empty());

        assert_eq!(consumer.references.reads, vec!["#raw_data"]);
    }

    #[test]
    fn test_execution_plan() {
        let analyzed = analyze(&SINGLE_HOP_SCRIPT);
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1, 3]);
        assert_eq!(plan.statements_to_skip, vec![2]);
        assert_eq!(plan.target_statement, 3);
    }

    #[test]
    fn test_transitive_dependencies() {
        let analyzed = analyze(&[
            "SELECT * FROM base INTO #t1",
            "SELECT * FROM #t1 INTO #t2",
            "SELECT * FROM #t2 INTO #t3",
            "SELECT * FROM unrelated",
            "SELECT * FROM #t3",
        ]);
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 5).unwrap();

        assert_eq!(plan.statements_to_execute, vec![1, 2, 3, 5]);
        assert_eq!(plan.statements_to_skip, vec![4]);
    }

    #[test]
    fn test_invalid_statement_number() {
        let analyzed = analyze(&["SELECT 1"]);

        // Statement numbers are 1-based, so 0 is out of range...
        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 0).is_err());
        // ...as is anything past the end of the script.
        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 5).is_err());
    }
}