Skip to main content

sql_orm/
raw_sql.rs

1use crate::context::SharedConnection;
2use sql_orm_core::{FromRow, OrmError, SqlTypeMapping, SqlValue};
3use sql_orm_query::CompiledQuery;
4use sql_orm_tiberius::ExecuteResult;
5use std::collections::BTreeSet;
6use std::marker::PhantomData;
7
/// Query hints for SQL Server that a typed raw query can request.
///
/// The ORM renders every requested hint inside one trailing `OPTION (...)`
/// clause appended to the raw SQL, so the SQL text itself must not carry its
/// own top-level `OPTION (...)` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryHint {
    /// Renders as `OPTION (RECOMPILE)`.
    Recompile,
}

impl QueryHint {
    /// T-SQL keyword text for this hint, as written inside `OPTION (...)`.
    const fn sql(self) -> &'static str {
        match self {
            QueryHint::Recompile => "RECOMPILE",
        }
    }
}
26
/// Summary of the `@P<n>` placeholders found while scanning raw SQL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RawPlaceholderPlan {
    /// Highest placeholder index seen; `0` when the SQL has no placeholders.
    max_index: usize,
}

impl RawPlaceholderPlan {
    /// Number of positional parameters the SQL needs.
    ///
    /// Placeholder scanning rejects gaps below the highest index, so the
    /// highest index is also the parameter count.
    const fn expected_param_count(&self) -> usize {
        let Self { max_index } = *self;
        max_index
    }
}
37
38/// Converts one raw SQL parameter into an ORM `SqlValue`.
39///
40/// Raw SQL placeholders are positional and must be written as `@P1`, `@P2`,
41/// and so on. Reusing the same placeholder index reuses the same parameter
42/// value.
43pub trait RawParam {
44    /// Converts this value into a SQL Server parameter value.
45    fn into_sql_value(self) -> SqlValue;
46}
47
48macro_rules! impl_raw_param_via_sql_type_mapping {
49    ($($ty:ty),+ $(,)?) => {
50        $(
51            impl RawParam for $ty {
52                fn into_sql_value(self) -> SqlValue {
53                    <Self as SqlTypeMapping>::to_sql_value(self)
54                }
55            }
56        )+
57    };
58}
59
60impl_raw_param_via_sql_type_mapping!(
61    bool,
62    i32,
63    i64,
64    f64,
65    String,
66    Vec<u8>,
67    uuid::Uuid,
68    rust_decimal::Decimal,
69    chrono::NaiveDate,
70    chrono::NaiveDateTime,
71);
72
73impl RawParam for SqlValue {
74    fn into_sql_value(self) -> SqlValue {
75        self
76    }
77}
78
79impl RawParam for &str {
80    fn into_sql_value(self) -> SqlValue {
81        SqlValue::String(self.to_string())
82    }
83}
84
85impl<T> RawParam for Option<T>
86where
87    T: RawParam,
88{
89    fn into_sql_value(self) -> SqlValue {
90        self.map(RawParam::into_sql_value).unwrap_or(SqlValue::Null)
91    }
92}
93
94/// Converts a parameter collection into raw SQL parameter values.
95///
96/// Implemented for `()`, `Vec<T>` where `T: RawParam`, and tuples up to 12
97/// elements.
98pub trait RawParams {
99    /// Converts this collection into positional SQL values.
100    fn into_sql_values(self) -> Vec<SqlValue>;
101}
102
103impl RawParams for () {
104    fn into_sql_values(self) -> Vec<SqlValue> {
105        Vec::new()
106    }
107}
108
109impl<T> RawParams for Vec<T>
110where
111    T: RawParam,
112{
113    fn into_sql_values(self) -> Vec<SqlValue> {
114        self.into_iter().map(RawParam::into_sql_value).collect()
115    }
116}
117
118macro_rules! impl_raw_params_tuple {
119    ($($name:ident),+ $(,)?) => {
120        impl<$($name),+> RawParams for ($($name,)+)
121        where
122            $($name: RawParam),+
123        {
124            #[allow(non_snake_case)]
125            fn into_sql_values(self) -> Vec<SqlValue> {
126                let ($($name,)+) = self;
127                vec![$($name.into_sql_value()),+]
128            }
129        }
130    };
131}
132
133impl_raw_params_tuple!(A);
134impl_raw_params_tuple!(A, B);
135impl_raw_params_tuple!(A, B, C);
136impl_raw_params_tuple!(A, B, C, D);
137impl_raw_params_tuple!(A, B, C, D, E);
138impl_raw_params_tuple!(A, B, C, D, E, F);
139impl_raw_params_tuple!(A, B, C, D, E, F, G);
140impl_raw_params_tuple!(A, B, C, D, E, F, G, H);
141impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I);
142impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J);
143impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K);
144impl_raw_params_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
145
146#[derive(Clone)]
147/// Typed raw SQL query that materializes rows as `T`.
148///
149/// `T` must implement `FromRow`. Raw SQL is executed exactly as written after
150/// parameter validation and optional query hint rendering; it does not apply
151/// ORM tenant or soft-delete filters automatically.
152pub struct RawQuery<T> {
153    connection: SharedConnection,
154    sql: String,
155    params: Vec<SqlValue>,
156    query_hints: BTreeSet<QueryHint>,
157    _row: PhantomData<fn() -> T>,
158}
159
160impl<T> RawQuery<T>
161where
162    T: FromRow + Send,
163{
164    pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
165        Self {
166            connection,
167            sql: sql.into(),
168            params: Vec::new(),
169            query_hints: BTreeSet::new(),
170            _row: PhantomData,
171        }
172    }
173
174    /// Appends one positional parameter.
175    pub fn param<P>(mut self, value: P) -> Self
176    where
177        P: RawParam,
178    {
179        self.params.push(value.into_sql_value());
180        self
181    }
182
183    /// Appends multiple positional parameters.
184    pub fn params<P>(mut self, values: P) -> Self
185    where
186        P: RawParams,
187    {
188        self.params.extend(values.into_sql_values());
189        self
190    }
191
192    /// Adds a SQL Server query hint to render at the end of the raw query.
193    ///
194    /// Repeated hints are deduplicated with stable ordering.
195    pub fn query_hint(mut self, hint: QueryHint) -> Self {
196        self.query_hints.insert(hint);
197        self
198    }
199
200    /// Executes the query and materializes all rows.
201    pub async fn all(self) -> Result<Vec<T>, OrmError> {
202        let compiled = self.compiled_query()?;
203        let mut connection = self.connection.lock().await?;
204        connection.fetch_all(compiled).await
205    }
206
207    /// Executes the query and materializes the first row, if any.
208    pub async fn first(self) -> Result<Option<T>, OrmError> {
209        let compiled = self.compiled_query()?;
210        let mut connection = self.connection.lock().await?;
211        connection.fetch_one(compiled).await
212    }
213
214    fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
215        compiled_raw_query_with_hints(&self.sql, self.params.clone(), &self.query_hints)
216    }
217}
218
219#[derive(Clone)]
220/// Raw SQL command for statements that do not materialize rows.
221///
222/// Use this for `INSERT`, `UPDATE`, `DELETE`, DDL, or stored procedure calls
223/// where the caller only needs the execution result. Query hints are supported
224/// only on `RawQuery<T>`.
225pub struct RawCommand {
226    connection: SharedConnection,
227    sql: String,
228    params: Vec<SqlValue>,
229}
230
231impl RawCommand {
232    pub(crate) fn new(connection: SharedConnection, sql: impl Into<String>) -> Self {
233        Self {
234            connection,
235            sql: sql.into(),
236            params: Vec::new(),
237        }
238    }
239
240    /// Appends one positional parameter.
241    pub fn param<P>(mut self, value: P) -> Self
242    where
243        P: RawParam,
244    {
245        self.params.push(value.into_sql_value());
246        self
247    }
248
249    /// Appends multiple positional parameters.
250    pub fn params<P>(mut self, values: P) -> Self
251    where
252        P: RawParams,
253    {
254        self.params.extend(values.into_sql_values());
255        self
256    }
257
258    /// Executes the command and returns affected-row information.
259    pub async fn execute(self) -> Result<ExecuteResult, OrmError> {
260        let compiled = self.compiled_query()?;
261        let mut connection = self.connection.lock().await?;
262        connection.execute(compiled).await
263    }
264
265    fn compiled_query(&self) -> Result<CompiledQuery, OrmError> {
266        compiled_raw_query(&self.sql, self.params.clone())
267    }
268}
269
270fn compiled_raw_query(sql: &str, params: Vec<SqlValue>) -> Result<CompiledQuery, OrmError> {
271    compiled_raw_query_with_hints(sql, params, &BTreeSet::new())
272}
273
274fn compiled_raw_query_with_hints(
275    sql: &str,
276    params: Vec<SqlValue>,
277    query_hints: &BTreeSet<QueryHint>,
278) -> Result<CompiledQuery, OrmError> {
279    validate_raw_sql_parameters(sql, params.len())?;
280
281    let sql = render_raw_sql_with_hints(sql, query_hints)?;
282
283    Ok(CompiledQuery::new(sql, params))
284}
285
286fn render_raw_sql_with_hints(
287    sql: &str,
288    query_hints: &BTreeSet<QueryHint>,
289) -> Result<String, OrmError> {
290    if query_hints.is_empty() {
291        return Ok(sql.to_string());
292    }
293
294    if contains_top_level_option_clause(sql) {
295        return Err(OrmError::new(
296            "raw SQL already contains OPTION (...); remove it before using query_hint(...)",
297        ));
298    }
299
300    let mut sql = sql.trim_end().trim_end_matches(';').trim_end().to_string();
301    let hints = query_hints
302        .iter()
303        .copied()
304        .map(QueryHint::sql)
305        .collect::<Vec<_>>()
306        .join(", ");
307
308    sql.push_str(" OPTION (");
309    sql.push_str(&hints);
310    sql.push(')');
311
312    Ok(sql)
313}
314
315pub(crate) fn validate_raw_sql_parameters(sql: &str, param_count: usize) -> Result<(), OrmError> {
316    let plan = analyze_placeholders(sql)?;
317
318    if plan.expected_param_count() != param_count {
319        return Err(OrmError::new(format!(
320            "raw SQL parameter count mismatch: SQL expects {} parameter(s), received {}",
321            plan.expected_param_count(),
322            param_count
323        )));
324    }
325
326    Ok(())
327}
328
329fn analyze_placeholders(sql: &str) -> Result<RawPlaceholderPlan, OrmError> {
330    let bytes = sql.as_bytes();
331    let mut index = 0;
332    let mut placeholders = BTreeSet::new();
333
334    while index + 2 < bytes.len() {
335        if bytes[index] == b'@' && bytes[index + 1] == b'P' && bytes[index + 2].is_ascii_digit() {
336            index += 2;
337            let start = index;
338
339            while index < bytes.len() && bytes[index].is_ascii_digit() {
340                index += 1;
341            }
342
343            let raw_index = sql[start..index]
344                .parse::<usize>()
345                .map_err(|_| OrmError::new("raw SQL placeholder index is larger than supported"))?;
346
347            if raw_index == 0 {
348                return Err(OrmError::new("raw SQL placeholders must start at @P1"));
349            }
350
351            placeholders.insert(raw_index);
352            continue;
353        }
354
355        index += 1;
356    }
357
358    let max_index = placeholders.iter().next_back().copied().unwrap_or(0);
359    for expected in 1..=max_index {
360        if !placeholders.contains(&expected) {
361            return Err(OrmError::new(format!(
362                "raw SQL placeholders must be continuous from @P1 to @P{}",
363                max_index
364            )));
365        }
366    }
367
368    Ok(RawPlaceholderPlan { max_index })
369}
370
371fn contains_top_level_option_clause(sql: &str) -> bool {
372    let bytes = sql.as_bytes();
373    let mut index = 0;
374    let mut depth = 0_i32;
375
376    while index < bytes.len() {
377        match bytes[index] {
378            b'\'' => {
379                index += 1;
380                while index < bytes.len() {
381                    if bytes[index] == b'\'' {
382                        index += 1;
383                        if index < bytes.len() && bytes[index] == b'\'' {
384                            index += 1;
385                            continue;
386                        }
387                        break;
388                    }
389                    index += 1;
390                }
391            }
392            b'[' => {
393                index += 1;
394                while index < bytes.len() {
395                    if bytes[index] == b']' {
396                        index += 1;
397                        break;
398                    }
399                    index += 1;
400                }
401            }
402            b'-' if index + 1 < bytes.len() && bytes[index + 1] == b'-' => {
403                index += 2;
404                while index < bytes.len() && !matches!(bytes[index], b'\n' | b'\r') {
405                    index += 1;
406                }
407            }
408            b'/' if index + 1 < bytes.len() && bytes[index + 1] == b'*' => {
409                index += 2;
410                while index + 1 < bytes.len() {
411                    if bytes[index] == b'*' && bytes[index + 1] == b'/' {
412                        index += 2;
413                        break;
414                    }
415                    index += 1;
416                }
417            }
418            b'(' => {
419                depth += 1;
420                index += 1;
421            }
422            b')' => {
423                depth = depth.saturating_sub(1);
424                index += 1;
425            }
426            _ if depth == 0 && starts_with_keyword(sql, index, "OPTION") => {
427                let after_keyword = index + "OPTION".len();
428                let mut cursor = after_keyword;
429                while cursor < bytes.len() && bytes[cursor].is_ascii_whitespace() {
430                    cursor += 1;
431                }
432                return cursor < bytes.len() && bytes[cursor] == b'(';
433            }
434            _ => index += 1,
435        }
436    }
437
438    false
439}
440
441fn starts_with_keyword(sql: &str, index: usize, keyword: &str) -> bool {
442    let bytes = sql.as_bytes();
443    let keyword_bytes = keyword.as_bytes();
444
445    if index + keyword_bytes.len() > bytes.len() {
446        return false;
447    }
448
449    if !bytes[index..index + keyword_bytes.len()].eq_ignore_ascii_case(keyword_bytes) {
450        return false;
451    }
452
453    let before_is_boundary = index == 0 || !is_identifier_byte(bytes[index - 1]);
454    let after = index + keyword_bytes.len();
455    let after_is_boundary = after == bytes.len() || !is_identifier_byte(bytes[after]);
456
457    before_is_boundary && after_is_boundary
458}
459
/// Returns `true` for ASCII letters, ASCII digits and `_`, and `false` for
/// every other byte; used for word-boundary checks around keywords.
fn is_identifier_byte(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
463
// Unit tests for this module's pure helpers: placeholder validation,
// parameter conversion, query-hint rendering and top-level `OPTION (...)`
// detection. None of these tests open a database connection.
#[cfg(test)]
mod tests {
    use super::{
        QueryHint, RawParam, RawParams, compiled_raw_query, compiled_raw_query_with_hints,
        contains_top_level_option_clause, validate_raw_sql_parameters,
    };
    use chrono::NaiveDate;
    use rust_decimal::Decimal;
    use sql_orm_core::SqlValue;
    use std::collections::BTreeSet;
    use uuid::Uuid;

    // --- Placeholder validation: `@P<n>` must run contiguously from @P1 and
    // the highest index must equal the number of supplied parameters. ---

    #[test]
    fn validates_continuous_placeholders_by_max_index() {
        validate_raw_sql_parameters("SELECT @P1, @P2, @P3", 3).unwrap();
    }

    // Multi-digit indices (@P10..@P12) are parsed as whole numbers.
    #[test]
    fn validates_continuous_placeholders_through_highest_index() {
        validate_raw_sql_parameters(
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11, @P12",
            12,
        )
        .unwrap();
    }

    #[test]
    fn allows_repeated_placeholder_to_reuse_one_param() {
        validate_raw_sql_parameters("SELECT @P1 WHERE owner_id = @P1", 1).unwrap();
    }

    #[test]
    fn rejects_extra_params_without_placeholders() {
        let error = validate_raw_sql_parameters("SELECT 1", 1).unwrap_err();

        assert!(error.message().contains("expects 0 parameter"));
    }

    #[test]
    fn rejects_missing_params() {
        let error = validate_raw_sql_parameters("SELECT @P1, @P2", 1).unwrap_err();

        assert!(error.message().contains("expects 2 parameter"));
    }

    #[test]
    fn rejects_non_continuous_placeholders() {
        let error = validate_raw_sql_parameters("SELECT @P1, @P3", 2).unwrap_err();

        assert!(error.message().contains("continuous from @P1 to @P3"));
    }

    #[test]
    fn rejects_zero_index_placeholder() {
        let error = validate_raw_sql_parameters("SELECT @P0", 0).unwrap_err();

        assert!(error.message().contains("start at @P1"));
    }

    // --- RawParam / RawParams conversions into SqlValue. ---

    // Covers a 12-element tuple (the largest supported arity) with one value
    // of each supported parameter type.
    #[test]
    fn raw_params_tuple_preserves_order_and_values() {
        let values = (
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            String::from("owned"),
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            NaiveDate::from_ymd_opt(2026, 4, 26).unwrap(),
            NaiveDate::from_ymd_opt(2026, 4, 26)
                .unwrap()
                .and_hms_opt(10, 20, 30)
                .unwrap(),
            SqlValue::Null,
        )
            .into_sql_values();

        assert_eq!(
            values,
            vec![
                SqlValue::Bool(true),
                SqlValue::I32(7),
                SqlValue::I64(9),
                SqlValue::F64(3.5),
                SqlValue::String("draft".to_string()),
                SqlValue::String("owned".to_string()),
                SqlValue::Bytes(vec![1, 2]),
                SqlValue::Uuid(Uuid::nil()),
                SqlValue::Decimal(Decimal::new(1234, 2)),
                SqlValue::Date(NaiveDate::from_ymd_opt(2026, 4, 26).unwrap()),
                SqlValue::DateTime(
                    NaiveDate::from_ymd_opt(2026, 4, 26)
                        .unwrap()
                        .and_hms_opt(10, 20, 30)
                        .unwrap()
                ),
                SqlValue::Null,
            ]
        );
    }

    #[test]
    fn raw_param_option_none_maps_to_null() {
        assert_eq!(Option::<i64>::None.into_sql_value(), SqlValue::Null);
    }

    #[test]
    fn raw_param_option_some_maps_inner_value() {
        assert_eq!(Some(42_i64).into_sql_value(), SqlValue::I64(42));
    }

    #[test]
    fn raw_params_vec_preserves_order() {
        let values = vec![1_i64, 2_i64, 3_i64].into_sql_values();

        assert_eq!(
            values,
            vec![SqlValue::I64(1), SqlValue::I64(2), SqlValue::I64(3)]
        );
    }

    #[test]
    fn raw_params_unit_maps_to_empty_params() {
        assert_eq!(().into_sql_values(), Vec::<SqlValue>::new());
    }

    // --- Compilation without hints: the SQL text passes through unchanged
    // and parameters keep their order. ---

    #[test]
    fn compiled_raw_query_preserves_sql_and_parameter_order() {
        let params = (
            SqlValue::Null,
            true,
            7_i32,
            9_i64,
            3.5_f64,
            "draft",
            vec![1_u8, 2],
            Uuid::nil(),
            Decimal::new(1234, 2),
            NaiveDate::from_ymd_opt(2026, 4, 26).unwrap(),
            NaiveDate::from_ymd_opt(2026, 4, 26)
                .unwrap()
                .and_hms_opt(10, 20, 30)
                .unwrap(),
        )
            .into_sql_values();

        let compiled = compiled_raw_query(
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11",
            params,
        )
        .unwrap();

        assert_eq!(
            compiled.sql,
            "SELECT @P1, @P2, @P3, @P4, @P5, @P6, @P7, @P8, @P9, @P10, @P11"
        );
        assert_eq!(
            compiled.params,
            vec![
                SqlValue::Null,
                SqlValue::Bool(true),
                SqlValue::I32(7),
                SqlValue::I64(9),
                SqlValue::F64(3.5),
                SqlValue::String("draft".to_string()),
                SqlValue::Bytes(vec![1, 2]),
                SqlValue::Uuid(Uuid::nil()),
                SqlValue::Decimal(Decimal::new(1234, 2)),
                SqlValue::Date(NaiveDate::from_ymd_opt(2026, 4, 26).unwrap()),
                SqlValue::DateTime(
                    NaiveDate::from_ymd_opt(2026, 4, 26)
                        .unwrap()
                        .and_hms_opt(10, 20, 30)
                        .unwrap()
                ),
            ]
        );
    }

    #[test]
    fn compiled_raw_query_allows_repeated_placeholder_with_single_param() {
        let compiled = compiled_raw_query(
            "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P1",
            vec![SqlValue::I64(42)],
        )
        .unwrap();

        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
    }

    // --- Query hint rendering and OPTION (...) conflict detection. ---

    #[test]
    fn compiled_raw_query_appends_recompile_hint_after_parameters() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let compiled = compiled_raw_query_with_hints(
            "SELECT * FROM users WHERE owner_id = @P1",
            vec![SqlValue::I64(42)],
            &hints,
        )
        .unwrap();

        assert_eq!(
            compiled.sql,
            "SELECT * FROM users WHERE owner_id = @P1 OPTION (RECOMPILE)"
        );
        assert_eq!(compiled.params, vec![SqlValue::I64(42)]);
    }

    #[test]
    fn compiled_raw_query_deduplicates_repeated_query_hints() {
        let hints = BTreeSet::from([QueryHint::Recompile, QueryHint::Recompile]);
        let compiled = compiled_raw_query_with_hints("SELECT 1", vec![], &hints).unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    // The trailing `;` and whitespace are dropped so the clause stays part of
    // the statement.
    #[test]
    fn compiled_raw_query_places_hint_before_trailing_semicolon() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let compiled = compiled_raw_query_with_hints("SELECT 1;   ", vec![], &hints).unwrap();

        assert_eq!(compiled.sql, "SELECT 1 OPTION (RECOMPILE)");
    }

    #[test]
    fn compiled_raw_query_rejects_existing_top_level_option_clause_with_hints() {
        let hints = BTreeSet::from([QueryHint::Recompile]);
        let error = compiled_raw_query_with_hints("SELECT 1 OPTION (MAXDOP 1)", vec![], &hints)
            .unwrap_err();

        assert!(error.message().contains("already contains OPTION"));
    }

    // Only an OPTION outside string literals and outside parentheses counts.
    #[test]
    fn detects_top_level_option_clause_without_matching_strings_or_nested_queries() {
        assert!(contains_top_level_option_clause(
            "SELECT 1 OPTION (RECOMPILE)"
        ));
        assert!(!contains_top_level_option_clause(
            "SELECT 'OPTION (RECOMPILE)' AS text_value"
        ));
        assert!(!contains_top_level_option_clause(
            "SELECT * FROM (SELECT 1 OPTION (RECOMPILE)) AS nested"
        ));
    }

    // Placeholder validation also runs as part of compilation.
    #[test]
    fn compiled_raw_query_rejects_non_continuous_placeholders() {
        let error = compiled_raw_query(
            "SELECT * FROM users WHERE owner_id = @P1 OR reviewer_id = @P3",
            vec![SqlValue::I64(42), SqlValue::I64(7)],
        )
        .unwrap_err();

        assert!(error.message().contains("continuous from @P1 to @P3"));
    }
}