Skip to main content

scythe_core/parser/
mod.rs

1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result shape requested by a query's `@returns` annotation.
///
/// Each variant corresponds to exactly one accepted `@returns` tag, which may
/// be written with or without a leading `:` (e.g. `@returns :exec_rows` or
/// `@returns exec_rows` both select [`QueryCommand::ExecRows`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryCommand {
    /// Selected by `@returns :one`.
    One,
    /// Selected by `@returns :opt`.
    Opt,
    /// Selected by `@returns :many`.
    Many,
    /// Selected by `@returns :exec`.
    Exec,
    /// Selected by `@returns :exec_result`.
    ExecResult,
    /// Selected by `@returns :exec_rows`.
    ExecRows,
    /// Selected by `@returns :batch`.
    Batch,
    /// Selected by `@returns :grouped`; the query must also carry `@group_by`.
    Grouped,
}
17
18impl std::fmt::Display for QueryCommand {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            QueryCommand::One => write!(f, "one"),
22            QueryCommand::Opt => write!(f, "opt"),
23            QueryCommand::Many => write!(f, "many"),
24            QueryCommand::Exec => write!(f, "exec"),
25            QueryCommand::ExecResult => write!(f, "exec_result"),
26            QueryCommand::ExecRows => write!(f, "exec_rows"),
27            QueryCommand::Batch => write!(f, "batch"),
28            QueryCommand::Grouped => write!(f, "grouped"),
29        }
30    }
31}
32
33impl QueryCommand {
34    fn from_str(s: &str) -> Result<Self, ScytheError> {
35        match s {
36            "one" => Ok(QueryCommand::One),
37            "opt" => Ok(QueryCommand::Opt),
38            "many" => Ok(QueryCommand::Many),
39            "exec" => Ok(QueryCommand::Exec),
40            "exec_result" => Ok(QueryCommand::ExecResult),
41            "exec_rows" => Ok(QueryCommand::ExecRows),
42            "batch" => Ok(QueryCommand::Batch),
43            "grouped" => Ok(QueryCommand::Grouped),
44            other => Err(ScytheError::invalid_annotation(format!(
45                "invalid @returns value: {other}"
46            ))),
47        }
48    }
49}
50
/// Documentation for one query parameter, collected from a `@param` annotation.
///
/// The annotation is written `@param <name>: <description>`; the colon and the
/// description are both optional.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamDoc {
    /// Parameter name: the text before the first `:`, trimmed, or the whole
    /// annotation value when it contains no `:`.
    pub name: String,
    /// Text after the first `:`, trimmed; empty when no description was given.
    pub description: String,
}
56
/// A `@json <column> = <Type>` annotation naming a Rust type for a column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JsonMapping {
    /// Column name written before the `=`, trimmed.
    pub column: String,
    /// Rust type name written after the `=`, trimmed.
    pub rust_type: String,
}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct Annotations {
65    pub name: String,
66    pub command: QueryCommand,
67    pub param_docs: Vec<ParamDoc>,
68    pub nullable_overrides: Vec<String>,
69    pub nonnull_overrides: Vec<String>,
70    pub json_mappings: Vec<JsonMapping>,
71    pub deprecated: Option<String>,
72    pub optional_params: Vec<String>,
73    pub group_by: Option<String>,
74}
75
76#[derive(Debug)]
77pub struct Query {
78    pub name: String,
79    pub command: QueryCommand,
80    pub sql: String,
81    pub stmt: sqlparser::ast::Statement,
82    pub annotations: Annotations,
83}
84
85/// Parse a single annotated SQL query into a `Query` using the PostgreSQL dialect.
86pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
87    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
88}
89
90/// Parse a single annotated SQL query into a `Query` using the specified dialect.
91pub fn parse_query_with_dialect(
92    query_sql: &str,
93    dialect: &SqlDialect,
94) -> Result<Query, ScytheError> {
95    let mut name: Option<String> = None;
96    let mut command: Option<QueryCommand> = None;
97    let mut param_docs = Vec::new();
98    let mut nullable_overrides = Vec::new();
99    let mut nonnull_overrides = Vec::new();
100    let mut json_mappings = Vec::new();
101    let mut deprecated: Option<String> = None;
102    let mut optional_params = Vec::new();
103    let mut group_by: Option<String> = None;
104
105    let mut sql_lines = Vec::new();
106
107    for line in query_sql.lines() {
108        let trimmed = line.trim();
109
110        // Check for annotation: "-- @..." or "--@..."
111        let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
112            let rest = rest.trim_start();
113            rest.strip_prefix('@')
114        } else {
115            None
116        };
117
118        if let Some(body) = annotation_body {
119            // Parse the annotation keyword and value
120            let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
121                Some(pos) => (&body[..pos], body[pos..].trim()),
122                None => (body, ""),
123            };
124
125            match keyword.to_ascii_lowercase().as_str() {
126                "name" => {
127                    name = Some(value.to_string());
128                }
129                "returns" => {
130                    let cmd_str = value.strip_prefix(':').unwrap_or(value);
131                    command = Some(QueryCommand::from_str(cmd_str)?);
132                }
133                "param" => {
134                    // format: "<name>: <description>" or "<name>:<description>"
135                    if let Some(colon_pos) = value.find(':') {
136                        let param_name = value[..colon_pos].trim().to_string();
137                        let description = value[colon_pos + 1..].trim().to_string();
138                        param_docs.push(ParamDoc {
139                            name: param_name,
140                            description,
141                        });
142                    } else {
143                        param_docs.push(ParamDoc {
144                            name: value.to_string(),
145                            description: String::new(),
146                        });
147                    }
148                }
149                "nullable" => {
150                    for col in value.split(',') {
151                        let col = col.trim();
152                        if !col.is_empty() {
153                            nullable_overrides.push(col.to_string());
154                        }
155                    }
156                }
157                "nonnull" => {
158                    for col in value.split(',') {
159                        let col = col.trim();
160                        if !col.is_empty() {
161                            nonnull_overrides.push(col.to_string());
162                        }
163                    }
164                }
165                "json" => {
166                    // format: "<col> = <Type>"
167                    if let Some(eq_pos) = value.find('=') {
168                        let column = value[..eq_pos].trim().to_string();
169                        let rust_type = value[eq_pos + 1..].trim().to_string();
170                        json_mappings.push(JsonMapping { column, rust_type });
171                    }
172                }
173                "deprecated" => {
174                    deprecated = Some(value.to_string());
175                }
176                "group_by" => {
177                    group_by = Some(value.to_string());
178                }
179                "optional" => {
180                    for param in value.split(',') {
181                        let param = param.trim();
182                        if !param.is_empty() {
183                            optional_params.push(param.to_string());
184                        }
185                    }
186                }
187                _ => {
188                    // Unknown annotation — ignore or could error
189                }
190            }
191        } else {
192            sql_lines.push(line);
193        }
194    }
195
196    let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
197    let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
198
199    if command == QueryCommand::Grouped && group_by.is_none() {
200        return Err(ScytheError::invalid_annotation(
201            "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
202        ));
203    }
204
205    let sql = sql_lines.join("\n").trim().to_string();
206
207    if sql.is_empty() {
208        return Err(ScytheError::syntax("empty SQL body"));
209    }
210
211    // Preprocess dialect-specific syntax before parsing:
212    // - Oracle: strip `RETURNING ... INTO` output binds, convert `:N` → `?`
213    // - MSSQL: convert `OUTPUT INSERTED.*` → `RETURNING` for parsing,
214    //          convert `@pN` → `?` for parsing; keep original SQL for codegen
215    let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
216        let processed = preprocess_oracle_sql(&sql);
217        (processed.clone(), processed)
218    } else if *dialect == SqlDialect::MsSql {
219        // For codegen: only convert @pN → ? placeholders (keep OUTPUT syntax)
220        let codegen_sql = convert_mssql_placeholders(&sql);
221        // For parsing: also convert OUTPUT INSERTED → RETURNING
222        let parse_sql = preprocess_mssql_sql(&sql);
223        (codegen_sql, parse_sql)
224    } else {
225        (sql.clone(), sql)
226    };
227
228    let parser_dialect = dialect.to_sqlparser_dialect();
229    let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
230        .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
231
232    if statements.len() != 1 {
233        // sqlparser may produce an extra empty statement from a trailing semicolon —
234        // filter those out by checking for exactly one non-empty statement.
235        let non_empty: Vec<_> = statements
236            .into_iter()
237            .filter(|s| {
238                !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
239            })
240            .collect();
241        if non_empty.len() != 1 {
242            return Err(ScytheError::syntax("expected exactly one SQL statement"));
243        }
244        let stmt = non_empty
245            .into_iter()
246            .next()
247            .expect("filtered to exactly one statement");
248        let annotations = Annotations {
249            name: name.clone(),
250            command: command.clone(),
251            param_docs,
252            nullable_overrides,
253            nonnull_overrides,
254            json_mappings,
255            deprecated,
256            optional_params,
257            group_by: group_by.clone(),
258        };
259        return Ok(Query {
260            name,
261            command,
262            sql,
263            stmt,
264            annotations,
265        });
266    }
267
268    let stmt = statements
269        .into_iter()
270        .next()
271        .expect("filtered to exactly one statement");
272
273    let annotations = Annotations {
274        name: name.clone(),
275        command: command.clone(),
276        param_docs,
277        nullable_overrides,
278        nonnull_overrides,
279        json_mappings,
280        deprecated,
281        optional_params,
282        group_by,
283    };
284
285    Ok(Query {
286        name,
287        command,
288        sql,
289        stmt,
290        annotations,
291    })
292}
293
294/// Preprocess Oracle SQL before parsing:
295/// 1. Strip `INTO :N, :N, ...` suffix from `RETURNING ... INTO` clauses
296/// 2. Convert `:N` positional placeholders to `?` (universally supported)
297fn preprocess_oracle_sql(sql: &str) -> String {
298    // Strip Oracle RETURNING ... INTO clause (output bind variables)
299    // e.g. "INSERT ... RETURNING id, name INTO :4, :5" → "INSERT ... RETURNING id, name"
300    let sql = strip_returning_into(sql);
301
302    // Convert :N → ? (outside string literals)
303    let mut result = String::with_capacity(sql.len());
304    let mut chars = sql.chars().peekable();
305    while let Some(ch) = chars.next() {
306        if ch == '\'' {
307            // Skip string literals
308            result.push(ch);
309            while let Some(inner) = chars.next() {
310                result.push(inner);
311                if inner == '\'' {
312                    if chars.peek() == Some(&'\'') {
313                        result.push(chars.next().unwrap());
314                    } else {
315                        break;
316                    }
317                }
318            }
319        } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
320            // Convert :N → ?
321            result.push('?');
322            while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
323                chars.next();
324            }
325        } else {
326            result.push(ch);
327        }
328    }
329    result
330}
331
/// Convert MSSQL `@pN` positional placeholders to `?` (outside string literals).
///
/// MsSqlDialect treats `@` as an identifier start, so `@p1` becomes an identifier
/// rather than a `Placeholder` token — preprocessing normalises it to `?`.
///
/// A token is rewritten only when it is exactly `@p`/`@P` followed by one or
/// more ASCII digits and *not* followed by another identifier character, so
/// T-SQL variables such as `@p1_total`, `@p2x` or `@myvar` are left untouched.
/// Single-quoted literals (including `''` escapes) are copied verbatim.
fn convert_mssql_placeholders(sql: &str) -> String {
    /// Characters that may continue a T-SQL identifier or variable name.
    fn continues_identifier(c: char) -> bool {
        c.is_alphanumeric() || matches!(c, '_' | '@' | '$' | '#')
    }

    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '\'' => {
                // Copy the string literal verbatim, honouring `''` escapes.
                result.push(ch);
                while let Some(inner) = chars.next() {
                    result.push(inner);
                    if inner == '\'' {
                        match chars.next_if_eq(&'\'') {
                            Some(escaped) => result.push(escaped),
                            None => break,
                        }
                    }
                }
            }
            '@' => {
                // Scan a copy of the iterator so a non-placeholder is left unconsumed.
                let mut lookahead = chars.clone();
                let has_p = lookahead.next_if(|&c| c == 'p' || c == 'P').is_some();
                let mut digits = 0usize;
                while has_p && lookahead.next_if(char::is_ascii_digit).is_some() {
                    digits += 1;
                }
                let at_boundary = !lookahead
                    .peek()
                    .is_some_and(|&c| continues_identifier(c));
                if digits > 0 && at_boundary {
                    chars = lookahead;
                    result.push('?');
                } else {
                    result.push(ch);
                }
            }
            _ => result.push(ch),
        }
    }
    result
}
373
374/// Preprocess MSSQL SQL before parsing:
375/// 1. Strip `OUTPUT INSERTED.col, ...` clauses and convert to RETURNING
376/// 2. Convert `@pN` positional placeholders to `?`
377fn preprocess_mssql_sql(sql: &str) -> String {
378    // First pass: convert OUTPUT INSERTED.col to RETURNING col
379    let sql = strip_and_convert_mssql_output(sql);
380    // Second pass: convert @pN to ?
381    convert_mssql_placeholders(&sql)
382}
383
384/// Strip MSSQL `OUTPUT INSERTED.col1, INSERTED.col2, ...` from INSERT statements
385/// and convert it to a `RETURNING col1, col2, ...` clause.
386/// The OUTPUT clause appears between the column list and VALUES clause:
387///   INSERT INTO table (cols) OUTPUT INSERTED.col1, INSERTED.col2, ... VALUES (...)
388/// becomes:
389///   INSERT INTO table (cols) VALUES (...) RETURNING col1, col2, ...
390fn strip_and_convert_mssql_output(sql: &str) -> String {
391    // Case-insensitive search for OUTPUT keyword in INSERT statements
392    let upper = sql.to_uppercase();
393
394    // Only process INSERT statements with OUTPUT
395    if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
396        return sql.to_string();
397    }
398
399    // Find the OUTPUT keyword
400    if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
401        // Check if this is actually part of an INSERT statement by finding INSERT before it
402        let before_output = &upper[..output_pos];
403        if !before_output.contains("INSERT") {
404            return sql.to_string();
405        }
406
407        // Look for the VALUES keyword after OUTPUT
408        let after_output = &upper[output_pos + "OUTPUT".len()..];
409        if let Some(values_offset) = find_word_position(after_output, "VALUES") {
410            let values_pos = output_pos + "OUTPUT".len() + values_offset;
411
412            // Extract the OUTPUT column list (between OUTPUT and VALUES)
413            let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
414
415            // Parse column names: strip "INSERTED." prefix from each column name
416            let cols = parse_inserted_columns(output_cols_str);
417
418            if !cols.is_empty() {
419                // Build result: keep everything before OUTPUT, then VALUES clause,
420                // then RETURNING clause (before any trailing semicolon)
421                let before_output_sql = sql[..output_pos].trim_end();
422                let after_values = sql[values_pos..].trim_end();
423                let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
424                {
425                    (stripped, ";")
426                } else {
427                    (after_values, "")
428                };
429
430                return format!(
431                    "{}\n{} RETURNING {}{}",
432                    before_output_sql, values_body, cols, trailing
433                );
434            }
435        }
436    }
437
438    sql.to_string()
439}
440
/// Find the byte offset of the first occurrence of `word` in `text` that stands
/// as a separate word, comparing ASCII letters case-insensitively.
///
/// An occurrence counts only when the bytes immediately before and after it are
/// not identifier characters (ASCII alphanumerics or `_`), so searching for
/// `VALUES` does not match inside `MY_VALUES` or `VALUES2`.
///
/// Returns `None` when there is no such occurrence. Never panics, including for
/// an empty `word`.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    let haystack = text.as_bytes();
    let needle = word.as_bytes();
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Byte-level matching is char-aligned: UTF-8 is self-synchronising, so a
    // valid needle can only match starting at a character boundary.
    let last_start = haystack.len().checked_sub(needle.len())?;
    (0..=last_start).find(|&start| {
        let end = start + needle.len();
        haystack[start..end].eq_ignore_ascii_case(needle)
            && (start == 0 || !is_ident_byte(haystack[start - 1]))
            && !haystack.get(end).is_some_and(|&b| is_ident_byte(b))
    })
}
471
/// Turn the column list of an MSSQL `OUTPUT` clause into a `RETURNING` list.
///
/// Each comma-separated item is expected to be qualified with the `INSERTED`
/// pseudo-table. The qualifier is matched ASCII-case-insensitively (T-SQL is
/// case-insensitive) and may be followed by whitespace before the `.`. It is
/// removed and the remainder is kept verbatim,
/// e.g. `INSERTED.id, Inserted.name AS n` → `id, name AS n`.
///
/// Items without the qualifier, or with nothing after it, are dropped; the
/// result is an empty string when no item qualifies.
fn parse_inserted_columns(output_str: &str) -> String {
    const QUALIFIER: &str = "INSERTED";

    output_str
        .split(',')
        .filter_map(|part| {
            let part = part.trim();
            // `get` yields None for items that are too short or not char-aligned.
            let head = part.get(..QUALIFIER.len())?;
            if !head.eq_ignore_ascii_case(QUALIFIER) {
                return None;
            }
            let rest = part[QUALIFIER.len()..].trim_start();
            let column = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!column.is_empty()).then_some(column)
        })
        .collect::<Vec<_>>()
        .join(", ")
}
495
/// Strip the `INTO :N, :N, ...` output-bind suffix from an Oracle
/// `RETURNING ... INTO` clause, e.g.
/// `INSERT ... RETURNING id, name INTO :4, :5` → `INSERT ... RETURNING id, name`.
///
/// Both keywords must appear as whole words (ASCII-case-insensitive), so a
/// column such as `into_date` or `printout` is not mistaken for `INTO`.
/// The last whole-word `RETURNING` is used, together with the first whole-word
/// `INTO` after it. Everything from that `INTO` onwards is dropped, including
/// any trailing `;`, and the whitespace before it is trimmed. SQL without such a
/// `RETURNING ... INTO` sequence is returned unchanged.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    // `to_ascii_uppercase` preserves byte lengths, so offsets found in `upper`
    // can be used to slice `sql` even when it contains non-ASCII text.
    let upper = sql.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    // True when `upper[start..start + len]` is not embedded in a longer identifier.
    let is_whole_word = |start: usize, len: usize| {
        (start == 0 || !is_ident(bytes[start - 1]))
            && !bytes.get(start + len).is_some_and(|&b| is_ident(b))
    };

    let Some(ret_pos) = upper
        .rmatch_indices(RETURNING)
        .map(|(pos, _)| pos)
        .find(|&pos| is_whole_word(pos, RETURNING.len()))
    else {
        return sql.to_string();
    };

    let search_from = ret_pos + RETURNING.len();
    let into_pos = upper[search_from..]
        .match_indices(INTO)
        .map(|(offset, _)| search_from + offset)
        .find(|&pos| is_whole_word(pos, INTO.len()));

    match into_pos {
        Some(pos) => sql[..pos].trim_end().to_string(),
        None => sql.to_string(),
    }
}
511
#[cfg(test)]
mod tests {
    //! Unit tests for annotation parsing, statement validation and the
    //! Oracle / MSSQL SQL preprocessing helpers.

