Skip to main content

scythe_core/parser/
mod.rs

1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result mode a query declares through its `@returns` annotation.
///
/// Every variant has a canonical lowercase tag. The tag is what `Display`
/// prints and what the `@returns` annotation accepts, with or without a
/// leading `:` (for example `-- @returns :many`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum QueryCommand {
    /// Tag `one`.
    One,
    /// Tag `opt`.
    Opt,
    /// Tag `many`.
    Many,
    /// Tag `exec`.
    Exec,
    /// Tag `exec_result`.
    ExecResult,
    /// Tag `exec_rows`.
    ExecRows,
    /// Tag `batch`.
    Batch,
    /// Tag `grouped`. Queries using this mode must also carry a `@group_by`
    /// annotation, otherwise parsing fails.
    Grouped,
}
18
19impl std::fmt::Display for QueryCommand {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            QueryCommand::One => write!(f, "one"),
23            QueryCommand::Opt => write!(f, "opt"),
24            QueryCommand::Many => write!(f, "many"),
25            QueryCommand::Exec => write!(f, "exec"),
26            QueryCommand::ExecResult => write!(f, "exec_result"),
27            QueryCommand::ExecRows => write!(f, "exec_rows"),
28            QueryCommand::Batch => write!(f, "batch"),
29            QueryCommand::Grouped => write!(f, "grouped"),
30        }
31    }
32}
33
34impl QueryCommand {
35    fn from_str(s: &str) -> Result<Self, ScytheError> {
36        match s {
37            "one" => Ok(QueryCommand::One),
38            "opt" => Ok(QueryCommand::Opt),
39            "many" => Ok(QueryCommand::Many),
40            "exec" => Ok(QueryCommand::Exec),
41            "exec_result" => Ok(QueryCommand::ExecResult),
42            "exec_rows" => Ok(QueryCommand::ExecRows),
43            "batch" => Ok(QueryCommand::Batch),
44            "grouped" => Ok(QueryCommand::Grouped),
45            other => Err(ScytheError::invalid_annotation(format!(
46                "invalid @returns value: {other}"
47            ))),
48        }
49    }
50}
51
/// Documentation for one query parameter, taken from a
/// `-- @param <name>: <description>` annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value, trimmed; the whole
    /// value when the annotation contains no `:`.
    pub name: String,
    /// Text after the first `:`, trimmed. Empty when no description was given.
    pub description: String,
}
58
/// A column-to-type mapping taken from a `-- @json <col> = <Type>` annotation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct JsonMapping {
    /// Text before the first `=` of the annotation value, trimmed.
    pub column: String,
    /// Text after the first `=`, trimmed. Stored verbatim; it is not
    /// validated by the parser.
    pub rust_type: String,
}
65
/// A custom (non-native) annotation captured from the SQL source.
///
/// Scythe parses its known annotations (`@name`, `@returns`, `@param`, `@nullable`,
/// `@nonnull`, `@json`, `@optional`, `@group_by`, `@deprecated`) into typed fields.
/// Any other `-- @<name> <value>` line is captured here as an opaque
/// name/value/line triple and exposed to crate consumers, who can layer their
/// own annotation vocabulary on top of scythe without coupling scythe to their
/// domain.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomAnnotation {
    /// Annotation name, ASCII-lowercased, without the leading `@`
    /// (e.g. `http`, `http_param`).
    pub name: String,
    /// Everything after the name on the line, trimmed. Empty if the annotation had no value.
    pub value: String,
    /// 1-based line number within the query SQL (annotation lines included),
    /// for diagnostics.
    pub line: usize,
}
83
/// Everything collected from the `-- @...` annotation lines of a single query.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Annotations {
    /// Value of the `@name` annotation (required; may be an empty string).
    pub name: String,
    /// Command parsed from the `@returns` annotation (required).
    pub command: QueryCommand,
    /// One entry per `@param` line, in source order.
    pub param_docs: Vec<ParamDoc>,
    /// Non-empty, trimmed items from the comma-separated `@nullable` values.
    pub nullable_overrides: Vec<String>,
    /// Non-empty, trimmed items from the comma-separated `@nonnull` values.
    pub nonnull_overrides: Vec<String>,
    /// One entry per `@json <col> = <Type>` line; lines without `=` are ignored.
    pub json_mappings: Vec<JsonMapping>,
    /// Value of the `@deprecated` annotation, if present (the last one wins).
    /// `Some("")` when the annotation carries no text.
    pub deprecated: Option<String>,
    /// Non-empty, trimmed items from the comma-separated `@optional` values.
    pub optional_params: Vec<String>,
    /// Value of the `@group_by` annotation, if present (the last one wins).
    pub group_by: Option<String>,
    /// Annotations scythe does not natively recognise, preserved in source order
    /// for crate consumers to interpret.
    pub custom: Vec<CustomAnnotation>,
}
100
/// A parsed, annotated SQL query.
#[derive(Debug)]
pub struct Query {
    /// Query name; the same value as `annotations.name`.
    pub name: String,
    /// Result mode; the same value as `annotations.command`.
    pub command: QueryCommand,
    /// The SQL body with annotation lines removed and surrounding whitespace
    /// trimmed. For Oracle and MSSQL this is the text after the
    /// dialect-specific rewriting performed by the parser; for other dialects
    /// it is the original body.
    pub sql: String,
    /// The single statement produced by `sqlparser` for this query.
    pub stmt: sqlparser::ast::Statement,
    /// All annotations collected while parsing.
    pub annotations: Annotations,
}
109
110/// Parse a single annotated SQL query into a `Query` using the PostgreSQL dialect.
111pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
112    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
113}
114
115/// Parse a single annotated SQL query into a `Query` using the specified dialect.
116pub fn parse_query_with_dialect(
117    query_sql: &str,
118    dialect: &SqlDialect,
119) -> Result<Query, ScytheError> {
120    let mut name: Option<String> = None;
121    let mut command: Option<QueryCommand> = None;
122    let mut param_docs = Vec::new();
123    let mut nullable_overrides = Vec::new();
124    let mut nonnull_overrides = Vec::new();
125    let mut json_mappings = Vec::new();
126    let mut deprecated: Option<String> = None;
127    let mut optional_params = Vec::new();
128    let mut group_by: Option<String> = None;
129    let mut custom: Vec<CustomAnnotation> = Vec::new();
130
131    let mut sql_lines = Vec::new();
132
133    for (line_idx, line) in query_sql.lines().enumerate() {
134        let line_no = line_idx + 1;
135        let trimmed = line.trim();
136
137        // Check for annotation: "-- @..." or "--@..."
138        let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
139            let rest = rest.trim_start();
140            rest.strip_prefix('@')
141        } else {
142            None
143        };
144
145        if let Some(body) = annotation_body {
146            // Parse the annotation keyword and value
147            let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
148                Some(pos) => (&body[..pos], body[pos..].trim()),
149                None => (body, ""),
150            };
151
152            match keyword.to_ascii_lowercase().as_str() {
153                "name" => {
154                    name = Some(value.to_string());
155                }
156                "returns" => {
157                    let cmd_str = value.strip_prefix(':').unwrap_or(value);
158                    command = Some(QueryCommand::from_str(cmd_str)?);
159                }
160                "param" => {
161                    // format: "<name>: <description>" or "<name>:<description>"
162                    if let Some(colon_pos) = value.find(':') {
163                        let param_name = value[..colon_pos].trim().to_string();
164                        let description = value[colon_pos + 1..].trim().to_string();
165                        param_docs.push(ParamDoc {
166                            name: param_name,
167                            description,
168                        });
169                    } else {
170                        param_docs.push(ParamDoc {
171                            name: value.to_string(),
172                            description: String::new(),
173                        });
174                    }
175                }
176                "nullable" => {
177                    for col in value.split(',') {
178                        let col = col.trim();
179                        if !col.is_empty() {
180                            nullable_overrides.push(col.to_string());
181                        }
182                    }
183                }
184                "nonnull" => {
185                    for col in value.split(',') {
186                        let col = col.trim();
187                        if !col.is_empty() {
188                            nonnull_overrides.push(col.to_string());
189                        }
190                    }
191                }
192                "json" => {
193                    // format: "<col> = <Type>"
194                    if let Some(eq_pos) = value.find('=') {
195                        let column = value[..eq_pos].trim().to_string();
196                        let rust_type = value[eq_pos + 1..].trim().to_string();
197                        json_mappings.push(JsonMapping { column, rust_type });
198                    }
199                }
200                "deprecated" => {
201                    deprecated = Some(value.to_string());
202                }
203                "group_by" => {
204                    group_by = Some(value.to_string());
205                }
206                "optional" => {
207                    for param in value.split(',') {
208                        let param = param.trim();
209                        if !param.is_empty() {
210                            optional_params.push(param.to_string());
211                        }
212                    }
213                }
214                other => {
215                    // Unknown annotation — capture verbatim for crate consumers.
216                    custom.push(CustomAnnotation {
217                        name: other.to_string(),
218                        value: value.to_string(),
219                        line: line_no,
220                    });
221                }
222            }
223        } else {
224            sql_lines.push(line);
225        }
226    }
227
228    let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
229    let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
230
231    if command == QueryCommand::Grouped && group_by.is_none() {
232        return Err(ScytheError::invalid_annotation(
233            "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
234        ));
235    }
236
237    let sql = sql_lines.join("\n").trim().to_string();
238
239    if sql.is_empty() {
240        return Err(ScytheError::syntax("empty SQL body"));
241    }
242
243    // Preprocess dialect-specific syntax before parsing:
244    //   * Oracle: strip `RETURNING ... INTO` output binds, convert `:N` → `?`.
245    //   * MSSQL: convert `OUTPUT INSERTED.*` → `RETURNING` for parsing,
246    //     convert `@pN` → `?` for parsing; keep original SQL for codegen.
247    //   * PostgreSQL: strip `WHERE …` between `ON CONFLICT (cols)` and `DO …`
248    //     for parsing (sqlparser-rs <= 0.61 doesn't recognise the
249    //     partial-index inference form); keep original SQL for codegen.
250    let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
251        let processed = preprocess_oracle_sql(&sql);
252        (processed.clone(), processed)
253    } else if *dialect == SqlDialect::MsSql {
254        // For codegen: only convert @pN → ? placeholders (keep OUTPUT syntax)
255        let codegen_sql = convert_mssql_placeholders(&sql);
256        // For parsing: also convert OUTPUT INSERTED → RETURNING
257        let parse_sql = preprocess_mssql_sql(&sql);
258        (codegen_sql, parse_sql)
259    } else if *dialect == SqlDialect::PostgreSQL {
260        let parse_sql = preprocess_postgres_sql(&sql);
261        (sql.clone(), parse_sql)
262    } else {
263        (sql.clone(), sql)
264    };
265
266    let parser_dialect = dialect.to_sqlparser_dialect();
267    let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
268        .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
269
270    if statements.len() != 1 {
271        // sqlparser may produce an extra empty statement from a trailing semicolon —
272        // filter those out by checking for exactly one non-empty statement.
273        let non_empty: Vec<_> = statements
274            .into_iter()
275            .filter(|s| {
276                !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
277            })
278            .collect();
279        if non_empty.len() != 1 {
280            return Err(ScytheError::syntax("expected exactly one SQL statement"));
281        }
282        let stmt = non_empty
283            .into_iter()
284            .next()
285            .expect("filtered to exactly one statement");
286        let annotations = Annotations {
287            name: name.clone(),
288            command: command.clone(),
289            param_docs,
290            nullable_overrides,
291            nonnull_overrides,
292            json_mappings,
293            deprecated,
294            optional_params,
295            group_by: group_by.clone(),
296            custom,
297        };
298        return Ok(Query {
299            name,
300            command,
301            sql,
302            stmt,
303            annotations,
304        });
305    }
306
307    let stmt = statements
308        .into_iter()
309        .next()
310        .expect("filtered to exactly one statement");
311
312    let annotations = Annotations {
313        name: name.clone(),
314        command: command.clone(),
315        param_docs,
316        nullable_overrides,
317        nonnull_overrides,
318        json_mappings,
319        deprecated,
320        optional_params,
321        group_by,
322        custom,
323    };
324
325    Ok(Query {
326        name,
327        command,
328        sql,
329        stmt,
330        annotations,
331    })
332}
333
334/// Strip the `WHERE …` predicate that PostgreSQL allows between
335/// `ON CONFLICT (cols)` and `DO …` (the index-inference form for partial
336/// unique indexes). sqlparser-rs through 0.61 does not parse this construct;
337/// we lift it out for the parser and let the caller keep the original SQL
338/// for codegen + runtime, where Postgres validates it.
339fn preprocess_postgres_sql(sql: &str) -> String {
340    // Strip line comments + string literals first so we only scan structural SQL.
341    // (We still emit the original `sql` slice byte-for-byte; the upper-mask is
342    //  only used to decide *where* to cut.)
343    let mask = mask_postgres_for_scan(sql);
344    let mask_bytes = mask.as_bytes();
345    let bytes = sql.as_bytes();
346    let mut search_from = 0;
347    let mut result = String::with_capacity(sql.len());
348    let mut last = 0;
349    while let Some(rel) = find_keyword(&mask[search_from..], "ON CONFLICT") {
350        let on_conflict_pos = search_from + rel;
351        let after_on_conflict = on_conflict_pos + "ON CONFLICT".len();
352        let mut idx = after_on_conflict;
353        while idx < mask_bytes.len() && mask_bytes[idx].is_ascii_whitespace() {
354            idx += 1;
355        }
356        if idx >= mask_bytes.len() || mask_bytes[idx] != b'(' {
357            search_from = after_on_conflict;
358            continue;
359        }
360        let mut depth = 0i32;
361        let mut close = idx;
362        while close < mask_bytes.len() {
363            match mask_bytes[close] {
364                b'(' => depth += 1,
365                b')' => {
366                    depth -= 1;
367                    if depth == 0 {
368                        break;
369                    }
370                }
371                _ => {}
372            }
373            close += 1;
374        }
375        if depth != 0 {
376            return sql.to_string();
377        }
378        let mut after_cols = close + 1;
379        while after_cols < mask_bytes.len() && mask_bytes[after_cols].is_ascii_whitespace() {
380            after_cols += 1;
381        }
382        if mask[after_cols..].starts_with("WHERE")
383            && let Some(do_rel) = find_keyword(&mask[after_cols + "WHERE".len()..], "DO")
384        {
385            let do_abs = after_cols + "WHERE".len() + do_rel;
386            // Slice from the original SQL (preserves casing + UTF-8) up to
387            // the byte before WHERE; skip ahead to DO.
388            result.push_str(std::str::from_utf8(&bytes[last..after_cols]).unwrap_or(""));
389            last = do_abs;
390            search_from = do_abs;
391            continue;
392        }
393        search_from = close + 1;
394    }
395    result.push_str(std::str::from_utf8(&bytes[last..]).unwrap_or(""));
396    result
397}
398
/// Build an ASCII-uppercase mask of `sql` with the same byte length, in which
/// `--` line comments, `/* … */` block comments, `'…'` string literals and
/// dollar-quoted strings (`$$…$$`, `$tag$…$tag$`) are replaced with spaces.
/// Non-ASCII bytes also become spaces, so every byte offset in the mask lines
/// up with the same offset in `sql`.
///
/// Newlines that terminate line comments are preserved. Unterminated
/// comments and literals are masked through the end of the input. Positional
/// placeholders such as `$1` are not dollar quotes and pass through.
fn mask_postgres_for_scan(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    // Start fully blanked: only structural ASCII is copied in below, so
    // masking a region just means advancing past it.
    let mut out = vec![b' '; len];
    let mut i = 0;
    while i < len {
        let b = bytes[i];

        if b == b'-' && bytes.get(i + 1) == Some(&b'-') {
            // Line comment — skip up to (not past) the newline.
            while i < len && bytes[i] != b'\n' {
                i += 1;
            }
            continue;
        }

        if b == b'/' && bytes.get(i + 1) == Some(&b'*') {
            // Block comment — skip through `*/`, or to the end if unterminated.
            i += 2;
            while i + 1 < len && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
                i += 1;
            }
            i = if i + 1 < len { i + 2 } else { len };
            continue;
        }

        if b == b'\'' {
            // String literal — `''` is an escaped quote, not the terminator.
            i += 1;
            while i < len {
                if bytes[i] != b'\'' {
                    i += 1;
                } else if bytes.get(i + 1) == Some(&b'\'') {
                    i += 2;
                } else {
                    i += 1;
                    break;
                }
            }
            continue;
        }

        if b == b'$' {
            if let Some(end) = dollar_quote_end(bytes, i) {
                i = end;
                continue;
            }
        }

        // Structural ASCII is kept (uppercased); non-ASCII stays blank.
        if b.is_ascii() {
            out[i] = b.to_ascii_uppercase();
        }
        i += 1;
    }
    String::from_utf8(out).expect("mask is ASCII by construction")
}

