// sql_cli/query_plan/dependency_analyzer.rs

1use anyhow::Result;
2use std::collections::{HashMap, HashSet};
3
4use crate::sql::parser::ast::SelectStatement;
5use crate::sql::recursive_parser::Parser;
6use crate::sql::script_parser::ScriptParser;
7
8/// Represents a single statement in the script with its dependencies
9#[derive(Debug, Clone)]
10pub struct StatementNode {
11    /// 1-based index in the script
12    pub index: usize,
13    /// The SQL text of this statement
14    pub sql: String,
15    /// Temporary tables this statement creates (e.g., ["#temp", "#summary"])
16    pub creates_tables: Vec<String>,
17    /// Tables this statement depends on (both temp tables and base tables)
18    pub depends_on_tables: Vec<String>,
19    /// The parsed AST (if parsing succeeded)
20    pub ast: Option<SelectStatement>,
21}
22
23/// Dependency graph for a SQL script (temp table dependencies between statements)
24#[derive(Debug)]
25pub struct ScriptDependencyGraph {
26    /// All statements in the script
27    pub statements: Vec<StatementNode>,
28    /// Map of table name -> statement index that creates it
29    pub table_creators: HashMap<String, usize>,
30}
31
32impl ScriptDependencyGraph {
33    /// Analyze a SQL script and build the dependency graph
34    ///
35    /// This parses the entire script, identifies what tables each statement
36    /// creates and depends on, and builds a graph of dependencies.
37    pub fn analyze(script_content: &str) -> Result<Self> {
38        let script_parser = ScriptParser::new(script_content);
39        let script_statements = script_parser.parse_script_statements();
40
41        let mut statements = Vec::new();
42        let mut table_creators = HashMap::new();
43
44        for (idx, script_stmt) in script_statements.iter().enumerate() {
45            let statement_num = idx + 1;
46
47            // Skip EXIT and [SKIP] directives
48            if script_stmt.is_exit() || script_stmt.should_skip() {
49                continue;
50            }
51
52            // Get the SQL query
53            let sql = match script_stmt.get_query() {
54                Some(s) => s.to_string(),
55                None => continue,
56            };
57
58            // Parse the statement to extract dependencies
59            let mut parser = Parser::new(&sql);
60            let ast = parser.parse().ok();
61
62            let mut creates_tables = Vec::new();
63            let mut depends_on_tables = Vec::new();
64
65            if let Some(ref stmt) = ast {
66                // Check if this statement creates a temp table (INTO clause)
67                if let Some(ref into_table) = stmt.into_table {
68                    creates_tables.push(into_table.name.clone());
69                    table_creators.insert(into_table.name.clone(), statement_num);
70                }
71
72                // Check what tables this statement depends on
73                if let Some(ref from_table) = stmt.from_table {
74                    depends_on_tables.push(from_table.clone());
75                }
76
77                // Check JOIN clauses for dependencies
78                for join in &stmt.joins {
79                    if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
80                        depends_on_tables.push(table_name.clone());
81                    }
82                }
83
84                // TODO: Could also check subqueries, CTEs, etc. for more complete analysis
85            }
86
87            statements.push(StatementNode {
88                index: statement_num,
89                sql,
90                creates_tables,
91                depends_on_tables,
92                ast,
93            });
94        }
95
96        Ok(Self {
97            statements,
98            table_creators,
99        })
100    }
101
102    /// Get the minimal set of statement indices needed to execute a target statement
103    ///
104    /// Returns statements in execution order (dependencies first, target last)
105    pub fn get_dependencies(&self, target_index: usize) -> Result<Vec<usize>> {
106        if target_index == 0 || target_index > self.statements.len() {
107            anyhow::bail!(
108                "Invalid statement index: {}. Script has {} statements.",
109                target_index,
110                self.statements.len()
111            );
112        }
113
114        let mut required = HashSet::new();
115        let mut to_process = vec![target_index];
116
117        // Traverse dependency graph backwards
118        while let Some(stmt_idx) = to_process.pop() {
119            if required.contains(&stmt_idx) {
120                continue; // Already processed
121            }
122
123            required.insert(stmt_idx);
124
125            // Find the statement (0-indexed in Vec, but stmt_idx is 1-indexed)
126            if let Some(stmt) = self.statements.iter().find(|s| s.index == stmt_idx) {
127                // For each table this statement depends on
128                for table in &stmt.depends_on_tables {
129                    // If it's a temp table, find who creates it
130                    if table.starts_with('#') {
131                        if let Some(&creator_idx) = self.table_creators.get(table) {
132                            to_process.push(creator_idx);
133                        }
134                        // If temp table is not found, it will error during execution
135                    }
136                    // Base tables don't need dependency tracking
137                }
138            }
139        }
140
141        // Convert to sorted vector (execution order)
142        let mut result: Vec<usize> = required.into_iter().collect();
143        result.sort();
144
145        Ok(result)
146    }
147
148    /// Generate a debug report showing the dependency analysis
149    pub fn explain_dependencies(&self, target_index: usize) -> Result<String> {
150        let deps = self.get_dependencies(target_index)?;
151
152        let mut output = String::new();
153        output.push_str("\n=== Dependency Analysis ===\n");
154        output.push_str(&format!(
155            "Script has {} statements total\n",
156            self.statements.len()
157        ));
158        output.push_str(&format!("Target: Statement {}\n\n", target_index));
159
160        // Show all relevant statements
161        for &stmt_idx in &deps {
162            if let Some(stmt) = self.statements.iter().find(|s| s.index == stmt_idx) {
163                let is_target = stmt_idx == target_index;
164                let marker = if is_target { " [TARGET]" } else { "" };
165
166                output.push_str(&format!("Statement {}{}\n", stmt_idx, marker));
167
168                // Show abbreviated SQL (first 60 chars)
169                let sql_preview = if stmt.sql.len() > 60 {
170                    format!("{}...", &stmt.sql[..60])
171                } else {
172                    stmt.sql.clone()
173                };
174                output.push_str(&format!("  SQL: {}\n", sql_preview.replace('\n', " ")));
175
176                if !stmt.creates_tables.is_empty() {
177                    output.push_str(&format!("  Creates: {}\n", stmt.creates_tables.join(", ")));
178                }
179
180                if !stmt.depends_on_tables.is_empty() {
181                    output.push_str(&format!(
182                        "  Depends on: {}\n",
183                        stmt.depends_on_tables.join(", ")
184                    ));
185                }
186
187                output.push('\n');
188            }
189        }
190
191        output.push_str("Execution Plan:\n");
192        for &stmt_idx in &deps {
193            let marker = if stmt_idx == target_index {
194                " ← target"
195            } else {
196                ""
197            };
198            output.push_str(&format!("  → Statement {}{}\n", stmt_idx, marker));
199        }
200
201        output.push_str(&format!(
202            "\nExecuting {} of {} statements...\n",
203            deps.len(),
204            self.statements.len()
205        ));
206
207        Ok(output)
208    }
209
210    /// Get a statement by its index
211    pub fn get_statement(&self, index: usize) -> Option<&StatementNode> {
212        self.statements.iter().find(|s| s.index == index)
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyze `script`, failing the test if analysis returns an error.
    fn graph_of(script: &str) -> ScriptDependencyGraph {
        ScriptDependencyGraph::analyze(script).expect("analysis should succeed")
    }

    /// True when `tables` lists `name`.
    fn lists(tables: &[String], name: &str) -> bool {
        tables.iter().any(|table| table == name)
    }

    #[test]
    fn test_simple_dependency_chain() {
        let graph = graph_of(
            r#"
SELECT * INTO #temp FROM data WHERE value > 100;
GO

SELECT * INTO #summary FROM #temp GROUP BY category;
GO

SELECT * FROM #summary WHERE total > 500;
GO
"#,
        );
        let stmts = &graph.statements;

        // Three statements chained through #temp and then #summary.
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0].creates_tables, vec!["#temp"]);
        assert!(lists(&stmts[1].depends_on_tables, "#temp"));
        assert_eq!(stmts[1].creates_tables, vec!["#summary"]);
        assert!(lists(&stmts[2].depends_on_tables, "#summary"));

        // The final query needs the whole chain.
        assert_eq!(graph.get_dependencies(3).unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_independent_statements() {
        let graph = graph_of(
            r#"
SELECT * FROM data1;
GO

SELECT * FROM data2;
GO

SELECT * FROM data3;
GO
"#,
        );

        assert_eq!(graph.statements.len(), 3);
        // No temp tables involved, so the target stands alone.
        assert_eq!(graph.get_dependencies(3).unwrap(), vec![3]);
    }

    #[test]
    fn test_partial_dependency() {
        let graph = graph_of(
            r#"
SELECT * INTO #temp1 FROM data;
GO

SELECT * INTO #temp2 FROM data;
GO

SELECT * FROM #temp2;
GO
"#,
        );

        // Only the creator of #temp2 is pulled in; #temp1 is irrelevant.
        assert_eq!(graph.get_dependencies(3).unwrap(), vec![2, 3]);
    }

    #[test]
    fn test_explain_output() {
        let report = graph_of(
            r#"
SELECT * INTO #temp FROM data;
GO

SELECT * FROM #temp;
GO
"#,
        )
        .explain_dependencies(2)
        .unwrap();

        let expected = [
            "Statement 1",
            "Statement 2",
            "[TARGET]",
            "Creates: #temp",
            "Depends on: #temp",
        ];
        for needle in expected {
            assert!(report.contains(needle), "missing {:?} in report", needle);
        }
    }
}