1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result shape declared by a query's `-- @returns :<tag>` annotation.
///
/// Formatting a value with `Display` writes its tag without the leading `:`
/// (for example `exec_rows`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryCommand {
    /// `:one`
    One,
    /// `:opt`
    Opt,
    /// `:many`
    Many,
    /// `:exec`
    Exec,
    /// `:exec_result`
    ExecResult,
    /// `:exec_rows`
    ExecRows,
    /// `:batch`
    Batch,
    /// `:grouped`; the query must also carry a `@group_by` annotation.
    Grouped,
}

impl std::fmt::Display for QueryCommand {
    /// Writes the annotation tag for this command. Width/fill flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            Self::One => "one",
            Self::Opt => "opt",
            Self::Many => "many",
            Self::Exec => "exec",
            Self::ExecResult => "exec_result",
            Self::ExecRows => "exec_rows",
            Self::Batch => "batch",
            Self::Grouped => "grouped",
        };
        f.write_str(tag)
    }
}
32
33impl QueryCommand {
34 fn from_str(s: &str) -> Result<Self, ScytheError> {
35 match s {
36 "one" => Ok(QueryCommand::One),
37 "opt" => Ok(QueryCommand::Opt),
38 "many" => Ok(QueryCommand::Many),
39 "exec" => Ok(QueryCommand::Exec),
40 "exec_result" => Ok(QueryCommand::ExecResult),
41 "exec_rows" => Ok(QueryCommand::ExecRows),
42 "batch" => Ok(QueryCommand::Batch),
43 "grouped" => Ok(QueryCommand::Grouped),
44 other => Err(ScytheError::invalid_annotation(format!(
45 "invalid @returns value: {other}"
46 ))),
47 }
48 }
49}
50
/// Documentation for one query parameter, taken from a
/// `-- @param name: description` annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamDoc {
    /// Text before the first `:` (trimmed), or the whole annotation value when
    /// it contains no `:`.
    pub name: String,
    /// Text after the first `:` (trimmed); empty when no `:` was present.
    pub description: String,
}
56
/// Associates a result column with a Rust type name, taken from a
/// `-- @json column = Type` annotation line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JsonMapping {
    /// Text before the first `=` (trimmed).
    pub column: String,
    /// Text after the first `=` (trimmed); presumably a Rust type path used by
    /// code generation — confirm with the consumer.
    pub rust_type: String,
}
62
/// Everything collected from a query's `-- @keyword value` annotation lines.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Annotations {
    /// Value of `@name` (may be empty when the annotation had no value).
    pub name: String,
    /// Parsed value of `@returns`.
    pub command: QueryCommand,
    /// One entry per `@param` line, in source order.
    pub param_docs: Vec<ParamDoc>,
    /// Names listed in `@nullable` lines (comma-separated, accumulated, blanks dropped).
    pub nullable_overrides: Vec<String>,
    /// Names listed in `@nonnull` lines (comma-separated, accumulated, blanks dropped).
    pub nonnull_overrides: Vec<String>,
    /// One entry per `@json` line that contains an `=`.
    pub json_mappings: Vec<JsonMapping>,
    /// Message from `@deprecated`, if the annotation was present (may be empty).
    pub deprecated: Option<String>,
    /// Names listed in `@optional` lines (comma-separated, accumulated, blanks dropped).
    pub optional_params: Vec<String>,
    /// Value of `@group_by`; mandatory when `command` is `QueryCommand::Grouped`.
    pub group_by: Option<String>,
}
75
/// A parsed, annotated query.
#[derive(Debug)]
pub struct Query {
    /// Same value as `annotations.name`.
    pub name: String,
    /// Same value as `annotations.command`.
    pub command: QueryCommand,
    /// SQL body with annotation lines removed and surrounding whitespace
    /// trimmed, rewritten for the dialect (e.g. positional placeholders turned
    /// into `?` for Oracle and SQL Server).
    pub sql: String,
    /// The single statement sqlparser produced for the body.
    pub stmt: sqlparser::ast::Statement,
    /// All annotations found in the query source.
    pub annotations: Annotations,
}
84
85pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
87 parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
88}
89
90pub fn parse_query_with_dialect(
92 query_sql: &str,
93 dialect: &SqlDialect,
94) -> Result<Query, ScytheError> {
95 let mut name: Option<String> = None;
96 let mut command: Option<QueryCommand> = None;
97 let mut param_docs = Vec::new();
98 let mut nullable_overrides = Vec::new();
99 let mut nonnull_overrides = Vec::new();
100 let mut json_mappings = Vec::new();
101 let mut deprecated: Option<String> = None;
102 let mut optional_params = Vec::new();
103 let mut group_by: Option<String> = None;
104
105 let mut sql_lines = Vec::new();
106
107 for line in query_sql.lines() {
108 let trimmed = line.trim();
109
110 let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
112 let rest = rest.trim_start();
113 rest.strip_prefix('@')
114 } else {
115 None
116 };
117
118 if let Some(body) = annotation_body {
119 let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
121 Some(pos) => (&body[..pos], body[pos..].trim()),
122 None => (body, ""),
123 };
124
125 match keyword.to_ascii_lowercase().as_str() {
126 "name" => {
127 name = Some(value.to_string());
128 }
129 "returns" => {
130 let cmd_str = value.strip_prefix(':').unwrap_or(value);
131 command = Some(QueryCommand::from_str(cmd_str)?);
132 }
133 "param" => {
134 if let Some(colon_pos) = value.find(':') {
136 let param_name = value[..colon_pos].trim().to_string();
137 let description = value[colon_pos + 1..].trim().to_string();
138 param_docs.push(ParamDoc {
139 name: param_name,
140 description,
141 });
142 } else {
143 param_docs.push(ParamDoc {
144 name: value.to_string(),
145 description: String::new(),
146 });
147 }
148 }
149 "nullable" => {
150 for col in value.split(',') {
151 let col = col.trim();
152 if !col.is_empty() {
153 nullable_overrides.push(col.to_string());
154 }
155 }
156 }
157 "nonnull" => {
158 for col in value.split(',') {
159 let col = col.trim();
160 if !col.is_empty() {
161 nonnull_overrides.push(col.to_string());
162 }
163 }
164 }
165 "json" => {
166 if let Some(eq_pos) = value.find('=') {
168 let column = value[..eq_pos].trim().to_string();
169 let rust_type = value[eq_pos + 1..].trim().to_string();
170 json_mappings.push(JsonMapping { column, rust_type });
171 }
172 }
173 "deprecated" => {
174 deprecated = Some(value.to_string());
175 }
176 "group_by" => {
177 group_by = Some(value.to_string());
178 }
179 "optional" => {
180 for param in value.split(',') {
181 let param = param.trim();
182 if !param.is_empty() {
183 optional_params.push(param.to_string());
184 }
185 }
186 }
187 _ => {
188 }
190 }
191 } else {
192 sql_lines.push(line);
193 }
194 }
195
196 let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
197 let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
198
199 if command == QueryCommand::Grouped && group_by.is_none() {
200 return Err(ScytheError::invalid_annotation(
201 "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
202 ));
203 }
204
205 let sql = sql_lines.join("\n").trim().to_string();
206
207 if sql.is_empty() {
208 return Err(ScytheError::syntax("empty SQL body"));
209 }
210
211 let (sql, parse_sql) = if *dialect == SqlDialect::Oracle {
216 let processed = preprocess_oracle_sql(&sql);
217 (processed.clone(), processed)
218 } else if *dialect == SqlDialect::MsSql {
219 let codegen_sql = convert_mssql_placeholders(&sql);
221 let parse_sql = preprocess_mssql_sql(&sql);
223 (codegen_sql, parse_sql)
224 } else {
225 (sql.clone(), sql)
226 };
227
228 let parser_dialect = dialect.to_sqlparser_dialect();
229 let statements = Parser::parse_sql(parser_dialect.as_ref(), &parse_sql)
230 .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
231
232 if statements.len() != 1 {
233 let non_empty: Vec<_> = statements
236 .into_iter()
237 .filter(|s| {
238 !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
239 })
240 .collect();
241 if non_empty.len() != 1 {
242 return Err(ScytheError::syntax("expected exactly one SQL statement"));
243 }
244 let stmt = non_empty
245 .into_iter()
246 .next()
247 .expect("filtered to exactly one statement");
248 let annotations = Annotations {
249 name: name.clone(),
250 command: command.clone(),
251 param_docs,
252 nullable_overrides,
253 nonnull_overrides,
254 json_mappings,
255 deprecated,
256 optional_params,
257 group_by: group_by.clone(),
258 };
259 return Ok(Query {
260 name,
261 command,
262 sql,
263 stmt,
264 annotations,
265 });
266 }
267
268 let stmt = statements
269 .into_iter()
270 .next()
271 .expect("filtered to exactly one statement");
272
273 let annotations = Annotations {
274 name: name.clone(),
275 command: command.clone(),
276 param_docs,
277 nullable_overrides,
278 nonnull_overrides,
279 json_mappings,
280 deprecated,
281 optional_params,
282 group_by,
283 };
284
285 Ok(Query {
286 name,
287 command,
288 sql,
289 stmt,
290 annotations,
291 })
292}
293
294fn preprocess_oracle_sql(sql: &str) -> String {
298 let sql = strip_returning_into(sql);
301
302 let mut result = String::with_capacity(sql.len());
304 let mut chars = sql.chars().peekable();
305 while let Some(ch) = chars.next() {
306 if ch == '\'' {
307 result.push(ch);
309 while let Some(inner) = chars.next() {
310 result.push(inner);
311 if inner == '\'' {
312 if chars.peek() == Some(&'\'') {
313 result.push(chars.next().unwrap());
314 } else {
315 break;
316 }
317 }
318 }
319 } else if ch == ':' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
320 result.push('?');
322 while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
323 chars.next();
324 }
325 } else {
326 result.push(ch);
327 }
328 }
329 result
330}
331
/// Rewrites SQL Server positional parameters (`@p1`, `@P2`, `@p10`, ...) into
/// `?` placeholders, leaving everything else untouched.
///
/// `@` followed by anything other than `p`/`P` plus a digit (e.g. `@myvar`)
/// is kept as is. Single-quoted string literals (with `''` escapes) and
/// comments (`-- ...` through end of line, `/* ... */`) are copied verbatim.
/// Comments are skipped explicitly so that an apostrophe in comment prose,
/// e.g. `-- the user's row`, is not mistaken for the start of a string
/// literal. Otherwise placeholder conversion would stop for the rest of the
/// statement.
fn convert_mssql_placeholders(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '\'' => {
                result.push(ch);
                while let Some(inner) = chars.next() {
                    result.push(inner);
                    if inner == '\'' {
                        // `''` is an escaped quote inside the literal.
                        if chars.peek() == Some(&'\'') {
                            result.push('\'');
                            chars.next();
                        } else {
                            break;
                        }
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                result.push(ch);
                for inner in chars.by_ref() {
                    result.push(inner);
                    if inner == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                result.push(ch);
                result.push('*');
                chars.next();
                let mut prev = '\0';
                for inner in chars.by_ref() {
                    result.push(inner);
                    if prev == '*' && inner == '/' {
                        break;
                    }
                    prev = inner;
                }
            }
            '@' if chars.peek().is_some_and(|c| *c == 'p' || *c == 'P') => {
                // Peek two characters ahead without consuming: `@p` counts as a
                // positional parameter only when a digit follows.
                let mut lookahead = chars.clone();
                lookahead.next();
                if lookahead.peek().is_some_and(char::is_ascii_digit) {
                    chars.next(); // the `p` / `P`
                    while chars.peek().is_some_and(char::is_ascii_digit) {
                        chars.next();
                    }
                    result.push('?');
                } else {
                    result.push(ch);
                }
            }
            _ => result.push(ch),
        }
    }
    result
}
373
374fn preprocess_mssql_sql(sql: &str) -> String {
378 let sql = strip_and_convert_mssql_output(sql);
380 convert_mssql_placeholders(&sql)
382}
383
384fn strip_and_convert_mssql_output(sql: &str) -> String {
391 let upper = sql.to_uppercase();
393
394 if !upper.contains("INSERT") || !upper.contains("OUTPUT") {
396 return sql.to_string();
397 }
398
399 if let Some(output_pos) = find_word_position(&upper, "OUTPUT") {
401 let before_output = &upper[..output_pos];
403 if !before_output.contains("INSERT") {
404 return sql.to_string();
405 }
406
407 let after_output = &upper[output_pos + "OUTPUT".len()..];
409 if let Some(values_offset) = find_word_position(after_output, "VALUES") {
410 let values_pos = output_pos + "OUTPUT".len() + values_offset;
411
412 let output_cols_str = &sql[output_pos + "OUTPUT".len()..values_pos];
414
415 let cols = parse_inserted_columns(output_cols_str);
417
418 if !cols.is_empty() {
419 let before_output_sql = sql[..output_pos].trim_end();
422 let after_values = sql[values_pos..].trim_end();
423 let (values_body, trailing) = if let Some(stripped) = after_values.strip_suffix(';')
424 {
425 (stripped, ";")
426 } else {
427 (after_values, "")
428 };
429
430 return format!(
431 "{}\n{} RETURNING {}{}",
432 before_output_sql, values_body, cols, trailing
433 );
434 }
435 }
436 }
437
438 sql.to_string()
439}
440
/// Returns the byte offset of the first occurrence of `word` in `text` that is
/// not embedded in a larger identifier.
///
/// A match counts only when the byte just before it and the byte just after it
/// are not ASCII alphanumerics or `_` (or lie outside `text`). The comparison is
/// case-sensitive, so callers search pre-uppercased text for uppercase words.
fn find_word_position(text: &str, word: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let is_ident_byte =
        |i: usize| bytes.get(i).is_some_and(|&b| b.is_ascii_alphanumeric() || b == b'_');

    let mut start = 0;
    loop {
        let found = start + text[start..].find(word)?;
        let clear_before = found == 0 || !is_ident_byte(found - 1);
        let clear_after = !is_ident_byte(found + word.len());
        if clear_before && clear_after {
            return Some(found);
        }
        start = found + 1;
    }
}
471
/// Extracts bare column names from the body of a T-SQL `OUTPUT` clause.
///
/// Each comma-separated item is expected to look like `INSERTED.<column>`.
/// The `INSERTED` pseudo-table prefix is matched ASCII case-insensitively,
/// because T-SQL keywords and identifiers are case-insensitive: `Inserted.id`
/// is as valid as `INSERTED.id`. One `.` directly after the prefix is removed
/// and the remainder is trimmed. Items without the prefix, or that are empty
/// once the prefix is removed, are skipped.
///
/// Returns the surviving column names joined with `", "`, or an empty string
/// when none were found.
fn parse_inserted_columns(output_str: &str) -> String {
    const PREFIX: &str = "INSERTED";

    output_str
        .split(',')
        .filter_map(|part| {
            let item = part.trim();
            // `get` yields None (instead of panicking) when the item is shorter
            // than the prefix or the cut would split a multi-byte character.
            let head = item.get(..PREFIX.len())?;
            if !head.eq_ignore_ascii_case(PREFIX) {
                return None;
            }
            let rest = &item[PREFIX.len()..];
            let col = rest.strip_prefix('.').unwrap_or(rest).trim();
            (!col.is_empty()).then(|| col.to_string())
        })
        .collect::<Vec<_>>()
        .join(", ")
}
495
496fn strip_returning_into(sql: &str) -> String {
498 let upper = sql.to_uppercase();
500 if let Some(ret_pos) = upper.rfind("RETURNING") {
501 let after_returning = &upper[ret_pos + "RETURNING".len()..];
502 if let Some(into_offset) = after_returning.find("INTO") {
503 let into_pos = ret_pos + "RETURNING".len() + into_offset;
504 let trimmed = sql[..into_pos].trim_end();
506 return trimmed.to_string();
507 }
508 }
509 sql.to_string()
510}
511
#[cfg(test)]
mod tests {
    //! Unit tests for annotation parsing and the Oracle / SQL Server
    //! placeholder preprocessing helpers.

    use super::*;
    use crate::errors::ErrorCode;

    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        let cases = [
            (":one", QueryCommand::One),
            (":opt", QueryCommand::Opt),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
            (":batch", QueryCommand::Batch),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {}", tag);
        }
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        let all = [
            QueryCommand::One,
            QueryCommand::Opt,
            QueryCommand::Many,
            QueryCommand::Exec,
            QueryCommand::ExecResult,
            QueryCommand::ExecRows,
            QueryCommand::Batch,
            QueryCommand::Grouped,
        ];
        for cmd in all {
            let text = cmd.to_string();
            assert_eq!(QueryCommand::from_str(&text).unwrap(), cmd, "failed for {}", text);
        }
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        // `@name` with no value is accepted and yields an empty name.
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_optional_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @optional a, , b\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.optional_params, vec!["a", "b"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    #[test]
    fn test_json_without_equals_is_ignored() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data\nSELECT 1";
        let q = parse(input).unwrap();
        assert!(q.annotations.json_mappings.is_empty());
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    #[test]
    fn test_unknown_annotation_is_ignored() {
        let input = "-- @name Foo\n-- @returns :one\n-- @cache 60\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_plain_comments_stay_in_sql() {
        let input = "-- @name Foo\n-- @returns :one\n-- plain comment\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "-- plain comment\nSELECT 1");
    }

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_find_word_position_respects_identifier_boundaries() {
        assert_eq!(find_word_position("MY_OUTPUT OUTPUT", "OUTPUT"), Some(10));
        assert_eq!(find_word_position("OUTPUTS", "OUTPUT"), None);
        assert_eq!(find_word_position("OUTPUT", "OUTPUT"), Some(0));
    }

    #[test]
    fn test_preprocess_oracle_colon_placeholders() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE id = :1"),
            "SELECT * FROM users WHERE id = ?"
        );
        assert_eq!(
            preprocess_oracle_sql("INSERT INTO users (name, email) VALUES (:1, :2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_oracle_preserves_string_literals() {
        assert_eq!(
            preprocess_oracle_sql("SELECT * FROM users WHERE name = ':1' AND id = :1"),
            "SELECT * FROM users WHERE name = ':1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_oracle_strips_returning_into() {
        assert_eq!(
            preprocess_oracle_sql(
                "INSERT INTO users (name) VALUES (:1) RETURNING id, name INTO :2, :3"
            ),
            "INSERT INTO users (name) VALUES (?) RETURNING id, name"
        );
    }

    #[test]
    fn test_preprocess_oracle_full_insert_returning_into() {
        let sql = "INSERT INTO users (name, email, active) VALUES (:1, :2, :3) RETURNING id, name, email, active, created_at INTO :4, :5, :6, :7, :8";
        let result = preprocess_oracle_sql(sql);
        assert_eq!(
            result,
            "INSERT INTO users (name, email, active) VALUES (?, ?, ?) RETURNING id, name, email, active, created_at"
        );
    }

    #[test]
    fn test_preprocess_oracle_no_returning_into_unchanged() {
        assert_eq!(
            preprocess_oracle_sql("DELETE FROM users WHERE id = :1"),
            "DELETE FROM users WHERE id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_single_placeholder() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @p1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_multiple_placeholders() {
        assert_eq!(
            preprocess_mssql_sql("INSERT INTO users (name, email) VALUES (@p1, @p2)"),
            "INSERT INTO users (name, email) VALUES (?, ?)"
        );
    }

    #[test]
    fn test_preprocess_mssql_preserves_string_literals() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE name = '@p1' AND id = @p1"),
            "SELECT * FROM users WHERE name = '@p1' AND id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_case_insensitive_p() {
        assert_eq!(
            preprocess_mssql_sql("SELECT * FROM users WHERE id = @P1"),
            "SELECT * FROM users WHERE id = ?"
        );
    }

    #[test]
    fn test_preprocess_mssql_non_placeholder_at_variable_unchanged() {
        assert_eq!(preprocess_mssql_sql("SELECT @myvar"), "SELECT @myvar");
    }

    #[test]
    fn test_preprocess_mssql_multi_digit_placeholder() {
        assert_eq!(preprocess_mssql_sql("SELECT @p10, @p2"), "SELECT ?, ?");
    }

    #[test]
    fn test_preprocess_mssql_output_inserted_simple() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
        assert!(!result.contains("OUTPUT"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_inserted_full_example() {
        let sql = "INSERT INTO users (id, name, email, active) OUTPUT INSERTED.id, INSERTED.name, INSERTED.email, INSERTED.active, INSERTED.created_at VALUES (@p1, @p2, @p3, @p4)";
        let result = preprocess_mssql_sql(sql);
        assert!(
            result.contains("RETURNING id, name, email, active, created_at"),
            "got: {}",
            result
        );
        assert!(result.contains("VALUES (?, ?, ?, ?)"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_case_insensitive() {
        let sql = "INSERT INTO users (id) output inserted.id values (@p1)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id"), "got: {}", result);
        // The original keyword casing of the VALUES clause is preserved.
        assert!(result.contains("values (?)"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_no_output_unchanged() {
        let sql = "INSERT INTO users (id, name) VALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert_eq!(result, "INSERT INTO users (id, name) VALUES (?, ?)");
    }

    #[test]
    fn test_preprocess_mssql_output_with_string_literal() {
        let sql =
            "INSERT INTO users (id, name) OUTPUT INSERTED.id, INSERTED.name VALUES (@p1, '@p2')";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("(?, '@p2')"), "got: {}", result);
    }

    #[test]
    fn test_preprocess_mssql_output_with_whitespace() {
        let sql =
            "INSERT INTO users (id, name)\nOUTPUT INSERTED.id,\n INSERTED.name\nVALUES (@p1, @p2)";
        let result = preprocess_mssql_sql(sql);
        assert!(result.contains("RETURNING id, name"), "got: {}", result);
        assert!(result.contains("VALUES (?, ?)"), "got: {}", result);
    }
}