    use super::*;
    use crate::errors::ErrorCode;

    /// Parse with the default (PostgreSQL) entry point.
    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    /// Parse input that must be accepted.
    fn parse_ok(sql: &str) -> Query {
        parse(sql).unwrap()
    }

    /// Parse input that must be rejected, returning the error.
    fn parse_err(sql: &str) -> ScytheError {
        parse(sql).unwrap_err()
    }

    /// Assert that `haystack` contains `needle`, printing `haystack` on failure.
    fn assert_contains(haystack: &str, needle: &str) {
        assert!(haystack.contains(needle), "got: {}", haystack);
    }

    /// Assert that Oracle preprocessing turns `input` into exactly `expected`.
    fn check_oracle(input: &str, expected: &str) {
        assert_eq!(preprocess_oracle_sql(input), expected);
    }

    /// Assert that MSSQL preprocessing turns `input` into exactly `expected`.
    fn check_mssql(input: &str, expected: &str) {
        assert_eq!(preprocess_mssql_sql(input), expected);
    }

    // --- annotations ------------------------------------------------------

    #[test]
    fn test_basic_parse() {
        let query = parse_ok("-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;");
        assert_eq!(query.name, "GetUsers");
        assert_eq!(query.command, QueryCommand::Many);
        assert_contains(&query.sql, "SELECT");
    }

