1use crate::context::SharedConnection;
2use sql_orm_core::{FromRow, OrmError, SqlTypeMapping, SqlValue};
3use sql_orm_query::CompiledQuery;
4use sql_orm_tiberius::ExecuteResult;
5use std::collections::BTreeSet;
6use std::marker::PhantomData;
7
/// Execution classification attached to a compiled raw SQL statement.
///
/// Each variant maps one-to-one onto the `sql_orm_query::QueryExecution`
/// variant of the same name (see `RawSqlExecution::query_execution`). How the
/// executor acts on the classification (for example whether a failed call may
/// be retried) is decided outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum RawSqlExecution {
    /// Intended for statements that only read data; selected by
    /// `RawQuery::read_only`.
    ReadOnly,
    /// Intended for data-modifying statements; the default for `RawCommand`.
    Write,
    /// Intended for schema-changing statements; selected by
    /// `RawCommand::migration`.
    Migration,
    /// Default for `RawQuery`, also selectable via `no_retry()`. Judging by the
    /// name, the executor should not retry it — NOTE(review): confirm in the
    /// executor.
    RawNoRetry,
}
24
25impl RawSqlExecution {
26 const fn query_execution(self) -> sql_orm_query::QueryExecution {
27 match self {
28 Self::ReadOnly => sql_orm_query::QueryExecution::ReadOnly,
29 Self::Write => sql_orm_query::QueryExecution::Write,
30 Self::Migration => sql_orm_query::QueryExecution::Migration,
31 Self::RawNoRetry => sql_orm_query::QueryExecution::RawNoRetry,
32 }
33 }
34}
35
/// A SQL Server query hint that `RawQuery::query_hint` renders into a
/// trailing `OPTION (...)` clause.
///
/// Hints are stored in a `BTreeSet`, so the derived `Ord` (declaration order)
/// fixes the order in which multiple hints are rendered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryHint {
    /// Rendered as `RECOMPILE`. In SQL Server, this asks the optimizer to
    /// compile a fresh plan for this execution instead of reusing a cached one.
    Recompile,
}
46
47impl QueryHint {
48 const fn sql(self) -> &'static str {
49 match self {
50 Self::Recompile => "RECOMPILE",
51 }
52 }
53}
54
/// Result of scanning raw SQL for `@P<n>` placeholders.
///
/// `analyze_placeholders` only builds this after checking that the distinct
/// indices are exactly `1..=max_index`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RawPlaceholderPlan {
    /// Highest placeholder index found, or 0 when the SQL has no placeholders.
    max_index: usize,
}
59
60impl RawPlaceholderPlan {
61 const fn expected_param_count(&self) -> usize {
62 self.max_index
63 }
64}
65
/// A single value that can be bound to a raw SQL placeholder.
///
/// Implemented below for the scalar/domain types that have a `SqlTypeMapping`,
/// for `&str`, for `SqlValue` itself, and for `Option<T>` (`None` binds
/// `SqlValue::Null`).
pub trait RawParam {
    /// Converts the value into the `SqlValue` sent with the statement.
    fn into_sql_value(self) -> SqlValue;
}
75
/// Implements `RawParam` for each listed type by delegating to that type's
/// `SqlTypeMapping::to_sql_value`, so raw parameters encode exactly like
/// mapped entity fields.
macro_rules! impl_raw_param_via_sql_type_mapping {
    ($($target:ty),+ $(,)?) => {
        $(
            impl RawParam for $target {
                fn into_sql_value(self) -> SqlValue {
                    <$target as SqlTypeMapping>::to_sql_value(self)
                }
            }
        )+
    };
}
87
88impl_raw_param_via_sql_type_mapping!(
89 bool,
90 i32,
91 i64,
92 f64,
93 String,
94 Vec<u8>,
95 uuid::Uuid,
96 rust_decimal::Decimal,
97 chrono::NaiveDate,
98 chrono::NaiveDateTime,
99);
100
101impl RawParam for SqlValue {
102 fn into_sql_value(self) -> SqlValue {
103 self
104 }
105}
106
107impl RawParam for &str {
108 fn into_sql_value(self) -> SqlValue {
109 SqlValue::String(self.to_string())
110 }
111}
112
113impl<T> RawParam for Option<T>
114where
115 T: RawParam,
116{
117 fn into_sql_value(self) -> SqlValue {
118 self.map(RawParam::into_sql_value).unwrap_or(SqlValue::Null)
119 }
120}
121
/// An ordered group of raw SQL parameters, accepted by `params(...)`.
///
/// Implemented for `()` (no parameters), `Vec<T>` and tuples of up to twelve
/// `RawParam` values. Element order is preserved, so the first value lines up
/// with the first positional parameter.
pub trait RawParams {
    /// Converts every element, in order, into its `SqlValue`.
    fn into_sql_values(self) -> Vec<SqlValue>;
}
130
131impl RawParams for () {
132 fn into_sql_values(self) -> Vec<SqlValue> {
133 Vec::new()
134 }
135}
136
137impl<T> RawParams for Vec<T>
138where
139 T: RawParam,
140{
141 fn into_sql_values(self) -> Vec<SqlValue> {
142 self.into_iter().map(RawParam::into_sql_value).collect()
143 }
144}
145
/// Implements `RawParams` for a tuple whose element types are the given
/// identifiers, converting the elements left to right.
macro_rules! impl_raw_params_tuple {
    ($($elem:ident),+ $(,)?) => {
        impl<$($elem),+> RawParams for ($($elem,)+)
        where
            $($elem: RawParam,)+
        {
            // The type-parameter identifiers are reused as binding names when
            // the tuple is destructured, so they are upper-case by design.
            #[allow(non_snake_case)]
            fn into_sql_values(self) -> Vec<SqlValue> {
                let ($($elem,)+) = self;
                Vec::from([$(RawParam::into_sql_value($elem)),+])
            }
        }
    };
}
160
161impl_raw_params_tuple!(A);
162impl_raw_params_tuple!(A, B);
163impl_raw_params_tuple!(A, B, C);
164impl_raw_params_tuple!(A, B, C, D);
165impl_raw_params_tuple!(A, B, C, D, E);
166impl_raw_params_tuple!(A, B, C, D, E, F);
167impl_raw_params_tuple!(A, B, C, D, E, F, G);
168impl_raw_params_tuple!(A, B, C, D, E, F, G, H);
169impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I);
170impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J);
171impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K);
172impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
173
174#[derive(Clone)]
175pub struct RawQuery<T> {
181 connection: SharedConnection,
182 sql: String,
183 params: Vec<SqlValue>,
184 query_hints: BTreeSet<QueryHint>,
185 execution: RawSqlExecution,
186 _row: PhantomData<fn() -> T>,
187}
188
189impl<T> RawQuery<T>
190where
191 T: FromRow + Send,
192{
193 pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
194 Self {
195 connection,
196 sql: sql.into(),
197 params: Vec::new(),
198 query_hints: BTreeSet::new(),
199 execution: RawSqlExecution::RawNoRetry,
200 _row: PhantomData,
201 }
202 }
203
204 pub fn param<P>(mut self, value: P) -> Self
206 where
207 P: RawParam,
208 {
209 self.params.push(value.into_sql_value());
210 self
211 }
212
213 pub fn params<P>(mut self, values: P) -> Self
215 where
216 P: RawParams,
217 {
218 self.params.extend(values.into_sql_values());
219 self
220 }
221
222 pub fn query_hint(mut self, hint: QueryHint) -> Self {
226 self.query_hints.insert(hint);
227 self
228 }
229
230 pub fn read_only(mut self) -> Self {
236 self.execution = RawSqlExecution::ReadOnly;
237 self
238 }
239
240 pub fn no_retry(mut self) -> Self {
242 self.execution = RawSqlExecution::RawNoRetry;
243 self
244 }
245
246 pub async fn all(self) -> Result<Vec<T>, OrmError> {
248 let compiled = self.compiled_query()?;
249 let mut connection = self.connection.lock().await?;
250 connection.fetch_all(compiled).await
251 }
252
253 pub async fn first(self) -> Result<Option<T>, OrmError> {
255 let compiled = self.compiled_query()?;
256 let mut connection = self.connection.lock().await?;
257 connection.fetch_one(compiled).await
258 }
259
260 fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
261 compiled_raw_query_with_hints(
262 &self.sql,
263 self.params.clone(),
264 &self.query_hints,
265 self.execution,
266 )
267 }
268}
269
/// A raw SQL statement executed for its effect rather than for rows.
///
/// Parameters and the execution classification are collected through builder
/// methods. Validation and compilation happen when `execute()` runs.
#[derive(Clone)]
pub struct RawCommand {
    /// Connection handle that `execute()` locks while running the statement.
    connection: SharedConnection,
    /// SQL text as supplied by the caller, using `@P1`, `@P2`, ... placeholders.
    sql: String,
    /// Positional parameter values, in binding order.
    params: Vec<SqlValue>,
    /// Classification forwarded to the compiled query (defaults to `Write`).
    execution: RawSqlExecution,
}
282
283impl RawCommand {
284 pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
285 Self {
286 connection,
287 sql: sql.into(),
288 params: Vec::new(),
289 execution: RawSqlExecution::Write,
290 }
291 }
292
293 pub fn param<P>(mut self, value: P) -> Self
295 where
296 P: RawParam,
297 {
298 self.params.push(value.into_sql_value());
299 self
300 }
301
302 pub fn params<P>(mut self, values: P) -> Self
304 where
305 P: RawParams,
306 {
307 self.params.extend(values.into_sql_values());
308 self
309 }
310
311 pub fn migration(mut self) -> Self {
313 self.execution = RawSqlExecution::Migration;
314 self
315 }
316
317 pub fn no_retry(mut self) -> Self {
319 self.execution = RawSqlExecution::RawNoRetry;
320 self
321 }
322
323 pub async fn execute(self) -> Result<ExecuteResult, OrmError> {
325 let compiled = self.compiled_query()?;
326 let mut connection = self.connection.lock().await?;
327 connection.execute(compiled).await
328 }
329
330 fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
331 compiled_raw_query_with_execution(&self.sql, self.params.clone(), self.execution)
332 }
333}
334
/// Test helper: compiles raw SQL with the default `RawNoRetry` classification
/// and no query hints.
#[cfg(test)]
fn compiled_raw_query(sql: &str, params: Vec<SqlValue>) -> Result<CompiledQuery, OrmError> {
    let default_execution = RawSqlExecution::RawNoRetry;
    compiled_raw_query_with_execution(sql, params, default_execution)
}
339
340fn compiled_raw_query_with_execution(
341 sql: &str,
342 params: Vec<SqlValue>,
343 execution: RawSqlExecution,
344) -> Result<CompiledQuery, OrmError> {
345 compiled_raw_query_with_hints(sql, params, &BTreeSet::new(), execution)
346}
347
348fn compiled_raw_query_with_hints(
349 sql: &str,
350 params: Vec<SqlValue>,
351 query_hints: &BTreeSet<QueryHint>,
352 execution: RawSqlExecution,
353) -> Result<CompiledQuery, OrmError> {
354 validate_raw_sql_parameters(sql, params.len())?;
355
356 let sql = render_raw_sql_with_hints(sql, query_hints)?;
357
358 Ok(CompiledQuery::with_execution(
359 sql,
360 params,
361 execution.query_execution(),
362 ))
363}
364
365fn render_raw_sql_with_hints(
366 sql: &str,
367 query_hints: &BTreeSet<QueryHint>,
368) -> Result<String, OrmError> {
369 if query_hints.is_empty() {
370 return Ok(sql.to_string());
371 }
372
373 if contains_top_level_option_clause(sql) {
374 return Err(OrmError::compile(
375 "raw SQL already contains OPTION (...); remove it before using query_hint(...)",
376 ));
377 }
378
379 let mut sql = sql.trim_end().trim_end_matches(';').trim_end().to_string();
380 let hints = query_hints
381 .iter()
382 .copied()
383 .map(QueryHint::sql)
384 .collect::<Vec<_>>()
385 .join(", ");
386
387 sql.push_str(" OPTION (");
388 sql.push_str(&hints);
389 sql.push(')');
390
391 Ok(sql)
392}
393
394pub(crate) fn validate_raw_sql_parameters(sql: &str, param_count: usize) -> Result<(), OrmError> {
395 let plan = analyze_placeholders(sql)?;
396
397 if plan.expected_param_count() != param_count {
398 return Err(OrmError::compile(format!(
399 "raw SQL parameter count mismatch: SQL expects {} parameter(s), received {}",
400 plan.expected_param_count(),
401 param_count
402 )));
403 }
404
405 Ok(())
406}
407
408fn analyze_placeholders(sql: &str) -> Result<RawPlaceholderPlan, OrmError> {
409 let bytes = sql.as_bytes();
410 let mut index = 0;
411 let mut placeholders = BTreeSet::new();
412
413 while index + 2 < bytes.len() {
414 if let Some(next_index) = skip_sql_non_code(bytes, index) {
415 index = next_index;
416 continue;
417 }
418
419 if bytes[index] == b'@' && bytes[index + 1] == b'P' && bytes[index + 2].is_ascii_digit() {
420 index += 2;
421 let start = index;
422
423 while index < bytes.len() && bytes[index].is_ascii_digit() {
424 index += 1;
425 }
426
427 let raw_index = sql[start..index].parse::<usize>().map_err(|_| {
428 OrmError::compile("raw SQL placeholder index is larger than supported")
429 })?;
430
431 if raw_index == 0 {
432 return Err(OrmError::compile("raw SQL placeholders must start at @P1"));
433 }
434
435 placeholders.insert(raw_index);
436 continue;
437 }
438
439 index += 1;
440 }
441
442 let max_index = placeholders.iter().next_back().copied().unwrap_or(0);
443 for expected in 1..=max_index {
444 if !placeholders.contains(&expected) {
445 return Err(OrmError::compile(format!(
446 "raw SQL placeholders must be continuous from @P1 to @P{}",
447 max_index
448 )));
449 }
450 }
451
452 Ok(RawPlaceholderPlan { max_index })
453}
454
455fn skip_sql_non_code(bytes: &[u8], index: usize) -> Option<usize> {
456 match bytes[index] {
457 b'\'' => Some(skip_quoted_string(bytes, index)),
458 b'[' => Some(skip_bracket_identifier(bytes, index)),
459 b'"' => Some(skip_double_quoted_identifier(bytes, index)),
460 b'-' if index + 1 < bytes.len() && bytes[index + 1] == b'-' => {
461 Some(skip_line_comment(bytes, index))
462 }
463 b'/' if index + 1 < bytes.len() && bytes[index + 1] == b'*' => {
464 Some(skip_block_comment(bytes, index))
465 }
466 _ => None,
467 }
468}
469
/// Skips a `'...'` string literal whose opening quote is at `start`.
///
/// A doubled quote (`''`) is an escaped quote inside the literal. Returns the
/// index just past the closing quote, or `bytes.len()` if the literal is
/// unterminated.
fn skip_quoted_string(bytes: &[u8], start: usize) -> usize {
    let mut cursor = start + 1;
    loop {
        match bytes.get(cursor).copied() {
            None => return cursor,
            Some(b'\'') if bytes.get(cursor + 1) == Some(&b'\'') => cursor += 2,
            Some(b'\'') => return cursor + 1,
            Some(_) => cursor += 1,
        }
    }
}
485
/// Skips a `[...]` delimited identifier whose opening bracket is at `start`.
///
/// A doubled closing bracket (`]]`) is an escaped `]` inside the identifier.
/// Returns the index just past the closing bracket, or `bytes.len()` if the
/// identifier is unterminated.
fn skip_bracket_identifier(bytes: &[u8], start: usize) -> usize {
    let mut cursor = start + 1;
    while let Some(&byte) = bytes.get(cursor) {
        cursor += 1;
        if byte != b']' {
            continue;
        }
        if bytes.get(cursor) == Some(&b']') {
            cursor += 1;
        } else {
            break;
        }
    }
    cursor
}
501
/// Skips a `"..."` quoted identifier whose opening quote is at `start`.
///
/// A doubled quote (`""`) is an escaped `"` inside the identifier. Returns the
/// index just past the closing quote, or `bytes.len()` if unterminated.
fn skip_double_quoted_identifier(bytes: &[u8], start: usize) -> usize {
    let mut pos = start + 1;
    loop {
        match (bytes.get(pos).copied(), bytes.get(pos + 1).copied()) {
            (None, _) => break pos,
            (Some(b'"'), Some(b'"')) => pos += 2,
            (Some(b'"'), _) => break pos + 1,
            (Some(_), _) => pos += 1,
        }
    }
}
517
/// Skips a `--` line comment whose first dash is at `index`.
///
/// Returns the index of the terminating `\n` or `\r` (left for the caller to
/// consume), or the end of input when no line break follows.
fn skip_line_comment(bytes: &[u8], index: usize) -> usize {
    let body_start = index + 2;
    match bytes.get(body_start..) {
        Some(rest) => rest
            .iter()
            .position(|&byte| byte == b'\n' || byte == b'\r')
            .map_or(bytes.len(), |offset| body_start + offset),
        None => body_start,
    }
}
525
/// Skips a `/* ... */` block comment whose opening `/` is at `index`.
///
/// T-SQL supports nested block comments: every `/*` inside a comment opens a
/// nested level that needs its own `*/`. The comment therefore ends only when
/// the nesting depth returns to zero. Returns the index just past the final
/// `*/`, or `bytes.len()` if the comment is unterminated.
fn skip_block_comment(bytes: &[u8], index: usize) -> usize {
    let mut depth = 1_usize;
    let mut cursor = index + 2;

    while cursor + 1 < bytes.len() {
        match (bytes[cursor], bytes[cursor + 1]) {
            (b'*', b'/') => {
                cursor += 2;
                depth -= 1;
                if depth == 0 {
                    return cursor;
                }
            }
            (b'/', b'*') => {
                cursor += 2;
                depth += 1;
            }
            _ => cursor += 1,
        }
    }

    bytes.len()
}
536
537fn contains_top_level_option_clause(sql: &str) -> bool {
538 let bytes = sql.as_bytes();
539 let mut index = 0;
540 let mut depth = 0_i32;
541
542 while index < bytes.len() {
543 match bytes[index] {
544 b'\'' => {
545 index += 1;
546 while index < bytes.len() {
547 if bytes[index] == b'\'' {
548 index += 1;
549 if index < bytes.len() && bytes[index] == b'\'' {
550 index += 1;
551 continue;
552 }
553 break;
554 }
555 index += 1;
556 }
557 }
558 b'[' => {
559 index += 1;
560 while index < bytes.len() {
561 if bytes[index] == b']' {
562 index += 1;
563 break;
564 }
565 index += 1;
566 }
567 }
568 b'-' if index + 1 < bytes.len() && bytes[index + 1] == b'-' => {
569 index += 2;
570 while index < bytes.len() && !matches!(bytes[index], b'\n' | b'\r') {
571 index += 1;
572 }
573 }
574 b'/' if index + 1 < bytes.len() && bytes[index + 1] == b'*' => {
575 index += 2;
576 while index + 1 < bytes.len() {
577 if bytes[index] == b'*' && bytes[index + 1] == b'/' {
578 index += 2;
579 break;
580 }
581 index += 1;
582 }
583 }
584 b'(' => {
585 depth += 1;
586 index += 1;
587 }
588 b')' => {
589 depth = depth.saturating_sub(1);
590 index += 1;
591 }
592 _ if depth == 0 && starts_with_keyword(sql, index, "OPTION") => {
593 let after_keyword = index + "OPTION".len();
594 let mut cursor = after_keyword;
595 while cursor < bytes.len() && bytes[cursor].is_ascii_whitespace() {
596 cursor += 1;
597 }
598 return cursor < bytes.len() && bytes[cursor] == b'(';
599 }
600 _ => index += 1,
601 }
602 }
603
604 false
605}
606
607fn starts_with_keyword(sql: &str, index: usize, keyword: &str) -> bool {
608 let bytes = sql.as_bytes();
609 let keyword_bytes = keyword.as_bytes();
610
611 if index + keyword_bytes.len() > bytes.len() {
612 return false;
613 }
614
615 if !bytes[index..index + keyword_bytes.len()].eq_ignore_ascii_case(keyword_bytes) {
616 return false;
617 }
618
619 let before_is_boundary = index == 0 || !is_identifier_byte(bytes[index - 1]);
620 let after = index + keyword_bytes.len();
621 let after_is_boundary = after == bytes.len() || !is_identifier_byte(bytes[after]);
622
623 before_is_boundary && after_is_boundary
624}
625
/// Returns `true` if `byte` can appear inside a T-SQL regular identifier
/// (after its first character). That means an ASCII letter, an ASCII digit,
/// or one of `_`, `@`, `#` or `$`.
///
/// Counting `@`, `#` and `$` as identifier bytes stops keyword matching from
/// firing inside names such as `@option` (a variable) or `#option` (a
/// temporary table). Non-ASCII identifier letters are not recognised here.
fn is_identifier_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'@' | b'#' | b'$')
}
629
// Unit tests for raw SQL placeholder validation, parameter conversion,
// query-hint rendering and execution classification. They call only the pure
// compile-time helpers; no connection is constructed.
#[cfg(test)]
mod tests {
    use super::RawSqlExecution;
    use super::{
        QueryHint, RawParam, RawParams, compiled_raw_query, compiled_raw_query_with_execution,
        compiled_raw_query_with_hints, contains_top_level_option_clause,
        validate_raw_sql_parameters,
    };
    use chrono::NaiveDate;
    use rust_decimal::Decimal;
    use sql_orm_core::{OrmErrorKind, SqlValue};
    use sql_orm_query::QueryExecution;
    use std::collections::BTreeSet;
    use uuid::Uuid;

