sql_cli/sql/
script_parser.rs1use anyhow::Result;
5
/// A directive attached to a script statement through a whole-line `--` comment.
///
/// Directives are matched case-insensitively; `-- [skip]` and `-- [ignore]`
/// both produce [`ScriptDirective::Skip`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScriptDirective {
    /// Marks the statement as one to skip (see `ScriptStatement::should_skip`).
    Skip,
}
12
/// The kind of a single parsed script statement.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScriptStatementType {
    /// A SQL query, stored as the full trimmed chunk text (comment lines included).
    Query(String),
    /// An `EXIT` / `EXIT <code>` command; `None` when no code was written.
    Exit(Option<i32>),
}
22
23#[derive(Debug, Clone)]
25pub struct ScriptStatement {
26 pub statement_type: ScriptStatementType,
28 pub directives: Vec<ScriptDirective>,
30}
31
32impl ScriptStatement {
33 pub fn should_skip(&self) -> bool {
35 self.directives.contains(&ScriptDirective::Skip)
36 }
37
38 pub fn is_exit(&self) -> bool {
40 matches!(self.statement_type, ScriptStatementType::Exit(_))
41 }
42
43 pub fn get_exit_code(&self) -> Option<i32> {
45 match &self.statement_type {
46 ScriptStatementType::Exit(code) => Some(code.unwrap_or(0)),
47 _ => None,
48 }
49 }
50
51 pub fn get_query(&self) -> Option<&str> {
53 match &self.statement_type {
54 ScriptStatementType::Query(sql) => Some(sql),
55 ScriptStatementType::Exit(_) => None,
56 }
57 }
58}
59
/// Splits a SQL script into `GO`-delimited statements and reads script-level hints.
#[derive(Debug, Clone)]
pub struct ScriptParser {
    /// The full, unmodified script text.
    content: String,
    /// Data-file path taken from a `#!` hint comment, extracted once at construction.
    data_file_hint: Option<String>,
}
65
66impl ScriptParser {
67 pub fn new(content: &str) -> Self {
69 let data_file_hint = Self::extract_data_file_hint(content);
70 Self {
71 content: content.to_string(),
72 data_file_hint,
73 }
74 }
75
76 fn extract_data_file_hint(content: &str) -> Option<String> {
82 for line in content.lines() {
83 let trimmed = line.trim();
84
85 if !trimmed.starts_with("--") {
87 continue;
88 }
89
90 let comment_content = trimmed.strip_prefix("--").unwrap().trim();
92
93 if let Some(path) = comment_content.strip_prefix("#!data:") {
95 return Some(path.trim().to_string());
96 }
97 if let Some(path) = comment_content.strip_prefix("#!datafile:") {
98 return Some(path.trim().to_string());
99 }
100 if let Some(path) = comment_content.strip_prefix("#!") {
101 let path = path.trim();
102 if path.contains('.') || path.contains('/') || path.contains('\\') {
104 return Some(path.to_string());
105 }
106 }
107 }
108 None
109 }
110
111 pub fn data_file_hint(&self) -> Option<&str> {
113 self.data_file_hint.as_deref()
114 }
115
116 fn parse_directives(comment_lines: &[String]) -> Vec<ScriptDirective> {
119 let mut directives = Vec::new();
120
121 for line in comment_lines {
122 let trimmed = line.trim();
123 if !trimmed.starts_with("--") {
124 continue;
125 }
126
127 let comment_content = trimmed.strip_prefix("--").unwrap().trim();
128
129 if comment_content.eq_ignore_ascii_case("[skip]")
131 || comment_content.eq_ignore_ascii_case("[ignore]")
132 {
133 directives.push(ScriptDirective::Skip);
134 }
135 }
136
137 directives
138 }
139
140 pub fn parse_script_statements(&self) -> Vec<ScriptStatement> {
143 let mut statements = Vec::new();
144 let mut current_statement = String::new();
145 let mut pending_comments = Vec::new();
146
147 for line in self.content.lines() {
148 let trimmed = line.trim();
149
150 if trimmed.eq_ignore_ascii_case("go") {
152 let statement = current_statement.trim().to_string();
154 if !statement.is_empty() && !Self::is_comment_only(&statement) {
155 let directives = Self::parse_directives(&pending_comments);
157
158 let statement_type = Self::parse_exit_statement(&statement)
160 .unwrap_or_else(|| ScriptStatementType::Query(statement));
161
162 statements.push(ScriptStatement {
163 statement_type,
164 directives,
165 });
166 }
167 current_statement.clear();
168 pending_comments.clear();
169 } else if trimmed.starts_with("--") {
170 pending_comments.push(line.to_string());
172 if !current_statement.is_empty() {
174 current_statement.push('\n');
175 }
176 current_statement.push_str(line);
177 } else {
178 if !current_statement.is_empty() {
180 current_statement.push('\n');
181 }
182 current_statement.push_str(line);
183 }
184 }
185
186 let statement = current_statement.trim().to_string();
188 if !statement.is_empty() && !Self::is_comment_only(&statement) {
189 let directives = Self::parse_directives(&pending_comments);
190
191 let statement_type = Self::parse_exit_statement(&statement)
192 .unwrap_or_else(|| ScriptStatementType::Query(statement));
193
194 statements.push(ScriptStatement {
195 statement_type,
196 directives,
197 });
198 }
199
200 statements
201 }
202
203 fn parse_exit_statement(statement: &str) -> Option<ScriptStatementType> {
207 let mut non_comment_lines = Vec::new();
209 for line in statement.lines() {
210 let trimmed = line.trim();
211 if !trimmed.is_empty() && !trimmed.starts_with("--") {
212 non_comment_lines.push(trimmed);
213 }
214 }
215
216 if non_comment_lines.is_empty() {
217 return None;
218 }
219
220 let content = non_comment_lines.join(" ");
222 let trimmed = content.trim().trim_end_matches(';').trim();
223
224 if trimmed.eq_ignore_ascii_case("exit") {
225 return Some(ScriptStatementType::Exit(None));
226 }
227
228 let parts: Vec<&str> = trimmed.split_whitespace().collect();
230 if parts.len() == 2 && parts[0].eq_ignore_ascii_case("exit") {
231 if let Ok(code) = parts[1].parse::<i32>() {
232 return Some(ScriptStatementType::Exit(Some(code)));
233 }
234 }
235
236 None
237 }
238
239 pub fn parse_statements(&self) -> Vec<String> {
243 self.parse_script_statements()
244 .into_iter()
245 .filter_map(|stmt| match stmt.statement_type {
246 ScriptStatementType::Query(sql) => Some(sql),
247 ScriptStatementType::Exit(_) => None,
248 })
249 .collect()
250 }
251
252 fn is_comment_only(statement: &str) -> bool {
254 for line in statement.lines() {
255 let trimmed = line.trim();
256 if trimmed.is_empty() || trimmed.starts_with("--") {
258 continue;
259 }
260 return false;
262 }
263 true
265 }
266
267 pub fn parse_and_validate(&self) -> Result<Vec<String>> {
270 let statements = self.parse_statements();
271
272 if statements.is_empty() {
273 anyhow::bail!("No SQL statements found in script");
274 }
275
276 for (i, stmt) in statements.iter().enumerate() {
278 if stmt.trim().is_empty() {
279 anyhow::bail!("Empty statement at position {}", i + 1);
280 }
281 }
282
283 Ok(statements)
284 }
285}
286
/// Outcome of executing one statement of a script.
#[derive(Debug, Clone, PartialEq)]
pub struct StatementResult {
    /// Position of the statement in the script, as supplied by the caller.
    pub statement_number: usize,
    /// The SQL text that was executed.
    pub sql: String,
    /// `true` if the statement succeeded.
    pub success: bool,
    /// Row count reported by the caller; `ScriptResult::add_failure` records `0`.
    pub rows_affected: usize,
    /// Error text for a failed statement; `None` on success.
    pub error_message: Option<String>,
    /// Execution time in milliseconds, as measured by the caller.
    pub execution_time_ms: f64,
}
297
298#[derive(Debug)]
300pub struct ScriptResult {
301 pub total_statements: usize,
302 pub successful_statements: usize,
303 pub failed_statements: usize,
304 pub total_execution_time_ms: f64,
305 pub statement_results: Vec<StatementResult>,
306}
307
308impl ScriptResult {
309 pub fn new() -> Self {
310 Self {
311 total_statements: 0,
312 successful_statements: 0,
313 failed_statements: 0,
314 total_execution_time_ms: 0.0,
315 statement_results: Vec::new(),
316 }
317 }
318
319 pub fn add_success(&mut self, statement_number: usize, sql: String, rows: usize, time_ms: f64) {
320 self.total_statements += 1;
321 self.successful_statements += 1;
322 self.total_execution_time_ms += time_ms;
323
324 self.statement_results.push(StatementResult {
325 statement_number,
326 sql,
327 success: true,
328 rows_affected: rows,
329 error_message: None,
330 execution_time_ms: time_ms,
331 });
332 }
333
334 pub fn add_failure(
335 &mut self,
336 statement_number: usize,
337 sql: String,
338 error: String,
339 time_ms: f64,
340 ) {
341 self.total_statements += 1;
342 self.failed_statements += 1;
343 self.total_execution_time_ms += time_ms;
344
345 self.statement_results.push(StatementResult {
346 statement_number,
347 sql,
348 success: false,
349 rows_affected: 0,
350 error_message: Some(error),
351 execution_time_ms: time_ms,
352 });
353 }
354
355 pub fn all_successful(&self) -> bool {
356 self.failed_statements == 0
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_single_statement() {
        let script = "SELECT * FROM users";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "SELECT * FROM users");
    }

    #[test]
    fn test_parse_multiple_statements_with_go() {
        let script = r"
SELECT * FROM users
GO
SELECT * FROM orders
GO
SELECT * FROM products
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 3);
        assert_eq!(statements[0].trim(), "SELECT * FROM users");
        assert_eq!(statements[1].trim(), "SELECT * FROM orders");
        assert_eq!(statements[2].trim(), "SELECT * FROM products");
    }

    #[test]
    fn test_go_case_insensitive() {
        let script = r"
SELECT 1
go
SELECT 2
Go
SELECT 3
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 3);
    }

    #[test]
    fn test_go_in_string_not_separator() {
        let script = r"
SELECT 'This string contains GO but should not split' as test
GO
SELECT 'Another statement' as test2
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("GO but should not split"));
    }

    #[test]
    fn test_multiline_statements() {
        let script = r"
SELECT
    id,
    name,
    email
FROM users
WHERE active = true
GO
SELECT COUNT(*)
FROM orders
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("WHERE active = true"));
    }

    #[test]
    fn test_empty_statements_filtered() {
        let script = r"
GO
SELECT 1
GO
GO
SELECT 2
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();

        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0].trim(), "SELECT 1");
        assert_eq!(statements[1].trim(), "SELECT 2");
    }

    #[test]
    fn test_crlf_line_endings() {
        let parser = ScriptParser::new("SELECT 1\r\nGO\r\nSELECT 2\r\n");
        let statements = parser.parse_statements();

        assert_eq!(statements, vec!["SELECT 1", "SELECT 2"]);
    }

    #[test]
    fn test_skip_directive_attaches_to_its_statement() {
        let script = "-- [skip]\nSELECT 1\nGO\n-- [IGNORE]\nSELECT 2\nGO\nSELECT 3\n";
        let statements = ScriptParser::new(script).parse_script_statements();

        assert_eq!(statements.len(), 3);
        assert!(statements[0].should_skip());
        assert!(statements[1].should_skip());
        // Directives must not leak into later chunks.
        assert!(!statements[2].should_skip());
    }

    #[test]
    fn test_exit_statements() {
        let script = "SELECT 1\nGO\nEXIT 2;\nGO\nSELECT 3\nGO\nexit\n";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_script_statements();

        assert_eq!(statements.len(), 4);
        assert_eq!(statements[0].get_query(), Some("SELECT 1"));
        assert!(statements[1].is_exit());
        assert_eq!(statements[1].get_exit_code(), Some(2));
        assert_eq!(statements[3].get_exit_code(), Some(0));
        assert_eq!(statements[3].get_query(), None);
        // parse_statements only returns queries.
        assert_eq!(parser.parse_statements(), vec!["SELECT 1", "SELECT 3"]);
    }

    #[test]
    fn test_data_file_hint() {
        assert_eq!(
            ScriptParser::new("-- #!data: data/users.csv\nSELECT 1").data_file_hint(),
            Some("data/users.csv")
        );
        assert_eq!(
            ScriptParser::new("-- #!datafile:sales.csv\nSELECT 1").data_file_hint(),
            Some("sales.csv")
        );
        assert_eq!(
            ScriptParser::new("-- #! ../fixtures/orders.csv\nSELECT 1").data_file_hint(),
            Some("../fixtures/orders.csv")
        );
        assert_eq!(ScriptParser::new("-- #! not a path\nSELECT 1").data_file_hint(), None);
        assert_eq!(ScriptParser::new("SELECT 1").data_file_hint(), None);
    }

    #[test]
    fn test_parse_and_validate() {
        let parser = ScriptParser::new("SELECT 1\nGO\nSELECT 2");
        assert_eq!(parser.parse_and_validate().unwrap(), vec!["SELECT 1", "SELECT 2"]);

        let comment_only = ScriptParser::new("-- just a comment\nGO\n");
        assert!(comment_only.parse_and_validate().is_err());
    }

    #[test]
    fn test_script_result_counters() {
        let mut result = ScriptResult::new();
        result.add_success(1, "SELECT 1".to_string(), 3, 1.5);
        result.add_failure(2, "SELECT x".to_string(), "boom".to_string(), 0.5);

        assert_eq!(result.total_statements, 2);
        assert_eq!(result.successful_statements, 1);
        assert_eq!(result.failed_statements, 1);
        assert!((result.total_execution_time_ms - 2.0).abs() < f64::EPSILON);
        assert!(!result.all_successful());
        assert_eq!(
            result.statement_results[1].error_message.as_deref(),
            Some("boom")
        );
    }
}