Skip to main content

sql_cli/analysis/
statement_dependencies.rs

1// Statement dependency analysis for --execute-statement feature
2// Analyzes SQL scripts to compute minimal execution plans based on temp table dependencies
3
4use anyhow::{anyhow, Result};
5use std::collections::{HashMap, HashSet, VecDeque};
6use tracing::{debug, info};
7
8use crate::sql::recursive_parser::{Parser, SelectStatement, SqlExpression, TableSource};
9
/// The tables a single statement touches, split by direction.
///
/// Both lists preserve first-seen order and never hold the same name twice
/// (comparison is exact and case-sensitive).
#[derive(Debug, Clone, PartialEq)]
pub struct TableReferences {
    /// Tables this statement reads from: FROM/JOIN sources, including tables read by nested
    /// subqueries
    pub reads: Vec<String>,
    /// Tables this statement writes to (e.g. the target of `SELECT ... INTO`)
    pub writes: Vec<String>,
}

impl TableReferences {
    /// Create a value with no reads and no writes.
    fn new() -> Self {
        Self {
            reads: vec![],
            writes: vec![],
        }
    }

    /// Record that `table` is read; a name already recorded is ignored.
    fn add_read(&mut self, table: String) {
        Self::push_unique(&mut self.reads, table);
    }

    /// Record that `table` is written; a name already recorded is ignored.
    fn add_write(&mut self, table: String) {
        Self::push_unique(&mut self.writes, table);
    }