    // --- Placeholder validation -------------------------------------------

    #[test]
    fn validates_continuous_placeholders_by_max_index() {
        validate_raw_sql_parameters("SELECT @P1, @P2, @P3", 3).unwrap();
    }

    // Multi-digit indices (@P10..@P12) must be parsed as whole numbers.
    #[test]
    fn validates_continuous_placeholders_through_highest_index() {
        validate_raw_sql_parameters(
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11, @P12",
            12,
        )
        .unwrap();
    }

    // A repeated placeholder consumes a single parameter value.
    #[test]
    fn allows_repeated_placeholder_to_reuse_one_param() {
        validate_raw_sql_parameters("SELECT @P1 WHERE owner_id = @P1", 1).unwrap();
    }

    #[test]
    fn rejects_extra_params_without_placeholders() {
        let error = validate_raw_sql_parameters("SELECT 1", 1).unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert!(error.message().contains("expects 0 parameter"));
    }

    #[test]
    fn rejects_missing_params() {
        let error = validate_raw_sql_parameters("SELECT @P1, @P2", 1).unwrap_err();

        assert!(error.message().contains("expects 2 parameter"));
    }

    // Gaps are rejected even when the parameter count matches the number of
    // distinct placeholders.
    #[test]
    fn rejects_non_continuous_placeholders() {
        let error = validate_raw_sql_parameters("SELECT @P1, @P3", 2).unwrap_err();

        assert!(error.message().contains("continuous from @P1 to @P3"));
    }

