Skip to main content

scythe_core/parser/
mod.rs

1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The command kind selected by a query's `@returns` annotation.
///
/// Each variant's annotation spelling (written with a leading `:` in the
/// annotation, e.g. `@returns :exec_rows`) is exactly what its `Display`
/// implementation renders, minus the colon.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueryCommand {
    /// `@returns :one`
    One,
    /// `@returns :many`
    Many,
    /// `@returns :exec`
    Exec,
    /// `@returns :exec_result`
    ExecResult,
    /// `@returns :exec_rows`
    ExecRows,
    /// `@returns :batch`
    Batch,
    /// `@returns :grouped` (must be paired with `@group_by`)
    Grouped,
}

impl std::fmt::Display for QueryCommand {
    /// Writes the lowercase annotation keyword, without the leading `:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match *self {
            Self::One => "one",
            Self::Many => "many",
            Self::Exec => "exec",
            Self::ExecResult => "exec_result",
            Self::ExecRows => "exec_rows",
            Self::Batch => "batch",
            Self::Grouped => "grouped",
        };
        f.write_str(keyword)
    }
}
30
31impl QueryCommand {
32    fn from_str(s: &str) -> Result<Self, ScytheError> {
33        match s {
34            "one" => Ok(QueryCommand::One),
35            "many" => Ok(QueryCommand::Many),
36            "exec" => Ok(QueryCommand::Exec),
37            "exec_result" => Ok(QueryCommand::ExecResult),
38            "exec_rows" => Ok(QueryCommand::ExecRows),
39            "batch" => Ok(QueryCommand::Batch),
40            "grouped" => Ok(QueryCommand::Grouped),
41            other => Err(ScytheError::invalid_annotation(format!(
42                "invalid @returns value: {other}"
43            ))),
44        }
45    }
46}
47
/// Documentation for one query parameter, taken from a `@param` annotation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value, trimmed
    /// (the whole value when there is no `:`).
    pub name: String,
    /// Text after the first `:`, trimmed; empty when no `:` was given.
    pub description: String,
}