    #[test]
    fn test_all_command_types() {
        let cases = [
            (":one", QueryCommand::One),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
        ];
        for (tag, expected) in cases {
            let query = parse_ok(&format!("-- @name Q\n-- @returns {tag}\nSELECT 1"));
            assert_eq!(query.command, expected, "failed for {tag}");
        }
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let query = parse_ok("-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1");
        assert_eq!(query.name, "GetUsers");
        assert_eq!(query.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let err = parse_err("-- @returns :many\nSELECT 1");
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let err = parse_err("-- @name Foo\nSELECT 1");
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let err = parse_err("-- @name Foo\n-- @returns :invalid\nSELECT 1");
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        // `@name` with no value is accepted and stored as an empty name.
        let query = parse_ok("-- @name\n-- @returns :one\nSELECT 1");
        assert_eq!(query.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let query =
            parse_ok("-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1");
        assert_eq!(
            query.annotations.param_docs,
            vec![ParamDoc {
                name: "id".to_string(),
                description: "the user ID".to_string(),
            }]
        );
    }

    #[test]
    fn test_param_no_description() {
        let query = parse_ok("-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1");
        assert_eq!(
            query.annotations.param_docs,
            vec![ParamDoc {
                name: "id".to_string(),
                description: String::new(),
            }]
        );
    }

    #[test]
    fn test_nullable_annotation() {
        let query =
            parse_ok("-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1");
        assert_eq!(query.annotations.nullable_overrides, ["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let query = parse_ok("-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1");
        assert_eq!(query.annotations.nonnull_overrides, ["col1"]);
    }

    #[test]
    fn test_json_annotation() {
        let query =
            parse_ok("-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1");
        assert_eq!(
            query.annotations.json_mappings,
            vec![JsonMapping {
                column: "data".to_string(),
                rust_type: "EventData".to_string(),
            }]
        );
    }

    #[test]
    fn test_deprecated_annotation() {
        let query = parse_ok("-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1");
        assert_eq!(query.annotations.deprecated.as_deref(), Some("Use V2"));
    }

    // --- SQL body and statement count ----------------------------------------

    #[test]
    fn test_sql_syntax_error() {
        let err = parse_err("-- @name Foo\n-- @returns :one\nSELCT * FROM users");
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let query = parse_ok("-- @name Foo\n-- @returns :one\nSELECT 1;");
        assert_eq!(query.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let err = parse_err("-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;");
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let query =
            parse_ok("-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1");
        assert_eq!(query.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let query = parse_ok("-- @name Foo\n-- @returns many\nSELECT 1");
        assert_eq!(query.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let query = parse_ok("-- @name Foo\n-- @returns :batch\nSELECT 1");
        assert_eq!(query.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let query = parse_ok(
            "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id",
        );
        assert_eq!(query.command, QueryCommand::Grouped);
        assert_eq!(query.annotations.group_by.as_deref(), Some("users.id"));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let err = parse_err("-- @name Foo\n-- @returns :grouped\nSELECT 1");
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        // `@group_by` is still recorded; it just is not required for `:many`.
        let query =
            parse_ok("-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1");
        assert_eq!(query.command, QueryCommand::Many);
        assert_eq!(query.annotations.group_by.as_deref(), Some("users.id"));
    }

    // --- Oracle preprocessing ------------------------------------------------

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        check_oracle(
            "SELECT * FROM users WHERE id = :1",
            "SELECT * FROM users WHERE id = ?",
        );
        check_oracle(
            "INSERT INTO users (name, email) VALUES (:1, :2)",
            "INSERT INTO users (name, email) VALUES (?, ?)",
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        check_oracle(
            "SELECT * FROM users WHERE name = ':1' AND id = :1",
            "SELECT * FROM users WHERE name = ':1' AND id = ?",
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        check_oracle(
            "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3",
            "INSERT INTO users (name) VALUES (?) RETURNING id, name",
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        check_oracle(
            "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8",
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at",
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        check_oracle(
            "DELETE FROM users WHERE id = :1",
            "DELETE FROM users WHERE id = ?",
        );
    }

    // --- MSSQL placeholders --------------------------------------------------

    #[test]
    fn test_preprocess_mssql_single_placeholder() {
        check_mssql(
            "SELECT * FROM users WHERE id = @p1",
            "SELECT * FROM users WHERE id = ?",
        );
    }

    #[test]
    fn test_preprocess_mssql_multiple_placeholders() {
        check_mssql(
            "INSERT INTO users (name, email) VALUES (@p1, @p2)",
            "INSERT INTO users (name, email) VALUES (?, ?)",
        );
    }

    #[test]
    fn test_preprocess_mssql_preserves_string_literals() {
        check_mssql(
            "SELECT * FROM users WHERE name = '@p1' AND id = @p1",
            "SELECT * FROM users WHERE name = '@p1' AND id = ?",
        );
    }

    #[test]
    fn test_preprocess_mssql_case_insensitive_p() {
        check_mssql(
            "SELECT * FROM users WHERE id = @P1",
            "SELECT * FROM users WHERE id = ?",
        );
    }

    #[test]
    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
        // A `@variable` that is not of the `@pN` form must not be touched.
        check_mssql("SELECT @myvar", "SELECT @myvar");
    }

    #[test]
    fn test_preprocess_mssql_multi_digit_placeholder() {
        check_mssql("SELECT @p10, @p2", "SELECT ?, ?");
    }

    // --- MSSQL OUTPUT INSERTED rewriting ---------------------------------------

    #[test]
    fn test_preprocess_mssql_output_inserted_simple() {
        let result = preprocess_mssql_sql(
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)",
        );
        // OUTPUT INSERTED.col becomes RETURNING col, and @pN becomes ?.
        assert_contains(&result, "RETURNING id, name");
        assert_contains(&result, "VALUES (?, ?)");
        assert!(!result.contains("OUTPUT"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_inserted_full_example() {
        let result = preprocess_mssql_sql(
            "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)",
        );
        assert_contains(&result, "RETURNING id, name, email, active, created_at");
        assert_contains(&result, "VALUES (?, ?, ?, ?)");
    }

    #[test]
    fn test_preprocess_mssql_output_case_insensitive() {
        let result = preprocess_mssql_sql("INSERT INTO users (id) output inserted.id values (@p1)");
        assert_contains(&result, "RETURNING id");
        // The original lowercase "values" is kept; only @p1 becomes ?.
        assert!(
            ["values (?)", "VALUES (?)"]
                .iter()
                .any(|needle| result.contains(needle)),
            "got: {}",
            result
        );
    }

    #[test]
    fn test_preprocess_mssql_no_output_unchanged() {
        check_mssql(
            "INSERT INTO users (id, name) VALUES (@p1, @p2)",
            "INSERT INTO users (id, name) VALUES (?, ?)",
        );
    }

    #[test]
    fn test_preprocess_mssql_output_with_string_literal() {
        // '@p2' inside a string literal must survive placeholder conversion.
        let result = preprocess_mssql_sql(
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')",
        );
        assert_contains(&result, "RETURNING id, name");
        assert_contains(&result, "(?, '@p2')");
    }

    #[test]
    fn test_preprocess_mssql_output_with_whitespace() {
        let result = preprocess_mssql_sql(
            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n  INSERTED.name\nVALUES (@p1, @p2)",
        );
        assert_contains(&result, "RETURNING id, name");
        assert_contains(&result, "VALUES (?, ?)");
    }
}