    #[test]
    fn rejects_zero_index_placeholder() {
        let error = validate_raw_sql_parameters("SELECT @P0", 0).unwrap_err();

        assert!(error.message().contains("start at @P1"));
    }

    // Placeholder-looking text inside literals, identifiers and comments must
    // not be counted; only the two real `@P1` uses are.
    #[test]
    fn ignores_placeholder_text_inside_sql_non_code_regions() {
        let sql = r#"
            SELECT @P1 AS value,
            '@P2 literal '' @P3 escaped quote' AS string_value,
            [@P4 identifier] AS bracket_identifier,
            "@P5 quoted identifier" AS quoted_identifier
            -- @P6 line comment
            /* @P7 block comment */
            WHERE label = @P1
        "#;

        validate_raw_sql_parameters(sql, 1).unwrap();
    }

    #[test]
    fn ignores_placeholder_text_in_raw_sql_without_parameters() {
        let sql = r#"
            SELECT '@P1 is documentation' AS literal,
            [@P2 is an identifier] AS identifier
            -- @P3 is a comment
            /* @P4 is also a comment */
        "#;

        validate_raw_sql_parameters(sql, 0).unwrap();
    }

    // Scanning resumes correctly after a skipped literal and a line comment.
    #[test]
    fn counts_placeholders_after_ignored_sql_regions() {
        let sql = "SELECT '@P1 ignored' AS label -- @P2 ignored\nWHERE id = @P1 AND code = @P2";

        validate_raw_sql_parameters(sql, 2).unwrap();
    }