/// A column-to-type mapping declared with `@json <col> = <Type>`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JsonMapping {
    /// Text before the first `=`, trimmed.
    pub column: String,
    /// Type name written after the first `=`, trimmed and otherwise stored verbatim.
    pub rust_type: String,
}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct Annotations {
62    pub name: String,
63    pub command: QueryCommand,
64    pub param_docs: Vec<ParamDoc>,
65    pub nullable_overrides: Vec<String>,
66    pub nonnull_overrides: Vec<String>,
67    pub json_mappings: Vec<JsonMapping>,
68    pub deprecated: Option<String>,
69    pub optional_params: Vec<String>,
70    pub group_by: Option<String>,
71}
72
73#[derive(Debug)]
74pub struct Query {
75    pub name: String,
76    pub command: QueryCommand,
77    pub sql: String,
78    pub stmt: sqlparser::ast::Statement,
79    pub annotations: Annotations,
80}
81
82/// Parse a single annotated SQL query into a `Query` using the PostgreSQL dialect.
83pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
84    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
85}
86
87/// Parse a single annotated SQL query into a `Query` using the specified dialect.
88pub fn parse_query_with_dialect(
89    query_sql: &str,
90    dialect: &SqlDialect,
91) -> Result<Query, ScytheError> {
92    let mut name: Option<String> = None;
93    let mut command: Option<QueryCommand> = None;
94    let mut param_docs = Vec::new();
95    let mut nullable_overrides = Vec::new();
96    let mut nonnull_overrides = Vec::new();
97    let mut json_mappings = Vec::new();
98    let mut deprecated: Option<String> = None;
99    let mut optional_params = Vec::new();
100    let mut group_by: Option<String> = None;
101
102    let mut sql_lines = Vec::new();
103
104    for line in query_sql.lines() {
105        let trimmed = line.trim();
106
107        // Check for annotation: "-- @..." or "--@..."
108        let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
109            let rest = rest.trim_start();
110            rest.strip_prefix('@')
111        } else {
112            None
113        };
114
115        if let Some(body) = annotation_body {
116            // Parse the annotation keyword and value
117            let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
118                Some(pos) => (&body[..pos], body[pos..].trim()),
119                None => (body, ""),
120            };
121
122            match keyword.to_ascii_lowercase().as_str() {
123                "name" => {
124                    name = Some(value.to_string());
125                }
126                "returns" => {
127                    let cmd_str = value.strip_prefix(':').unwrap_or(value);
128                    command = Some(QueryCommand::from_str(cmd_str)?);
129                }
130                "param" => {
131                    // format: "<name>: <description>" or "<name>:<description>"
132                    if let Some(colon_pos) = value.find(':') {
133                        let param_name = value[..colon_pos].trim().to_string();
134                        let description = value[colon_pos + 1..].trim().to_string();
135                        param_docs.push(ParamDoc {
136                            name: param_name,
137                            description,
138                        });
139                    } else {
140                        param_docs.push(ParamDoc {
141                            name: value.to_string(),
142                            description: String::new(),
143                        });
144                    }
145                }
146                "nullable" => {
147                    for col in value.split(',') {
148                        let col = col.trim();
149                        if !col.is_empty() {
150                            nullable_overrides.push(col.to_string());
151                        }
152                    }
153                }
154                "nonnull" => {
155                    for col in value.split(',') {
156                        let col = col.trim();
157                        if !col.is_empty() {
158                            nonnull_overrides.push(col.to_string());
159                        }
160                    }
161                }
162                "json" => {
163                    // format: "<col> = <Type>"
164                    if let Some(eq_pos) = value.find('=') {
165                        let column = value[..eq_pos].trim().to_string();
166                        let rust_type = value[eq_pos + 1..].trim().to_string();
167                        json_mappings.push(JsonMapping { column, rust_type });
168                    }
169                }
170                "deprecated" => {
171                    deprecated = Some(value.to_string());
172                }
173                "group_by" => {
174                    group_by = Some(value.to_string());
175                }
176                "optional" => {
177                    for param in value.split(',') {
178                        let param = param.trim();
179                        if !param.is_empty() {
180                            optional_params.push(param.to_string());
181                        }
182                    }
183                }
184                _ => {
185                    // Unknown annotation — ignore or could error
186                }
187            }
188        } else {
189            sql_lines.push(line);
190        }
191    }
192
193    let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
194    let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
195
196    if command == QueryCommand::Grouped && group_by.is_none() {
197        return Err(ScytheError::invalid_annotation(
198            "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
199        ));
200    }
201
202    let sql = sql_lines.join("\n").trim().to_string();
203
204    if sql.is_empty() {
205        return Err(ScytheError::syntax("empty SQL body"));
206    }
207
208    let parser_dialect = dialect.to_sqlparser_dialect();
209    let statements = Parser::parse_sql(parser_dialect.as_ref(), &sql)
210        .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
211
212    if statements.len() != 1 {
213        // sqlparser may produce an extra empty statement from a trailing semicolon —
214        // filter those out by checking for exactly one non-empty statement.
215        let non_empty: Vec<_> = statements
216            .into_iter()
217            .filter(|s| {
218                !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
219            })
220            .collect();
221        if non_empty.len() != 1 {
222            return Err(ScytheError::syntax("expected exactly one SQL statement"));
223        }
224        let stmt = non_empty
225            .into_iter()
226            .next()
227            .expect("filtered to exactly one statement");
228        let annotations = Annotations {
229            name: name.clone(),
230            command: command.clone(),
231            param_docs,
232            nullable_overrides,
233            nonnull_overrides,
234            json_mappings,
235            deprecated,
236            optional_params,
237            group_by: group_by.clone(),
238        };
239        return Ok(Query {
240            name,
241            command,
242            sql,
243            stmt,
244            annotations,
245        });
246    }
247
248    let stmt = statements
249        .into_iter()
250        .next()
251        .expect("filtered to exactly one statement");
252
253    let annotations = Annotations {
254        name: name.clone(),
255        command: command.clone(),
256        param_docs,
257        nullable_overrides,
258        nonnull_overrides,
259        json_mappings,
260        deprecated,
261        optional_params,
262        group_by,
263    };
264
265    Ok(Query {
266        name,
267        command,
268        sql,
269        stmt,
270        annotations,
271    })
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        let cases = vec![
            (":one", QueryCommand::One),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {}", tag);
        }
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        // An empty name is accepted by the parser (it stores "")
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_compact_annotation_prefix() {
        // "--@keyword" (no space after the dashes) is also an annotation line.
        let input = "--@name Foo\n--@returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
        assert_eq!(q.command, QueryCommand::One);
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_optional_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @optional a, b,\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.optional_params, vec!["a", "b"]);
    }

    #[test]
    fn test_unknown_annotation_is_dropped_from_sql() {
        let input = "-- @name Foo\n-- @returns :one\n-- @frobnicate yes\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_plain_comment_kept_in_sql() {
        let input = "-- @name Foo\n-- @returns :one\n-- fetch one row\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "-- fetch one row\nSELECT 1");
    }

    #[test]
    fn test_json_without_equals_is_ignored() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data\nSELECT 1";
        let q = parse(input).unwrap();
        assert!(q.annotations.json_mappings.is_empty());
    }

    #[test]
    fn test_empty_sql_body_errors() {
        let input = "-- @name Foo\n-- @returns :one\n";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_annotations_mirror_query_fields() {
        let input = "-- @name Foo\n-- @returns :exec\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.name, q.name);
        assert_eq!(q.annotations.command, q.command);
    }
}