sql_cli/sql/
script_parser.rs

1// Script parser for handling multi-statement SQL scripts with GO separator
2// Similar to SQL Server's batch execution model
3
4use anyhow::Result;
5
/// Parses SQL scripts into individual statements using `GO` as a batch separator.
///
/// The parser also records an optional data-file hint taken from the script's
/// `--` comments (see [`ScriptParser::data_file_hint`]).
#[derive(Debug, Clone)]
pub struct ScriptParser {
    /// Raw script text, copied from the argument given to [`ScriptParser::new`].
    content: String,
    /// Data-file path found in a `-- #!…` comment, if any.
    data_file_hint: Option<String>,
}
11
12impl ScriptParser {
13    /// Create a new script parser with the given content
14    pub fn new(content: &str) -> Self {
15        let data_file_hint = Self::extract_data_file_hint(content);
16        Self {
17            content: content.to_string(),
18            data_file_hint,
19        }
20    }
21
22    /// Extract data file hint from script comments
23    /// Looks for patterns like:
24    /// -- #!data: path/to/file.csv
25    /// -- #!datafile: path/to/file.csv  
26    /// -- #! /path/to/file.csv
27    fn extract_data_file_hint(content: &str) -> Option<String> {
28        for line in content.lines() {
29            let trimmed = line.trim();
30
31            // Skip non-comment lines
32            if !trimmed.starts_with("--") {
33                continue;
34            }
35
36            // Remove the comment prefix
37            let comment_content = trimmed.strip_prefix("--").unwrap().trim();
38
39            // Check for data file hint patterns
40            if let Some(path) = comment_content.strip_prefix("#!data:") {
41                return Some(path.trim().to_string());
42            }
43            if let Some(path) = comment_content.strip_prefix("#!datafile:") {
44                return Some(path.trim().to_string());
45            }
46            if let Some(path) = comment_content.strip_prefix("#!") {
47                let path = path.trim();
48                // Check if it looks like a file path
49                if path.contains('.') || path.contains('/') || path.contains('\\') {
50                    return Some(path.to_string());
51                }
52            }
53        }
54        None
55    }
56
57    /// Get the data file hint if present
58    pub fn data_file_hint(&self) -> Option<&str> {
59        self.data_file_hint.as_deref()
60    }
61
62    /// Parse the script into individual SQL statements
63    /// GO must be on its own line (case-insensitive)
64    /// Returns a vector of SQL statements to execute
65    pub fn parse_statements(&self) -> Vec<String> {
66        let mut statements = Vec::new();
67        let mut current_statement = String::new();
68
69        for line in self.content.lines() {
70            let trimmed = line.trim();
71
72            // Check if this line is just "GO" (case-insensitive)
73            if trimmed.eq_ignore_ascii_case("go") {
74                // Add the current statement if it's not empty or just comments
75                let statement = current_statement.trim().to_string();
76                if !statement.is_empty() && !Self::is_comment_only(&statement) {
77                    statements.push(statement);
78                }
79                current_statement.clear();
80            } else {
81                // Add this line to the current statement
82                if !current_statement.is_empty() {
83                    current_statement.push('\n');
84                }
85                current_statement.push_str(line);
86            }
87        }
88
89        // Don't forget the last statement if there's no trailing GO
90        let statement = current_statement.trim().to_string();
91        if !statement.is_empty() && !Self::is_comment_only(&statement) {
92            statements.push(statement);
93        }
94
95        statements
96    }
97
98    /// Check if a statement contains only comments (no actual SQL)
99    fn is_comment_only(statement: &str) -> bool {
100        for line in statement.lines() {
101            let trimmed = line.trim();
102            // Skip empty lines and comments
103            if trimmed.is_empty() || trimmed.starts_with("--") {
104                continue;
105            }
106            // If we find any non-comment content, it's not comment-only
107            return false;
108        }
109        // All lines were comments or empty
110        true
111    }
112
113    /// Parse and validate that all statements are valid SQL
114    /// Returns the statements or an error if any are invalid
115    pub fn parse_and_validate(&self) -> Result<Vec<String>> {
116        let statements = self.parse_statements();
117
118        if statements.is_empty() {
119            anyhow::bail!("No SQL statements found in script");
120        }
121
122        // Basic validation - ensure no statement is just whitespace
123        for (i, stmt) in statements.iter().enumerate() {
124            if stmt.trim().is_empty() {
125                anyhow::bail!("Empty statement at position {}", i + 1);
126            }
127        }
128
129        Ok(statements)
130    }
131}
132
/// Result of executing a single statement in a script.
#[derive(Debug, Clone, PartialEq)]
pub struct StatementResult {
    /// Statement number as supplied by the caller.
    pub statement_number: usize,
    /// SQL text of the statement.
    pub sql: String,
    /// Whether the statement was recorded as a success.
    pub success: bool,
    /// Row count supplied on success; always 0 for failures.
    pub rows_affected: usize,
    /// Error text for failed statements; `None` on success.
    pub error_message: Option<String>,
    /// Execution time supplied by the caller, in milliseconds.
    pub execution_time_ms: f64,
}