    // --- Parameter conversion ---------------------------------------------

    // Covers every built-in `RawParam` type in a maximal (12-element) tuple.
    #[test]
    fn raw_params_tuple_preserves_order_and_values() {
        let values = (
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            String::from("owned"),
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            NaiveDate::from_ymd_opt(2026, 4, 26).unwrap(),
            NaiveDate::from_ymd_opt(2026, 4, 26)
                .unwrap()
                .and_hms_opt(10, 20, 30)
                .unwrap(),
            SqlValue::Null,
        )
            .into_sql_values();

        assert_eq!(
            values,
            vec![
                SqlValue::Bool(true),
                SqlValue::I32(7),
                SqlValue::I64(9),
                SqlValue::F64(3.5),
                SqlValue::String("draft".to_string()),
                SqlValue::String("owned".to_string()),
                SqlValue::Bytes(vec![1, 2]),
                SqlValue::Uuid(Uuid::nil()),
                SqlValue::Decimal(Decimal::new(1234, 2)),
                SqlValue::Date(NaiveDate::from_ymd_opt(2026, 4, 26).unwrap()),
                SqlValue::DateTime(
                    NaiveDate::from_ymd_opt(2026, 4, 26)
                        .unwrap()
                        .and_hms_opt(10, 20, 30)
                        .unwrap()
                ),
                SqlValue::Null,
            ]
        );
    }

