1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result kind a query declares through its `@returns` annotation.
///
/// Each variant has a lowercase textual form (shown on the variant) which is
/// what the annotation parser accepts and what `Display` prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryCommand {
    /// Written `one`.
    One,
    /// Written `opt`.
    Opt,
    /// Written `many`.
    Many,
    /// Written `exec`.
    Exec,
    /// Written `exec_result`.
    ExecResult,
    /// Written `exec_rows`.
    ExecRows,
    /// Written `batch`.
    Batch,
    /// Written `grouped`; the query parser rejects this command unless a
    /// `@group_by` annotation is also present.
    Grouped,
}
17
18impl std::fmt::Display for QueryCommand {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 QueryCommand::One => write!(f, "one"),
22 QueryCommand::Opt => write!(f, "opt"),
23 QueryCommand::Many => write!(f, "many"),
24 QueryCommand::Exec => write!(f, "exec"),
25 QueryCommand::ExecResult => write!(f, "exec_result"),
26 QueryCommand::ExecRows => write!(f, "exec_rows"),
27 QueryCommand::Batch => write!(f, "batch"),
28 QueryCommand::Grouped => write!(f, "grouped"),
29 }
30 }
31}
32
33impl QueryCommand {
34 fn from_str(s: &str) -> Result<Self, ScytheError> {
35 match s {
36 "one" => Ok(QueryCommand::One),
37 "opt" => Ok(QueryCommand::Opt),
38 "many" => Ok(QueryCommand::Many),
39 "exec" => Ok(QueryCommand::Exec),
40 "exec_result" => Ok(QueryCommand::ExecResult),
41 "exec_rows" => Ok(QueryCommand::ExecRows),
42 "batch" => Ok(QueryCommand::Batch),
43 "grouped" => Ok(QueryCommand::Grouped),
44 other => Err(ScytheError::invalid_annotation(format!(
45 "invalid @returns value: {other}"
46 ))),
47 }
48 }
49}
50
/// Documentation for one query parameter, collected from a `@param`
/// annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value (the whole value
    /// when there is no `:`), trimmed.
    pub name: String,
    /// Text after the first `:`, trimmed; empty when no description was given.
    pub description: String,
}
56
/// One `@json column = Type` annotation: associates a column with a type name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JsonMapping {
    /// Text to the left of the first `=`, trimmed.
    pub column: String,
    /// Text to the right of the first `=`, trimmed.
    pub rust_type: String,
}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct Annotations {
65 pub name: String,
66 pub command: QueryCommand,
67 pub param_docs: Vec<ParamDoc>,
68 pub nullable_overrides: Vec<String>,
69 pub nonnull_overrides: Vec<String>,
70 pub json_mappings: Vec<JsonMapping>,
71 pub deprecated: Option<String>,
72 pub optional_params: Vec<String>,
73 pub group_by: Option<String>,
74}
75
76#[derive(Debug)]
77pub struct Query {
78 pub name: String,
79 pub command: QueryCommand,
80 pub sql: String,
81 pub stmt: sqlparser::ast::Statement,
82 pub annotations: Annotations,
83}
84
85pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
87 parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
88}
89
90pub fn parse_query_with_dialect(
92 query_sql: &str,
93 dialect: &SqlDialect,
94) -> Result<Query, ScytheError> {
95 let mut name: Option<String> = None;
96 let mut command: Option<QueryCommand> = None;
97 let mut param_docs = Vec::new();
98 let mut nullable_overrides = Vec::new();
99 let mut nonnull_overrides = Vec::new();
100 let mut json_mappings = Vec::new();
101 let mut deprecated: Option<String> = None;
102 let mut optional_params = Vec::new();
103 let mut group_by: Option<String> = None;
104
105 let mut sql_lines = Vec::new();
106
107 for line in query_sql.lines() {
108 let trimmed = line.trim();
109
110 let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
112 let rest = rest.trim_start();
113 rest.strip_prefix('@')
114 } else {
115 None
116 };
117
118 if let Some(body) = annotation_body {
119 let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
121 Some(pos) => (&body[..pos], body[pos..].trim()),
122 None => (body, ""),
123 };
124
125 match keyword.to_ascii_lowercase().as_str() {
126 "name" => {
127 name = Some(value.to_string());
128 }
129 "returns" => {
130 let cmd_str = value.strip_prefix(':').unwrap_or(value);
131 command = Some(QueryCommand::from_str(cmd_str)?);
132 }
133 "param" => {
134 if let Some(colon_pos) = value.find(':') {
136 let param_name = value[..colon_pos].trim().to_string();
137 let description = value[colon_pos + 1..].trim().to_string();
138 param_docs.push(ParamDoc {
139 name: param_name,
140 description,
141 });
142 } else {
143 param_docs.push(ParamDoc {
144 name: value.to_string(),
145 description: String::new(),
146 });
147 }
148 }
149 "nullable" => {
150 for col in value.split(',') {
151 let col = col.trim();
152 if !col.is_empty() {
153 nullable_overrides.push(col.to_string());
154 }
155 }
156 }
157 "nonnull" => {
158 for col in value.split(',') {
159 let col = col.trim();
160 if !col.is_empty() {
161 nonnull_overrides.push(col.to_string());
162 }
163 }
164 }
165 "json" => {
166 if let Some(eq_pos) = value.find('=') {
168 let column = value[..eq_pos].trim().to_string();
169 let rust_type = value[eq_pos + 1..].trim().to_string();
170 json_mappings.push(JsonMapping { column, rust_type });
171 }
172 }
173 "deprecated" => {
174 deprecated = Some(value.to_string());
175 }
176 "group_by" => {
177 group_by = Some(value.to_string());
178 }
179 "optional" => {
180 for param in value.split(',') {
181 let param = param.trim();
182 if !param.is_empty() {
183 optional_params.push(param.to_string());
184 }
185 }
186 }
187 _ => {
188 }
190 }
191 } else {
192 sql_lines.push(line);
193 }
194 }
195
196 let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
197 let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
198
199 if command == QueryCommand::Grouped && group_by.is_none() {
200 return Err(ScytheError::invalid_annotation(
201 "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
202 ));
203 }
204
205 let sql = sql_lines.join("\n").trim().to_string();
206
207 if sql.is_empty() {
208 return Err(ScytheError::syntax("empty SQL body"));
209 }
210
211 let sql = if *dialect == SqlDialect::Oracle {
215 preprocess_oracle_sql(&sql)
216 } else {
217 sql
218 };
219
220 let parser_dialect = dialect.to_sqlparser_dialect();
221 let statements = Parser::parse_sql(parser_dialect.as_ref(), &sql)
222 .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
223
224 if statements.len() != 1 {
225 let non_empty: Vec<_> = statements
228 .into_iter()
229 .filter(|s| {
230 !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
231 })
232 .collect();
233 if non_empty.len() != 1 {
234 return Err(ScytheError::syntax("expected exactly one SQL statement"));
235 }
236 let stmt = non_empty
237 .into_iter()
238 .next()
239 .expect("filtered to exactly one statement");
240 let annotations = Annotations {
241 name: name.clone(),
242 command: command.clone(),
243 param_docs,
244 nullable_overrides,
245 nonnull_overrides,
246 json_mappings,
247 deprecated,
248 optional_params,
249 group_by: group_by.clone(),
250 };
251 return Ok(Query {
252 name,
253 command,
254 sql,
255 stmt,
256 annotations,
257 });
258 }
259
260 let stmt = statements
261 .into_iter()
262 .next()
263 .expect("filtered to exactly one statement");
264
265 let annotations = Annotations {
266 name: name.clone(),
267 command: command.clone(),
268 param_docs,
269 nullable_overrides,
270 nonnull_overrides,
271 json_mappings,
272 deprecated,
273 optional_params,
274 group_by,
275 };
276
277 Ok(Query {
278 name,
279 command,
280 sql,
281 stmt,
282 annotations,
283 })
284}
285
286fn preprocess_oracle_sql(sql: &str) -> String {
290 let sql = strip_returning_into(sql);
293
294 let mut result = String::with_capacity(sql.len());
296 let mut chars = sql.chars().peekable();
297 while let Some(ch) = chars.next() {
298 if ch == '\'' {
299 result.push(ch);
301 while let Some(inner) = chars.next() {
302 result.push(inner);
303 if inner == '\'' {
304 if chars.peek() == Some(&'\'') {
305 result.push(chars.next().unwrap());
306 } else {
307 break;
308 }
309 }
310 }
311 } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
312 result.push('?');
314 while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
315 chars.next();
316 }
317 } else {
318 result.push(ch);
319 }
320 }
321 result
322}
323
/// Removes the `INTO ...` tail from a trailing `RETURNING ... INTO ...`
/// clause, returning the SQL up to (and excluding) `INTO` with trailing
/// whitespace trimmed.
///
/// The last occurrence of `RETURNING` is located case-insensitively, then the
/// first standalone `INTO` keyword after it. `INTO` only counts when it is not
/// part of a longer identifier, so column names such as `into_count` or
/// `pinto` are left alone. If either keyword is missing the input is returned
/// unchanged.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    // ASCII-only upper-casing keeps every byte offset identical to `sql`.
    // Full Unicode upper-casing can change byte lengths (e.g. `ı` -> `I`),
    // which would make offsets found here invalid for slicing `sql`.
    let upper = sql.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let Some(returning_pos) = upper.rfind(RETURNING) else {
        return sql.to_string();
    };
    let clause_start = returning_pos + RETURNING.len();

    let into_pos = upper[clause_start..]
        .match_indices(INTO)
        .map(|(offset, _)| clause_start + offset)
        .find(|&pos| {
            // `pos >= clause_start > 0`, so `pos - 1` is always in bounds.
            let starts_word = !is_ident_byte(bytes[pos - 1]);
            let ends_word = !bytes
                .get(pos + INTO.len())
                .is_some_and(|&b| is_ident_byte(b));
            starts_word && ends_word
        });

    match into_pos {
        Some(pos) => sql[..pos].trim_end().to_string(),
        None => sql.to_string(),
    }
}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    /// Parses with the default (PostgreSQL) dialect.
    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        // `:grouped` is covered separately because it also needs `@group_by`.
        let cases = [
            (":one", QueryCommand::One),
            (":opt", QueryCommand::Opt),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
            (":batch", QueryCommand::Batch),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {tag}\nSELECT 1");
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {tag}");
            // `Display` must print the same spelling the parser accepts.
            assert_eq!(format!(":{}", q.command), tag);
        }
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        // A bare `@name` is accepted by the parser and yields an empty name.
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        // The value is still recorded; it just has no effect on the command.
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
            "SELECT * FROM users WHERE id = ?"
        );
        assert_eq!(
            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
            "SELECT * FROM users WHERE name = ':1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        assert_eq!(
            preprocess_oracle_sql(
                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
            ),
            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
        let result = preprocess_oracle_sql(sql);
        assert_eq!(
            result,
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        assert_eq!(
            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
            "DELETE FROM users WHERE id = ?"
        );
    }
}