    /// Append `table` to `list` unless an identical name is already present.
    fn push_unique(list: &mut Vec<String>, table: String) {
        if list.iter().any(|existing| *existing == table) {
            return;
        }
        list.push(table);
    }
}
39
/// A parsed SQL statement with its dependencies.
///
/// Produced by `DependencyAnalyzer::analyze_statements`, one per input statement, in script
/// order.
#[derive(Debug, Clone)]
pub struct DependencyStatement {
    /// Statement number (1-based for user display); when produced by `analyze_statements` this
    /// is the statement's position in the input list plus one
    pub number: usize,
    /// The SQL text, exactly as it was given to the analyzer
    pub sql: String,
    /// Tables this statement reads and writes; dependency planning is driven by these lists
    pub references: TableReferences,
    /// Whether this is a temp table creation (a `SELECT ... INTO` target in the AST, or a textual
    /// `CREATE TEMP/TEMPORARY TABLE`)
    pub creates_temp_table: bool,
}
52
/// Execution plan for a target statement.
///
/// Describes which statements of a script must run, and in what order, so the target sees
/// every table it (transitively) depends on, and which statements can be skipped.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Statement numbers to execute, in ascending (script) order (1-based). When built by
    /// `DependencyAnalyzer::compute_execution_plan` this always includes the target.
    pub statements_to_execute: Vec<usize>,
    /// Statement numbers that will be skipped, in ascending order (1-based)
    pub statements_to_skip: Vec<usize>,
    /// The target statement number (1-based)
    pub target_statement: usize,
    /// Direct dependencies, kept for debugging:
    /// statement_number -> [numbers of the earlier statements it depends on].
    /// Statements without any dependency have no entry.
    pub dependency_graph: HashMap<usize, Vec<usize>>,
}
66
67impl ExecutionPlan {
68    /// Create a formatted debug trace of the execution plan
69    pub fn format_debug_trace(&self, statements: &[DependencyStatement]) -> String {
70        let mut output = Vec::new();
71
72        output.push("=== Execution Plan Debug Trace ===\n".to_string());
73        output.push(format!("Target Statement: #{}\n", self.target_statement));
74        output.push(format!(
75            "Statements to Execute: {:?}\n",
76            self.statements_to_execute
77        ));
78        output.push(format!(
79            "Statements to Skip: {:?}\n\n",
80            self.statements_to_skip
81        ));
82
83        output.push("--- Dependency Graph ---\n".to_string());
84        for stmt_num in &self.statements_to_execute {
85            if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
86                output.push(format!("\nStatement #{}: ", stmt_num));
87                if stmt.creates_temp_table {
88                    output.push("[TEMP TABLE] ".to_string());
89                }
90                output.push(format!("\n  Reads: {:?}", stmt.references.reads));
91                output.push(format!("\n  Writes: {:?}", stmt.references.writes));
92
93                if let Some(deps) = self.dependency_graph.get(stmt_num) {
94                    if !deps.is_empty() {
95                        output.push(format!("\n  Depends on: {:?}", deps));
96                    }
97                }
98                output.push("\n  SQL: ".to_string());
99                output.push(
100                    stmt.sql
101                        .lines()
102                        .map(|line| format!("    {}", line))
103                        .collect::<Vec<_>>()
104                        .join("\n"),
105                );
106            }
107        }
108
109        output.push("\n\n--- Skipped Statements ---\n".to_string());
110        for stmt_num in &self.statements_to_skip {
111            if let Some(stmt) = statements.iter().find(|s| s.number == *stmt_num) {
112                output.push(format!("\nStatement #{}: [SKIPPED]\n", stmt_num));
113                output.push(format!("  Reads: {:?}\n", stmt.references.reads));
114                output.push(format!("  Writes: {:?}\n", stmt.references.writes));
115            }
116        }
117
118        output.join("")
119    }
120}
121
122/// Dependency analyzer for SQL scripts
123pub struct DependencyAnalyzer;
124
125impl DependencyAnalyzer {
126    /// Analyze a list of SQL statements and extract their dependencies
127    pub fn analyze_statements(statements: &[String]) -> Result<Vec<DependencyStatement>> {
128        let mut analyzed = Vec::new();
129
130        for (idx, sql) in statements.iter().enumerate() {
131            let number = idx + 1; // 1-based numbering for user display
132
133            // Parse the SQL statement directly - parser handles all syntaxes
134            let mut parser = Parser::new(sql);
135            let ast = parser
136                .parse()
137                .map_err(|e| anyhow!("Failed to parse statement {}: {}", number, e))?;
138
139            // Check if this creates a temp table using AST
140            let creates_temp_table = ast.into_table.is_some() || Self::is_create_temp_table(sql);
141
142            // Extract table references from the AST
143            let references = Self::extract_table_references(&ast)?;
144
145            analyzed.push(DependencyStatement {
146                number,
147                sql: sql.clone(),
148                references,
149                creates_temp_table,
150            });
151        }
152
153        Ok(analyzed)
154    }
155
156    /// Extract table references from a parsed AST
157    fn extract_table_references(ast: &SelectStatement) -> Result<TableReferences> {
158        let mut refs = TableReferences::new();
159
160        // Check for SELECT INTO (write operation) - parser extracts this for us!
161        if let Some(ref into_table) = ast.into_table {
162            refs.add_write(into_table.name.clone());
163        }
164
165        // Extract FROM clause tables (read operations)
166        if let Some(table) = &ast.from_table {
167            refs.add_read(table.clone());
168        }
169
170        // Extract from subquery if present
171        if let Some(subquery) = &ast.from_subquery {
172            let subquery_refs = Self::extract_table_references(subquery)?;
173            for table in subquery_refs.reads {
174                refs.add_read(table);
175            }
176        }
177
178        // Extract from function if present (table functions like RANGE)
179        if let Some(_function) = &ast.from_function {
180            // Table functions don't reference existing tables, so nothing to extract
181        }
182
183        // Extract from JOINs
184        for join in &ast.joins {
185            Self::extract_from_table_source(&join.table, &mut refs)?;
186        }
187
188        // Extract from CTEs
189        for cte in &ast.ctes {
190            match &cte.cte_type {
191                crate::sql::parser::ast::CTEType::Standard(stmt) => {
192                    let cte_refs = Self::extract_table_references(stmt)?;
193                    for table in cte_refs.reads {
194                        refs.add_read(table);
195                    }
196                }
197                _ => {} // WEB CTEs don't have dependencies on local tables
198            }
199        }
200
201        // Extract from WHERE clause (for subqueries)
202        if let Some(where_clause) = &ast.where_clause {
203            for condition in &where_clause.conditions {
204                Self::extract_from_expression(&condition.expr, &mut refs)?;
205            }
206        }
207
208        Ok(refs)
209    }
210
211    /// Extract table references from a table source (handles subqueries, table functions, etc.)
212    fn extract_from_table_source(
213        table_source: &TableSource,
214        refs: &mut TableReferences,
215    ) -> Result<()> {
216        match table_source {
217            TableSource::Table(name) => {
218                refs.add_read(name.clone());
219            }
220            TableSource::DerivedTable { query, .. } => {
221                let subquery_refs = Self::extract_table_references(query)?;
222                for table in subquery_refs.reads {
223                    refs.add_read(table);
224                }
225            }
226            TableSource::Pivot { source, .. } => {
227                // Recursively extract from the pivot source
228                Self::extract_from_table_source(source, refs)?;
229            }
230        }
231        Ok(())
232    }
233
234    /// Extract table references from expressions (for subqueries in WHERE, etc.)
235    fn extract_from_expression(expr: &SqlExpression, refs: &mut TableReferences) -> Result<()> {
236        match expr {
237            SqlExpression::ScalarSubquery { query } => {
238                let subquery_refs = Self::extract_table_references(query)?;
239                for table in subquery_refs.reads {
240                    refs.add_read(table);
241                }
242            }
243            SqlExpression::InSubquery {
244                expr: inner_expr,
245                subquery,
246            } => {
247                Self::extract_from_expression(inner_expr, refs)?;
248                let subquery_refs = Self::extract_table_references(subquery)?;
249                for table in subquery_refs.reads {
250                    refs.add_read(table);
251                }
252            }
253            SqlExpression::NotInSubquery {
254                expr: inner_expr,
255                subquery,
256            } => {
257                Self::extract_from_expression(inner_expr, refs)?;
258                let subquery_refs = Self::extract_table_references(subquery)?;
259                for table in subquery_refs.reads {
260                    refs.add_read(table);
261                }
262            }
263            SqlExpression::BinaryOp { left, right, .. } => {
264                Self::extract_from_expression(left, refs)?;
265                Self::extract_from_expression(right, refs)?;
266            }
267            SqlExpression::FunctionCall { args, .. } => {
268                for arg in args {
269                    Self::extract_from_expression(arg, refs)?;
270                }
271            }
272            SqlExpression::WindowFunction { args, .. } => {
273                for arg in args {
274                    Self::extract_from_expression(arg, refs)?;
275                }
276            }
277            SqlExpression::MethodCall { args, .. } => {
278                for arg in args {
279                    Self::extract_from_expression(arg, refs)?;
280                }
281            }
282            SqlExpression::ChainedMethodCall { base, args, .. } => {
283                Self::extract_from_expression(base, refs)?;
284                for arg in args {
285                    Self::extract_from_expression(arg, refs)?;
286                }
287            }
288            _ => {} // Other expression types don't contain table references
289        }
290        Ok(())
291    }
292
293    /// Check if SQL uses CREATE TEMP TABLE syntax (not SELECT INTO, which is handled by AST)
294    fn is_create_temp_table(sql: &str) -> bool {
295        let sql_lower = sql.to_lowercase();
296        sql_lower.contains("create temp table") || sql_lower.contains("create temporary table")
297    }
298
299    /// Compute execution plan for a target statement
300    /// Returns the minimal set of statements needed to execute the target
301    pub fn compute_execution_plan(
302        statements: &[DependencyStatement],
303        target_statement_number: usize,
304    ) -> Result<ExecutionPlan> {
305        if target_statement_number == 0 || target_statement_number > statements.len() {
306            return Err(anyhow!(
307                "Invalid target statement number: {}. Must be 1-{}",
308                target_statement_number,
309                statements.len()
310            ));
311        }
312
313        info!(
314            "Computing execution plan for statement #{}",
315            target_statement_number
316        );
317
318        // Build dependency graph: statement_number -> [depends_on_statement_numbers]
319        let mut dependency_graph: HashMap<usize, Vec<usize>> = HashMap::new();
320
321        // Build a map of table_name -> statement_number that creates it
322        let mut table_creators: HashMap<String, usize> = HashMap::new();
323
324        for stmt in statements {
325            // Register temp tables this statement creates
326            for table in &stmt.references.writes {
327                table_creators.insert(table.clone(), stmt.number);
328            }
329
330            // Find dependencies for tables this statement reads
331            let mut depends_on = Vec::new();
332            for table in &stmt.references.reads {
333                // Look for the latest statement that creates this table (before current statement)
334                for candidate in statements {
335                    if candidate.number >= stmt.number {
336                        break; // Only look at earlier statements
337                    }
338                    if candidate.references.writes.contains(table) {
339                        if !depends_on.contains(&candidate.number) {
340                            depends_on.push(candidate.number);
341                        }
342                    }
343                }
344            }
345
346            if !depends_on.is_empty() {
347                dependency_graph.insert(stmt.number, depends_on);
348            }
349        }
350
351        debug!("Dependency graph: {:?}", dependency_graph);
352
353        // Compute transitive dependencies using BFS
354        let mut to_execute = HashSet::new();
355        let mut queue = VecDeque::new();
356        queue.push_back(target_statement_number);
357
358        while let Some(stmt_num) = queue.pop_front() {
359            if to_execute.insert(stmt_num) {
360                // First time seeing this statement, add its dependencies to queue
361                if let Some(deps) = dependency_graph.get(&stmt_num) {
362                    for &dep in deps {
363                        queue.push_back(dep);
364                    }
365                }
366            }
367        }
368
369        // Sort statements to execute in order
370        let mut statements_to_execute: Vec<usize> = to_execute.into_iter().collect();
371        statements_to_execute.sort_unstable();
372
373        // Compute skipped statements
374        let statements_to_skip: Vec<usize> = (1..=statements.len())
375            .filter(|n| !statements_to_execute.contains(n))
376            .collect();
377
378        info!(
379            "Execution plan: execute {:?}, skip {:?}",
380            statements_to_execute, statements_to_skip
381        );
382
383        Ok(ExecutionPlan {
384            statements_to_execute,
385            statements_to_skip,
386            target_statement: target_statement_number,
387            dependency_graph,
388        })
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared three-statement script: 1 creates #raw_data, 2 is independent, 3 reads #raw_data.
    fn sample_script() -> Vec<String> {
        vec![
            "SELECT * FROM sales INTO #raw_data".to_string(),
            "SELECT COUNT(*) FROM customers".to_string(),
            "SELECT * FROM #raw_data WHERE amount > 100".to_string(),
        ]
    }

    #[test]
    fn test_simple_dependency() {
        let statements = sample_script();

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
        assert_eq!(analyzed.len(), 3);

        // Statement 1: creates #raw_data, reads sales
        assert_eq!(analyzed[0].references.writes, vec!["#raw_data"]);
        assert_eq!(analyzed[0].references.reads, vec!["sales"]);

        // Statement 2: reads customers (independent)
        assert_eq!(analyzed[1].references.reads, vec!["customers"]);
        assert!(analyzed[1].references.writes.is_empty());

        // Statement 3: reads #raw_data
        assert_eq!(analyzed[2].references.reads, vec!["#raw_data"]);
    }

    #[test]
    fn test_execution_plan() {
        let statements = sample_script();

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();

        // Should execute statement 1 (creates #raw_data) and 3 (target)
        // Should skip statement 2 (independent)
        assert_eq!(plan.statements_to_execute, vec![1, 3]);
        assert_eq!(plan.statements_to_skip, vec![2]);
        assert_eq!(plan.target_statement, 3);
    }

    #[test]
    fn test_target_without_dependencies() {
        let statements = sample_script();

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();

        // Statement 2 reads nothing any earlier statement writes.
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 2).unwrap();
        assert_eq!(plan.statements_to_execute, vec![2]);
        assert_eq!(plan.statements_to_skip, vec![1, 3]);
        assert!(plan.dependency_graph.get(&2).is_none());

        // Statement 1 only reads a base table; later statements are never dependencies.
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 1).unwrap();
        assert_eq!(plan.statements_to_execute, vec![1]);
        assert_eq!(plan.statements_to_skip, vec![2, 3]);
    }

    #[test]
    fn test_transitive_dependencies() {
        let statements = vec![
            "SELECT * FROM base INTO #t1".to_string(),
            "SELECT * FROM #t1 INTO #t2".to_string(),
            "SELECT * FROM #t2 INTO #t3".to_string(),
            "SELECT * FROM unrelated".to_string(),
            "SELECT * FROM #t3".to_string(),
        ];

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 5).unwrap();

        // Should execute 1 -> 2 -> 3 -> 5 (transitive chain)
        // Should skip 4 (unrelated)
        assert_eq!(plan.statements_to_execute, vec![1, 2, 3, 5]);
        assert_eq!(plan.statements_to_skip, vec![4]);
    }

    #[test]
    fn test_all_earlier_writers_are_dependencies() {
        let statements = vec![
            "SELECT * FROM a INTO #t".to_string(),
            "SELECT * FROM b INTO #t".to_string(),
            "SELECT * FROM #t".to_string(),
        ];

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();

        // Every earlier writer of #t is kept, in script order.
        assert_eq!(plan.dependency_graph.get(&3), Some(&vec![1, 2]));
        assert_eq!(plan.statements_to_execute, vec![1, 2, 3]);
        assert!(plan.statements_to_skip.is_empty());
    }

    #[test]
    fn test_debug_trace_lists_plan() {
        let statements = sample_script();

        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
        let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();
        let trace = plan.format_debug_trace(&analyzed);

        assert!(trace.contains("Target Statement: #3"));
        assert!(trace.contains("Statements to Execute: [1, 3]"));
        assert!(trace.contains("Statements to Skip: [2]"));
        assert!(trace.contains("[TEMP TABLE]"));
        assert!(trace.contains("Depends on: [1]"));
        assert!(trace.contains("Statement #2: [SKIPPED]"));
    }

    #[test]
    fn test_invalid_statement_number() {
        let statements = vec!["SELECT 1".to_string()];
        let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();

        // Test statement number 0
        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 0).is_err());

        // Test statement number > len
        assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 5).is_err());
    }
}
467}