    #[test]
    fn raw_param_option_none_maps_to_null() {
        assert_eq!(Option::<i64>::None.into_sql_value(), SqlValue::Null);
    }

    #[test]
    fn raw_param_option_some_maps_inner_value() {
        assert_eq!(Some(42_i64).into_sql_value(), SqlValue::I64(42));
    }

    #[test]
    fn raw_params_vec_preserves_order() {
        let values = vec![1_i64, 2_i64, 3_i64].into_sql_values();

        assert_eq!(
            values,
            vec![SqlValue::I64(1), SqlValue::I64(2), SqlValue::I64(3)]
        );
    }

    #[test]
    fn raw_params_unit_maps_to_empty_params() {
        assert_eq!(().into_sql_values(), Vec::<SqlValue>::new());
    }

    // --- Query compilation -------------------------------------------------

    // Without hints the SQL text passes through unchanged and parameters keep
    // their order.
    #[test]
    fn compiled_raw_query_preserves_sql_and_parameter_order() {
        let params = (
            SqlValue::Null,
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            NaiveDate::from_ymd_opt(2026, 4, 26).unwrap(),
            NaiveDate::from_ymd_opt(2026, 4, 26)
                .unwrap()
                .and_hms_opt(10, 20, 30)
                .unwrap(),
        )
            .into_sql_values();

        let compiled = compiled_raw_query(
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11",
            params,
        )
        .unwrap();

        assert_eq!(
            compiled.sql,
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11"
        );
        assert_eq!(
            compiled.params,
            vec![
                SqlValue::Null,
                SqlValue::Bool(true),
                SqlValue::I32(7),
                SqlValue::I64(9),
                SqlValue::F64(3.5),
                SqlValue::String("draft".to_string()),
                SqlValue::Bytes(vec![1, 2]),
                SqlValue::Uuid(Uuid::nil()),
                SqlValue::Decimal(Decimal::new(1234, 2)),
                SqlValue::Date(NaiveDate::from_ymd_opt(2026, 4, 26).unwrap()),
                SqlValue::DateTime(
                    NaiveDate::from_ymd_opt(2026, 4, 26)
                        .unwrap()
                        .and_hms_opt(10, 20, 30)
                        .unwrap()
                ),
            ]
        );
    }

