Skip to main content

scythe_core/parser/
mod.rs

1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result mode a query declares through its `@returns` annotation.
///
/// Each variant corresponds to exactly one accepted `@returns` tag (written
/// with or without a leading `:`); the `Display` impl prints that same
/// lowercase tag back.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum QueryCommand {
    /// `@returns :one`
    One,
    /// `@returns :opt`
    Opt,
    /// `@returns :many`
    Many,
    /// `@returns :exec` — also the `Default` value of this enum.
    #[default]
    Exec,
    /// `@returns :exec_result`
    ExecResult,
    /// `@returns :exec_rows`
    ExecRows,
    /// `@returns :batch`
    Batch,
    /// `@returns :grouped` — the parser rejects this mode unless a `@group_by`
    /// annotation is also present.
    Grouped,
}
19
20impl std::fmt::Display for QueryCommand {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            QueryCommand::One => write!(f, "one"),
24            QueryCommand::Opt => write!(f, "opt"),
25            QueryCommand::Many => write!(f, "many"),
26            QueryCommand::Exec => write!(f, "exec"),
27            QueryCommand::ExecResult => write!(f, "exec_result"),
28            QueryCommand::ExecRows => write!(f, "exec_rows"),
29            QueryCommand::Batch => write!(f, "batch"),
30            QueryCommand::Grouped => write!(f, "grouped"),
31        }
32    }
33}
34
35impl QueryCommand {
36    fn from_str(s: &str) -> Result<Self, ScytheError> {
37        match s {
38            "one" => Ok(QueryCommand::One),
39            "opt" => Ok(QueryCommand::Opt),
40            "many" => Ok(QueryCommand::Many),
41            "exec" => Ok(QueryCommand::Exec),
42            "exec_result" => Ok(QueryCommand::ExecResult),
43            "exec_rows" => Ok(QueryCommand::ExecRows),
44            "batch" => Ok(QueryCommand::Batch),
45            "grouped" => Ok(QueryCommand::Grouped),
46            other => Err(ScytheError::invalid_annotation(format!(
47                "invalid @returns value: {other}"
48            ))),
49        }
50    }
51}
52
/// Documentation for one query parameter, taken from a
/// `-- @param <name>: <description>` annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value, trimmed; the whole
    /// value when the line contains no `:`.
    pub name: String,
    /// Text after the first `:`, trimmed; empty when the line contains no `:`.
    pub description: String,
}
59
/// A column-to-type association taken from a `-- @json <col> = <Type>`
/// annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct JsonMapping {
    /// Text before the first `=` of the annotation value, trimmed.
    pub column: String,
    /// Text after the first `=`, trimmed. Stored verbatim; the parser does not
    /// validate that it names a real type.
    pub rust_type: String,
}
66
/// A custom (non-native) annotation captured verbatim from the SQL source.
///
/// Scythe parses its known annotations (`@name`, `@returns`, `@param`, `@nullable`,
/// `@nonnull`, `@json`, `@optional`, `@group_by`, `@deprecated`) into typed fields.
/// Any other `-- @<name> <value>` line is captured here as an opaque
/// name/value/line triple and exposed to crate consumers, who can layer their
/// own annotation vocabulary on top of scythe without coupling scythe to their
/// domain.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomAnnotation {
    /// Annotation name, ASCII-lowercased, without the leading `@` (e.g. `http`, `http_param`).
    pub name: String,
    /// Everything after the name on the line, trimmed. Empty if the annotation had no value.
    pub value: String,
    /// 1-based line number within the query text handed to the parser
    /// (annotation lines are counted too), for diagnostics.
    pub line: usize,
}
84
/// Everything the parser extracted from the `-- @…` annotation lines of one query.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Annotations {
    /// Value of the `@name` annotation (may be empty if the line had no value).
    pub name: String,
    /// Command parsed from the `@returns` annotation.
    pub command: QueryCommand,
    /// One entry per `@param` line, in source order.
    pub param_docs: Vec<ParamDoc>,
    /// Non-empty, trimmed items of every comma-separated `@nullable` list.
    pub nullable_overrides: Vec<String>,
    /// Non-empty, trimmed items of every comma-separated `@nonnull` list.
    pub nonnull_overrides: Vec<String>,
    /// One entry per `@json` line that contains an `=`; lines without `=` add nothing.
    pub json_mappings: Vec<JsonMapping>,
    /// Value of the `@deprecated` annotation if present (an empty string when
    /// the line had no value); a later line replaces an earlier one.
    pub deprecated: Option<String>,
    /// Non-empty, trimmed items of every comma-separated `@optional` list.
    pub optional_params: Vec<String>,
    /// Value of the `@group_by` annotation if present; a later line replaces
    /// an earlier one.
    pub group_by: Option<String>,
    /// Annotations scythe does not natively recognise, preserved in source order
    /// for crate consumers to interpret.
    pub custom: Vec<CustomAnnotation>,
}
101
/// A single parsed query: its annotations plus the SQL text and parsed statement.
#[derive(Debug)]
pub struct Query {
    /// Query name from the `@name` annotation (same value as `annotations.name`).
    pub name: String,
    /// Command from the `@returns` annotation (same value as `annotations.command`).
    pub command: QueryCommand,
    /// SQL body with annotation lines removed and surrounding whitespace trimmed.
    /// For Oracle and MSSQL the dialect preprocessing has already normalised
    /// positional placeholders to `?` in this text.
    pub sql: String,
    /// The single statement sqlparser produced for this query. It may have been
    /// parsed from a further-preprocessed copy of the SQL rather than from `sql`.
    pub stmt: sqlparser::ast::Statement,
    /// All annotations collected from the query's comment lines.
    pub annotations: Annotations,
}
110
/// Parse a single annotated SQL query into a `Query` using the PostgreSQL dialect.
///
/// Convenience wrapper: delegates to `parse_query_with_dialect` with
/// `SqlDialect::PostgreSQL` and returns its result (value or error) unchanged.
pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
}
115
116/// Parse a single annotated SQL query into a `Query` using the specified dialect.
117pub fn parse_query_with_dialect(
118    query_sql: &str,
119    dialect: &SqlDialect,
120) -> Result<Query, ScytheError> {
121    let mut name: Option<String> = None;
122    let mut command: Option<QueryCommand> = None;
123    let mut param_docs = Vec::new();
124    let mut nullable_overrides = Vec::new();
125    let mut nonnull_overrides = Vec::new();
126    let mut json_mappings = Vec::new();
127    let mut deprecated: Option<String> = None;
128    let mut optional_params = Vec::new();
129    let mut group_by: Option<String> = None;
130    let mut custom: Vec<CustomAnnotation> = Vec::new();
131
132    let mut sql_lines = Vec::new();
133
134    for (line_idx, line) in query_sql.lines().enumerate() {
135        let line_no = line_idx + 1;
136        let trimmed = line.trim();
137
138        // Check for annotation: "-- @..." or "--@..."
139        let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
140            let rest = rest.trim_start();
141            rest.strip_prefix('@')
142        } else {
143            None
144        };
145
146        if let Some(body) = annotation_body {
147            // Parse the annotation keyword and value
148            let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
149                Some(pos) => (&body[..pos], body[pos..].trim()),
150                None => (body, ""),
151            };
152
153            match keyword.to_ascii_lowercase().as_str() {
154                "name" => {
155                    name = Some(value.to_string());
156                }
157                "returns" => {
158                    let cmd_str = value.strip_prefix(':').unwrap_or(value);
159                    command = Some(QueryCommand::from_str(cmd_str)?);
160                }
161                "param" => {
162                    // format: "<name>: <description>" or "<name>:<description>"
163                    if let Some(colon_pos) = value.find(':') {
164                        let param_name = value[..colon_pos].trim().to_string();
165                        let description = value[colon_pos + 1..].trim().to_string();
166                        param_docs.push(ParamDoc {
167                            name: param_name,
168                            description,
169                        });
170                    } else {
171                        param_docs.push(ParamDoc {
172                            name: value.to_string(),
173                            description: String::new(),
174                        });
175                    }
176                }
177                "nullable" => {
178                    for col in value.split(',') {
179                        let col = col.trim();
180                        if !col.is_empty() {
181                            nullable_overrides.push(col.to_string());
182                        }
183                    }
184                }
185                "nonnull" => {
186                    for col in value.split(',') {
187                        let col = col.trim();
188                        if !col.is_empty() {
189                            nonnull_overrides.push(col.to_string());
190                        }
191                    }
192                }
193                "json" => {
194                    // format: "<col> = <Type>"
195                    if let Some(eq_pos) = value.find('=') {
196                        let column = value[..eq_pos].trim().to_string();
197                        let rust_type = value[eq_pos + 1..].trim().to_string();
198                        json_mappings.push(JsonMapping { column, rust_type });
199                    }
200                }
201                "deprecated" => {
202                    deprecated = Some(value.to_string());
203                }
204                "group_by" => {
205                    group_by = Some(value.to_string());
206                }
207                "optional" => {
208                    for param in value.split(',') {
209                        let param = param.trim();
210                        if !param.is_empty() {
211                            optional_params.push(param.to_string());
212                        }
213                    }
214                }
215                other => {
216                    // Unknown annotation — capture verbatim for crate consumers.
217                    custom.push(CustomAnnotation {
218                        name: other.to_string(),
219                        value: value.to_string(),
220                        line: line_no,
221                    });
222                }
223            }
224        } else {
225            sql_lines.push(line);
226        }
227    }
228
229    let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
230    let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
231
232    if command == QueryCommand::Grouped && group_by.is_none() {
233        return Err(ScytheError::invalid_annotation(
234            "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
235        ));
236    }
237
238    let sql = sql_lines.join("\n").trim().to_string();
239
240    if sql.is_empty() {
241        return Err(ScytheError::syntax("empty SQL body"));
242    }
243
244    // Preprocess dialect-specific syntax before parsing:
245    //   * Oracle: strip `RETURNING ... INTO` output binds, convert `:N` → `?`.
246    //   * MSSQL: convert `OUTPUT INSERTED.*` → `RETURNING` for parsing,
247    //     convert `@pN` → `?` for parsing; keep original SQL for codegen.
248    //   * PostgreSQL: strip `WHERE …` between `ON CONFLICT (cols)` and `DO …`
249    //     for parsing (sqlparser-rs <= 0.61 doesn't recognise the
250    //     partial-index inference form); keep original SQL for codegen.
251    let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
252        let processed = preprocess_oracle_sql(&sql);
253        (processed.clone(), processed)
254    } else if *dialect == SqlDialect::MsSql {
255        // For codegen: only convert @pN → ? placeholders (keep OUTPUT syntax)
256        let codegen_sql = convert_mssql_placeholders(&sql);
257        // For parsing: also convert OUTPUT INSERTED → RETURNING
258        let parse_sql = preprocess_mssql_sql(&sql);
259        (codegen_sql, parse_sql)
260    } else if *dialect == SqlDialect::PostgreSQL {
261        let parse_sql = preprocess_postgres_sql(&sql);
262        (sql.clone(), parse_sql)
263    } else {
264        (sql.clone(), sql)
265    };
266
267    let parser_dialect = dialect.to_sqlparser_dialect();
268    let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
269        .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
270
271    if statements.len() != 1 {
272        // sqlparser may produce an extra empty statement from a trailing semicolon —
273        // filter those out by checking for exactly one non-empty statement.
274        let non_empty: Vec<_> = statements
275            .into_iter()
276            .filter(|s| {
277                !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
278            })
279            .collect();
280        if non_empty.len() != 1 {
281            return Err(ScytheError::syntax("expected exactly one SQL statement"));
282        }
283        let stmt = non_empty
284            .into_iter()
285            .next()
286            .expect("filtered to exactly one statement");
287        let annotations = Annotations {
288            name: name.clone(),
289            command: command.clone(),
290            param_docs,
291            nullable_overrides,
292            nonnull_overrides,
293            json_mappings,
294            deprecated,
295            optional_params,
296            group_by: group_by.clone(),
297            custom,
298        };
299        return Ok(Query {
300            name,
301            command,
302            sql,
303            stmt,
304            annotations,
305        });
306    }
307
308    let stmt = statements
309        .into_iter()
310        .next()
311        .expect("filtered to exactly one statement");
312
313    let annotations = Annotations {
314        name: name.clone(),
315        command: command.clone(),
316        param_docs,
317        nullable_overrides,
318        nonnull_overrides,
319        json_mappings,
320        deprecated,
321        optional_params,
322        group_by,
323        custom,
324    };
325
326    Ok(Query {
327        name,
328        command,
329        sql,
330        stmt,
331        annotations,
332    })
333}
334
335/// Strip the `WHERE …` predicate that PostgreSQL allows between
336/// `ON CONFLICT (cols)` and `DO …` (the index-inference form for partial
337/// unique indexes). sqlparser-rs through 0.61 does not parse this construct;
338/// we lift it out for the parser and let the caller keep the original SQL
339/// for codegen + runtime, where Postgres validates it.
340fn preprocess_postgres_sql(sql: &str) -> String {
341    // Strip line comments + string literals first so we only scan structural SQL.
342    // (We still emit the original `sql` slice byte-for-byte; the upper-mask is
343    //  only used to decide *where* to cut.)
344    let mask = mask_postgres_for_scan(sql);
345    let mask_bytes = mask.as_bytes();
346    let bytes = sql.as_bytes();
347    let mut search_from = 0;
348    let mut result = String::with_capacity(sql.len());
349    let mut last = 0;
350    while let Some(rel) = find_keyword(&mask[search_from..], "ON CONFLICT") {
351        let on_conflict_pos = search_from + rel;
352        let after_on_conflict = on_conflict_pos + "ON CONFLICT".len();
353        let mut idx = after_on_conflict;
354        while idx < mask_bytes.len() && mask_bytes[idx].is_ascii_whitespace() {
355            idx += 1;
356        }
357        if idx >= mask_bytes.len() || mask_bytes[idx] != b'(' {
358            search_from = after_on_conflict;
359            continue;
360        }
361        let mut depth = 0i32;
362        let mut close = idx;
363        while close < mask_bytes.len() {
364            match mask_bytes[close] {
365                b'(' => depth += 1,
366                b')' => {
367                    depth -= 1;
368                    if depth == 0 {
369                        break;
370                    }
371                }
372                _ => {}
373            }
374            close += 1;
375        }
376        if depth != 0 {
377            return sql.to_string();
378        }
379        let mut after_cols = close + 1;
380        while after_cols < mask_bytes.len() && mask_bytes[after_cols].is_ascii_whitespace() {
381            after_cols += 1;
382        }
383        if mask[after_cols..].starts_with("WHERE")
384            && let Some(do_rel) = find_keyword(&mask[after_cols + "WHERE".len()..], "DO")
385        {
386            let do_abs = after_cols + "WHERE".len() + do_rel;
387            // Slice from the original SQL (preserves casing + UTF-8) up to
388            // the byte before WHERE; skip ahead to DO.
389            result.push_str(std::str::from_utf8(&bytes[last..after_cols]).unwrap_or(""));
390            last = do_abs;
391            search_from = do_abs;
392            continue;
393        }
394        search_from = close + 1;
395    }
396    result.push_str(std::str::from_utf8(&bytes[last..]).unwrap_or(""));
397    result
398}
399
/// Build an ASCII-uppercase, fixed-byte-offset mask of `sql` in which the
/// following regions are replaced with spaces:
///
/// * `--` line comments (the terminating newline is kept),
/// * `/* … */` block comments, including nested ones,
/// * `'…'` string literals (with `''` as an escaped quote),
/// * dollar-quoted strings: `$$…$$` and `$tag$…$tag$`.
///
/// Non-ASCII bytes also become spaces, so the mask has exactly the same byte
/// length as `sql` and every offset in the mask is valid in the original.
/// Unterminated comments and literals are masked through the end of input.
/// Positional parameters such as `$1` are not dollar quotes and pass through.
fn mask_postgres_for_scan(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Pre-filled with spaces: masking a region simply means skipping over it.
    let mut out = vec![b' '; len];
    let mut i = 0;
    while i < len {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        if b == b'-' && next == Some(b'-') {
            // Line comment — stop at the newline so it survives in the mask.
            while i < len && bytes[i] != b'\n' {
                i += 1;
            }
            continue;
        }

        if b == b'/' && next == Some(b'*') {
            // Block comment — PostgreSQL block comments nest.
            let mut depth = 1usize;
            i += 2;
            while i < len && depth > 0 {
                match (bytes[i], bytes.get(i + 1).copied()) {
                    (b'/', Some(b'*')) => {
                        depth += 1;
                        i += 2;
                    }
                    (b'*', Some(b'/')) => {
                        depth -= 1;
                        i += 2;
                    }
                    _ => i += 1,
                }
            }
            continue;
        }

        if b == b'\'' {
            // String literal — a doubled quote is an escaped quote, not the end.
            i += 1;
            while i < len {
                if bytes[i] != b'\'' {
                    i += 1;
                    continue;
                }
                if bytes.get(i + 1) == Some(&b'\'') {
                    i += 2;
                    continue;
                }
                i += 1;
                break;
            }
            continue;
        }

        if b == b'$' {
            // Candidate opening delimiter `$tag$` (tag may be empty). A `$`
            // glued to a preceding identifier is part of that identifier, and
            // a tag cannot start with a digit (`$1` is a parameter).
            let glued_to_identifier =
                i > 0 && (is_ident(bytes[i - 1]) || !bytes[i - 1].is_ascii());
            let digit_first = next.is_some_and(|n| n.is_ascii_digit());
            let tag_end = (i + 1..len).find(|&j| !is_ident(bytes[j])).unwrap_or(len);
            if !glued_to_identifier && !digit_first && bytes.get(tag_end) == Some(&b'$') {
                let delimiter = &bytes[i..=tag_end];
                let body_start = tag_end + 1;
                i = bytes[body_start..]
                    .windows(delimiter.len())
                    .position(|window| window == delimiter)
                    .map_or(len, |found| body_start + found + delimiter.len());
                continue;
            }
        }

        // Structural byte: ASCII is copied uppercased; non-ASCII stays a space
        // so the mask remains one byte per position.
        if b.is_ascii() {
            out[i] = b.to_ascii_uppercase();
        }
        i += 1;
    }
    String::from_utf8(out).expect("mask is ASCII by construction")
}
465
/// Locate `keyword` as a standalone word in `haystack` and return the byte
/// offset of its first occurrence, or `None` if there is none.
///
/// The comparison is exact (callers pass an already-uppercased haystack), and a
/// multi-word keyword such as `ON CONFLICT` must appear with exactly the
/// spacing given. A match only counts when the bytes immediately before and
/// after it are not identifier characters (ASCII letters, digits or `_`), so
/// `DO` is not found inside `UNDO`, `DO_FLAG` or `X_DO`.
fn find_keyword(haystack: &str, keyword: &str) -> Option<usize> {
    let hay = haystack.as_bytes();
    let key = keyword.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Last offset at which the keyword could still fit; `None` if it is longer
    // than the haystack.
    let last_start = hay.len().checked_sub(key.len())?;
    (0..=last_start).find(|&start| {
        let end = start + key.len();
        &hay[start..end] == key
            && (start == 0 || !is_ident(hay[start - 1]))
            && hay.get(end).map_or(true, |&b| !is_ident(b))
    })
}
485
486/// Preprocess Oracle SQL before parsing:
487/// 1. Strip `INTO :N, :N, ...` suffix from `RETURNING ... INTO` clauses.
488/// 2. Convert `:N` positional placeholders to `?` (universally supported).
489fn preprocess_oracle_sql(sql: &str) -> String {
490    // Strip Oracle RETURNING ... INTO clause (output bind variables)
491    // e.g. "INSERT ... RETURNING id, name INTO :4, :5" → "INSERT ... RETURNING id, name"
492    let sql = strip_returning_into(sql);
493
494    // Convert :N → ? (outside string literals)
495    let mut result = String::with_capacity(sql.len());
496    let mut chars = sql.chars().peekable();
497    while let Some(ch) = chars.next() {
498        if ch == '\'' {
499            // Skip string literals
500            result.push(ch);
501            while let Some(inner) = chars.next() {
502                result.push(inner);
503                if inner == '\'' {
504                    if chars.peek() == Some(&'\'') {
505                        result.push(chars.next().unwrap());
506                    } else {
507                        break;
508                    }
509                }
510            }
511        } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
512            // Convert :N → ?
513            result.push('?');
514            while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
515                chars.next();
516            }
517        } else {
518            result.push(ch);
519        }
520    }
521    result
522}
523
/// Rewrite MSSQL `@pN` / `@PN` positional placeholders to `?`.
///
/// MsSqlDialect treats `@` as an identifier start, so `@p1` would be lexed as
/// an identifier rather than a `Placeholder` token; normalising to `?` avoids
/// that. Only `@p` followed by at least one digit is rewritten (`@param`,
/// `@p`, `@px1` are left alone), and text inside `'…'` string literals (with
/// `''` as an escaped quote) is copied through untouched.
fn convert_mssql_placeholders(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();
    while let Some(current) = rest.next() {
        match current {
            '\'' => {
                // Copy the literal verbatim up to its closing quote.
                out.push(current);
                while let Some(inner) = rest.next() {
                    out.push(inner);
                    if inner == '\'' {
                        match rest.next_if_eq(&'\'') {
                            Some(escaped) => out.push(escaped),
                            None => break,
                        }
                    }
                }
            }
            '@' if matches!(rest.peek().copied(), Some('p' | 'P')) => {
                // Look one character past the `p` without consuming anything.
                let mut probe = rest.clone();
                probe.next();
                if probe.peek().is_some_and(|c| c.is_ascii_digit()) {
                    rest.next(); // the 'p' / 'P'
                    while rest.next_if(|c| c.is_ascii_digit()).is_some() {}
                    out.push('?');
                } else {
                    out.push(current);
                }
            }
            other => out.push(other),
        }
    }
    out
}
565
566/// Preprocess MSSQL SQL before parsing:
567/// 1. Strip `OUTPUT INSERTED.col, ...` clauses and convert to RETURNING
568/// 2. Convert `@pN` positional placeholders to `?`
569fn preprocess_mssql_sql(sql: &str) -> String {
570    // First pass: convert OUTPUT INSERTED.col to RETURNING col
571    let sql = strip_and_convert_mssql_output(sql);
572    // Second pass: convert @pN to ?
573    convert_mssql_placeholders(&sql)
574}
575
576/// Strip MSSQL `OUTPUT INSERTED.col1, INSERTED.col2, ...` from INSERT statements
577/// and convert it to a `RETURNING col1, col2, ...` clause.
578/// The OUTPUT clause appears between the column list and VALUES clause:
579///   INSERT INTO table (cols) OUTPUT INSERTED.col1, INSERTED.col2, ... VALUES (...)
580/// becomes:
581///   INSERT INTO table (cols) VALUES (...) RETURNING col1, col2, ...
582fn strip_and_convert_mssql_output(sql: &str) -> String {
583    // Case-insensitive search for OUTPUT keyword in INSERT statements
584    let upper = sql.to_uppercase();
585
586    // Only process INSERT statements with OUTPUT
587    if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
588        return sql.to_string();
589    }
590
591    // Find the OUTPUT keyword
592    if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
593        // Check if this is actually part of an INSERT statement by finding INSERT before it
594        let before_output = &upper[..output_pos];
595        if !before_output.contains("INSERT") {
596            return sql.to_string();
597        }
598
599        // Look for the VALUES keyword after OUTPUT
600        let after_output = &upper[output_pos + "OUTPUT".len()..];
601        if let Some(values_offset) = find_word_position(after_output, "VALUES") {
602            let values_pos = output_pos + "OUTPUT".len() + values_offset;
603
604            // Extract the OUTPUT column list (between OUTPUT and VALUES)
605            let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
606
607            // Parse column names: strip "INSERTED." prefix from each column name
608            let cols = parse_inserted_columns(output_cols_str);
609
610            if !cols.is_empty() {
611                // Build result: keep everything before OUTPUT, then VALUES clause,
612                // then RETURNING clause (before any trailing semicolon)
613                let before_output_sql = sql[..output_pos].trim_end();
614                let after_values = sql[values_pos..].trim_end();
615                let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
616                {
617                    (stripped, ";")
618                } else {
619                    (after_values, "")
620                };
621
622                return format!(
623                    "{}\n{} RETURNING {}{}",
624                    before_output_sql, values_body, cols, trailing
625                );
626            }
627        }
628    }
629
630    sql.to_string()
631}
632
/// Find the byte offset of the first standalone occurrence of `word` in `text`.
///
/// The comparison itself is exact and case-sensitive; callers that want a
/// case-insensitive search pass an already-uppercased `text` and `word`.
/// An occurrence is standalone when the bytes immediately before and after it
/// are not identifier characters (ASCII letters, digits or `_`).
/// Returns `None` when there is no such occurrence or when `word` is empty.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    // An empty word matches everywhere; treat it as "not found" instead of
    // stepping past the end of `text`.
    let first_char = word.chars().next()?;
    // Step past a rejected match by one whole character so the next slice
    // always starts on a character boundary.
    let step = first_char.len_utf8();

    let bytes = text.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut from = 0;
    while let Some(rel) = text[from..].find(word) {
        let start = from + rel;
        let end = start + word.len();

        let before_ok = start == 0 || !is_ident(bytes[start - 1]);
        let after_ok = bytes.get(end).map_or(true, |&b| !is_ident(b));
        if before_ok && after_ok {
            return Some(start);
        }
        from = start + step;
    }
    None
}
663
/// Parse `INSERTED.col1, INSERTED.col2, ...` and return the bare column names
/// joined as `"col1, col2, ..."`.
///
/// The `INSERTED` pseudo-table prefix is matched case-insensitively
/// (`INSERTED.id`, `inserted.id`, `Inserted.id`); the `.` after it is optional.
/// Comma-separated items that do not start with `INSERTED`, or that have
/// nothing after the prefix, are skipped. Returns an empty string when no
/// column was found.
fn parse_inserted_columns(output_str: &str) -> String {
    const PREFIX: &str = "INSERTED";

    output_str
        .split(',')
        .filter_map(|part| {
            let part = part.trim();
            // `get` yields None when the item is too short or the cut would
            // fall inside a multi-byte character.
            let head = part.get(..PREFIX.len())?;
            if !head.eq_ignore_ascii_case(PREFIX) {
                return None;
            }
            let rest = &part[PREFIX.len()..];
            let column = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!column.is_empty()).then(|| column.to_string())
        })
        .collect::<Vec<_>>()
        .join(", ")
}
687
/// Strip the `INTO :N, :N, ...` suffix from an Oracle `RETURNING ... INTO` clause.
///
/// Looks at the last `RETURNING` in the statement (case-insensitively) and cuts
/// the text at the first standalone `INTO` keyword after it; everything from
/// `INTO` onwards — including any trailing semicolon — is dropped and trailing
/// whitespace is trimmed. `INTO` embedded in an identifier (e.g. a returned
/// column named `into_stock`) is not treated as the keyword. Returns the input
/// unchanged when there is no such clause.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    // ASCII-only uppercasing keeps every byte offset identical to `sql`;
    // full Unicode uppercasing can change byte lengths and shift the cut.
    let upper = sql.to_ascii_uppercase();
    let scan = upper.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let Some(returning_pos) = upper.rfind(RETURNING) else {
        return sql.to_string();
    };
    let clause_start = returning_pos + RETURNING.len();

    let into_pos = upper[clause_start..]
        .match_indices(INTO)
        .map(|(offset, _)| clause_start + offset)
        .find(|&at| {
            // `at >= clause_start > 0`, so there is always a preceding byte.
            !is_ident(scan[at - 1])
                && scan.get(at + INTO.len()).map_or(true, |&b| !is_ident(b))
        });

    match into_pos {
        Some(at) => sql[..at].trim_end().to_string(),
        None => sql.to_string(),
    }
}
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707    use crate::errors::ErrorCode;
708
709    fn parse(sql: &str) -> Result<Query, ScytheError> {
710        parse_query(sql)
711    }
712
713    #[test]
714    fn test_basic_parse() {
715        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
716        let q = parse(input).unwrap();
717        assert_eq!(q.name, "GetUsers");
718        assert_eq!(q.command, QueryCommand::Many);
719        assert!(q.sql.contains("SELECT"));
720    }
721
722    #[test]
723    fn test_all_command_types() {
724        let cases = vec![
725            (":one", QueryCommand::One),
726            (":many", QueryCommand::Many),
727            (":exec", QueryCommand::Exec),
728            (":exec_result", QueryCommand::ExecResult),
729            (":exec_rows", QueryCommand::ExecRows),
730        ];
731        for (tag, expected) in cases {
732            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
733            let q = parse(&input).unwrap();
734            assert_eq!(q.command, expected, "failed for {}", tag);
735        }
736    }
737
738    #[test]
739    fn test_case_insensitive_keywords() {
740        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
741        let q = parse(input).unwrap();
742        assert_eq!(q.name, "GetUsers");
743        assert_eq!(q.command, QueryCommand::Many);
744    }
745
746    #[test]
747    fn test_missing_name_errors() {
748        let input = "-- @returns :many\nSELECT 1";
749        let err = parse(input).unwrap_err();
750        assert_eq!(err.code, ErrorCode::MissingAnnotation);
751        assert!(err.message.contains("name"));
752    }
753
754    #[test]
755    fn test_missing_returns_errors() {
756        let input = "-- @name Foo\nSELECT 1";
757        let err = parse(input).unwrap_err();
758        assert_eq!(err.code, ErrorCode::MissingAnnotation);
759        assert!(err.message.contains("returns"));
760    }
761
762    #[test]
763    fn test_invalid_returns_value() {
764        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
765        let err = parse(input).unwrap_err();
766        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
767    }
768
769    #[test]
770    fn test_empty_name_value() {
771        // An empty name is accepted by the parser (it stores "")
772        let input = "-- @name\n-- @returns :one\nSELECT 1";
773        let q = parse(input).unwrap();
774        assert_eq!(q.name, "");
775    }
776
777    #[test]
778    fn test_param_annotation() {
779        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
780        let q = parse(input).unwrap();
781        assert_eq!(q.annotations.param_docs.len(), 1);
782        assert_eq!(q.annotations.param_docs[0].name, "id");
783        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
784    }
785
786    #[test]
787    fn test_param_no_description() {
788        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
789        let q = parse(input).unwrap();
790        assert_eq!(q.annotations.param_docs.len(), 1);
791        assert_eq!(q.annotations.param_docs[0].name, "id");
792        assert_eq!(q.annotations.param_docs[0].description, "");
793    }
794
795    #[test]
796    fn test_nullable_annotation() {
797        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
798        let q = parse(input).unwrap();
799        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
800    }
801
802    #[test]
803    fn test_nonnull_annotation() {
804        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
805        let q = parse(input).unwrap();
806        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
807    }
808
809    #[test]
810    fn test_json_annotation() {
811        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
812        let q = parse(input).unwrap();
813        assert_eq!(q.annotations.json_mappings.len(), 1);
814        assert_eq!(q.annotations.json_mappings[0].column, "data");
815        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
816    }
817
818    #[test]
819    fn test_custom_annotations_captured() {
820        // Unknown @xxx lines are captured verbatim as CustomAnnotation triples;
821        // native annotations remain in their typed fields.
822        let input = "-- @name GetUser
823-- @returns :one
824-- @http GET /users/{id}
825-- @http_auth bearer:jwt
826-- @http_status 200,404
827SELECT id FROM users WHERE id = $1";
828        let q = parse(input).unwrap();
829        assert_eq!(q.annotations.custom.len(), 3);
830        assert_eq!(q.annotations.custom[0].name, "http");
831        assert_eq!(q.annotations.custom[0].value, "GET /users/{id}");
832        assert_eq!(q.annotations.custom[0].line, 3);
833        assert_eq!(q.annotations.custom[1].name, "http_auth");
834        assert_eq!(q.annotations.custom[1].value, "bearer:jwt");
835        assert_eq!(q.annotations.custom[1].line, 4);
836        assert_eq!(q.annotations.custom[2].name, "http_status");
837        assert_eq!(q.annotations.custom[2].value, "200,404");
838        assert_eq!(q.annotations.custom[2].line, 5);
839    }
840
841    #[test]
842    fn test_custom_annotation_without_value() {
843        let input = "-- @name GetUser
844-- @returns :one
845-- @http_internal
846SELECT 1";
847        let q = parse(input).unwrap();
848        assert_eq!(q.annotations.custom.len(), 1);
849        assert_eq!(q.annotations.custom[0].name, "http_internal");
850        assert_eq!(q.annotations.custom[0].value, "");
851    }
852
853    #[cfg(feature = "serde")]
854    #[test]
855    fn test_custom_annotation_serde_round_trip() {
856        let original = CustomAnnotation {
857            name: "http".to_string(),
858            value: "GET /users/{id}".to_string(),
859            line: 7,
860        };
861        let json = serde_json::to_string(&original).unwrap();
862        let back: CustomAnnotation = serde_json::from_str(&json).unwrap();
863        assert_eq!(back, original);
864    }
865
866    #[test]
867    fn test_custom_annotation_name_lowercased() {
868        let input = "-- @name GetUser
869-- @returns :one
870-- @HTTP_Auth Bearer
871SELECT 1";
872        let q = parse(input).unwrap();
873        assert_eq!(q.annotations.custom.len(), 1);
874        assert_eq!(q.annotations.custom[0].name, "http_auth");
875        assert_eq!(q.annotations.custom[0].value, "Bearer");
876    }
877
878    #[test]
879    fn test_deprecated_annotation() {
880        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
881        let q = parse(input).unwrap();
882        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
883    }
884
885    #[test]
886    fn test_sql_syntax_error() {
887        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
888        let err = parse(input).unwrap_err();
889        assert_eq!(err.code, ErrorCode::SyntaxError);
890    }
891
892    #[test]
893    fn test_trailing_semicolon() {
894        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
895        let q = parse(input).unwrap();
896        assert_eq!(q.name, "Foo");
897    }
898
899    #[test]
900    fn test_multiple_statements_error() {
901        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
902        let err = parse(input).unwrap_err();
903        assert_eq!(err.code, ErrorCode::SyntaxError);
904    }
905
906    #[test]
907    fn test_sql_preserved_without_annotations() {
908        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
909        let q = parse(input).unwrap();
910        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
911    }
912
913    #[test]
914    fn test_returns_without_colon_prefix() {
915        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
916        let q = parse(input).unwrap();
917        assert_eq!(q.command, QueryCommand::Many);
918    }
919
920    #[test]
921    fn test_batch_command() {
922        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
923        let q = parse(input).unwrap();
924        assert_eq!(q.command, QueryCommand::Batch);
925    }
926
927    #[test]
928    fn test_grouped_command_with_group_by() {
929        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
930        let q = parse(input).unwrap();
931        assert_eq!(q.command, QueryCommand::Grouped);
932        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
933    }
934
935    #[test]
936    fn test_grouped_command_without_group_by_errors() {
937        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
938        let err = parse(input).unwrap_err();
939        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
940        assert!(err.message.contains("@group_by"));
941    }
942
943    #[test]
944    fn test_group_by_without_grouped_is_ignored() {
945        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
946        let q = parse(input).unwrap();
947        assert_eq!(q.command, QueryCommand::Many);
948        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
949    }
950
951    #[test]
952    fn test_preprocess_postgres_strips_partial_index_where() {
953        let sql = "INSERT INTO billing_events (project_id, stripe_event_id) \
954                   VALUES ($1, $2) \
955                   ON CONFLICT (stripe_event_id) WHERE stripe_event_id IS NOT NULL DO NOTHING";
956        let cleaned = preprocess_postgres_sql(sql);
957        assert!(
958            !cleaned
959                .to_uppercase()
960                .contains("WHERE STRIPE_EVENT_ID IS NOT NULL"),
961            "WHERE clause must be stripped between ON CONFLICT cols and DO; got: {cleaned}"
962        );
963        assert!(
964            cleaned
965                .to_uppercase()
966                .contains("ON CONFLICT (STRIPE_EVENT_ID) DO NOTHING")
967        );
968        // sqlparser must accept the cleaned form.
969        sqlparser::parser::Parser::parse_sql(&sqlparser::dialect::PostgreSqlDialect {}, &cleaned)
970            .expect("cleaned SQL should parse");
971    }
972
973    #[test]
974    fn test_preprocess_postgres_no_op_when_no_partial_clause() {
975        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a";
976        assert_eq!(preprocess_postgres_sql(sql), sql);
977    }
978
979    #[test]
980    fn test_preprocess_postgres_leaves_on_conflict_on_constraint_alone() {
981        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT ON CONSTRAINT t_a_uidx DO NOTHING";
982        assert_eq!(preprocess_postgres_sql(sql), sql);
983    }
984
985    #[test]
986    fn test_preprocess_postgres_handles_compound_index_cols() {
987        let sql = "INSERT INTO t (a, b) VALUES ($1, $2) \
988                   ON CONFLICT (a, b) WHERE a IS NOT NULL AND b > 0 DO UPDATE SET b = EXCLUDED.b";
989        let cleaned = preprocess_postgres_sql(sql);
990        assert!(
991            cleaned
992                .to_uppercase()
993                .contains("ON CONFLICT (A, B) DO UPDATE")
994        );
995        assert!(!cleaned.to_uppercase().contains("WHERE A IS NOT NULL"));
996    }
997
998    #[test]
999    fn test_preprocess_postgres_preserves_unrelated_where() {
1000        // The DELETE's WHERE is its own clause, not an ON-CONFLICT predicate;
1001        // it must survive untouched.
1002        let sql = "DELETE FROM t WHERE id = $1";
1003        assert_eq!(preprocess_postgres_sql(sql), sql);
1004    }
1005
1006    #[test]
1007    fn test_preprocess_postgres_ignores_text_inside_line_comments() {
1008        // Earlier scans treated this as a real `ON CONFLICT (col) WHERE … DO`
1009        // and excised the entire comment + INSERT body up to the next `DO`.
1010        // Comments must be opaque to the predicate-stripping pass.
1011        let sql = "-- inline doc: `ON CONFLICT (col) WHERE …` is the partial form\n\
1012                   INSERT INTO t (a) VALUES ($1) \
1013                   ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING";
1014        let cleaned = preprocess_postgres_sql(sql);
1015        assert!(
1016            cleaned.contains("-- inline doc"),
1017            "comment must survive the pass; got: {cleaned}"
1018        );
1019        assert!(cleaned.contains("ON CONFLICT (a) DO NOTHING"));
1020    }
1021
1022    #[test]
1023    fn test_preprocess_postgres_ignores_text_inside_string_literals() {
1024        let sql = "SELECT 'ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING' AS s";
1025        assert_eq!(preprocess_postgres_sql(sql), sql);
1026    }
1027
1028    #[test]
1029    fn test_preprocess_oracle_colon_placeholders() {
1030        assert_eq!(
1031            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
1032            "SELECT * FROM users WHERE id = ?"
1033        );
1034        assert_eq!(
1035            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
1036            "INSERT INTO users (name, email) VALUES (?, ?)"
1037        );
1038    }
1039
1040    #[test]
1041    fn test_preprocess_oracle_preserves_string_literals() {
1042        assert_eq!(
1043            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
1044            "SELECT * FROM users WHERE name = ':1' AND id = ?"
1045        );
1046    }
1047
1048    #[test]
1049    fn test_preprocess_oracle_strips_returning_into() {
1050        assert_eq!(
1051            preprocess_oracle_sql(
1052                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
1053            ),
1054            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
1055        );
1056    }
1057
1058    #[test]
1059    fn test_preprocess_oracle_full_insert_returning_into() {
1060        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
1061        let result = preprocess_oracle_sql(sql);
1062        assert_eq!(
1063            result,
1064            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
1065        );
1066    }
1067
1068    #[test]
1069    fn test_preprocess_oracle_no_returning_into_unchanged() {
1070        assert_eq!(
1071            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
1072            "DELETE FROM users WHERE id = ?"
1073        );
1074    }
1075
1076    #[test]
1077    fn test_preprocess_mssql_single_placeholder() {
1078        assert_eq!(
1079            preprocess_mssql_sql("SELECT * FROM users WHERE id = @p1"),
1080            "SELECT * FROM users WHERE id = ?"
1081        );
1082    }
1083
1084    #[test]
1085    fn test_preprocess_mssql_multiple_placeholders() {
1086        assert_eq!(
1087            preprocess_mssql_sql("INSERT INTO users (name, email) VALUES (@p1, @p2)"),
1088            "INSERT INTO users (name, email) VALUES (?, ?)"
1089        );
1090    }
1091
1092    #[test]
1093    fn test_preprocess_mssql_preserves_string_literals() {
1094        assert_eq!(
1095            preprocess_mssql_sql("SELECT * FROM users WHERE name = '@p1' AND id = @p1"),
1096            "SELECT * FROM users WHERE name = '@p1' AND id = ?"
1097        );
1098    }
1099
1100    #[test]
1101    fn test_preprocess_mssql_case_insensitive_p() {
1102        assert_eq!(
1103            preprocess_mssql_sql("SELECT * FROM users WHERE id = @P1"),
1104            "SELECT * FROM users WHERE id = ?"
1105        );
1106    }
1107
1108    #[test]
1109    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
1110        // @variable (not @pN pattern) must not be touched
1111        assert_eq!(preprocess_mssql_sql("SELECT @myvar"), "SELECT @myvar");
1112    }
1113
1114    #[test]
1115    fn test_preprocess_mssql_multi_digit_placeholder() {
1116        assert_eq!(preprocess_mssql_sql("SELECT @p10, @p2"), "SELECT ?, ?");
1117    }
1118
1119    #[test]
1120    fn test_preprocess_mssql_output_inserted_simple() {
1121        let sql =
1122            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)";
1123        let result = preprocess_mssql_sql(sql);
1124        // Should convert OUTPUT INSERTED.col to RETURNING col and @pN to ?
1125        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1126        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
1127        assert!(!result.contains("OUTPUT"), "got: {}", result);
1128    }
1129
1130    #[test]
1131    fn test_preprocess_mssql_output_inserted_full_example() {
1132        let sql = "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)";
1133        let result = preprocess_mssql_sql(sql);
1134        assert!(
1135            result.contains("RETURNING id, name, email, active, created_at"),
1136            "got: {}",
1137            result
1138        );
1139        assert!(result.contains("VALUES (?, ?, ?, ?)"), "got: {}", result);
1140    }
1141
1142    #[test]
1143    fn test_preprocess_mssql_output_case_insensitive() {
1144        let sql = "INSERT INTO users (id) output inserted.id values (@p1)";
1145        let result = preprocess_mssql_sql(sql);
1146        assert!(result.contains("RETURNING id"), "got: {}", result);
1147        // The original lowercase "values" is preserved, then @p1 becomes ?
1148        assert!(
1149            result.contains("values (?)") || result.contains("VALUES (?)"),
1150            "got: {}",
1151            result
1152        );
1153    }
1154
1155    #[test]
1156    fn test_preprocess_mssql_no_output_unchanged() {
1157        let sql = "INSERT INTO users (id, name) VALUES (@p1, @p2)";
1158        let result = preprocess_mssql_sql(sql);
1159        assert_eq!(result, "INSERT INTO users (id, name) VALUES (?, ?)");
1160    }
1161
1162    #[test]
1163    fn test_preprocess_mssql_output_with_string_literal() {
1164        // @p1 inside a string should be preserved by placeholder conversion
1165        let sql =
1166            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')";
1167        let result = preprocess_mssql_sql(sql);
1168        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1169        assert!(result.contains("(?, '@p2')"), "got: {}", result);
1170    }
1171
1172    #[test]
1173    fn test_preprocess_mssql_output_with_whitespace() {
1174        let sql =
1175            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n  INSERTED.name\nVALUES (@p1, @p2)";
1176        let result = preprocess_mssql_sql(sql);
1177        assert!(result.contains("RETURNING id, name"), "got: {}", result);
1178        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
1179    }
1180}