1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// Result/execution mode of a query, selected by its `@returns` annotation.
///
/// Every variant has exactly one lowercase tag; the `Display` impl below
/// writes that tag. `Exec` is the `Default` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum QueryCommand {
    One,
    Opt,
    Many,
    #[default]
    Exec,
    ExecResult,
    ExecRows,
    Batch,
    Grouped,
}

impl std::fmt::Display for QueryCommand {
    /// Writes the lowercase annotation tag of the variant (e.g. `exec_rows`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the tag first, then emit it with a single write.
        let tag = match self {
            QueryCommand::One => "one",
            QueryCommand::Opt => "opt",
            QueryCommand::Many => "many",
            QueryCommand::Exec => "exec",
            QueryCommand::ExecResult => "exec_result",
            QueryCommand::ExecRows => "exec_rows",
            QueryCommand::Batch => "batch",
            QueryCommand::Grouped => "grouped",
        };
        f.write_str(tag)
    }
}
34
35impl QueryCommand {
36 fn from_str(s: &str) -> Result<Self, ScytheError> {
37 match s {
38 "one" => Ok(QueryCommand::One),
39 "opt" => Ok(QueryCommand::Opt),
40 "many" => Ok(QueryCommand::Many),
41 "exec" => Ok(QueryCommand::Exec),
42 "exec_result" => Ok(QueryCommand::ExecResult),
43 "exec_rows" => Ok(QueryCommand::ExecRows),
44 "batch" => Ok(QueryCommand::Batch),
45 "grouped" => Ok(QueryCommand::Grouped),
46 other => Err(ScytheError::invalid_annotation(format!(
47 "invalid @returns value: {other}"
48 ))),
49 }
50 }
51}
52
/// Documentation for one query parameter, taken from an `@param` annotation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value (the whole value
    /// when there is no `:`).
    pub name: String,
    /// Text after the first `:`; empty when the annotation has no `:`.
    pub description: String,
}
59
/// One `@json column = Type` annotation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct JsonMapping {
    /// Trimmed text before the first `=` of the annotation value.
    pub column: String,
    /// Trimmed text after the first `=` of the annotation value.
    pub rust_type: String,
}
66
67#[derive(Debug, Clone, PartialEq, Eq)]
75#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
76pub struct CustomAnnotation {
77 pub name: String,
79 pub value: String,
81 pub line: usize,
83}
84
/// Everything collected from the `-- @...` annotation lines of one query.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Annotations {
    /// Value of the `@name` annotation.
    pub name: String,
    /// Command parsed from the `@returns` annotation.
    pub command: QueryCommand,
    /// One entry per `@param` annotation, in source order.
    pub param_docs: Vec<ParamDoc>,
    /// Non-empty items of every `@nullable` comma-separated list.
    pub nullable_overrides: Vec<String>,
    /// Non-empty items of every `@nonnull` comma-separated list.
    pub nonnull_overrides: Vec<String>,
    /// One entry per `@json` annotation that contains an `=`.
    pub json_mappings: Vec<JsonMapping>,
    /// Value of the last `@deprecated` annotation, if any.
    pub deprecated: Option<String>,
    /// Non-empty items of every `@optional` comma-separated list.
    pub optional_params: Vec<String>,
    /// Value of the last `@group_by` annotation, if any.
    pub group_by: Option<String>,
    /// Annotations with keywords not listed above, in source order.
    pub custom: Vec<CustomAnnotation>,
}
101
/// A parsed, annotated query as produced by `parse_query_with_dialect`.
#[derive(Debug)]
pub struct Query {
    /// Value of the `@name` annotation (same as `annotations.name`).
    pub name: String,
    /// Command from the `@returns` annotation (same as `annotations.command`).
    pub command: QueryCommand,
    /// SQL body with annotation lines removed and surrounding whitespace
    /// trimmed; for the Oracle and MsSql dialects this is the text after
    /// their dialect-specific rewriting.
    pub sql: String,
    /// The single statement `sqlparser` produced for the query.
    pub stmt: sqlparser::ast::Statement,
    /// All annotations collected from the query text.
    pub annotations: Annotations,
}
110
/// Parses an annotated query using the PostgreSQL dialect.
///
/// Shorthand for `parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)`;
/// the result and errors are exactly those of that function.
pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
}
115
116pub fn parse_query_with_dialect(
118 query_sql: &str,
119 dialect: &SqlDialect,
120) -> Result<Query, ScytheError> {
121 let mut name: Option<String> = None;
122 let mut command: Option<QueryCommand> = None;
123 let mut param_docs = Vec::new();
124 let mut nullable_overrides = Vec::new();
125 let mut nonnull_overrides = Vec::new();
126 let mut json_mappings = Vec::new();
127 let mut deprecated: Option<String> = None;
128 let mut optional_params = Vec::new();
129 let mut group_by: Option<String> = None;
130 let mut custom: Vec<CustomAnnotation> = Vec::new();
131
132 let mut sql_lines = Vec::new();
133
134 for (line_idx, line) in query_sql.lines().enumerate() {
135 let line_no = line_idx + 1;
136 let trimmed = line.trim();
137
138 let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
140 let rest = rest.trim_start();
141 rest.strip_prefix('@')
142 } else {
143 None
144 };
145
146 if let Some(body) = annotation_body {
147 let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
149 Some(pos) => (&body[..pos], body[pos..].trim()),
150 None => (body, ""),
151 };
152
153 match keyword.to_ascii_lowercase().as_str() {
154 "name" => {
155 name = Some(value.to_string());
156 }
157 "returns" => {
158 let cmd_str = value.strip_prefix(':').unwrap_or(value);
159 command = Some(QueryCommand::from_str(cmd_str)?);
160 }
161 "param" => {
162 if let Some(colon_pos) = value.find(':') {
164 let param_name = value[..colon_pos].trim().to_string();
165 let description = value[colon_pos + 1..].trim().to_string();
166 param_docs.push(ParamDoc {
167 name: param_name,
168 description,
169 });
170 } else {
171 param_docs.push(ParamDoc {
172 name: value.to_string(),
173 description: String::new(),
174 });
175 }
176 }
177 "nullable" => {
178 for col in value.split(',') {
179 let col = col.trim();
180 if !col.is_empty() {
181 nullable_overrides.push(col.to_string());
182 }
183 }
184 }
185 "nonnull" => {
186 for col in value.split(',') {
187 let col = col.trim();
188 if !col.is_empty() {
189 nonnull_overrides.push(col.to_string());
190 }
191 }
192 }
193 "json" => {
194 if let Some(eq_pos) = value.find('=') {
196 let column = value[..eq_pos].trim().to_string();
197 let rust_type = value[eq_pos + 1..].trim().to_string();
198 json_mappings.push(JsonMapping { column, rust_type });
199 }
200 }
201 "deprecated" => {
202 deprecated = Some(value.to_string());
203 }
204 "group_by" => {
205 group_by = Some(value.to_string());
206 }
207 "optional" => {
208 for param in value.split(',') {
209 let param = param.trim();
210 if !param.is_empty() {
211 optional_params.push(param.to_string());
212 }
213 }
214 }
215 other => {
216 custom.push(CustomAnnotation {
218 name: other.to_string(),
219 value: value.to_string(),
220 line: line_no,
221 });
222 }
223 }
224 } else {
225 sql_lines.push(line);
226 }
227 }
228
229 let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
230 let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
231
232 if command == QueryCommand::Grouped && group_by.is_none() {
233 return Err(ScytheError::invalid_annotation(
234 "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
235 ));
236 }
237
238 let sql = sql_lines.join("\n").trim().to_string();
239
240 if sql.is_empty() {
241 return Err(ScytheError::syntax("empty SQL body"));
242 }
243
244 let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
252 let processed = preprocess_oracle_sql(&sql);
253 (processed.clone(), processed)
254 } else if *dialect == SqlDialect::MsSql {
255 let codegen_sql = convert_mssql_placeholders(&sql);
257 let parse_sql = preprocess_mssql_sql(&sql);
259 (codegen_sql, parse_sql)
260 } else if *dialect == SqlDialect::PostgreSQL {
261 let parse_sql = preprocess_postgres_sql(&sql);
262 (sql.clone(), parse_sql)
263 } else {
264 (sql.clone(), sql)
265 };
266
267 let parser_dialect = dialect.to_sqlparser_dialect();
268 let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
269 .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
270
271 if statements.len() != 1 {
272 let non_empty: Vec<_> = statements
275 .into_iter()
276 .filter(|s| {
277 !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
278 })
279 .collect();
280 if non_empty.len() != 1 {
281 return Err(ScytheError::syntax("expected exactly one SQL statement"));
282 }
283 let stmt = non_empty
284 .into_iter()
285 .next()
286 .expect("filtered to exactly one statement");
287 let annotations = Annotations {
288 name: name.clone(),
289 command: command.clone(),
290 param_docs,
291 nullable_overrides,
292 nonnull_overrides,
293 json_mappings,
294 deprecated,
295 optional_params,
296 group_by: group_by.clone(),
297 custom,
298 };
299 return Ok(Query {
300 name,
301 command,
302 sql,
303 stmt,
304 annotations,
305 });
306 }
307
308 let stmt = statements
309 .into_iter()
310 .next()
311 .expect("filtered to exactly one statement");
312
313 let annotations = Annotations {
314 name: name.clone(),
315 command: command.clone(),
316 param_docs,
317 nullable_overrides,
318 nonnull_overrides,
319 json_mappings,
320 deprecated,
321 optional_params,
322 group_by,
323 custom,
324 };
325
326 Ok(Query {
327 name,
328 command,
329 sql,
330 stmt,
331 annotations,
332 })
333}
334
335fn preprocess_postgres_sql(sql: &str) -> String {
341 let mask = mask_postgres_for_scan(sql);
345 let mask_bytes = mask.as_bytes();
346 let bytes = sql.as_bytes();
347 let mut search_from = 0;
348 let mut result = String::with_capacity(sql.len());
349 let mut last = 0;
350 while let Some(rel) = find_keyword(&mask[search_from..], "ON CONFLICT") {
351 let on_conflict_pos = search_from + rel;
352 let after_on_conflict = on_conflict_pos + "ON CONFLICT".len();
353 let mut idx = after_on_conflict;
354 while idx < mask_bytes.len() && mask_bytes[idx].is_ascii_whitespace() {
355 idx += 1;
356 }
357 if idx >= mask_bytes.len() || mask_bytes[idx] != b'(' {
358 search_from = after_on_conflict;
359 continue;
360 }
361 let mut depth = 0i32;
362 let mut close = idx;
363 while close < mask_bytes.len() {
364 match mask_bytes[close] {
365 b'(' => depth += 1,
366 b')' => {
367 depth -= 1;
368 if depth == 0 {
369 break;
370 }
371 }
372 _ => {}
373 }
374 close += 1;
375 }
376 if depth != 0 {
377 return sql.to_string();
378 }
379 let mut after_cols = close + 1;
380 while after_cols < mask_bytes.len() && mask_bytes[after_cols].is_ascii_whitespace() {
381 after_cols += 1;
382 }
383 if mask[after_cols..].starts_with("WHERE")
384 && let Some(do_rel) = find_keyword(&mask[after_cols + "WHERE".len()..], "DO")
385 {
386 let do_abs = after_cols + "WHERE".len() + do_rel;
387 result.push_str(std::str::from_utf8(&bytes[last..after_cols]).unwrap_or(""));
390 last = do_abs;
391 search_from = do_abs;
392 continue;
393 }
394 search_from = close + 1;
395 }
396 result.push_str(std::str::from_utf8(&bytes[last..]).unwrap_or(""));
397 result
398}
399
/// Builds an ASCII "scan mask" of `sql` with the same byte length.
///
/// In the mask:
/// * `--` line comments (up to, but not including, the newline), `/* ... */`
///   block comments and single-quoted literals (`''` is an escaped quote)
///   are replaced by spaces;
/// * non-ASCII bytes are replaced by spaces;
/// * every other byte is ASCII-uppercased.
///
/// Offsets in the mask therefore line up with offsets in `sql`.
fn mask_postgres_for_scan(sql: &str) -> String {
    let src = sql.as_bytes();
    let len = src.len();
    // Start fully blanked; only live SQL bytes are copied in below.
    let mut mask = vec![b' '; len];
    let mut at = 0;

    while at < len {
        match &src[at..] {
            [b'-', b'-', ..] => {
                // Skip to the newline (left for the next iteration) or the end.
                let comment_len = src[at..]
                    .iter()
                    .position(|&byte| byte == b'\n')
                    .unwrap_or(len - at);
                at += comment_len;
            }
            [b'/', b'*', ..] => {
                at += 2;
                while at + 1 < len && !(src[at] == b'*' && src[at + 1] == b'/') {
                    at += 1;
                }
                // Step over the terminator only when one was found.
                if at + 1 < len {
                    at += 2;
                }
            }
            [b'\'', ..] => {
                at += 1;
                while at < len {
                    if src[at] != b'\'' {
                        at += 1;
                    } else if src.get(at + 1) == Some(&b'\'') {
                        // Doubled quote: still inside the literal.
                        at += 2;
                    } else {
                        at += 1;
                        break;
                    }
                }
            }
            [live, ..] => {
                if live.is_ascii() {
                    mask[at] = live.to_ascii_uppercase();
                }
                at += 1;
            }
            [] => break,
        }
    }

    String::from_utf8(mask).expect("mask is ASCII by construction")
}
465
/// Finds the first whole-word occurrence of `keyword` in `haystack`.
///
/// The comparison is byte-exact, so callers pass an already upper-cased
/// haystack. A match only counts when it is not glued to an identifier
/// character (`A-Z`, `a-z`, `0-9` or `_`) on either side. Underscore is
/// included so that, for example, `DO` is not found inside `DO_FLAG`.
///
/// Returns the byte offset of the match, or `None` when there is none or the
/// keyword is longer than the haystack.
fn find_keyword(haystack: &str, keyword: &str) -> Option<usize> {
    let bytes = haystack.as_bytes();
    let key = keyword.as_bytes();
    let is_ident_byte = |byte: u8| byte.is_ascii_alphanumeric() || byte == b'_';

    // Last offset at which the keyword still fits; bail out if it never does.
    let last_start = bytes.len().checked_sub(key.len())?;
    (0..=last_start).find(|&start| {
        let end = start + key.len();
        &bytes[start..end] == key
            && (start == 0 || !is_ident_byte(bytes[start - 1]))
            && (end == bytes.len() || !is_ident_byte(bytes[end]))
    })
}
485
486fn preprocess_oracle_sql(sql: &str) -> String {
490 let sql = strip_returning_into(sql);
493
494 let mut result = String::with_capacity(sql.len());
496 let mut chars = sql.chars().peekable();
497 while let Some(ch) = chars.next() {
498 if ch == '\'' {
499 result.push(ch);
501 while let Some(inner) = chars.next() {
502 result.push(inner);
503 if inner == '\'' {
504 if chars.peek() == Some(&'\'') {
505 result.push(chars.next().unwrap());
506 } else {
507 break;
508 }
509 }
510 }
511 } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
512 result.push('?');
514 while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
515 chars.next();
516 }
517 } else {
518 result.push(ch);
519 }
520 }
521 result
522}
523
/// Replaces SQL Server `@pN` / `@PN` placeholders with `?`.
///
/// A placeholder is `@`, then `p` or `P`, then at least one ASCII digit; all
/// following digits are consumed too. Single-quoted literals (with `''`
/// escapes) are copied verbatim, so placeholder-like text inside them is left
/// alone. Any other `@...` text is unchanged.
fn convert_mssql_placeholders(sql: &str) -> String {
    let input: Vec<char> = sql.chars().collect();
    let mut output = String::with_capacity(sql.len());
    let mut pos = 0;

    while pos < input.len() {
        match input[pos] {
            '\'' => {
                output.push('\'');
                pos += 1;
                // Copy through the closing quote (or to the end if unterminated).
                while pos < input.len() {
                    let current = input[pos];
                    output.push(current);
                    pos += 1;
                    if current != '\'' {
                        continue;
                    }
                    if input.get(pos) == Some(&'\'') {
                        // Escaped quote: copy its twin and stay in the literal.
                        output.push('\'');
                        pos += 1;
                    } else {
                        break;
                    }
                }
            }
            '@' if matches!(input.get(pos + 1).copied(), Some('p' | 'P'))
                && input.get(pos + 2).is_some_and(char::is_ascii_digit) =>
            {
                // Skip "@p" and then every digit of the placeholder number.
                pos += 2;
                while input.get(pos).is_some_and(char::is_ascii_digit) {
                    pos += 1;
                }
                output.push('?');
            }
            other => {
                output.push(other);
                pos += 1;
            }
        }
    }

    output
}
565
/// Produces the parser-facing form of a SQL Server query.
///
/// First passes the text through `strip_and_convert_mssql_output`, then
/// replaces `@pN` placeholders with `?` via `convert_mssql_placeholders`.
fn preprocess_mssql_sql(sql: &str) -> String {
    let sql = strip_and_convert_mssql_output(sql);
    convert_mssql_placeholders(&sql)
}
575
576fn strip_and_convert_mssql_output(sql: &str) -> String {
583 let upper = sql.to_uppercase();
585
586 if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
588 return sql.to_string();
589 }
590
591 if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
593 let before_output = &upper[..output_pos];
595 if !before_output.contains("INSERT") {
596 return sql.to_string();
597 }
598
599 let after_output = &upper[output_pos + "OUTPUT".len()..];
601 if let Some(values_offset) = find_word_position(after_output, "VALUES") {
602 let values_pos = output_pos + "OUTPUT".len() + values_offset;
603
604 let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
606
607 let cols = parse_inserted_columns(output_cols_str);
609
610 if !cols.is_empty() {
611 let before_output_sql = sql[..output_pos].trim_end();
614 let after_values = sql[values_pos..].trim_end();
615 let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
616 {
617 (stripped, ";")
618 } else {
619 (after_values, "")
620 };
621
622 return format!(
623 "{}\n{} RETURNING {}{}",
624 before_output_sql, values_body, cols, trailing
625 );
626 }
627 }
628 }
629
630 sql.to_string()
631}
632
/// Returns the byte offset of the first whole-word occurrence of `word`.
///
/// An occurrence counts only when the bytes directly before and after it are
/// not identifier bytes (ASCII letters, digits or `_`); the start and end of
/// `text` count as boundaries. Matching is case-sensitive.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    let haystack = text.as_bytes();
    let is_ident_byte = |byte: u8| byte == b'_' || byte.is_ascii_alphanumeric();

    let mut cursor = 0;
    loop {
        // `?` ends the search as soon as no further occurrence exists.
        let start = cursor + text[cursor..].find(word)?;
        let end = start + word.len();

        let glued_before = start > 0 && is_ident_byte(haystack[start - 1]);
        let glued_after = end < haystack.len() && is_ident_byte(haystack[end]);
        if !glued_before && !glued_after {
            return Some(start);
        }

        // Part of a longer identifier: retry one byte further on.
        cursor = start + 1;
    }
}
663
/// Extracts column names from the column list of an `OUTPUT` clause.
///
/// `output_str` is split on commas. Each item that starts with `INSERTED`
/// (in any ASCII letter case, e.g. `Inserted.id`) contributes the trimmed
/// text after that prefix and after one optional `.`. Items without the
/// prefix, and items with nothing after it, are skipped.
///
/// Returns the collected names joined with `", "`, or an empty string when
/// there are none.
fn parse_inserted_columns(output_str: &str) -> String {
    const PREFIX: &str = "inserted";

    output_str
        .split(',')
        .filter_map(|part| {
            let item = part.trim();
            // `get` yields `None` for short items and non-boundary offsets.
            let head = item.get(..PREFIX.len())?;
            if !head.eq_ignore_ascii_case(PREFIX) {
                return None;
            }
            let rest = &item[PREFIX.len()..];
            let column = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!column.is_empty()).then_some(column)
        })
        .collect::<Vec<_>>()
        .join(", ")
}
687
/// Drops the `INTO ...` tail of an Oracle `RETURNING ... INTO ...` clause.
///
/// Looks for the last `RETURNING` (case-insensitive) and the first `INTO`
/// after it. When both exist, everything from `INTO` onward is removed and
/// trailing whitespace is trimmed. Otherwise `sql` is returned unchanged.
///
/// Keywords are located in an ASCII-uppercased copy of `sql`. ASCII-only
/// case folding never changes byte lengths, so offsets found in that copy
/// are valid offsets into `sql` even when it contains non-ASCII text.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    let upper = sql.to_ascii_uppercase();
    if let Some(returning_pos) = upper.rfind(RETURNING) {
        let targets_from = returning_pos + RETURNING.len();
        if let Some(into_offset) = upper[targets_from..].find(INTO) {
            return sql[..targets_from + into_offset].trim_end().to_string();
        }
    }
    sql.to_string()
}
703
// Unit tests for annotation parsing (`parse_query`) and for the
// dialect-specific SQL preprocessing helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    // Shorthand: parse with the default (PostgreSQL) dialect.
    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    // ---- Mandatory annotations: @name and @returns ----

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        let cases = vec![
            (":one", QueryCommand::One),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {}", tag);
        }
    }

    // Annotation keywords are matched case-insensitively.
    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    // A bare `@name` is accepted and yields an empty name.
    #[test]
    fn test_empty_name_value() {
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    // ---- Optional built-in annotations ----

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    // ---- Custom (unrecognised) annotations ----

    // Unknown keywords are captured with their value and 1-based line number.
    #[test]
    fn test_custom_annotations_captured() {
        let input = "-- @name GetUser
-- @returns :one
-- @http GET /users/{id}
-- @http_auth bearer:jwt
-- @http_status 200,404
SELECT id FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 3);
        assert_eq!(q.annotations.custom[0].name, "http");
        assert_eq!(q.annotations.custom[0].value, "GET /users/{id}");
        assert_eq!(q.annotations.custom[0].line, 3);
        assert_eq!(q.annotations.custom[1].name, "http_auth");
        assert_eq!(q.annotations.custom[1].value, "bearer:jwt");
        assert_eq!(q.annotations.custom[1].line, 4);
        assert_eq!(q.annotations.custom[2].name, "http_status");
        assert_eq!(q.annotations.custom[2].value, "200,404");
        assert_eq!(q.annotations.custom[2].line, 5);
    }

    #[test]
    fn test_custom_annotation_without_value() {
        let input = "-- @name GetUser
-- @returns :one
-- @http_internal
SELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 1);
        assert_eq!(q.annotations.custom[0].name, "http_internal");
        assert_eq!(q.annotations.custom[0].value, "");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_custom_annotation_serde_round_trip() {
        let original = CustomAnnotation {
            name: "http".to_string(),
            value: "GET /users/{id}".to_string(),
            line: 7,
        };
        let json = serde_json::to_string(&original).unwrap();
        let back: CustomAnnotation = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    // The keyword is lowercased; the value keeps its original case.
    #[test]
    fn test_custom_annotation_name_lowercased() {
        let input = "-- @name GetUser
-- @returns :one
-- @HTTP_Auth Bearer
SELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 1);
        assert_eq!(q.annotations.custom[0].name, "http_auth");
        assert_eq!(q.annotations.custom[0].value, "Bearer");
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    // ---- SQL body handling ----

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    // Annotation lines are removed; the SQL text itself is kept verbatim.
    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    // ---- :grouped / @group_by ----

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    // `@group_by` is still recorded when the command is not `:grouped`.
    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    // ---- PostgreSQL preprocessing: ON CONFLICT (...) WHERE ... DO ----

    #[test]
    fn test_preprocess_postgres_strips_partial_index_where() {
        let sql = "INSERT INTO billing_events (project_id, stripe_event_id) \
                   VALUES ($1, $2) \
                   ON CONFLICT (stripe_event_id) WHERE stripe_event_id IS NOT NULL DO NOTHING";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            !cleaned
                .to_uppercase()
                .contains("WHERE STRIPE_EVENT_ID IS NOT NULL"),
            "WHERE clause must be stripped between ON CONFLICT cols and DO; got: {cleaned}"
        );
        assert!(
            cleaned
                .to_uppercase()
                .contains("ON CONFLICT (STRIPE_EVENT_ID) DO NOTHING")
        );
        sqlparser::parser::Parser::parse_sql(&sqlparser::dialect::PostgreSqlDialect {}, &cleaned)
            .expect("cleaned SQL should parse");
    }

    #[test]
    fn test_preprocess_postgres_no_op_when_no_partial_clause() {
        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    #[test]
    fn test_preprocess_postgres_leaves_on_conflict_on_constraint_alone() {
        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT ON CONSTRAINT t_a_uidx DO NOTHING";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    #[test]
    fn test_preprocess_postgres_handles_compound_index_cols() {
        let sql = "INSERT INTO t (a, b) VALUES ($1, $2) \
                   ON CONFLICT (a, b) WHERE a IS NOT NULL AND b > 0 DO UPDATE SET b = EXCLUDED.b";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            cleaned
                .to_uppercase()
                .contains("ON CONFLICT (A, B) DO UPDATE")
        );
        assert!(!cleaned.to_uppercase().contains("WHERE A IS NOT NULL"));
    }

    #[test]
    fn test_preprocess_postgres_preserves_unrelated_where() {
        let sql = "DELETE FROM t WHERE id = $1";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    // Text inside a `--` comment must not be treated as an ON CONFLICT clause.
    #[test]
    fn test_preprocess_postgres_ignores_text_inside_line_comments() {
        let sql = "-- inline doc: `ON CONFLICT (col) WHERE …` is the partial form\n\
                   INSERT INTO t (a) VALUES ($1) \
                   ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            cleaned.contains("-- inline doc"),
            "comment must survive the pass; got: {cleaned}"
        );
        assert!(cleaned.contains("ON CONFLICT (a) DO NOTHING"));
    }

    // Text inside a string literal must not be rewritten either.
    #[test]
    fn test_preprocess_postgres_ignores_text_inside_string_literals() {
        let sql = "SELECT 'ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING' AS s";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    // ---- Oracle preprocessing: `:N` placeholders and RETURNING ... INTO ----

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
            "SELECT * FROM users WHERE id = ?"
        );
        assert_eq!(
            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
            "SELECT * FROM users WHERE name = ':1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        assert_eq!(
            preprocess_oracle_sql(
                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
            ),
            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
        let result = preprocess_oracle_sql(sql);
        assert_eq!(
            result,
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        assert_eq!(
            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
            "DELETE FROM users WHERE id = ?"
        );
    }

    // ---- SQL Server preprocessing: `@pN` placeholders ----

    #[test]
    fn test_preprocess_mssql_single_placeholder() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @p1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_multiple_placeholders() {
        assert_eq!(
            preprocess_mssql_sql("INSERT INTO users (name, email) VALUES (@p1, @p2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_mssql_preserves_string_literals() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE name = '@p1' AND id = @p1"),
            "SELECT * FROM users WHERE name = '@p1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_case_insensitive_p() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @P1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    // `@name` variables that are not `@p<digits>` are left untouched.
    #[test]
    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
        assert_eq!(preprocess_mssql_sql("SELECT @myvar"), "SELECT @myvar");
    }

    #[test]
    fn test_preprocess_mssql_multi_digit_placeholder() {
        assert_eq!(preprocess_mssql_sql("SELECT @p10, @p2"), "SELECT ?, ?");
    }

    // ---- SQL Server preprocessing: OUTPUT INSERTED.* -> RETURNING ----

    #[test]
    fn test_preprocess_mssql_output_inserted_simple() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
        assert!(!result.contains("OUTPUT"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_inserted_full_example() {
        let sql = "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)";
        let result = preprocess_mssql_sql(sql);
        assert!(
            result.contains("RETURNING id, name, email, active, created_at"),
            "got: {}",
            result
        );
        assert!(result.contains("VALUES (?, ?, ?, ?)"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_case_insensitive() {
        let sql = "INSERT INTO users (id) output inserted.id values (@p1)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id"), "got: {}", result);
        assert!(
            result.contains("values (?)") || result.contains("VALUES (?)"),
            "got: {}",
            result
        );
    }

    #[test]
    fn test_preprocess_mssql_no_output_unchanged() {
        let sql = "INSERT INTO users (id, name) VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert_eq!(result, "INSERT INTO users (id, name) VALUES (?, ?)");
    }

    #[test]
    fn test_preprocess_mssql_output_with_string_literal() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("(?, '@p2')"), "got: {}", result);
    }

    // The OUTPUT column list may span several lines.
    #[test]
    fn test_preprocess_mssql_output_with_whitespace() {
        let sql =
            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n INSERTED.name\nVALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
    }
}