1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result mode requested by a query's `@returns` annotation.
///
/// Each variant has a lowercase textual spelling (see the `Display` impl
/// below), which is the same spelling accepted after `@returns`.
/// NOTE(review): what each mode means at runtime is decided by the code that
/// consumes a parsed query, not by this type.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum QueryCommand {
    /// Spelled `one`.
    One,
    /// Spelled `opt`.
    Opt,
    /// Spelled `many`.
    Many,
    /// Spelled `exec`.
    Exec,
    /// Spelled `exec_result`.
    ExecResult,
    /// Spelled `exec_rows`.
    ExecRows,
    /// Spelled `batch`.
    Batch,
    /// Spelled `grouped`; the parser additionally requires `@group_by`.
    Grouped,
}

/// Renders the command using its lowercase annotation spelling.
impl std::fmt::Display for QueryCommand {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Self::One => "one",
            Self::Opt => "opt",
            Self::Many => "many",
            Self::Exec => "exec",
            Self::ExecResult => "exec_result",
            Self::ExecRows => "exec_rows",
            Self::Batch => "batch",
            Self::Grouped => "grouped",
        };
        f.write_str(spelling)
    }
}
33
34impl QueryCommand {
35 fn from_str(s: &str) -> Result<Self, ScytheError> {
36 match s {
37 "one" => Ok(QueryCommand::One),
38 "opt" => Ok(QueryCommand::Opt),
39 "many" => Ok(QueryCommand::Many),
40 "exec" => Ok(QueryCommand::Exec),
41 "exec_result" => Ok(QueryCommand::ExecResult),
42 "exec_rows" => Ok(QueryCommand::ExecRows),
43 "batch" => Ok(QueryCommand::Batch),
44 "grouped" => Ok(QueryCommand::Grouped),
45 other => Err(ScytheError::invalid_annotation(format!(
46 "invalid @returns value: {other}"
47 ))),
48 }
49 }
50}
51
/// Documentation for one parameter, taken from a `@param` annotation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value (trimmed), or the
    /// whole value when there is no `:`.
    pub name: String,
    /// Text after the first `:` (trimmed); empty when no `:` was given.
    pub description: String,
}
58
/// A `column = type` pair taken from a `@json` annotation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct JsonMapping {
    /// Text to the left of the first `=` (trimmed).
    pub column: String,
    /// Text to the right of the first `=` (trimmed).
    /// NOTE(review): presumably the name of the type the JSON column is
    /// decoded into — confirm against the code generators.
    pub rust_type: String,
}
65
/// An annotation whose keyword is not one of the built-in ones.
///
/// The parser keeps these instead of rejecting them.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomAnnotation {
    /// Annotation keyword without the leading `@`, ASCII-lowercased.
    pub name: String,
    /// Remainder of the annotation line after the keyword (trimmed); may be
    /// empty.
    pub value: String,
    /// 1-based line number of the annotation within the parsed query text.
    pub line: usize,
}
83
/// Everything collected from the `-- @keyword value` lines of one query.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Annotations {
    /// Value of the `@name` annotation.
    pub name: String,
    /// Command parsed from the `@returns` annotation.
    pub command: QueryCommand,
    /// One entry per `@param` annotation, in source order.
    pub param_docs: Vec<ParamDoc>,
    /// Non-empty, comma-separated entries of every `@nullable` annotation.
    pub nullable_overrides: Vec<String>,
    /// Non-empty, comma-separated entries of every `@nonnull` annotation.
    pub nonnull_overrides: Vec<String>,
    /// One entry per `@json` annotation that contains an `=`.
    pub json_mappings: Vec<JsonMapping>,
    /// Value of the `@deprecated` annotation, if one was present.
    pub deprecated: Option<String>,
    /// Non-empty, comma-separated entries of every `@optional` annotation.
    pub optional_params: Vec<String>,
    /// Value of the `@group_by` annotation, if one was present.
    pub group_by: Option<String>,
    /// Annotations with keywords the parser does not recognise, in source
    /// order.
    pub custom: Vec<CustomAnnotation>,
}
100
/// A parsed query: its annotations, its SQL text and the parsed statement.
#[derive(Debug)]
pub struct Query {
    /// Value of the `@name` annotation (duplicated in `annotations.name`).
    pub name: String,
    /// Command from `@returns` (duplicated in `annotations.command`).
    pub command: QueryCommand,
    /// SQL text with annotation lines removed and surrounding whitespace
    /// trimmed. For the Oracle and MsSql dialects the placeholders have been
    /// rewritten by the dialect preprocessing step.
    pub sql: String,
    /// The single statement produced by the SQL parser.
    pub stmt: sqlparser::ast::Statement,
    /// All annotations collected for this query.
    pub annotations: Annotations,
}
109
/// Parses an annotated query using the PostgreSQL dialect.
///
/// Shorthand for [`parse_query_with_dialect`] with `SqlDialect::PostgreSQL`;
/// it returns whatever that function returns, including its errors.
pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
    parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
}
114
115pub fn parse_query_with_dialect(
117 query_sql: &str,
118 dialect: &SqlDialect,
119) -> Result<Query, ScytheError> {
120 let mut name: Option<String> = None;
121 let mut command: Option<QueryCommand> = None;
122 let mut param_docs = Vec::new();
123 let mut nullable_overrides = Vec::new();
124 let mut nonnull_overrides = Vec::new();
125 let mut json_mappings = Vec::new();
126 let mut deprecated: Option<String> = None;
127 let mut optional_params = Vec::new();
128 let mut group_by: Option<String> = None;
129 let mut custom: Vec<CustomAnnotation> = Vec::new();
130
131 let mut sql_lines = Vec::new();
132
133 for (line_idx, line) in query_sql.lines().enumerate() {
134 let line_no = line_idx + 1;
135 let trimmed = line.trim();
136
137 let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
139 let rest = rest.trim_start();
140 rest.strip_prefix('@')
141 } else {
142 None
143 };
144
145 if let Some(body) = annotation_body {
146 let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
148 Some(pos) => (&body[..pos], body[pos..].trim()),
149 None => (body, ""),
150 };
151
152 match keyword.to_ascii_lowercase().as_str() {
153 "name" => {
154 name = Some(value.to_string());
155 }
156 "returns" => {
157 let cmd_str = value.strip_prefix(':').unwrap_or(value);
158 command = Some(QueryCommand::from_str(cmd_str)?);
159 }
160 "param" => {
161 if let Some(colon_pos) = value.find(':') {
163 let param_name = value[..colon_pos].trim().to_string();
164 let description = value[colon_pos + 1..].trim().to_string();
165 param_docs.push(ParamDoc {
166 name: param_name,
167 description,
168 });
169 } else {
170 param_docs.push(ParamDoc {
171 name: value.to_string(),
172 description: String::new(),
173 });
174 }
175 }
176 "nullable" => {
177 for col in value.split(',') {
178 let col = col.trim();
179 if !col.is_empty() {
180 nullable_overrides.push(col.to_string());
181 }
182 }
183 }
184 "nonnull" => {
185 for col in value.split(',') {
186 let col = col.trim();
187 if !col.is_empty() {
188 nonnull_overrides.push(col.to_string());
189 }
190 }
191 }
192 "json" => {
193 if let Some(eq_pos) = value.find('=') {
195 let column = value[..eq_pos].trim().to_string();
196 let rust_type = value[eq_pos + 1..].trim().to_string();
197 json_mappings.push(JsonMapping { column, rust_type });
198 }
199 }
200 "deprecated" => {
201 deprecated = Some(value.to_string());
202 }
203 "group_by" => {
204 group_by = Some(value.to_string());
205 }
206 "optional" => {
207 for param in value.split(',') {
208 let param = param.trim();
209 if !param.is_empty() {
210 optional_params.push(param.to_string());
211 }
212 }
213 }
214 other => {
215 custom.push(CustomAnnotation {
217 name: other.to_string(),
218 value: value.to_string(),
219 line: line_no,
220 });
221 }
222 }
223 } else {
224 sql_lines.push(line);
225 }
226 }
227
228 let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
229 let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
230
231 if command == QueryCommand::Grouped && group_by.is_none() {
232 return Err(ScytheError::invalid_annotation(
233 "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
234 ));
235 }
236
237 let sql = sql_lines.join("\n").trim().to_string();
238
239 if sql.is_empty() {
240 return Err(ScytheError::syntax("empty SQL body"));
241 }
242
243 let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
251 let processed = preprocess_oracle_sql(&sql);
252 (processed.clone(), processed)
253 } else if *dialect == SqlDialect::MsSql {
254 let codegen_sql = convert_mssql_placeholders(&sql);
256 let parse_sql = preprocess_mssql_sql(&sql);
258 (codegen_sql, parse_sql)
259 } else if *dialect == SqlDialect::PostgreSQL {
260 let parse_sql = preprocess_postgres_sql(&sql);
261 (sql.clone(), parse_sql)
262 } else {
263 (sql.clone(), sql)
264 };
265
266 let parser_dialect = dialect.to_sqlparser_dialect();
267 let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
268 .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
269
270 if statements.len() != 1 {
271 let non_empty: Vec<_> = statements
274 .into_iter()
275 .filter(|s| {
276 !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
277 })
278 .collect();
279 if non_empty.len() != 1 {
280 return Err(ScytheError::syntax("expected exactly one SQL statement"));
281 }
282 let stmt = non_empty
283 .into_iter()
284 .next()
285 .expect("filtered to exactly one statement");
286 let annotations = Annotations {
287 name: name.clone(),
288 command: command.clone(),
289 param_docs,
290 nullable_overrides,
291 nonnull_overrides,
292 json_mappings,
293 deprecated,
294 optional_params,
295 group_by: group_by.clone(),
296 custom,
297 };
298 return Ok(Query {
299 name,
300 command,
301 sql,
302 stmt,
303 annotations,
304 });
305 }
306
307 let stmt = statements
308 .into_iter()
309 .next()
310 .expect("filtered to exactly one statement");
311
312 let annotations = Annotations {
313 name: name.clone(),
314 command: command.clone(),
315 param_docs,
316 nullable_overrides,
317 nonnull_overrides,
318 json_mappings,
319 deprecated,
320 optional_params,
321 group_by,
322 custom,
323 };
324
325 Ok(Query {
326 name,
327 command,
328 sql,
329 stmt,
330 annotations,
331 })
332}
333
334fn preprocess_postgres_sql(sql: &str) -> String {
340 let mask = mask_postgres_for_scan(sql);
344 let mask_bytes = mask.as_bytes();
345 let bytes = sql.as_bytes();
346 let mut search_from = 0;
347 let mut result = String::with_capacity(sql.len());
348 let mut last = 0;
349 while let Some(rel) = find_keyword(&mask[search_from..], "ON CONFLICT") {
350 let on_conflict_pos = search_from + rel;
351 let after_on_conflict = on_conflict_pos + "ON CONFLICT".len();
352 let mut idx = after_on_conflict;
353 while idx < mask_bytes.len() && mask_bytes[idx].is_ascii_whitespace() {
354 idx += 1;
355 }
356 if idx >= mask_bytes.len() || mask_bytes[idx] != b'(' {
357 search_from = after_on_conflict;
358 continue;
359 }
360 let mut depth = 0i32;
361 let mut close = idx;
362 while close < mask_bytes.len() {
363 match mask_bytes[close] {
364 b'(' => depth += 1,
365 b')' => {
366 depth -= 1;
367 if depth == 0 {
368 break;
369 }
370 }
371 _ => {}
372 }
373 close += 1;
374 }
375 if depth != 0 {
376 return sql.to_string();
377 }
378 let mut after_cols = close + 1;
379 while after_cols < mask_bytes.len() && mask_bytes[after_cols].is_ascii_whitespace() {
380 after_cols += 1;
381 }
382 if mask[after_cols..].starts_with("WHERE")
383 && let Some(do_rel) = find_keyword(&mask[after_cols + "WHERE".len()..], "DO")
384 {
385 let do_abs = after_cols + "WHERE".len() + do_rel;
386 result.push_str(std::str::from_utf8(&bytes[last..after_cols]).unwrap_or(""));
389 last = do_abs;
390 search_from = do_abs;
391 continue;
392 }
393 search_from = close + 1;
394 }
395 result.push_str(std::str::from_utf8(&bytes[last..]).unwrap_or(""));
396 result
397}
398
/// Builds a scan mask of `sql`: an ASCII string with exactly the same byte
/// length in which
///
/// * `--` line comments (up to, not including, the newline),
/// * `/* ... */` block comments, including nested ones as PostgreSQL allows,
/// * single-quoted string literals (with `''` escapes), and
/// * every non-ASCII byte
///
/// are replaced by spaces, and every other byte is ASCII-uppercased.
///
/// Because byte offsets are preserved, a position found in the mask can be
/// used directly to slice the original SQL. An unterminated comment or
/// string literal blanks everything up to the end of the input.
///
/// # Panics
///
/// Never in practice: the final UTF-8 conversion cannot fail because only
/// ASCII bytes are written.
fn mask_postgres_for_scan(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    // Everything starts blanked; only visible bytes are copied over.
    let mut out = vec![b' '; len];
    let mut i = 0;

    while i < len {
        let b = bytes[i];
        match (b, bytes.get(i + 1).copied()) {
            (b'-', Some(b'-')) => {
                // Line comment: skip to the newline, which is then handled
                // as an ordinary visible byte on the next iteration.
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
            }
            (b'/', Some(b'*')) => {
                // Block comment: track nesting depth until it returns to
                // zero or the input ends.
                i += 2;
                let mut depth = 1usize;
                while i < len && depth > 0 {
                    let pair = (bytes[i], bytes.get(i + 1).copied());
                    match pair {
                        (b'/', Some(b'*')) => {
                            depth += 1;
                            i += 2;
                        }
                        (b'*', Some(b'/')) => {
                            depth -= 1;
                            i += 2;
                        }
                        _ => i += 1,
                    }
                }
            }
            (b'\'', _) => {
                // String literal: `''` is an escaped quote, a lone `'` ends
                // the literal.
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        i += 1;
                    } else if bytes.get(i + 1) == Some(&b'\'') {
                        i += 2;
                    } else {
                        i += 1;
                        break;
                    }
                }
            }
            _ => {
                if b.is_ascii() {
                    out[i] = b.to_ascii_uppercase();
                }
                i += 1;
            }
        }
    }

    String::from_utf8(out).expect("mask is ASCII by construction")
}
464
/// Returns the byte offset of the first whole-word occurrence of `keyword`
/// in `haystack`, or `None` when there is none.
///
/// The comparison is byte-exact, so callers pass text that has already been
/// uppercased. An occurrence counts only when it is not directly preceded or
/// followed by an identifier character (ASCII letter, digit or `_`), so
/// `DO` is not found inside `DOUBLE` or `DO_NOT_TRACK`. The start and end of
/// `haystack` count as word boundaries.
fn find_keyword(haystack: &str, keyword: &str) -> Option<usize> {
    let bytes = haystack.as_bytes();
    let key = keyword.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // No start position exists when the keyword is longer than the text.
    let last_start = bytes.len().checked_sub(key.len())?;

    (0..=last_start).find(|&start| {
        let end = start + key.len();
        &bytes[start..end] == key
            && (start == 0 || !is_ident(bytes[start - 1]))
            && !bytes.get(end).is_some_and(|&b| is_ident(b))
    })
}
484
485fn preprocess_oracle_sql(sql: &str) -> String {
489 let sql = strip_returning_into(sql);
492
493 let mut result = String::with_capacity(sql.len());
495 let mut chars = sql.chars().peekable();
496 while let Some(ch) = chars.next() {
497 if ch == '\'' {
498 result.push(ch);
500 while let Some(inner) = chars.next() {
501 result.push(inner);
502 if inner == '\'' {
503 if chars.peek() == Some(&'\'') {
504 result.push(chars.next().unwrap());
505 } else {
506 break;
507 }
508 }
509 }
510 } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
511 result.push('?');
513 while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
514 chars.next();
515 }
516 } else {
517 result.push(ch);
518 }
519 }
520 result
521}
522
/// Replaces every `@p<digits>` / `@P<digits>` placeholder outside a
/// single-quoted string literal with `?`.
///
/// Other `@name` variables and the contents of string literals (including
/// `''` escapes) are copied through unchanged.
fn convert_mssql_placeholders(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();

    while let Some(current) = rest.next() {
        match current {
            '\'' => {
                out.push(current);
                // Copy the literal up to and including its closing quote.
                while let Some(inner) = rest.next() {
                    out.push(inner);
                    if inner != '\'' {
                        continue;
                    }
                    match rest.next_if_eq(&'\'') {
                        // `''` is an escaped quote: still inside the literal.
                        Some(escaped) => out.push(escaped),
                        None => break,
                    }
                }
            }
            '@' => {
                // Look two characters ahead without consuming anything.
                let mut probe = rest.clone();
                let is_placeholder = probe.next().is_some_and(|c| c == 'p' || c == 'P')
                    && probe.next().is_some_and(|c| c.is_ascii_digit());

                if is_placeholder {
                    rest.next(); // the `p` / `P`
                    while rest.next_if(char::is_ascii_digit).is_some() {}
                    out.push('?');
                } else {
                    out.push(current);
                }
            }
            other => out.push(other),
        }
    }

    out
}
564
565fn preprocess_mssql_sql(sql: &str) -> String {
569 let sql = strip_and_convert_mssql_output(sql);
571 convert_mssql_placeholders(&sql)
573}
574
575fn strip_and_convert_mssql_output(sql: &str) -> String {
582 let upper = sql.to_uppercase();
584
585 if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
587 return sql.to_string();
588 }
589
590 if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
592 let before_output = &upper[..output_pos];
594 if !before_output.contains("INSERT") {
595 return sql.to_string();
596 }
597
598 let after_output = &upper[output_pos + "OUTPUT".len()..];
600 if let Some(values_offset) = find_word_position(after_output, "VALUES") {
601 let values_pos = output_pos + "OUTPUT".len() + values_offset;
602
603 let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
605
606 let cols = parse_inserted_columns(output_cols_str);
608
609 if !cols.is_empty() {
610 let before_output_sql = sql[..output_pos].trim_end();
613 let after_values = sql[values_pos..].trim_end();
614 let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
615 {
616 (stripped, ";")
617 } else {
618 (after_values, "")
619 };
620
621 return format!(
622 "{}\n{} RETURNING {}{}",
623 before_output_sql, values_body, cols, trailing
624 );
625 }
626 }
627 }
628
629 sql.to_string()
630}
631
/// Returns the byte offset of the first occurrence of `word` in `text` that
/// is not directly preceded or followed by an ASCII letter, digit or `_`.
///
/// The comparison is exact (case-sensitive). Rejected occurrences are
/// skipped by restarting the search one byte after their start.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    let raw = text.as_bytes();
    let is_ident = |b: &u8| b.is_ascii_alphanumeric() || *b == b'_';

    let mut from = 0;
    loop {
        let start = from + text[from..].find(word)?;
        let end = start + word.len();

        let clear_left = start == 0 || !raw.get(start - 1).is_some_and(is_ident);
        // `get` yields `None` past the end, which counts as a boundary.
        let clear_right = !raw.get(end).is_some_and(is_ident);

        if clear_left && clear_right {
            return Some(start);
        }
        from = start + 1;
    }
}
662
/// Extracts the column names from the list of an MsSql `OUTPUT` clause and
/// returns them joined with `", "`.
///
/// Each comma-separated item must start with the `INSERTED` pseudo-table
/// name, matched ASCII-case-insensitively so `INSERTED.id`, `inserted.id`
/// and `Inserted.id` are all accepted. An optional `.` after the prefix is
/// dropped and the rest, trimmed, is the column name. Items without the
/// prefix, or with nothing after it, are skipped. The result is empty when
/// no item qualifies.
fn parse_inserted_columns(output_str: &str) -> String {
    const PSEUDO_TABLE: &str = "INSERTED";

    output_str
        .split(',')
        .filter_map(|part| {
            let item = part.trim();
            // `get` is `None` when the item is too short or the cut would
            // fall inside a multi-byte character.
            let head = item.get(..PSEUDO_TABLE.len())?;
            if !head.eq_ignore_ascii_case(PSEUDO_TABLE) {
                return None;
            }
            let rest = &item[PSEUDO_TABLE.len()..];
            let column = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!column.is_empty()).then(|| column.to_string())
        })
        .collect::<Vec<_>>()
        .join(", ")
}
686
/// Cuts the `INTO ...` target list off an Oracle `RETURNING ... INTO ...`
/// clause, leaving `... RETURNING <columns>`.
///
/// The last standalone `RETURNING` keyword is located, then the first
/// standalone `INTO` after it; everything from that `INTO` on is dropped and
/// trailing whitespace is trimmed. "Standalone" means not directly adjacent
/// to an ASCII letter, digit or `_`, so a column such as `into_col` is not
/// mistaken for the keyword. Input without such a clause is returned
/// unchanged.
///
/// Keywords are searched on an ASCII-uppercased copy. ASCII-only case
/// folding preserves byte offsets, so the positions found can be used to
/// slice `sql` even when it contains non-ASCII text. String literals are not
/// skipped by this search.
fn strip_returning_into(sql: &str) -> String {
    const RETURNING: &str = "RETURNING";
    const INTO: &str = "INTO";

    let upper = sql.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let stands_alone = |start: usize, len: usize| {
        let clear_before = start == 0 || !is_word_byte(bytes[start - 1]);
        let clear_after = !bytes.get(start + len).is_some_and(|&b| is_word_byte(b));
        clear_before && clear_after
    };

    let Some(returning_pos) = upper
        .rmatch_indices(RETURNING)
        .map(|(pos, _)| pos)
        .find(|&pos| stands_alone(pos, RETURNING.len()))
    else {
        return sql.to_string();
    };

    let tail_start = returning_pos + RETURNING.len();
    let into_pos = upper[tail_start..]
        .match_indices(INTO)
        .map(|(offset, _)| tail_start + offset)
        .find(|&pos| stands_alone(pos, INTO.len()));

    match into_pos {
        Some(pos) => sql[..pos].trim_end().to_string(),
        None => sql.to_string(),
    }
}
702
// Unit tests for annotation parsing and for the per-dialect SQL
// preprocessing helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    // Shorthand: parse with the default (PostgreSQL) dialect.
    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    // ----- annotation parsing -----

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        let cases = vec![
            (":one", QueryCommand::One),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {}", tag);
        }
    }

    // Annotation keywords are matched case-insensitively.
    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    // A `@name` with no value is accepted and yields an empty name.
    #[test]
    fn test_empty_name_value() {
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    // ----- custom (unrecognised) annotations -----

    // Unknown keywords are collected with their value and 1-based line.
    #[test]
    fn test_custom_annotations_captured() {
        let input = "-- @name GetUser
-- @returns :one
-- @http GET /users/{id}
-- @http_auth bearer:jwt
-- @http_status 200,404
SELECT id FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 3);
        assert_eq!(q.annotations.custom[0].name, "http");
        assert_eq!(q.annotations.custom[0].value, "GET /users/{id}");
        assert_eq!(q.annotations.custom[0].line, 3);
        assert_eq!(q.annotations.custom[1].name, "http_auth");
        assert_eq!(q.annotations.custom[1].value, "bearer:jwt");
        assert_eq!(q.annotations.custom[1].line, 4);
        assert_eq!(q.annotations.custom[2].name, "http_status");
        assert_eq!(q.annotations.custom[2].value, "200,404");
        assert_eq!(q.annotations.custom[2].line, 5);
    }

    #[test]
    fn test_custom_annotation_without_value() {
        let input = "-- @name GetUser
-- @returns :one
-- @http_internal
SELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 1);
        assert_eq!(q.annotations.custom[0].name, "http_internal");
        assert_eq!(q.annotations.custom[0].value, "");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_custom_annotation_serde_round_trip() {
        let original = CustomAnnotation {
            name: "http".to_string(),
            value: "GET /users/{id}".to_string(),
            line: 7,
        };
        let json = serde_json::to_string(&original).unwrap();
        let back: CustomAnnotation = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    // The keyword is lowercased; the value keeps its original case.
    #[test]
    fn test_custom_annotation_name_lowercased() {
        let input = "-- @name GetUser
-- @returns :one
-- @HTTP_Auth Bearer
SELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.custom.len(), 1);
        assert_eq!(q.annotations.custom[0].name, "http_auth");
        assert_eq!(q.annotations.custom[0].value, "Bearer");
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    // ----- SQL body handling -----

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    // Annotation lines are removed from the stored SQL text.
    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    // ----- :grouped / @group_by -----

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    // `@group_by` is still recorded when the command is not `:grouped`.
    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    // ----- PostgreSQL preprocessing -----

    #[test]
    fn test_preprocess_postgres_strips_partial_index_where() {
        let sql = "INSERT INTO billing_events (project_id, stripe_event_id) \
                   VALUES ($1, $2) \
                   ON CONFLICT (stripe_event_id) WHERE stripe_event_id IS NOT NULL DO NOTHING";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            !cleaned
                .to_uppercase()
                .contains("WHERE STRIPE_EVENT_ID IS NOT NULL"),
            "WHERE clause must be stripped between ON CONFLICT cols and DO; got: {cleaned}"
        );
        assert!(
            cleaned
                .to_uppercase()
                .contains("ON CONFLICT (STRIPE_EVENT_ID) DO NOTHING")
        );
        // The cleaned text must be accepted by the SQL parser.
        sqlparser::parser::Parser::parse_sql(&sqlparser::dialect::PostgreSqlDialect {}, &cleaned)
            .expect("cleaned SQL should parse");
    }

    #[test]
    fn test_preprocess_postgres_no_op_when_no_partial_clause() {
        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    #[test]
    fn test_preprocess_postgres_leaves_on_conflict_on_constraint_alone() {
        let sql = "INSERT INTO t (a) VALUES ($1) ON CONFLICT ON CONSTRAINT t_a_uidx DO NOTHING";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    #[test]
    fn test_preprocess_postgres_handles_compound_index_cols() {
        let sql = "INSERT INTO t (a, b) VALUES ($1, $2) \
                   ON CONFLICT (a, b) WHERE a IS NOT NULL AND b > 0 DO UPDATE SET b = EXCLUDED.b";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            cleaned
                .to_uppercase()
                .contains("ON CONFLICT (A, B) DO UPDATE")
        );
        assert!(!cleaned.to_uppercase().contains("WHERE A IS NOT NULL"));
    }

    #[test]
    fn test_preprocess_postgres_preserves_unrelated_where() {
        let sql = "DELETE FROM t WHERE id = $1";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    // Text inside a `--` comment must not be rewritten.
    #[test]
    fn test_preprocess_postgres_ignores_text_inside_line_comments() {
        let sql = "-- inline doc: `ON CONFLICT (col) WHERE …` is the partial form\n\
                   INSERT INTO t (a) VALUES ($1) \
                   ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING";
        let cleaned = preprocess_postgres_sql(sql);
        assert!(
            cleaned.contains("-- inline doc"),
            "comment must survive the pass; got: {cleaned}"
        );
        assert!(cleaned.contains("ON CONFLICT (a) DO NOTHING"));
    }

    // Text inside a string literal must not be rewritten.
    #[test]
    fn test_preprocess_postgres_ignores_text_inside_string_literals() {
        let sql = "SELECT 'ON CONFLICT (a) WHERE a IS NOT NULL DO NOTHING' AS s";
        assert_eq!(preprocess_postgres_sql(sql), sql);
    }

    // ----- Oracle preprocessing -----

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
            "SELECT * FROM users WHERE id = ?"
        );
        assert_eq!(
            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
            "SELECT * FROM users WHERE name = ':1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        assert_eq!(
            preprocess_oracle_sql(
                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
            ),
            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
        let result = preprocess_oracle_sql(sql);
        assert_eq!(
            result,
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        assert_eq!(
            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
            "DELETE FROM users WHERE id = ?"
        );
    }

    // ----- MsSql placeholder conversion -----

    #[test]
    fn test_preprocess_mssql_single_placeholder() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @p1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_multiple_placeholders() {
        assert_eq!(
            preprocess_mssql_sql("INSERT INTO users (name, email) VALUES (@p1, @p2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_mssql_preserves_string_literals() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE name = '@p1' AND id = @p1"),
            "SELECT * FROM users WHERE name = '@p1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_case_insensitive_p() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @P1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    // `@name` variables that are not `@p<digits>` are left untouched.
    #[test]
    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
        assert_eq!(preprocess_mssql_sql("SELECT @myvar"), "SELECT @myvar");
    }

    #[test]
    fn test_preprocess_mssql_multi_digit_placeholder() {
        assert_eq!(preprocess_mssql_sql("SELECT @p10, @p2"), "SELECT ?, ?");
    }

    // ----- MsSql OUTPUT INSERTED -> RETURNING rewrite -----

    #[test]
    fn test_preprocess_mssql_output_inserted_simple() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
        assert!(!result.contains("OUTPUT"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_inserted_full_example() {
        let sql = "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)";
        let result = preprocess_mssql_sql(sql);
        assert!(
            result.contains("RETURNING id, name, email, active, created_at"),
            "got: {}",
            result
        );
        assert!(result.contains("VALUES (?, ?, ?, ?)"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_case_insensitive() {
        let sql = "INSERT INTO users (id) output inserted.id values (@p1)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id"), "got: {}", result);
        assert!(
            result.contains("values (?)") || result.contains("VALUES (?)"),
            "got: {}",
            result
        );
    }

    #[test]
    fn test_preprocess_mssql_no_output_unchanged() {
        let sql = "INSERT INTO users (id, name) VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert_eq!(result, "INSERT INTO users (id, name) VALUES (?, ?)");
    }

    #[test]
    fn test_preprocess_mssql_output_with_string_literal() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("(?, '@p2')"), "got: {}", result);
    }

    // The OUTPUT list may span several lines.
    #[test]
    fn test_preprocess_mssql_output_with_whitespace() {
        let sql =
            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n INSERTED.name\nVALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
    }
}