1use crate::context::SharedConnection;
2use sql_orm_core::{FromRow, OrmError, SqlTypeMapping, SqlValue};
3use sql_orm_query::CompiledQuery;
4use sql_orm_tiberius::ExecuteResult;
5use std::collections::BTreeSet;
6use std::marker::PhantomData;
7
/// Query-level hints that can be attached to a raw query.
///
/// Hints are kept in an ordered set by [`RawQuery`], so the derived ordering
/// determines the order in which they are rendered inside `OPTION (...)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryHint {
    /// Rendered as the `RECOMPILE` keyword.
    Recompile,
}
18
19impl QueryHint {
20 const fn sql(self) -> &'static str {
21 match self {
22 Self::Recompile => "RECOMPILE",
23 }
24 }
25}
26
/// Result of scanning raw SQL for `@P<n>` placeholders.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RawPlaceholderPlan {
    /// Highest placeholder index found in the SQL; `0` when there is none.
    max_index: usize,
}
31
32impl RawPlaceholderPlan {
33 const fn expected_param_count(&self) -> usize {
34 self.max_index
35 }
36}
37
38pub trait RawParam {
44 fn into_sql_value(self) -> SqlValue;
46}
47
48macro_rules! impl_raw_param_via_sql_type_mapping {
49 ($($ty:ty),+ $(,)?) => {
50 $(
51 impl RawParam for $ty {
52 fn into_sql_value(self) -> SqlValue {
53 <Self as SqlTypeMapping>::to_sql_value(self)
54 }
55 }
56 )+
57 };
58}
59
60impl_raw_param_via_sql_type_mapping!(
61 bool,
62 i32,
63 i64,
64 f64,
65 String,
66 Vec<u8>,
67 uuid::Uuid,
68 rust_decimal::Decimal,
69 chrono::NaiveDate,
70 chrono::NaiveDateTime,
71);
72
73impl RawParam for SqlValue {
74 fn into_sql_value(self) -> SqlValue {
75 self
76 }
77}
78
79impl RawParam for &str {
80 fn into_sql_value(self) -> SqlValue {
81 SqlValue::String(self.to_string())
82 }
83}
84
85impl<T> RawParam for Option<T>
86where
87 T: RawParam,
88{
89 fn into_sql_value(self) -> SqlValue {
90 self.map(RawParam::into_sql_value).unwrap_or(SqlValue::Null)
91 }
92}
93
94pub trait RawParams {
99 fn into_sql_values(self) -> Vec<SqlValue>;
101}
102
103impl RawParams for () {
104 fn into_sql_values(self) -> Vec<SqlValue> {
105 Vec::new()
106 }
107}
108
109impl<T> RawParams for Vec<T>
110where
111 T: RawParam,
112{
113 fn into_sql_values(self) -> Vec<SqlValue> {
114 self.into_iter().map(RawParam::into_sql_value).collect()
115 }
116}
117
118macro_rules! impl_raw_params_tuple {
119 ($($name:ident),+ $(,)?) => {
120 impl<$($name),+> RawParams for ($($name,)+)
121 where
122 $($name: RawParam),+
123 {
124 #[allow(non_snake_case)]
125 fn into_sql_values(self) -> Vec<SqlValue> {
126 let ($($name,)+) = self;
127 vec![$($name.into_sql_value()),+]
128 }
129 }
130 };
131}
132
133impl_raw_params_tuple!(A);
134impl_raw_params_tuple!(A, B);
135impl_raw_params_tuple!(A, B, C);
136impl_raw_params_tuple!(A, B, C, D);
137impl_raw_params_tuple!(A, B, C, D, E);
138impl_raw_params_tuple!(A, B, C, D, E, F);
139impl_raw_params_tuple!(A, B, C, D, E, F, G);
140impl_raw_params_tuple!(A, B, C, D, E, F, G, H);
141impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I);
142impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J);
143impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K);
144impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
145
146#[derive(Clone)]
147pub struct RawQuery<T> {
153 connection: SharedConnection,
154 sql: String,
155 params: Vec<SqlValue>,
156 query_hints: BTreeSet<QueryHint>,
157 _row: PhantomData<fn() -> T>,
158}
159
160impl<T> RawQuery<T>
161where
162 T: FromRow + Send,
163{
164 pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
165 Self {
166 connection,
167 sql: sql.into(),
168 params: Vec::new(),
169 query_hints: BTreeSet::new(),
170 _row: PhantomData,
171 }
172 }
173
174 pub fn param<P>(mut self, value: P) -> Self
176 where
177 P: RawParam,
178 {
179 self.params.push(value.into_sql_value());
180 self
181 }
182
183 pub fn params<P>(mut self, values: P) -> Self
185 where
186 P: RawParams,
187 {
188 self.params.extend(values.into_sql_values());
189 self
190 }
191
192 pub fn query_hint(mut self, hint: QueryHint) -> Self {
196 self.query_hints.insert(hint);
197 self
198 }
199
200 pub async fn all(self) -> Result<Vec<T>, OrmError> {
202 let compiled = self.compiled_query()?;
203 let mut connection = self.connection.lock().await?;
204 connection.fetch_all(compiled).await
205 }
206
207 pub async fn first(self) -> Result<Option<T>, OrmError> {
209 let compiled = self.compiled_query()?;
210 let mut connection = self.connection.lock().await?;
211 connection.fetch_one(compiled).await
212 }
213
214 fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
215 compiled_raw_query_with_hints(&self.sql, self.params.clone(), &self.query_hints)
216 }
217}
218
219#[derive(Clone)]
220pub struct RawCommand {
226 connection: SharedConnection,
227 sql: String,
228 params: Vec<SqlValue>,
229}
230
231impl RawCommand {
232 pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
233 Self {
234 connection,
235 sql: sql.into(),
236 params: Vec::new(),
237 }
238 }
239
240 pub fn param<P>(mut self, value: P) -> Self
242 where
243 P: RawParam,
244 {
245 self.params.push(value.into_sql_value());
246 self
247 }
248
249 pub fn params<P>(mut self, values: P) -> Self
251 where
252 P: RawParams,
253 {
254 self.params.extend(values.into_sql_values());
255 self
256 }
257
258 pub async fn execute(self) -> Result<ExecuteResult, OrmError> {
260 let compiled = self.compiled_query()?;
261 let mut connection = self.connection.lock().await?;
262 connection.execute(compiled).await
263 }
264
265 fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
266 compiled_raw_query(&self.sql, self.params.clone())
267 }
268}
269
270fn compiled_raw_query(sql: &str, params: Vec<SqlValue>) -> Result<CompiledQuery, OrmError> {
271 compiled_raw_query_with_hints(sql, params, &BTreeSet::new())
272}
273
274fn compiled_raw_query_with_hints(
275 sql: &str,
276 params: Vec<SqlValue>,
277 query_hints: &BTreeSet<QueryHint>,
278) -> Result<CompiledQuery, OrmError> {
279 validate_raw_sql_parameters(sql, params.len())?;
280
281 let sql = render_raw_sql_with_hints(sql, query_hints)?;
282
283 Ok(CompiledQuery::new(sql, params))
284}
285
286fn render_raw_sql_with_hints(
287 sql: &str,
288 query_hints: &BTreeSet<QueryHint>,
289) -> Result<String, OrmError> {
290 if query_hints.is_empty() {
291 return Ok(sql.to_string());
292 }
293
294 if contains_top_level_option_clause(sql) {
295 return Err(OrmError::new(
296 "raw SQL already contains OPTION (...); remove it before using query_hint(...)",
297 ));
298 }
299
300 let mut sql = sql.trim_end().trim_end_matches(';').trim_end().to_string();
301 let hints = query_hints
302 .iter()
303 .copied()
304 .map(QueryHint::sql)
305 .collect::<Vec<_>>()
306 .join(", ");
307
308 sql.push_str(" OPTION (");
309 sql.push_str(&hints);
310 sql.push(')');
311
312 Ok(sql)
313}
314
315pub(crate) fn validate_raw_sql_parameters(sql: &str, param_count: usize) -> Result<(), OrmError> {
316 let plan = analyze_placeholders(sql)?;
317
318 if plan.expected_param_count() != param_count {
319 return Err(OrmError::new(format!(
320 "raw SQL parameter count mismatch: SQL expects {} parameter(s), received {}",
321 plan.expected_param_count(),
322 param_count
323 )));
324 }
325
326 Ok(())
327}
328
329fn analyze_placeholders(sql: &str) -> Result<RawPlaceholderPlan, OrmError> {
330 let bytes = sql.as_bytes();
331 let mut index = 0;
332 let mut placeholders = BTreeSet::new();
333
334 while index + 2 < bytes.len() {
335 if bytes[index] == b'@' && bytes[index + 1] == b'P' && bytes[index + 2].is_ascii_digit() {
336 index += 2;
337 let start = index;
338
339 while index < bytes.len() && bytes[index].is_ascii_digit() {
340 index += 1;
341 }
342
343 let raw_index = sql[start..index]
344 .parse::<usize>()
345 .map_err(|_| OrmError::new("raw SQL placeholder index is larger than supported"))?;
346
347 if raw_index == 0 {
348 return Err(OrmError::new("raw SQL placeholders must start at @P1"));
349 }
350
351 placeholders.insert(raw_index);
352 continue;
353 }
354
355 index += 1;
356 }
357
358 let max_index = placeholders.iter().next_back().copied().unwrap_or(0);
359 for expected in 1..=max_index {
360 if !placeholders.contains(&expected) {
361 return Err(OrmError::new(format!(
362 "raw SQL placeholders must be continuous from @P1 to @P{}",
363 max_index
364 )));
365 }
366 }
367
368 Ok(RawPlaceholderPlan { max_index })
369}
370
371fn contains_top_level_option_clause(sql: &str) -> bool {
372 let bytes = sql.as_bytes();
373 let mut index = 0;
374 let mut depth = 0_i32;
375
376 while index < bytes.len() {
377 match bytes[index] {
378 b'\'' => {
379 index += 1;
380 while index < bytes.len() {
381 if bytes[index] == b'\'' {
382 index += 1;
383 if index < bytes.len() && bytes[index] == b'\'' {
384 index += 1;
385 continue;
386 }
387 break;
388 }
389 index += 1;
390 }
391 }
392 b'[' => {
393 index += 1;
394 while index < bytes.len() {
395 if bytes[index] == b']' {
396 index += 1;
397 break;
398 }
399 index += 1;
400 }
401 }
402 b'-' if index + 1 < bytes.len() && bytes[index + 1] == b'-' => {
403 index += 2;
404 while index < bytes.len() && !matches!(bytes[index], b'\n' | b'\r') {
405 index += 1;
406 }
407 }
408 b'/' if index + 1 < bytes.len() && bytes[index + 1] == b'*' => {
409 index += 2;
410 while index + 1 < bytes.len() {
411 if bytes[index] == b'*' && bytes[index + 1] == b'/' {
412 index += 2;
413 break;
414 }
415 index += 1;
416 }
417 }
418 b'(' => {
419 depth += 1;
420 index += 1;
421 }
422 b')' => {
423 depth = depth.saturating_sub(1);
424 index += 1;
425 }
426 _ if depth == 0 && starts_with_keyword(sql, index, "OPTION") => {
427 let after_keyword = index + "OPTION".len();
428 let mut cursor = after_keyword;
429 while cursor < bytes.len() && bytes[cursor].is_ascii_whitespace() {
430 cursor += 1;
431 }
432 return cursor < bytes.len() && bytes[cursor] == b'(';
433 }
434 _ => index += 1,
435 }
436 }
437
438 false
439}
440
441fn starts_with_keyword(sql: &str, index: usize, keyword: &str) -> bool {
442 let bytes = sql.as_bytes();
443 let keyword_bytes = keyword.as_bytes();
444
445 if index + keyword_bytes.len() > bytes.len() {
446 return false;
447 }
448
449 if !bytes[index..index + keyword_bytes.len()].eq_ignore_ascii_case(keyword_bytes) {
450 return false;
451 }
452
453 let before_is_boundary = index == 0 || !is_identifier_byte(bytes[index - 1]);
454 let after = index + keyword_bytes.len();
455 let after_is_boundary = after == bytes.len() || !is_identifier_byte(bytes[after]);
456
457 before_is_boundary && after_is_boundary
458}
459
/// Reports whether `byte` can be part of an unquoted T-SQL identifier
/// (ASCII subset): letters, digits, `_`, `@`, `#` and `$`.
///
/// Including `@`, `#` and `$` keeps names such as `@option`, `#option` or
/// `col$option` from being treated as containing a standalone keyword.
fn is_identifier_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'@' | b'#' | b'$')
}
463
#[cfg(test)]
mod tests {
    use super::{
        QueryHint, RawParam, RawParams, compiled_raw_query, compiled_raw_query_with_hints,
        contains_top_level_option_clause, validate_raw_sql_parameters,
    };
    use chrono::NaiveDate;
    use rust_decimal::Decimal;
    use sql_orm_core::SqlValue;
    use std::collections::BTreeSet;
    use uuid::Uuid;

