//! Annotated SQL query parser (`scythe_core/parser/mod.rs`).

1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// Result shape declared by a query's `@returns` annotation.
///
/// Each variant corresponds to exactly one annotation value; the `Display` impl below
/// writes that spelling (without the leading `:`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueryCommand {
    /// `@returns :one`
    One,
    /// `@returns :opt`
    Opt,
    /// `@returns :many`
    Many,
    /// `@returns :exec`
    Exec,
    /// `@returns :exec_result`
    ExecResult,
    /// `@returns :exec_rows`
    ExecRows,
    /// `@returns :batch`
    Batch,
    /// `@returns :grouped` — the parser additionally requires a `@group_by` annotation.
    Grouped,
}

impl std::fmt::Display for QueryCommand {
    /// Writes the annotation spelling of the command, e.g. `exec_rows` for
    /// `QueryCommand::ExecRows`. Width/fill formatting options are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::One => "one",
            Self::Opt => "opt",
            Self::Many => "many",
            Self::Exec => "exec",
            Self::ExecResult => "exec_result",
            Self::ExecRows => "exec_rows",
            Self::Batch => "batch",
            Self::Grouped => "grouped",
        };
        f.write_str(label)
    }
}

33impl QueryCommand {
34    fn from_str(s: &str) -> Result<Self, ScytheError> {
35        match s {
36            "one" => Ok(QueryCommand::One),
37            "opt" => Ok(QueryCommand::Opt),
38            "many" => Ok(QueryCommand::Many),
39            "exec" => Ok(QueryCommand::Exec),
40            "exec_result" => Ok(QueryCommand::ExecResult),
41            "exec_rows" => Ok(QueryCommand::ExecRows),
42            "batch" => Ok(QueryCommand::Batch),
43            "grouped" => Ok(QueryCommand::Grouped),
44            other => Err(ScytheError::invalid_annotation(format!(
45                "invalid @returns value: {other}"
46            ))),
47        }
48    }
49}
50
/// Documentation for one query parameter, taken from a `-- @param <name>: <description>`
/// annotation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParamDoc {
    /// Parameter name (text before the first `:`, trimmed).
    pub name: String,
    /// Description (text after the first `:`, trimmed); empty when no `:` was given.
    pub description: String,
}

/// Association of a result column with a type name, taken from a
/// `-- @json <column> = <Type>` annotation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JsonMapping {
    /// Column name (text before the first `=`, trimmed).
    pub column: String,
    /// Type name as written after the first `=` (trimmed), e.g. `EventData`.
    pub rust_type: String,
}

