sql_cli/analysis/
statement_dependencies.rs

1// Statement dependency analysis for --execute-statement feature
2// Analyzes SQL scripts to compute minimal execution plans based on temp table dependencies
3
4use anyhow::{anyhow, Result};
5use std::collections::{HashMap, HashSet, VecDeque};
6use tracing::{debug, info};
7
8use crate::sql::recursive_parser::{Parser, SelectStatement, SqlExpression, TableSource};
9
/// Tables referenced by a single statement, split into reads and writes.
///
/// When populated through `add_read` / `add_write`, each list keeps first-seen
/// order and contains no duplicates.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TableReferences {
    /// Tables this statement reads from: FROM and JOIN sources, plus tables read
    /// by nested FROM subqueries, derived tables, CTE bodies and WHERE subqueries.
    pub reads: Vec<String>,
    /// Tables this statement writes to. The analyzer fills this only from the
    /// parsed statement's `into_table` (e.g. `SELECT ... INTO #t`).
    /// NOTE(review): whether other forms such as `CREATE TEMP TABLE` or `INSERT`
    /// end up in `into_table` depends on the parser — confirm before relying on it.
    pub writes: Vec<String>,
}
18
19impl TableReferences {
20    fn new() -> Self {
21        Self {
22            reads: Vec::new(),
23            writes: Vec::new(),
24        }
25    }
26
27    fn add_read(&mut self, table: String) {
28        if !self.reads.contains(&table) {
29            self.reads.push(table);
30        }
31    }
32
33    fn add_write(&mut self, table: String) {
34        if !self.writes.contains(&table) {
35            self.writes.push(table);
36        }
37    }
38}
39
40/// A parsed SQL statement with its dependencies
41#[derive(Debug, Clone)]
42pub struct DependencyStatement {
43    /// Statement number (1-based for user display)
44    pub number: usize,
45    /// The SQL text
46    pub sql: String,
47    /// Tables this statement references
48    pub references: TableReferences,
49    /// Whether this is a temp table creation
50    pub creates_temp_table: bool,
51}
52
/// Execution plan for a target statement.
///
/// All statement numbers are 1-based, matching `DependencyStatement::number`.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Statement numbers to execute, in ascending (script) order; includes the target.
    pub statements_to_execute: Vec<usize>,
    /// Statement numbers that are not needed for the target and will be skipped.
    pub statements_to_skip: Vec<usize>,
    /// The target statement number.
    pub target_statement: usize,
    /// Dependency graph as an adjacency list (for debugging):
    /// `statement_number -> [numbers of the statements it depends on]`.
    ///
    /// The values are the statement's prerequisites (earlier statements that
    /// write a table it reads), not the statements that depend on it.
    /// Statements without prerequisites may have no entry at all.
    pub dependency_graph: HashMap<usize, Vec<usize>>,
}
66
67impl ExecutionPlan {
68    /// Create a formatted debug trace of the execution plan
69    pub fn format_debug_trace(&self, statements: &[DependencyStatement]) -> String {
70        let mut output = Vec::new();
71
72        output.push("=== Execution Plan Debug Trace ===\n".to_string());
73        output.push(format!("Target Statement: #{}\n", self.target_statement));
74        output.push(format!(
75            "Statements to Execute: {:?}\n",
76            self.statements_to_execute
77        ));
78        output.push(format!(
79            "Statements to Skip: {:?}\n\n",
80            self.statements_to_skip
81        ));
82
83        output.push("--- Dependency Graph ---\n".to_string());
84        for stmt_num in &self.statements_to_execute {
85            if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
86                output.push(format!("\nStatement #{}: ", stmt_num));
87                if stmt.creates_temp_table {
88                    output.push("[TEMP TABLE] ".to_string());
89                }
90                output.push(format!("\n  Reads: {:?}", stmt.references.reads));
91                output.push(format!("\n  Writes: {:?}", stmt.references.writes));
92
93                if let Some(deps) = self.dependency_graph.get(stmt_num) {
94                    if !deps.is_empty() {
95                        output.push(format!("\n  Depends on: {:?}", deps));
96                    }
97                }
98                output.push("\n  SQL: ".to_string());
99                output.push(
100                    stmt.sql
101                        .lines()
102                        .map(|line| format!("    {}", line))
103                        .collect::<Vec<_>>()
104                        .join("\n"),
105                );
106            }
107        }
108
109        output.push("\n\n--- Skipped Statements ---\n".to_string());
110        for stmt_num in &self.statements_to_skip {
111            if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
112                output.push(format!("\nStatement #{}: [SKIPPED]\n", stmt_num));
113                output.push(format!("  Reads: {:?}\n", stmt.references.reads));
114                output.push(format!("  Writes: {:?}\n", stmt.references.writes));
115            }
116        }
117
118        output.join("")
119    }
120}
121
122/// Dependency analyzer for SQL scripts
123pub struct DependencyAnalyzer;
124
125impl DependencyAnalyzer {
126    /// Analyze a list of SQL statements and extract their dependencies
127    pub fn analyze_statements(statements: &[String]) -> Result<Vec<DependencyStatement>> {
128        let mut analyzed = Vec::new();
129
130        for (idx, sql) in statements.iter().enumerate() {
131            let number = idx + 1; // 1-based numbering for user display
132
133            // Parse the SQL statement directly - parser handles all syntaxes
134            let mut parser = Parser::new(sql);
135            let ast = parser
136                .parse()
137                .map_err(|e| anyhow!("Failed to parse statement {}: {}", number, e))?;
138
139            // Check if this creates a temp table using AST
140            let creates_temp_table = ast.into_table.is_some() || Self::is_create_temp_table(sql);
141
142            // Extract table references from the AST
143            let references = Self::extract_table_references(&ast)?;
144
145            analyzed.push(DependencyStatement {
146                number,
147                sql: sql.clone(),
148                references,
149                creates_temp_table,
150            });
151        }
152
153        Ok(analyzed)
154    }
155
156    /// Extract table references from a parsed AST
157    fn extract_table_references(ast: &SelectStatement) -> Result<TableReferences> {
158        let mut refs = TableReferences::new();
159
160        // Check for SELECT INTO (write operation) - parser extracts this for us!
161        if let Some(ref into_table) = ast.into_table {
162            refs.add_write(into_table.name.clone());
163        }
164
165        // Extract FROM clause tables (read operations)
166        if let Some(table) = &ast.from_table {
167            refs.add_read(table.clone());
168        }
169
170        // Extract from subquery if present
171        if let Some(subquery) = &ast.from_subquery {
172            let subquery_refs = Self::extract_table_references(subquery)?;
173            for table in subquery_refs.reads {
174                refs.add_read(table);
175            }
176        }
177
178        // Extract from function if present (table functions like RANGE)
179        if let Some(_function) = &ast.from_function {
180            // Table functions don't reference existing tables, so nothing to extract
181        }
182
183        // Extract from JOINs
184        for join in &ast.joins {
185            Self::extract_from_table_source(&join.table, &mut refs)?;
186        }
187
188        // Extract from CTEs
189        for cte in &ast.ctes {
190            match &cte.cte_type {
191                crate::sql::parser::ast::CTEType::Standard(stmt) => {
192                    let cte_refs = Self::extract_table_references(stmt)?;
193                    for table in cte_refs.reads {
194                        refs.add_read(table);
195                    }
196                }
197                _ => {} // WEB CTEs don't have dependencies on local tables
198            }
199        }
200
201        // Extract from WHERE clause (for subqueries)
202        if let Some(where_clause) = &ast.where_clause {
203            for condition in &where_clause.conditions {
204                Self::extract_from_expression(&condition.expr, &mut refs)?;
205            }
206        }
207
208        Ok(refs)
209    }
210
211    /// Extract table references from a table source (handles subqueries, table functions, etc.)
212    fn extract_from_table_source(
213        table_source: &TableSource,
214        refs: &mut TableReferences,
215    ) -> Result<()> {
216        match table_source {
217            TableSource::Table(name) => {
218                refs.add_read(name.clone());
219            }
220            TableSource::DerivedTable { query, .. } => {
221                let subquery_refs = Self::extract_table_references(query)?;
222                for table in subquery_refs.reads {
223                    refs.add_read(table);
224                }
225            }
226        }
227        Ok(())
228    }
229
230    /// Extract table references from expressions (for subqueries in WHERE, etc.)
231    fn extract_from_expression(expr: &SqlExpression, refs: &mut TableReferences) -> Result<()> {
232        match expr {
233            SqlExpression::ScalarSubquery { query } => {
234                let subquery_refs = Self::extract_table_references(query)?;
235                for table in subquery_refs.reads {
236                    refs.add_read(table);
237                }
238            }
239            SqlExpression::InSubquery {
240                expr: inner_expr,
241                subquery,
242            } => {
243                Self::extract_from_expression(inner_expr, refs)?;
244                let subquery_refs = Self::extract_table_references(subquery)?;
245                for table in subquery_refs.reads {
246                    refs.add_read(table);
247                }
248            }
249            SqlExpression::NotInSubquery {
250                expr: inner_expr,
251                subquery,
252            } => {
253                Self::extract_from_expression(inner_expr, refs)?;
254                let subquery_refs = Self::extract_table_references(subquery)?;
255                for table in subquery_refs.reads {
256                    refs.add_read(table);
257                }
258            }
259            SqlExpression::BinaryOp { left, right, .. } => {
260                Self::extract_from_expression(left, refs)?;
261                Self::extract_from_expression(right, refs)?;
262            }
263            SqlExpression::FunctionCall { args, .. } => {
264                for arg in args {
265                    Self::extract_from_expression(arg, refs)?;
266                }
267            }
268            SqlExpression::WindowFunction { args, .. } => {
269                for arg in args {
270                    Self::extract_from_expression(arg, refs)?;
271                }
272            }
273            SqlExpression::MethodCall { args, .. } => {
274                for arg in args {
275                    Self::extract_from_expression(arg, refs)?;
276                }
277            }
278            SqlExpression::ChainedMethodCall { base, args, .. } => {
279                Self::extract_from_expression(base, refs)?;
280                for arg in args {
281                    Self::extract_from_expression(arg, refs)?;
282                }
283            }
284            _ => {} // Other expression types don't contain table references
285        }
286        Ok(())
287    }
288
289    /// Check if SQL uses CREATE TEMP TABLE syntax (not SELECT INTO, which is handled by AST)
290    fn is_create_temp_table(sql: &str) -> bool {
291        let sql_lower = sql.to_lowercase();
292        sql_lower.contains("create temp table") || sql_lower.contains("create temporary table")
293    }
294
295    /// Compute execution plan for a target statement
296    /// Returns the minimal set of statements needed to execute the target
297    pub fn compute_execution_plan(
298        statements: &[DependencyStatement],
299        target_statement_number: usize,
300    ) -> Result<ExecutionPlan> {
301        if target_statement_number == 0 || target_statement_number > statements.len() {
302            return Err(anyhow!(
303                "Invalid target statement number: {}. Must be 1-{}",
304                target_statement_number,
305                statements.len()
306            ));
307        }
308
309        info!(
310            "Computing execution plan for statement #{}",
311            target_statement_number
312        );
313
314        // Build dependency graph: statement_number -> [depends_on_statement_numbers]
315        let mut dependency_graph: HashMap<usize, Vec<usize>> = HashMap::new();
316
317        // Build a map of table_name -> statement_number that creates it
318        let mut table_creators: HashMap<String, usize> = HashMap::new();
319
320        for stmt in statements {
321            // Register temp tables this statement creates
322            for table in &stmt.references.writes {
323                table_creators.insert(table.clone(), stmt.number);
324            }
325
326            // Find dependencies for tables this statement reads
327            let mut depends_on = Vec::new();
328            for table in &stmt.references.reads {
329                // Look for the latest statement that creates this table (before current statement)
330                for candidate in statements {
331                    if candidate.number >= stmt.number {
332                        break; // Only look at earlier statements
333                    }
334                    if candidate.references.writes.contains(table) {
335                        if !depends_on.contains(&candidate.number) {
336                            depends_on.push(candidate.number);
337                        }
338                    }
339                }
340            }
341
342            if !depends_on.is_empty() {
343                dependency_graph.insert(stmt.number, depends_on);
344            }
345        }
346
347        debug!("Dependency graph: {:?}", dependency_graph);
348
349        // Compute transitive dependencies using BFS
350        let mut to_execute = HashSet::new();
351        let mut queue = VecDeque::new();
352        queue.push_back(target_statement_number);
353
354        while let Some(stmt_num) = queue.pop_front() {
355            if to_execute.insert(stmt_num) {
356                // First time seeing this statement, add its dependencies to queue
357                if let Some(deps) = dependency_graph.get(&stmt_num) {
358                    for &dep in deps {
359                        queue.push_back(dep);
360                    }
361                }
362            }
363        }
364
365        // Sort statements to execute in order
366        let mut statements_to_execute: Vec<usize> = to_execute.into_iter().collect();
367        statements_to_execute.sort_unstable();
368
369        // Compute skipped statements
370        let statements_to_skip: Vec<usize> = (1..=statements.len())
371            .filter(|n| !statements_to_execute.contains(n))
372            .collect();
373
374        info!(
375            "Execution plan: execute {:?}, skip {:?}",
376            statements_to_execute, statements_to_skip
377        );
378
379        Ok(ExecutionPlan {
380            statements_to_execute,
381            statements_to_skip,
382            target_statement: target_statement_number,
383            dependency_graph,
384        })
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the owned statement list that `analyze_statements` expects.
    fn script(sqls: &[&str]) -> Vec<String> {
        sqls.iter().map(|sql| sql.to_string()).collect()
    }

    /// Analyze `sqls` and compute the plan for the 1-based `target`.
    fn plan_for(sqls: &[&str], target: usize) -> ExecutionPlan {
        let analyzed = DependencyAnalyzer::analyze_statements(&script(sqls)).unwrap();
        DependencyAnalyzer::compute_execution_plan(&analyzed, target).unwrap()
    }

    #[test]
    fn test_simple_dependency() {
        let analyzed = DependencyAnalyzer::analyze_statements(&script(&[
            "SELECT * FROM sales INTO #raw_data",
            "SELECT COUNT(*) FROM customers",
            "SELECT * FROM #raw_data WHERE amount > 100",
        ]))
        .unwrap();
        assert_eq!(analyzed.len(), 3);

        // #1 materializes #raw_data from sales.
        let producer = &analyzed[0];
        assert_eq!(producer.references.writes, vec!["#raw_data"]);
        assert_eq!(producer.references.reads, vec!["sales"]);

        // #2 is independent: it reads customers and writes nothing.
        let independent = &analyzed[1];
        assert_eq!(independent.references.reads, vec!["customers"]);
        assert!(independent.references.writes.is_empty());

        // #3 consumes #raw_data.
        let consumer = &analyzed[2];
        assert_eq!(consumer.references.reads, vec!["#raw_data"]);
    }

    #[test]
    fn test_execution_plan() {
        let plan = plan_for(
            &[
                "SELECT * FROM sales INTO #raw_data",
                "SELECT COUNT(*) FROM customers",
                "SELECT * FROM #raw_data WHERE amount > 100",
            ],
            3,
        );

        // #1 (creates #raw_data) and the target #3 run; #2 is unrelated.
        assert_eq!(plan.statements_to_execute, vec![1, 3]);
        assert_eq!(plan.statements_to_skip, vec![2]);
        assert_eq!(plan.target_statement, 3);
    }

    #[test]
    fn test_transitive_dependencies() {
        let plan = plan_for(
            &[
                "SELECT * FROM base INTO #t1",
                "SELECT * FROM #t1 INTO #t2",
                "SELECT * FROM #t2 INTO #t3",
                "SELECT * FROM unrelated",
                "SELECT * FROM #t3",
            ],
            5,
        );

        // The chain 1 -> 2 -> 3 -> 5 is required; #4 is unrelated.
        assert_eq!(plan.statements_to_execute, vec![1, 2, 3, 5]);
        assert_eq!(plan.statements_to_skip, vec![4]);
    }

    #[test]
    fn test_invalid_statement_number() {
        let analyzed = DependencyAnalyzer::analyze_statements(&script(&["SELECT 1"])).unwrap();

        // Zero and anything past the last statement are both rejected.
        for bad_target in [0, 5] {
            assert!(
                DependencyAnalyzer::compute_execution_plan(&analyzed, bad_target).is_err(),
                "target {} should be rejected",
                bad_target
            );
        }
    }
}