    /// Calendar date shared by the parameter-conversion fixtures.
    fn fixture_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2026, 4, 26).unwrap()
    }

    /// Timestamp on the fixture date at 10:20:30.
    fn fixture_datetime() -> chrono::NaiveDateTime {
        fixture_date().and_hms_opt(10, 20, 30).unwrap()
    }

    /// Hint set containing only `RECOMPILE`.
    fn recompile_only() -> BTreeSet<QueryHint> {
        BTreeSet::from([QueryHint::Recompile])
    }

    // --- Placeholder validation -------------------------------------------

    #[test]
    fn validates_continuous_placeholders_by_max_index() {
        let outcome = validate_raw_sql_parameters("SELECT @P1, @P2, @P3", 3);

        outcome.unwrap();
    }

    #[test]
    fn validates_continuous_placeholders_through_highest_index() {
        let sql = "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11, @P12";

        validate_raw_sql_parameters(sql, 12).unwrap();
    }

    #[test]
    fn allows_repeated_placeholder_to_reuse_one_param() {
        let outcome = validate_raw_sql_parameters("SELECT @P1 WHERE owner_id = @P1", 1);

        outcome.unwrap();
    }

    #[test]
    fn rejects_extra_params_without_placeholders() {
        let failure = validate_raw_sql_parameters("SELECT 1", 1).unwrap_err();

        assert!(failure.message().contains("expects 0 parameter"));
    }

    #[test]
    fn rejects_missing_params() {
        let failure = validate_raw_sql_parameters("SELECT @P1, @P2", 1).unwrap_err();

        assert!(failure.message().contains("expects 2 parameter"));
    }

    #[test]
    fn rejects_non_continuous_placeholders() {
        let failure = validate_raw_sql_parameters("SELECT @P1, @P3", 2).unwrap_err();

        assert!(failure.message().contains("continuous from @P1 to @P3"));
    }

    #[test]
    fn rejects_zero_index_placeholder() {
        let failure = validate_raw_sql_parameters("SELECT @P0", 0).unwrap_err();

        assert!(failure.message().contains("start at @P1"));
    }

    // --- Parameter conversion ---------------------------------------------

    #[test]
    fn raw_params_tuple_preserves_order_and_values() {
        let actual = (
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            String::from("owned"),
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            fixture_date(),
            fixture_datetime(),
            SqlValue::Null,
        )
            .into_sql_values();

        let expected = vec![
            SqlValue::Bool(true),
            SqlValue::I32(7),
            SqlValue::I64(9),
            SqlValue::F64(3.5),
            SqlValue::String("draft".to_string()),
            SqlValue::String("owned".to_string()),
            SqlValue::Bytes(vec![1, 2]),
            SqlValue::Uuid(Uuid::nil()),
            SqlValue::Decimal(Decimal::new(1234, 2)),
            SqlValue::Date(fixture_date()),
            SqlValue::DateTime(fixture_datetime()),
            SqlValue::Null,
        ];

        assert_eq!(actual, expected);
    }

    #[test]
    fn raw_param_option_none_maps_to_null() {
        let absent: Option<i64> = None;

        assert_eq!(absent.into_sql_value(), SqlValue::Null);
    }

    #[test]
    fn raw_param_option_some_maps_inner_value() {
        let present = Some(42_i64);

        assert_eq!(present.into_sql_value(), SqlValue::I64(42));
    }

    #[test]
    fn raw_params_vec_preserves_order() {
        let actual = vec![1_i64, 2_i64, 3_i64].into_sql_values();
        let expected = vec![SqlValue::I64(1), SqlValue::I64(2), SqlValue::I64(3)];

        assert_eq!(actual, expected);
    }

    #[test]
    fn raw_params_unit_maps_to_empty_params() {
        let actual = ().into_sql_values();

        assert_eq!(actual, Vec::<SqlValue>::new());
    }

    // --- Query compilation ------------------------------------------------

    #[test]
    fn compiled_raw_query_preserves_sql_and_parameter_order() {
        let sql = "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11";
        let params = (
            SqlValue::Null,
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            fixture_date(),
            fixture_datetime(),
        )
            .into_sql_values();

        let compiled = compiled_raw_query(sql, params).unwrap();

        assert_eq!(compiled.sql, sql);

        let expected_params = vec![
            SqlValue::Null,
            SqlValue::Bool(true),
            SqlValue::I32(7),
            SqlValue::I64(9),
            SqlValue::F64(3.5),
            SqlValue::String("draft".to_string()),
            SqlValue::Bytes(vec![1, 2]),
            SqlValue::Uuid(Uuid::nil()),
            SqlValue::Decimal(Decimal::new(1234, 2)),
            SqlValue::Date(fixture_date()),
            SqlValue::DateTime(fixture_datetime()),
        ];
        assert_eq!(compiled.params, expected_params);
    }

    #[test]
    fn compiled_raw_query_allows_repeated_placeholder_with_single_param() {
        let sql = "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P1";

        let compiled = compiled_raw_query(sql, vec![SqlValue::I64(42)]).unwrap();

        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
    }

    #[test]
    fn compiled_raw_query_appends_recompile_hint_after_parameters() {
        let hints = recompile_only();

        let compiled = compiled_raw_query_with_hints(
            "SELECT * FROM users WHERE owner_id = @P1",
            vec![SqlValue::I64(42)],
            &hints,
        )
        .unwrap();

        assert_eq!(
            compiled.sql,
            "SELECT * FROM users WHERE owner_id = @P1 OPTION (RECOMPILE)"
        );
        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
    }

    #[test]
    fn compiled_raw_query_deduplicates_repeated_query_hints() {
        // The same hint twice must collapse into a single set entry.
        let hints = BTreeSet::from([QueryHint::Recompile, QueryHint::Recompile]);

        let compiled = compiled_raw_query_with_hints("SELECT 1", vec![], &hints).unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    #[test]
    fn compiled_raw_query_places_hint_before_trailing_semicolon() {
        let hints = recompile_only();

        let compiled = compiled_raw_query_with_hints("SELECT 1; ", vec![], &hints).unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    #[test]
    fn compiled_raw_query_rejects_existing_top_level_option_clause_with_hints() {
        let hints = recompile_only();

        let failure = compiled_raw_query_with_hints("SELECT 1 OPTION (MAXDOP 1)", vec![], &hints)
            .unwrap_err();

        assert!(failure.message().contains("already contains OPTION"));
    }

    #[test]
    fn detects_top_level_option_clause_without_matching_strings_or_nested_queries() {
        let top_level = "SELECT 1 OPTION (RECOMPILE)";
        let inside_literal = "SELECT 'OPTION (RECOMPILE)' AS text_value";
        let inside_subquery = "SELECT * FROM (SELECT 1 OPTION (RECOMPILE)) AS nested";

        assert!(contains_top_level_option_clause(top_level));
        assert!(!contains_top_level_option_clause(inside_literal));
        assert!(!contains_top_level_option_clause(inside_subquery));
    }

    #[test]
    fn compiled_raw_query_rejects_non_continuous_placeholders() {
        let sql = "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P3";
        let params = vec![SqlValue::I64(42), SqlValue::I64(7)];

        let failure = compiled_raw_query(sql, params).unwrap_err();

        assert!(failure.message().contains("continuous from @P1 to @P3"));
    }
}