63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct Annotations {
65    pub name: String,
66    pub command: QueryCommand,
67    pub param_docs: Vec<ParamDoc>,
68    pub nullable_overrides: Vec<String>,
69    pub nonnull_overrides: Vec<String>,
70    pub json_mappings: Vec<JsonMapping>,
71    pub deprecated: Option<String>,
72    pub optional_params: Vec<String>,
73    pub group_by: Option<String>,
74}
75
76#[derive(Debug)]
77pub struct Query {
78    pub name: String,
79    pub command: QueryCommand,
80    pub sql: String,
81    pub stmt: sqlparser::ast::Statement,
82    pub annotations: Annotations,
83}
84
85/// Parse a single annotated SQL query into a `Query` using the PostgreSQL dialect.
86pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
87    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
88}
89
90/// Parse a single annotated SQL query into a `Query` using the specified dialect.
91pub fn parse_query_with_dialect(
92    query_sql: &str,
93    dialect: &SqlDialect,
94) -> Result<Query, ScytheError> {
95    let mut name: Option<String> = None;
96    let mut command: Option<QueryCommand> = None;
97    let mut param_docs = Vec::new();
98    let mut nullable_overrides = Vec::new();
99    let mut nonnull_overrides = Vec::new();
100    let mut json_mappings = Vec::new();
101    let mut deprecated: Option<String> = None;
102    let mut optional_params = Vec::new();
103    let mut group_by: Option<String> = None;
104
105    let mut sql_lines = Vec::new();
106
107    for line in query_sql.lines() {
108        let trimmed = line.trim();
109
110        // Check for annotation: "-- @..." or "--@..."
111        let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
112            let rest = rest.trim_start();
113            rest.strip_prefix('@')
114        } else {
115            None
116        };
117
118        if let Some(body) = annotation_body {
119            // Parse the annotation keyword and value
120            let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
121                Some(pos) => (&body[..pos], body[pos..].trim()),
122                None => (body, ""),
123            };
124
125            match keyword.to_ascii_lowercase().as_str() {
126                "name" => {
127                    name = Some(value.to_string());
128                }
129                "returns" => {
130                    let cmd_str = value.strip_prefix(':').unwrap_or(value);
131                    command = Some(QueryCommand::from_str(cmd_str)?);
132                }
133                "param" => {
134                    // format: "<name>: <description>" or "<name>:<description>"
135                    if let Some(colon_pos) = value.find(':') {
136                        let param_name = value[..colon_pos].trim().to_string();
137                        let description = value[colon_pos + 1..].trim().to_string();
138                        param_docs.push(ParamDoc {
139                            name: param_name,
140                            description,
141                        });
142                    } else {
143                        param_docs.push(ParamDoc {
144                            name: value.to_string(),
145                            description: String::new(),
146                        });
147                    }
148                }
149                "nullable" => {
150                    for col in value.split(',') {
151                        let col = col.trim();
152                        if !col.is_empty() {
153                            nullable_overrides.push(col.to_string());
154                        }
155                    }
156                }
157                "nonnull" => {
158                    for col in value.split(',') {
159                        let col = col.trim();
160                        if !col.is_empty() {
161                            nonnull_overrides.push(col.to_string());
162                        }
163                    }
164                }
165                "json" => {
166                    // format: "<col> = <Type>"
167                    if let Some(eq_pos) = value.find('=') {
168                        let column = value[..eq_pos].trim().to_string();
169                        let rust_type = value[eq_pos + 1..].trim().to_string();
170                        json_mappings.push(JsonMapping { column, rust_type });
171                    }
172                }
173                "deprecated" => {
174                    deprecated = Some(value.to_string());
175                }
176                "group_by" => {
177                    group_by = Some(value.to_string());
178                }
179                "optional" => {
180                    for param in value.split(',') {
181                        let param = param.trim();
182                        if !param.is_empty() {
183                            optional_params.push(param.to_string());
184                        }
185                    }
186                }
187                _ => {
188                    // Unknown annotation — ignore or could error
189                }
190            }
191        } else {
192            sql_lines.push(line);
193        }
194    }
195
196    let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
197    let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
198
199    if command == QueryCommand::Grouped && group_by.is_none() {
200        return Err(ScytheError::invalid_annotation(
201            "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
202        ));
203    }
204
205    let sql = sql_lines.join("\n").trim().to_string();
206
207    if sql.is_empty() {
208        return Err(ScytheError::syntax("empty SQL body"));
209    }
210
211    // Preprocess Oracle-specific syntax before parsing:
212    // 1. Strip `INTO :N, :N, ...` from `RETURNING ... INTO` clauses (Oracle output binds)
213    // 2. Convert `:N` positional placeholders to `?` (universal placeholder)
214    let sql = if *dialect == SqlDialect::Oracle {
215        preprocess_oracle_sql(&sql)
216    } else {
217        sql
218    };
219
220    let parser_dialect = dialect.to_sqlparser_dialect();
221    let statements = Parser::parse_sql(parser_dialect.as_ref(), &sql)
222        .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
223
224    if statements.len() != 1 {
225        // sqlparser may produce an extra empty statement from a trailing semicolon —
226        // filter those out by checking for exactly one non-empty statement.
227        let non_empty: Vec<_> = statements
228            .into_iter()
229            .filter(|s| {
230                !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
231            })
232            .collect();
233        if non_empty.len() != 1 {
234            return Err(ScytheError::syntax("expected exactly one SQL statement"));
235        }
236        let stmt = non_empty
237            .into_iter()
238            .next()
239            .expect("filtered to exactly one statement");
240        let annotations = Annotations {
241            name: name.clone(),
242            command: command.clone(),
243            param_docs,
244            nullable_overrides,
245            nonnull_overrides,
246            json_mappings,
247            deprecated,
248            optional_params,
249            group_by: group_by.clone(),
250        };
251        return Ok(Query {
252            name,
253            command,
254            sql,
255            stmt,
256            annotations,
257        });
258    }
259
260    let stmt = statements
261        .into_iter()
262        .next()
263        .expect("filtered to exactly one statement");
264
265    let annotations = Annotations {
266        name: name.clone(),
267        command: command.clone(),
268        param_docs,
269        nullable_overrides,
270        nonnull_overrides,
271        json_mappings,
272        deprecated,
273        optional_params,
274        group_by,
275    };
276
277    Ok(Query {
278        name,
279        command,
280        sql,
281        stmt,
282        annotations,
283    })
284}
285
286/// Preprocess Oracle SQL before parsing:
287/// 1. Strip `INTO :N, :N, ...` suffix from `RETURNING ... INTO` clauses
288/// 2. Convert `:N` positional placeholders to `?` (universally supported)
289fn preprocess_oracle_sql(sql: &str) -> String {
290    // Strip Oracle RETURNING ... INTO clause (output bind variables)
291    // e.g. "INSERT ... RETURNING id, name INTO :4, :5" → "INSERT ... RETURNING id, name"
292    let sql = strip_returning_into(sql);
293
294    // Convert :N → ? (outside string literals)
295    let mut result = String::with_capacity(sql.len());
296    let mut chars = sql.chars().peekable();
297    while let Some(ch) = chars.next() {
298        if ch == '\'' {
299            // Skip string literals
300            result.push(ch);
301            while let Some(inner) = chars.next() {
302                result.push(inner);
303                if inner == '\'' {
304                    if chars.peek() == Some(&'\'') {
305                        result.push(chars.next().unwrap());
306                    } else {
307                        break;
308                    }
309                }
310            }
311        } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
312            // Convert :N → ?
313            result.push('?');
314            while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
315                chars.next();
316            }
317        } else {
318            result.push(ch);
319        }
320    }
321    result
322}
323
324/// Strip the `INTO :N, :N, ...` suffix from an Oracle `RETURNING ... INTO` clause.
325fn strip_returning_into(sql: &str) -> String {
326    // Case-insensitive search for "INTO" after "RETURNING" at the end of the statement
327    let upper = sql.to_uppercase();
328    if let Some(ret_pos) = upper.rfind("RETURNING") {
329        let after_returning = &upper[ret_pos + "RETURNING".len()..];
330        if let Some(into_offset) = after_returning.find("INTO") {
331            let into_pos = ret_pos + "RETURNING".len() + into_offset;
332            // Keep everything before INTO, trim trailing whitespace/semicolons
333            let trimmed = sql[..into_pos].trim_end();
334            return trimmed.to_string();
335        }
336    }
337    sql.to_string()
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        // `:grouped` is excluded here because it additionally requires `@group_by`;
        // it has dedicated tests below.
        let cases = [
            (":one", QueryCommand::One),
            (":opt", QueryCommand::Opt),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
            (":batch", QueryCommand::Batch),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {tag}\nSELECT 1");
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {tag}");
        }
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        let all = [
            QueryCommand::One,
            QueryCommand::Opt,
            QueryCommand::Many,
            QueryCommand::Exec,
            QueryCommand::ExecResult,
            QueryCommand::ExecRows,
            QueryCommand::Batch,
            QueryCommand::Grouped,
        ];
        for cmd in all {
            let parsed = QueryCommand::from_str(&cmd.to_string()).unwrap();
            assert_eq!(parsed, cmd);
        }
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_annotation_without_space_after_dashes() {
        let input = "--@name Foo\n--@returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
        assert_eq!(q.command, QueryCommand::One);
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        // An empty name is accepted by the parser (it stores "")
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_colon_without_space() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id:the id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the id");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_optional_annotation_skips_empty_entries() {
        let input = "-- @name Foo\n-- @returns :one\n-- @optional a, , b\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.optional_params, vec!["a", "b"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    #[test]
    fn test_json_without_equals_is_ignored() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data\nSELECT 1";
        let q = parse(input).unwrap();
        assert!(q.annotations.json_mappings.is_empty());
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    #[test]
    fn test_unknown_annotation_is_stripped_from_sql() {
        let input = "-- @name Foo\n-- @returns :one\n-- @frobnicate yes\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_ordinary_comment_is_kept_in_sql() {
        let input = "-- @name Foo\n-- @returns :one\n-- just a comment\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "-- just a comment\nSELECT 1");
    }

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
            "SELECT * FROM users WHERE id = ?"
        );
        assert_eq!(
            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
            "SELECT * FROM users WHERE name = ':1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        assert_eq!(
            preprocess_oracle_sql(
                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
            ),
            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
        let result = preprocess_oracle_sql(sql);
        assert_eq!(
            result,
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        assert_eq!(
            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
            "DELETE FROM users WHERE id = ?"
        );
    }
}