    #[test]
    fn compiled_raw_query_allows_repeated_placeholder_with_single_param() {
        let compiled = compiled_raw_query(
            "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P1",
            vec![SqlValue::I64(42)],
        )
        .unwrap();

        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
    }

    // --- Query hints --------------------------------------------------------

    #[test]
    fn compiled_raw_query_appends_recompile_hint_after_parameters() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let compiled = compiled_raw_query_with_hints(
            "SELECT * FROM users WHERE owner_id = @P1",
            vec![SqlValue::I64(42)],
            &hints,
            RawSqlExecution::ReadOnly,
        )
        .unwrap();

        assert_eq!(
            compiled.sql,
            "SELECT * FROM users WHERE owner_id = @P1 OPTION (RECOMPILE)"
        );
        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
        assert_eq!(compiled.execution, QueryExecution::ReadOnly);
    }

    // The hint set itself collapses duplicates, so RECOMPILE renders once.
    #[test]
    fn compiled_raw_query_deduplicates_repeated_query_hints() {
        let hints = BTreeSet::from([QueryHint::Recompile, QueryHint::Recompile]);
        let compiled =
            compiled_raw_query_with_hints("SELECT 1", vec![], &hints, RawSqlExecution::ReadOnly)
                .unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    // Trailing `;` and whitespace are trimmed before the clause is appended.
    #[test]
    fn compiled_raw_query_places_hint_before_trailing_semicolon() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let compiled = compiled_raw_query_with_hints(
            "SELECT 1; ",
            vec![],
            &hints,
            RawSqlExecution::ReadOnly,
        )
        .unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    #[test]
    fn compiled_raw_query_rejects_existing_top_level_option_clause_with_hints() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let error = compiled_raw_query_with_hints(
            "SELECT 1 OPTION (MAXDOP 1)",
            vec![],
            &hints,
            RawSqlExecution::ReadOnly,
        )
        .unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert!(error.message().contains("already contains OPTION"));
    }

    // --- Execution classification -------------------------------------------

    #[test]
    fn raw_query_defaults_to_no_retry_execution_classification() {
        let compiled = compiled_raw_query("SELECT 1", vec![]).unwrap();

        assert_eq!(compiled.execution, QueryExecution::RawNoRetry);
    }

    #[test]
    fn raw_sql_execution_classification_is_explicit() {
        let read_only =
            compiled_raw_query_with_execution("SELECT 1", vec![], RawSqlExecution::ReadOnly)
                .unwrap();
        let write = compiled_raw_query_with_execution(
            "UPDATE users SET active = 1",
            vec![],
            RawSqlExecution::Write,
        )
        .unwrap();
        let migration = compiled_raw_query_with_execution(
            "ALTER TABLE users ADD active bit NOT NULL DEFAULT 1",
            vec![],
            RawSqlExecution::Migration,
        )
        .unwrap();

        assert_eq!(read_only.execution, QueryExecution::ReadOnly);
        assert_eq!(write.execution, QueryExecution::Write);
        assert_eq!(migration.execution, QueryExecution::Migration);
    }

    // --- OPTION clause detection ------------------------------------------

    // Only a clause at parenthesis depth 0 and outside literals counts.
    #[test]
    fn detects_top_level_option_clause_without_matching_strings_or_nested_queries() {
        assert!(contains_top_level_option_clause(
            "SELECT 1 OPTION (RECOMPILE)"
        ));
        assert!(!contains_top_level_option_clause(
            "SELECT 'OPTION (RECOMPILE)' AS text_value"
        ));
        assert!(!contains_top_level_option_clause(
            "SELECT * FROM (SELECT 1 OPTION (RECOMPILE)) AS nested"
        ));
    }

    #[test]
    fn compiled_raw_query_rejects_non_continuous_placeholders() {
        let error = compiled_raw_query(
            "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P3",
            vec![SqlValue::I64(42), SqlValue::I64(7)],
        )
        .unwrap_err();

        assert!(error.message().contains("continuous from @P1 to @P3"));
    }
}
972}