/// Result of executing an entire script.
///
/// The counters and total time are updated together with `statement_results`
/// by [`ScriptResult::add_success`] and [`ScriptResult::add_failure`].
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ScriptResult {
    pub total_statements: usize,
    pub successful_statements: usize,
    pub failed_statements: usize,
    pub total_execution_time_ms: f64,
    pub statement_results: Vec<StatementResult>,
}

impl ScriptResult {
    /// Create an empty result with no statements recorded.
    ///
    /// Equivalent to [`ScriptResult::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Record a successful statement that affected `rows` rows.
    pub fn add_success(&mut self, statement_number: usize, sql: String, rows: usize, time_ms: f64) {
        self.successful_statements += 1;
        self.record(StatementResult {
            statement_number,
            sql,
            success: true,
            rows_affected: rows,
            error_message: None,
            execution_time_ms: time_ms,
        });
    }

    /// Record a failed statement together with its error message.
    pub fn add_failure(
        &mut self,
        statement_number: usize,
        sql: String,
        error: String,
        time_ms: f64,
    ) {
        self.failed_statements += 1;
        self.record(StatementResult {
            statement_number,
            sql,
            success: false,
            rows_affected: 0,
            error_message: Some(error),
            execution_time_ms: time_ms,
        });
    }

    /// `true` when no failure has been recorded (also `true` for an empty result).
    pub fn all_successful(&self) -> bool {
        self.failed_statements == 0
    }

    /// Push `result` and update the totals shared by successes and failures.
    fn record(&mut self, result: StatementResult) {
        self.total_statements += 1;
        self.total_execution_time_ms += result.execution_time_ms;
        self.statement_results.push(result);
    }
}
202        self.failed_statements == 0
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the parser over `script` and return the resulting batches.
    fn parse(script: &str) -> Vec<String> {
        ScriptParser::new(script).parse_statements()
    }

    /// Assert that `actual` holds exactly the `expected` batches, comparing
    /// each batch after trimming.
    fn assert_batches(actual: &[String], expected: &[&str]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert_eq!(got.trim(), *want);
        }
    }

    #[test]
    fn test_parse_single_statement() {
        let statements = parse("SELECT * FROM users");

        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "SELECT * FROM users");
    }

    #[test]
    fn test_parse_multiple_statements_with_go() {
        let statements = parse(
            r"
SELECT * FROM users
GO
SELECT * FROM orders
GO
SELECT * FROM products
",
        );

        assert_batches(
            &statements,
            &[
                "SELECT * FROM users",
                "SELECT * FROM orders",
                "SELECT * FROM products",
            ],
        );
    }

    #[test]
    fn test_go_case_insensitive() {
        let statements = parse(
            r"
SELECT 1
go
SELECT 2
Go
SELECT 3
GO
",
        );

        assert_eq!(statements.len(), 3);
    }

    #[test]
    fn test_go_in_string_not_separator() {
        let statements = parse(
            r"
SELECT 'This string contains GO but should not split' as test
GO
SELECT 'Another statement' as test2
",
        );

        assert_eq!(statements.len(), 2);
        let first = &statements[0];
        assert!(first.contains("GO but should not split"));
    }

    #[test]
    fn test_multiline_statements() {
        let statements = parse(
            r"
SELECT 
    id,
    name,
    email
FROM users
WHERE active = true
GO
SELECT COUNT(*) 
FROM orders
",
        );

        assert_eq!(statements.len(), 2);
        let first = &statements[0];
        assert!(first.contains("WHERE active = true"));
    }

    #[test]
    fn test_empty_statements_filtered() {
        let statements = parse(
            r"
GO
SELECT 1
GO
GO
SELECT 2
GO
",
        );

        assert_batches(&statements, &["SELECT 1", "SELECT 2"]);
    }
}
305}