sql_cli/sql/
script_parser.rs

1// Script parser for handling multi-statement SQL scripts with GO separator
2// Similar to SQL Server's batch execution model
3
4use anyhow::Result;
5
/// A directive attached to a script statement through a comment line.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so directives can be
/// used as set/map keys and in APIs that require total equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScriptDirective {
    /// Skip execution of this statement
    Skip,
}
12
/// The kind of statement found in a script batch.
///
/// Both payloads (`String`, `Option<i32>`) have total equality, so `Eq` is
/// derived in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScriptStatementType {
    /// Regular SQL query; the payload is the statement text.
    Query(String),
    /// EXIT statement - stops script execution.
    /// Carries the optional exit code; `None` means no code was written
    /// (callers treat that as 0, i.e. success).
    Exit(Option<i32>),
}
22
23/// A parsed script statement with optional directives
24#[derive(Debug, Clone)]
25pub struct ScriptStatement {
26    /// The type of statement (Query or Exit)
27    pub statement_type: ScriptStatementType,
28    /// Directives attached to this statement (from comments above it)
29    pub directives: Vec<ScriptDirective>,
30}
31
32impl ScriptStatement {
33    /// Check if this statement should be skipped
34    pub fn should_skip(&self) -> bool {
35        self.directives.contains(&ScriptDirective::Skip)
36    }
37
38    /// Check if this is an EXIT statement
39    pub fn is_exit(&self) -> bool {
40        matches!(self.statement_type, ScriptStatementType::Exit(_))
41    }
42
43    /// Get exit code if this is an EXIT statement
44    pub fn get_exit_code(&self) -> Option<i32> {
45        match &self.statement_type {
46            ScriptStatementType::Exit(code) => Some(code.unwrap_or(0)),
47            _ => None,
48        }
49    }
50
51    /// Get the SQL query if this is a query statement
52    pub fn get_query(&self) -> Option<&str> {
53        match &self.statement_type {
54            ScriptStatementType::Query(sql) => Some(sql),
55            ScriptStatementType::Exit(_) => None,
56        }
57    }
58}
59
/// Parses SQL scripts into individual statements using GO as separator.
///
/// Holds an owned copy of the script text plus the data file hint that was
/// extracted from its comments at construction time.
#[derive(Debug, Clone)]
pub struct ScriptParser {
    /// Full script text as passed to the constructor.
    content: String,
    /// Data file path found in a `-- #!...` comment, if any.
    data_file_hint: Option<String>,
}
65
66impl ScriptParser {
67    /// Create a new script parser with the given content
68    pub fn new(content: &str) -> Self {
69        let data_file_hint = Self::extract_data_file_hint(content);
70        Self {
71            content: content.to_string(),
72            data_file_hint,
73        }
74    }
75
76    /// Extract data file hint from script comments
77    /// Looks for patterns like:
78    /// -- #!data: path/to/file.csv
79    /// -- #!datafile: path/to/file.csv  
80    /// -- #! /path/to/file.csv
81    fn extract_data_file_hint(content: &str) -> Option<String> {
82        for line in content.lines() {
83            let trimmed = line.trim();
84
85            // Skip non-comment lines
86            if !trimmed.starts_with("--") {
87                continue;
88            }
89
90            // Remove the comment prefix
91            let comment_content = trimmed.strip_prefix("--").unwrap().trim();
92
93            // Check for data file hint patterns
94            if let Some(path) = comment_content.strip_prefix("#!data:") {
95                return Some(path.trim().to_string());
96            }
97            if let Some(path) = comment_content.strip_prefix("#!datafile:") {
98                return Some(path.trim().to_string());
99            }
100            if let Some(path) = comment_content.strip_prefix("#!") {
101                let path = path.trim();
102                // Check if it looks like a file path
103                if path.contains('.') || path.contains('/') || path.contains('\\') {
104                    return Some(path.to_string());
105                }
106            }
107        }
108        None
109    }
110
111    /// Get the data file hint if present
112    pub fn data_file_hint(&self) -> Option<&str> {
113        self.data_file_hint.as_deref()
114    }
115
116    /// Parse directives from comment lines
117    /// Looks for patterns like: -- [SKIP], -- [TODO], etc.
118    fn parse_directives(comment_lines: &[String]) -> Vec<ScriptDirective> {
119        let mut directives = Vec::new();
120
121        for line in comment_lines {
122            let trimmed = line.trim();
123            if !trimmed.starts_with("--") {
124                continue;
125            }
126
127            let comment_content = trimmed.strip_prefix("--").unwrap().trim();
128
129            // Check for directive patterns: [SKIP], [IGNORE]
130            if comment_content.eq_ignore_ascii_case("[skip]")
131                || comment_content.eq_ignore_ascii_case("[ignore]")
132            {
133                directives.push(ScriptDirective::Skip);
134            }
135        }
136
137        directives
138    }
139
140    /// Parse the script into ScriptStatements with directives
141    /// GO must be on its own line (case-insensitive)
142    pub fn parse_script_statements(&self) -> Vec<ScriptStatement> {
143        let mut statements = Vec::new();
144        let mut current_statement = String::new();
145        let mut pending_comments = Vec::new();
146
147        for line in self.content.lines() {
148            let trimmed = line.trim();
149
150            // Check if this line is just "GO" (case-insensitive)
151            if trimmed.eq_ignore_ascii_case("go") {
152                // Add the current statement if it's not empty
153                let statement = current_statement.trim().to_string();
154                if !statement.is_empty() && !Self::is_comment_only(&statement) {
155                    // Parse directives from pending comments
156                    let directives = Self::parse_directives(&pending_comments);
157
158                    // Check if this is an EXIT statement
159                    let statement_type = Self::parse_exit_statement(&statement)
160                        .unwrap_or_else(|| ScriptStatementType::Query(statement));
161
162                    statements.push(ScriptStatement {
163                        statement_type,
164                        directives,
165                    });
166                }
167                current_statement.clear();
168                pending_comments.clear();
169            } else if trimmed.starts_with("--") {
170                // This is a comment line - save it for directive parsing
171                pending_comments.push(line.to_string());
172                // Also add to current statement
173                if !current_statement.is_empty() {
174                    current_statement.push('\n');
175                }
176                current_statement.push_str(line);
177            } else {
178                // Regular line - add to current statement
179                if !current_statement.is_empty() {
180                    current_statement.push('\n');
181                }
182                current_statement.push_str(line);
183            }
184        }
185
186        // Don't forget the last statement if there's no trailing GO
187        let statement = current_statement.trim().to_string();
188        if !statement.is_empty() && !Self::is_comment_only(&statement) {
189            let directives = Self::parse_directives(&pending_comments);
190
191            let statement_type = Self::parse_exit_statement(&statement)
192                .unwrap_or_else(|| ScriptStatementType::Query(statement));
193
194            statements.push(ScriptStatement {
195                statement_type,
196                directives,
197            });
198        }
199
200        statements
201    }
202
203    /// Try to parse an EXIT statement with optional exit code
204    /// Supports: EXIT, EXIT;, EXIT 0, EXIT 1;, etc.
205    /// Strips comments before checking
206    fn parse_exit_statement(statement: &str) -> Option<ScriptStatementType> {
207        // Extract non-comment content
208        let mut non_comment_lines = Vec::new();
209        for line in statement.lines() {
210            let trimmed = line.trim();
211            if !trimmed.is_empty() && !trimmed.starts_with("--") {
212                non_comment_lines.push(trimmed);
213            }
214        }
215
216        if non_comment_lines.is_empty() {
217            return None;
218        }
219
220        // Join non-comment lines and check if it's EXIT
221        let content = non_comment_lines.join(" ");
222        let trimmed = content.trim().trim_end_matches(';').trim();
223
224        if trimmed.eq_ignore_ascii_case("exit") {
225            return Some(ScriptStatementType::Exit(None));
226        }
227
228        // Check for EXIT with a number: EXIT 0, EXIT 1, etc.
229        let parts: Vec<&str> = trimmed.split_whitespace().collect();
230        if parts.len() == 2 && parts[0].eq_ignore_ascii_case("exit") {
231            if let Ok(code) = parts[1].parse::<i32>() {
232                return Some(ScriptStatementType::Exit(Some(code)));
233            }
234        }
235
236        None
237    }
238
239    /// Parse the script into individual SQL statements (legacy method)
240    /// GO must be on its own line (case-insensitive)
241    /// Returns a vector of SQL statements to execute
242    pub fn parse_statements(&self) -> Vec<String> {
243        self.parse_script_statements()
244            .into_iter()
245            .filter_map(|stmt| match stmt.statement_type {
246                ScriptStatementType::Query(sql) => Some(sql),
247                ScriptStatementType::Exit(_) => None,
248            })
249            .collect()
250    }
251
252    /// Check if a statement contains only comments (no actual SQL)
253    fn is_comment_only(statement: &str) -> bool {
254        for line in statement.lines() {
255            let trimmed = line.trim();
256            // Skip empty lines and comments
257            if trimmed.is_empty() || trimmed.starts_with("--") {
258                continue;
259            }
260            // If we find any non-comment content, it's not comment-only
261            return false;
262        }
263        // All lines were comments or empty
264        true
265    }
266
267    /// Parse and validate that all statements are valid SQL
268    /// Returns the statements or an error if any are invalid
269    pub fn parse_and_validate(&self) -> Result<Vec<String>> {
270        let statements = self.parse_statements();
271
272        if statements.is_empty() {
273            anyhow::bail!("No SQL statements found in script");
274        }
275
276        // Basic validation - ensure no statement is just whitespace
277        for (i, stmt) in statements.iter().enumerate() {
278            if stmt.trim().is_empty() {
279                anyhow::bail!("Empty statement at position {}", i + 1);
280            }
281        }
282
283        Ok(statements)
284    }
285}
286
/// Result of executing a single statement in a script.
///
/// `Clone` and `PartialEq` are derived so results can be copied into reports
/// and compared in tests. `Eq` is not derivable because of the `f64` timing.
#[derive(Debug, Clone, PartialEq)]
pub struct StatementResult {
    /// Caller-supplied number identifying the statement within the script.
    pub statement_number: usize,
    /// SQL text of the statement.
    pub sql: String,
    /// Whether the statement was recorded as successful.
    pub success: bool,
    /// Row count reported for the statement.
    pub rows_affected: usize,
    /// Error text for a failed statement, if any.
    pub error_message: Option<String>,
    /// Execution time in milliseconds.
    pub execution_time_ms: f64,
}
297
298/// Result of executing an entire script
299#[derive(Debug)]
300pub struct ScriptResult {
301    pub total_statements: usize,
302    pub successful_statements: usize,
303    pub failed_statements: usize,
304    pub total_execution_time_ms: f64,
305    pub statement_results: Vec<StatementResult>,
306}
307
308impl ScriptResult {
309    pub fn new() -> Self {
310        Self {
311            total_statements: 0,
312            successful_statements: 0,
313            failed_statements: 0,
314            total_execution_time_ms: 0.0,
315            statement_results: Vec::new(),
316        }
317    }
318
319    pub fn add_success(&mut self, statement_number: usize, sql: String, rows: usize, time_ms: f64) {
320        self.total_statements += 1;
321        self.successful_statements += 1;
322        self.total_execution_time_ms += time_ms;
323
324        self.statement_results.push(StatementResult {
325            statement_number,
326            sql,
327            success: true,
328            rows_affected: rows,
329            error_message: None,
330            execution_time_ms: time_ms,
331        });
332    }
333
334    pub fn add_failure(
335        &mut self,
336        statement_number: usize,
337        sql: String,
338        error: String,
339        time_ms: f64,
340    ) {
341        self.total_statements += 1;
342        self.failed_statements += 1;
343        self.total_execution_time_ms += time_ms;
344
345        self.statement_results.push(StatementResult {
346            statement_number,
347            sql,
348            success: false,
349            rows_affected: 0,
350            error_message: Some(error),
351            execution_time_ms: time_ms,
352        });
353    }
354
355    pub fn all_successful(&self) -> bool {
356        self.failed_statements == 0
357    }
358}
359
#[cfg(test)]
mod tests {
    //! Unit tests for GO-based batch splitting, directives, EXIT handling and
    //! data file hint extraction.