/// If a dollar-quoted string opens at `bytes[start]` (which must be `$`),
/// return the offset just past its closing delimiter, or `bytes.len()` when
/// the string is never closed. Returns `None` when this `$` does not open a
/// dollar quote (e.g. the `$1` placeholder, or a `$` that continues an
/// identifier).
fn dollar_quote_end(bytes: &[u8], start: usize) -> Option<usize> {
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();

    // A `$` directly after an identifier character belongs to that identifier.
    if start > 0 && is_ident(bytes[start - 1]) {
        return None;
    }
    // The tag is an identifier that must not start with a digit (so `$1`
    // stays a placeholder) and must itself be closed by `$`.
    if bytes.get(start + 1).is_some_and(|b| b.is_ascii_digit()) {
        return None;
    }
    let mut tag_end = start + 1;
    while tag_end < bytes.len() && is_ident(bytes[tag_end]) {
        tag_end += 1;
    }
    if bytes.get(tag_end) != Some(&b'$') {
        return None;
    }

    let delimiter = &bytes[start..=tag_end];
    let body_start = tag_end + 1;
    let close = bytes[body_start..]
        .windows(delimiter.len())
        .position(|window| window == delimiter);
    Some(match close {
        Some(rel) => body_start + rel + delimiter.len(),
        None => bytes.len(),
    })
}
464
/// Locate `keyword` as a standalone word in `haystack` and return the byte
/// offset where it starts, or `None` if it does not occur.
///
/// The comparison is exact, so callers pass an uppercase haystack and an
/// uppercase keyword. A match only counts when it is not glued to an
/// identifier character (ASCII letter, digit or `_`) on either side — so
/// `DO` is not found inside `UNDO`, `DO_NOT_CONTACT` or `TO_DO`. An empty
/// keyword never matches.
fn find_keyword(haystack: &str, keyword: &str) -> Option<usize> {
    let hay = haystack.as_bytes();
    let key = keyword.as_bytes();
    if key.is_empty() {
        return None;
    }
    // `_` is part of SQL identifiers, so it must block a keyword match too.
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    hay.windows(key.len())
        .enumerate()
        .find_map(|(start, window)| {
            if window != key {
                return None;
            }
            let end = start + key.len();
            let clear_before = start == 0 || !is_ident(hay[start - 1]);
            let clear_after = end == hay.len() || !is_ident(hay[end]);
            (clear_before && clear_after).then_some(start)
        })
}
484
485/// Preprocess Oracle SQL before parsing:
486/// 1. Strip `INTO :N, :N, ...` suffix from `RETURNING ... INTO` clauses.
487/// 2. Convert `:N` positional placeholders to `?` (universally supported).
488fn preprocess_oracle_sql(sql: &str) -> String {
489    // Strip Oracle RETURNING ... INTO clause (output bind variables)
490    // e.g. "INSERT ... RETURNING id, name INTO :4, :5" → "INSERT ... RETURNING id, name"
491    let sql = strip_returning_into(sql);
492
493    // Convert :N → ? (outside string literals)
494    let mut result = String::with_capacity(sql.len());
495    let mut chars = sql.chars().peekable();
496    while let Some(ch) = chars.next() {
497        if ch == '\'' {
498            // Skip string literals
499            result.push(ch);
500            while let Some(inner) = chars.next() {
501                result.push(inner);
502                if inner == '\'' {
503                    if chars.peek() == Some(&'\'') {
504                        result.push(chars.next().unwrap());
505                    } else {
506                        break;
507                    }
508                }
509            }
510        } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
511            // Convert :N → ?
512            result.push('?');
513            while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
514                chars.next();
515            }
516        } else {
517            result.push(ch);
518        }
519    }
520    result
521}
522
/// Replace MSSQL `@pN` / `@PN` positional placeholders with `?`, leaving
/// single-quoted string literals untouched.
///
/// MsSqlDialect treats `@` as an identifier start, so `@p1` would be lexed as
/// an identifier rather than a placeholder; normalising to `?` avoids that.
/// An `@p` that is not followed by a digit is left as written.
fn convert_mssql_placeholders(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();
    while let Some(current) = rest.next() {
        match current {
            '\'' => {
                // Copy the whole literal; `''` is an escaped quote inside it.
                out.push(current);
                while let Some(inner) = rest.next() {
                    out.push(inner);
                    if inner != '\'' {
                        continue;
                    }
                    match rest.next_if_eq(&'\'') {
                        Some(escaped) => out.push(escaped),
                        None => break,
                    }
                }
            }
            '@' if matches!(rest.peek(), Some('p' | 'P')) => {
                // Probe on a copy: only commit if a digit follows the `p`.
                let mut probe = rest.clone();
                probe.next();
                if probe.peek().is_some_and(|c| c.is_ascii_digit()) {
                    rest = probe;
                    while rest.next_if(|c| c.is_ascii_digit()).is_some() {}
                    out.push('?');
                } else {
                    out.push(current);
                }
            }
            other => out.push(other),
        }
    }
    out
}
564
565/// Preprocess MSSQL SQL before parsing:
566/// 1. Strip `OUTPUT INSERTED.col, ...` clauses and convert to RETURNING
567/// 2. Convert `@pN` positional placeholders to `?`
568fn preprocess_mssql_sql(sql: &str) -> String {
569    // First pass: convert OUTPUT INSERTED.col to RETURNING col
570    let sql = strip_and_convert_mssql_output(sql);
571    // Second pass: convert @pN to ?
572    convert_mssql_placeholders(&sql)
573}
574
575/// Strip MSSQL `OUTPUT INSERTED.col1, INSERTED.col2, ...` from INSERT statements
576/// and convert it to a `RETURNING col1, col2, ...` clause.
577/// The OUTPUT clause appears between the column list and VALUES clause:
578///   INSERT INTO table (cols) OUTPUT INSERTED.col1, INSERTED.col2, ... VALUES (...)
579/// becomes:
580///   INSERT INTO table (cols) VALUES (...) RETURNING col1, col2, ...
581fn strip_and_convert_mssql_output(sql: &str) -> String {
582    // Case-insensitive search for OUTPUT keyword in INSERT statements
583    let upper = sql.to_uppercase();
584
585    // Only process INSERT statements with OUTPUT
586    if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
587        return sql.to_string();
588    }
589
590    // Find the OUTPUT keyword
591    if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
592        // Check if this is actually part of an INSERT statement by finding INSERT before it
593        let before_output = &upper[..output_pos];
594        if !before_output.contains("INSERT") {
595            return sql.to_string();
596        }
597
598        // Look for the VALUES keyword after OUTPUT
599        let after_output = &upper[output_pos + "OUTPUT".len()..];
600        if let Some(values_offset) = find_word_position(after_output, "VALUES") {
601            let values_pos = output_pos + "OUTPUT".len() + values_offset;
602
603            // Extract the OUTPUT column list (between OUTPUT and VALUES)
604            let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
605
606            // Parse column names: strip "INSERTED." prefix from each column name
607            let cols = parse_inserted_columns(output_cols_str);
608
609            if !cols.is_empty() {
610                // Build result: keep everything before OUTPUT, then VALUES clause,
611                // then RETURNING clause (before any trailing semicolon)
612                let before_output_sql = sql[..output_pos].trim_end();
613                let after_values = sql[values_pos..].trim_end();
614                let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
615                {
616                    (stripped, ";")
617                } else {
618                    (after_values, "")
619                };
620
621                return format!(
622                    "{}\n{} RETURNING {}{}",
623                    before_output_sql, values_body, cols, trailing
624                );
625            }
626        }
627    }
628
629    sql.to_string()
630}
631
/// Return the byte offset of the first occurrence of `word` in `text` that
/// stands alone as a word, or `None` if there is none.
///
/// The match is exact, so callers that want case-insensitive behaviour pass
/// an already-uppercased `text` and an uppercase `word`. An occurrence is
/// rejected when the byte directly before or after it is an ASCII letter,
/// digit or `_`, i.e. when it is part of a longer identifier.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    let raw = text.as_bytes();
    let is_ident_at = |at: usize| {
        raw.get(at)
            .is_some_and(|&b| b.is_ascii_alphanumeric() || b == b'_')
    };

    let mut from = 0;
    loop {
        // `?` ends the search as soon as no further occurrence exists.
        let start = from + text[from..].find(word)?;
        let end = start + word.len();

        let clear_before = start == 0 || !is_ident_at(start - 1);
        let clear_after = end >= text.len() || !is_ident_at(end);
        if clear_before && clear_after {
            return Some(start);
        }
        // Embedded in an identifier: retry just past this occurrence's start.
        from = start + 1;
    }
}
662
/// Turn an MSSQL OUTPUT column list such as `INSERTED.col1, INSERTED.col2`
/// into the bare column list `col1, col2`.
///
/// The `INSERTED` prefix is matched ASCII-case-insensitively and whitespace
/// around the dot is tolerated (`Inserted.id`, `INSERTED . id`). Items that do
/// not start with `INSERTED`, or that have nothing after the prefix, are
/// skipped. Returns an empty string when no column is recognised.
fn parse_inserted_columns(output_str: &str) -> String {
    const PREFIX: &str = "INSERTED";

    output_str
        .split(',')
        .filter_map(|part| {
            let item = part.trim();
            // `get` returns None for items shorter than the prefix or when the
            // cut would fall inside a multi-byte character.
            let head = item.get(..PREFIX.len())?;
            if !head.eq_ignore_ascii_case(PREFIX) {
                return None;
            }
            let rest = item[PREFIX.len()..].trim_start();
            // The dot is optional: a bare `INSERTED col` is accepted too.
            let column = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!column.is_empty()).then(|| column.to_string())
        })
        .collect::<Vec<_>>()
        .join(", ")
}
686
/// Strip the `INTO :N, :N, ...` tail from an Oracle `RETURNING ... INTO`
/// clause.
///
/// Looks for the last `RETURNING` (ASCII-case-insensitive) and the first
/// standalone `INTO` word after it. Everything from that `INTO` onwards is
/// dropped and trailing whitespace is trimmed. `INTO` embedded in a longer
/// identifier (e.g. a column named `into_count`) is not treated as the
/// keyword. Returns the input unchanged when no such clause exists.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    // ASCII-only uppercasing keeps byte offsets identical to `sql`; full
    // Unicode uppercasing can change byte lengths and shift the cut point.
    let upper = sql.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let Some(ret_pos) = upper.rfind(RETURNING) else {
        return sql.to_string();
    };
    let tail_start = ret_pos + RETURNING.len();

    // First `INTO` after RETURNING that is a word of its own.
    let into_pos = upper[tail_start..]
        .match_indices(INTO)
        .map(|(rel, _)| tail_start + rel)
        .find(|&at| {
            let end = at + INTO.len();
            let clear_before = at == 0 || !is_ident(bytes[at - 1]);
            let clear_after = end >= bytes.len() || !is_ident(bytes[end]);
            clear_before && clear_after
        });

    match into_pos {
        Some(at) => sql[..at].trim_end().to_string(),
        None => sql.to_string(),
    }
}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use crate::errors::ErrorCode;
707
708    fn parse(sql: &str) -> Result<Query, ScytheError> {
709        parse_query(sql)
710    }
711
712    #[test]
713    fn test_basic_parse() {
714        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
715        let q = parse(input).unwrap();
716        assert_eq!(q.name, "GetUsers");
717        assert_eq!(q.command, QueryCommand::Many);
718        assert!(q.sql.contains("SELECT"));
719    }
720
721    #[test]
722    fn test_all_command_types() {
723        let cases = vec![
724            (":one", QueryCommand::One),
725            (":many", QueryCommand::Many),
726            (":exec", QueryCommand::Exec),
727            (":exec_result", QueryCommand::ExecResult),
728            (":exec_rows", QueryCommand::ExecRows),
729        ];
730        for (tag, expected) in cases {
731            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
732            let q = parse(&input).unwrap();
733            assert_eq!(q.command, expected, "failed for {}", tag);
734        }
735    }
736
737    #[test]
738    fn test_case_insensitive_keywords() {
739        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
740        let q = parse(input).unwrap();
741        assert_eq!(q.name, "GetUsers");
742        assert_eq!(q.command, QueryCommand::Many);
743    }
744
745    #[test]
746    fn test_missing_name_errors() {
747        let input = "-- @returns :many\nSELECT 1";
748        let err = parse(input).unwrap_err();
749        assert_eq!(err.code, ErrorCode::MissingAnnotation);
750        assert!(err.message.contains("name"));
751    }
752
753    #[test]
754    fn test_missing_returns_errors() {
755        let input = "-- @name Foo\nSELECT 1";
756        let err = parse(input).unwrap_err();
757        assert_eq!(err.code, ErrorCode::MissingAnnotation);
758        assert!(err.message.contains("returns"));
759    }
760
761    #[test]
762    fn test_invalid_returns_value() {
763        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
764        let err = parse(input).unwrap_err();
765        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
766    }
767
768    #[test]
769    fn test_empty_name_value() {
770        // An empty name is accepted by the parser (it stores "")
771        let input = "-- @name\n-- @returns :one\nSELECT 1";
772        let q = parse(input).unwrap();
773        assert_eq!(q.name, "");
774    }
775
776    #[test]
777    fn test_param_annotation() {
778        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
779        let q = parse(input).unwrap();
780        assert_eq!(q.annotations.param_docs.len(), 1);
781        assert_eq!(q.annotations.param_docs[0].name, "id");
782        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
783    }
784
785    #[test]
786    fn test_param_no_description() {
787        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
788        let q = parse(input).unwrap();
789        assert_eq!(q.annotations.param_docs.len(), 1);
790        assert_eq!(q.annotations.param_docs[0].name, "id");
791        assert_eq!(q.annotations.param_docs[0].description, "");
792    }
793
794    #[test]
795    fn test_nullable_annotation() {
796        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
797        let q = parse(input).unwrap();
798        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
799    }
800
801    #[test]
802    fn test_nonnull_annotation() {
803        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
804        let q = parse(input).unwrap();
805        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
806    }
807
808    #[test]
809    fn test_json_annotation() {
810        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
811        let q = parse(input).unwrap();
812        assert_eq!(q.annotations.json_mappings.len(), 1);
813        assert_eq!(q.annotations.json_mappings[0].column, "data");
814        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
815    }
816
817    #[test]
818    fn test_custom_annotations_captured() {
819        // Unknown @xxx lines are captured verbatim as CustomAnnotation triples;
820        // native annotations remain in their typed fields.
821        let input = "-- @name GetUser
822-- @returns :one
823-- @http GET /users/{id}
824-- @http_auth bearer:jwt
825-- @http_status 200,404
826SELECT id FROM users WHERE id = $1";
827        let q = parse(input).unwrap();
828        assert_eq!(q.annotations.custom.len(), 3);
829        assert_eq!(q.annotations.custom[0].name, "http");
830        assert_eq!(q.annotations.custom[0].value, "GET /users/{id}");
831        assert_eq!(q.annotations.custom[0].line, 3);
832        assert_eq!(q.annotations.custom[1].name, "http_auth");
833        assert_eq!(q.annotations.custom[1].value, "bearer:jwt");
834        assert_eq!(q.annotations.custom[1].line, 4);
835        assert_eq!(q.annotations.custom[2].name, "http_status");
836        assert_eq!(q.annotations.custom[2].value, "200,404");
837        assert_eq!(q.annotations.custom[2].line, 5);
838    }
839
840    #[test]
841    fn test_custom_annotation_without_value() {
842        let input = "-- @name GetUser
843-- @returns :one
844-- @http_internal
845SELECT 1";
846        let q = parse(input).unwrap();
847        assert_eq!(q.annotations.custom.len(), 1);
848        assert_eq!(q.annotations.custom[0].name, "http_internal");
849        assert_eq!(q.annotations.custom[0].value, "");
850    }
851
852    #[cfg(feature = "serde")]
853    #[test]
854    fn test_custom_annotation_serde_round_trip() {
855        let original = CustomAnnotation {
856            name: "http".to_string(),
857            value: "GET /users/{id}".to_string(),
858            line: 7,
859        };
860        let json = serde_json::to_string(&original).unwrap();
861        let back: CustomAnnotation = serde_json::from_str(&json).unwrap();
862        assert_eq!(back, original);
863    }
864
865    #[test]
866    fn test_custom_annotation_name_lowercased() {
867        let input = "-- @name GetUser
868-- @returns :one
869-- @HTTP_Auth Bearer
870SELECT 1";
871        let q = parse(input).unwrap();
872        assert_eq!(q.annotations.custom.len(), 1);
873        assert_eq!(q.annotations.custom[0].name, "http_auth");
874        assert_eq!(q.annotations.custom[0].value, "Bearer");
875    }
876
877    #[test]
878    fn test_deprecated_annotation() {
879        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
880        let q = parse(input).unwrap();
881        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
882    }
883
884    #[test]
885    fn test_sql_syntax_error() {
886        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
887        let err = parse(input).unwrap_err();
888        assert_eq!(err.code, ErrorCode::SyntaxError);
889    }
890
891    #[test]
892    fn test_trailing_semicolon() {
893        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
894        let q = parse(input).unwrap();
895        assert_eq!(q.name, "Foo");
896    }
897
898    #[test]
899    fn test_multiple_statements_error() {
900        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
901        let err = parse(input).unwrap_err();
902        assert_eq!(err.code, ErrorCode::SyntaxError);
903    }
904
905    #[test]
906    fn test_sql_preserved_without_annotations() {
907        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
908        let q = parse(input).unwrap();
909        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
910    }
911
912    #[test]
913    fn test_returns_without_colon_prefix() {
914        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
915        let q = parse(input).unwrap();
916        assert_eq!(q.command, QueryCommand::Many);
917    }
918
919    #[test]
920    fn test_batch_command() {
921        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
922        let q = parse(input).unwrap();
923        assert_eq!(q.command, QueryCommand::Batch);
924    }
925
926    #[test]
927    fn test_grouped_command_with_group_by() {
928        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
929        let q = parse(input).unwrap();
930        assert_eq!(q.command, QueryCommand::Grouped);
931        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
932    }
933
934    #[test]
935    fn test_grouped_command_without_group_by_errors() {
936        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
937        let err = parse(input).unwrap_err();
938        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
939        assert!(err.message.contains("@group_by"));
940    }
941
942    #[test]
943    fn test_group_by_without_grouped_is_ignored() {
944        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
945        let q = parse(input).unwrap();
946        assert_eq!(q.command, QueryCommand::Many);
947        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
948    }
949
950    #[test]
951    fn test_preprocess_postgres_strips_partial_index_where() {
952        let sql = "INSERT INTO billing_events (project_id, stripe_event_id) \
953                   VALUES ($1, $2) \
954                   ON CONFLICT (stripe_event_id) WHERE stripe_event_id IS NOT NULL DO NOTHING";
955        let cleaned = preprocess_postgres_sql(sql);
956        assert!(
957            !cleaned
958                .to_uppercase()
959                .contains("WHERE STRIPE_EVENT_ID IS NOT NULL"),
960            "WHERE clause must be stripped between ON CONFLICT cols and DO; got: {cleaned}"
961        );
962        assert!(
963            cleaned
964                .to_uppercase()
965                .contains("ON CONFLICT (STRIPE_EVENT_ID) DO NOTHING")
966        );
967        // sqlparser must accept the cleaned form.
968        sqlparser::parser::Parser::parse_sql(&sqlparser::dialect::PostgreSqlDialect {}, &cleaned)
969            .expect("cleaned SQL should parse");
970    }
971
972    #[test]
973    fn test_preprocess_postgres_no_op_when_no_partial_clause() {
974        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a";
975        assert_eq!(preprocess_postgres_sql(sql), sql);
976    }
977
978    #[test]
979    fn test_preprocess_postgres_leaves_on_conflict_on_constraint_alone() {
980        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT ON CONSTRAINT t_a_uidx DO NOTHING";
981        assert_eq!(preprocess_postgres_sql(sql), sql);
982    }
983
984    #[test]
985    fn test_preprocess_postgres_handles_compound_index_cols() {
986        let sql = "INSERT INTO t (a, b) VALUES ($1, $2) \
987                   ON CONFLICT (a, b) WHERE a IS NOT NULL AND b > 0 DO UPDATE SET b = EXCLUDED.b";
988        let cleaned = preprocess_postgres_sql(sql);
989        assert!(
990            cleaned
991                .to_uppercase()
992                .contains("ON CONFLICT (A, B) DO UPDATE")
993        );
994        assert!(!cleaned.to_uppercase().contains("WHERE A IS NOT NULL"));
995    }
996
997    #[test]
998    fn test_preprocess_postgres_preserves_unrelated_where() {
999        // The DELETE's WHERE is its own clause, not an ON-CONFLICT predicate;
1000        // it must survive untouched.
1001        let sql = "DELETE FROM t WHERE id = $1";
1002        assert_eq!(preprocess_postgres_sql(sql), sql);
1003    }
1004
1005    #[test]
1006    fn test_preprocess_postgres_ignores_text_inside_line_comments() {
1007        // Earlier scans treated this as a real `ON CONFLICT (col) WHERE … DO`
1008        // and excised the entire comment + INSERT body up to the next `DO`.
1009        // Comments must be opaque to the predicate-stripping pass.
1010        let sql = "-- inline doc: `ON CONFLICT (col) WHERE …` is the partial form\n\
1011                   INSERT INTO t (a) VALUES ($1) \
1012                   ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING";
1013        let cleaned = preprocess_postgres_sql(sql);
1014        assert!(
1015            cleaned.contains("-- inline doc"),
1016            "comment must survive the pass; got: {cleaned}"
1017        );
1018        assert!(cleaned.contains("ON CONFLICT (a) DO NOTHING"));
1019    }
1020
1021    #[test]
1022    fn test_preprocess_postgres_ignores_text_inside_string_literals() {
1023        let sql = "SELECT 'ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING' AS s";
1024        assert_eq!(preprocess_postgres_sql(sql), sql);
1025    }
1026
1027    #[test]
1028    fn test_preprocess_oracle_colon_placeholders() {
1029        assert_eq!(
1030            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
1031            "SELECT * FROM users WHERE id = ?"
1032        );
1033        assert_eq!(
1034            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
1035            "INSERT INTO users (name, email) VALUES (?, ?)"
1036        );
1037    }
1038
1039    #[test]
1040    fn test_preprocess_oracle_preserves_string_literals() {
1041        assert_eq!(
1042            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
1043            "SELECT * FROM users WHERE name = ':1' AND id = ?"
1044        );
1045    }
1046
1047    #[test]
1048    fn test_preprocess_oracle_strips_returning_into() {
1049        assert_eq!(
1050            preprocess_oracle_sql(
1051                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
1052            ),
1053            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
1054        );
1055    }
1056
1057    #[test]
1058    fn test_preprocess_oracle_full_insert_returning_into() {
1059        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
1060        let result = preprocess_oracle_sql(sql);
1061        assert_eq!(
1062            result,
1063            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
1064        );
1065    }
1066
1067    #[test]
1068    fn test_preprocess_oracle_no_returning_into_unchanged() {
1069        assert_eq!(
1070            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
1071            "DELETE FROM users WHERE id = ?"
1072        );
1073    }
1074
1075    #[test]
1076    fn test_preprocess_mssql_single_placeholder() {
1077        assert_eq!(
1078            preprocess_mssql_sql("SELECT * FROM users WHERE id = @p1"),
1079            "SELECT * FROM users WHERE id = ?"
1080        );
1081    }
1082
1083    #[test]
1084    fn test_preprocess_mssql_multiple_placeholders() {
1085        assert_eq!(
1086            preprocess_mssql_sql("INSERT INTO users (name, email) VALUES (@p1, @p2)"),
1087            "INSERT INTO users (name, email) VALUES (?, ?)"
1088        );
1089    }
1090
1091    #[test]
1092    fn test_preprocess_mssql_preserves_string_literals() {
1093        assert_eq!(
1094            preprocess_mssql_sql("SELECT * FROM users WHERE name = '@p1' AND id = @p1"),
1095            "SELECT * FROM users WHERE name = '@p1' AND id = ?"
1096        );
1097    }
1098
1099    #[test]
1100    fn test_preprocess_mssql_case_insensitive_p() {
1101        assert_eq!(
1102            preprocess_mssql_sql("SELECT * FROM users WHERE id = @P1"),
1103            "SELECT * FROM users WHERE id = ?"
1104        );
1105    }
1106
1107    #[test]
1108    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
1109        // @variable (not @pN pattern) must not be touched
1110        assert_eq!(preprocess_mssql_sql("SELECT @myvar"), "SELECT @myvar");
1111    }
1112
1113    #[test]
1114    fn test_preprocess_mssql_multi_digit_placeholder() {
1115        assert_eq!(preprocess_mssql_sql("SELECT @p10, @p2"), "SELECT ?, ?");
1116    }
1117
1118    #[test]
1119    fn test_preprocess_mssql_output_inserted_simple() {
1120        let sql =
1121            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)";
1122        let result = preprocess_mssql_sql(sql);
1123        // Should convert OUTPUT INSERTED.col to RETURNING col and @pN to ?
1124        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1125        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
1126        assert!(!result.contains("OUTPUT"), "got: {}", result);
1127    }
1128
1129    #[test]
1130    fn test_preprocess_mssql_output_inserted_full_example() {
1131        let sql = "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)";
1132        let result = preprocess_mssql_sql(sql);
1133        assert!(
1134            result.contains("RETURNING id, name, email, active, created_at"),
1135            "got: {}",
1136            result
1137        );
1138        assert!(result.contains("VALUES (?, ?, ?, ?)"), "got: {}", result);
1139    }
1140
1141    #[test]
1142    fn test_preprocess_mssql_output_case_insensitive() {
1143        let sql = "INSERT INTO users (id) output inserted.id values (@p1)";
1144        let result = preprocess_mssql_sql(sql);
1145        assert!(result.contains("RETURNING id"), "got: {}", result);
1146        // The original lowercase "values" is preserved, then @p1 becomes ?
1147        assert!(
1148            result.contains("values (?)") || result.contains("VALUES (?)"),
1149            "got: {}",
1150            result
1151        );
1152    }
1153
1154    #[test]
1155    fn test_preprocess_mssql_no_output_unchanged() {
1156        let sql = "INSERT INTO users (id, name) VALUES (@p1, @p2)";
1157        let result = preprocess_mssql_sql(sql);
1158        assert_eq!(result, "INSERT INTO users (id, name) VALUES (?, ?)");
1159    }
1160
1161    #[test]
1162    fn test_preprocess_mssql_output_with_string_literal() {
1163        // @p1 inside a string should be preserved by placeholder conversion
1164        let sql =
1165            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')";
1166        let result = preprocess_mssql_sql(sql);
1167        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1168        assert!(result.contains("(?, '@p2')"), "got: {}", result);
1169    }
1170
1171    #[test]
1172    fn test_preprocess_mssql_output_with_whitespace() {
1173        let sql =
1174            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n  INSERTED.name\nVALUES (@p1, @p2)";
1175        let result = preprocess_mssql_sql(sql);
1176        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1177        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
1178    }
1179}