    use super::*;

    /// A script without GO is a single statement.
    #[test]
    fn test_parse_single_statement() {
        let script = "SELECT * FROM users";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "SELECT * FROM users");
    }

    /// GO lines split the script; the trailing batch needs no closing GO.
    #[test]
    fn test_parse_multiple_statements_with_go() {
        let script = r"
SELECT * FROM users
GO
SELECT * FROM orders
GO
SELECT * FROM products
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 3);
        assert_eq!(statements[0].trim(), "SELECT * FROM users");
        assert_eq!(statements[1].trim(), "SELECT * FROM orders");
        assert_eq!(statements[2].trim(), "SELECT * FROM products");
    }

    /// The separator is matched regardless of ASCII case.
    #[test]
    fn test_go_case_insensitive() {
        let script = r"
SELECT 1
go
SELECT 2
Go
SELECT 3
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 3);
    }

    /// GO inside a line (here: in a string literal) does not split.
    #[test]
    fn test_go_in_string_not_separator() {
        let script = r"
SELECT 'This string contains GO but should not split' as test
GO
SELECT 'Another statement' as test2
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("GO but should not split"));
    }

    /// Statements may span several lines.
    #[test]
    fn test_multiline_statements() {
        let script = r"
SELECT 
    id,
    name,
    email
FROM users
WHERE active = true
GO
SELECT COUNT(*) 
FROM orders
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("WHERE active = true"));
    }

    /// Consecutive GO lines do not produce empty statements.
    #[test]
    fn test_empty_statements_filtered() {
        let script = r"
GO
SELECT 1
GO
GO
SELECT 2
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0].trim(), "SELECT 1");
        assert_eq!(statements[1].trim(), "SELECT 2");
    }

    /// A batch holding only comments is dropped.
    #[test]
    fn test_comment_only_batch_filtered() {
        let script = "-- just a note\nGO\nSELECT 1";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "SELECT 1");
    }

    /// `-- [SKIP]` marks only the statement in its own batch.
    #[test]
    fn test_skip_directive_attached_to_statement() {
        let script = "-- [SKIP]\nSELECT 1\nGO\nSELECT 2";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_script_statements();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].should_skip());
        assert!(!statements[1].should_skip());
        assert_eq!(statements[1].get_query(), Some("SELECT 2"));
    }

    /// EXIT batches are parsed as exit statements and left out of the legacy list.
    #[test]
    fn test_exit_statement_with_code() {
        let script = "SELECT 1\nGO\nEXIT 3\nGO\nexit;\nGO\nSELECT 2";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_script_statements();

        assert_eq!(statements.len(), 4);
        assert!(!statements[0].is_exit());
        assert!(statements[1].is_exit());
        assert_eq!(statements[1].get_exit_code(), Some(3));
        assert_eq!(statements[2].get_exit_code(), Some(0));
        assert_eq!(statements[3].get_exit_code(), None);

        assert_eq!(parser.parse_statements().len(), 2);
    }

    /// The data file hint is read from a `-- #!data:` comment.
    #[test]
    fn test_data_file_hint() {
        let with_hint = ScriptParser::new("-- #!data: data/trades.csv\nSELECT 1");
        assert_eq!(with_hint.data_file_hint(), Some("data/trades.csv"));

        let without_hint = ScriptParser::new("-- plain comment\nSELECT 1");
        assert_eq!(without_hint.data_file_hint(), None);
    }

    /// Validation rejects a script with no query statements.
    #[test]
    fn test_parse_and_validate_rejects_empty_script() {
        assert!(ScriptParser::new("-- nothing here\nGO\n").parse_and_validate().is_err());
        assert_eq!(
            ScriptParser::new("SELECT 1\nGO\nSELECT 2")
                .parse_and_validate()
                .unwrap()
                .len(),
            2
        );
    }
}
459}