Skip to main content

sql_orm_tiberius/
executor.rs

1use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::{MssqlConnection, run_with_timeout};
3use crate::error::{TiberiusErrorContext, is_transient_tiberius_error, map_tiberius_error};
4use crate::parameter::PreparedQuery;
5use crate::row::MssqlRow;
6use crate::telemetry::{QueryTrace, classify_sql, trace_query};
7use crate::transaction::MssqlTransaction;
8use async_trait::async_trait;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::CompiledQuery;
12use std::time::Duration;
13use tiberius::Client;
14use tiberius::QueryStream;
15
/// Row counts produced by [`Executor::execute`].
///
/// Holds one affected-row count per entry reported by the driver, in the order received.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecuteResult {
    /// Affected-row counts as reported by the driver.
    rows_affected: Vec<u64>,
}

impl ExecuteResult {
    /// Wraps an already-collected list of affected-row counts.
    pub fn new(rows_affected: Vec<u64>) -> Self {
        ExecuteResult { rows_affected }
    }

    /// Borrows every affected-row count, in driver order.
    pub fn rows_affected(&self) -> &[u64] {
        self.rows_affected.as_slice()
    }

    /// Sums all affected-row counts into a single total.
    pub fn total(&self) -> u64 {
        self.rows_affected.iter().copied().sum()
    }
}
34
35#[async_trait]
36pub trait Executor {
37    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError>;
38    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
39    where
40        T: FromRow + Send;
41    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
42    where
43        T: FromRow + Send;
44}
45
46#[async_trait]
47impl<S> Executor for MssqlConnection<S>
48where
49    S: AsyncRead + AsyncWrite + Unpin + Send,
50{
51    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
52        MssqlConnection::execute(self, query).await
53    }
54
55    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
56    where
57        T: FromRow + Send,
58    {
59        MssqlConnection::fetch_one(self, query).await
60    }
61
62    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
63    where
64        T: FromRow + Send,
65    {
66        MssqlConnection::fetch_all(self, query).await
67    }
68}
69
70#[async_trait]
71impl<S> Executor for MssqlTransaction<'_, S>
72where
73    S: AsyncRead + AsyncWrite + Unpin + Send,
74{
75    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
76        MssqlTransaction::execute(self, query).await
77    }
78
79    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
80    where
81        T: FromRow + Send,
82    {
83        MssqlTransaction::fetch_one(self, query).await
84    }
85
86    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
87    where
88        T: FromRow + Send,
89    {
90        MssqlTransaction::fetch_all(self, query).await
91    }
92}
93
94impl<S> MssqlConnection<S>
95where
96    S: AsyncRead + AsyncWrite + Unpin + Send,
97{
98    pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
99        let tracing_options = self.tracing_options();
100        let slow_query_options = self.slow_query_options();
101        let server_addr = self.server_addr();
102        let query_timeout = self.query_timeout();
103        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
104            execute_compiled(
105                self.client_mut(),
106                query,
107                tracing_options,
108                slow_query_options,
109                &server_addr,
110                query_timeout,
111            )
112            .await
113        })
114        .await
115    }
116
117    pub async fn query_raw<'a>(
118        &'a mut self,
119        query: CompiledQuery,
120    ) -> Result<QueryStream<'a>, OrmError> {
121        let tracing_options = self.tracing_options();
122        let slow_query_options = self.slow_query_options();
123        let server_addr = self.server_addr();
124        let query_timeout = self.query_timeout();
125        run_with_timeout(query_timeout, "SQL Server query timed out", async {
126            query_raw_compiled(
127                self.client_mut(),
128                query,
129                tracing_options,
130                slow_query_options,
131                &server_addr,
132                query_timeout,
133            )
134            .await
135        })
136        .await
137    }
138
139    pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
140    where
141        T: FromRow + Send,
142    {
143        let tracing_options = self.tracing_options();
144        let slow_query_options = self.slow_query_options();
145        let retry_options = self.retry_options();
146        let server_addr = self.server_addr();
147        let query_timeout = self.query_timeout();
148        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
149            fetch_one_compiled(
150                self.client_mut(),
151                query,
152                tracing_options,
153                slow_query_options,
154                retry_options,
155                &server_addr,
156                query_timeout,
157            )
158            .await
159        })
160        .await
161    }
162
163    pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
164    where
165        T: FromRow + Send,
166    {
167        let tracing_options = self.tracing_options();
168        let slow_query_options = self.slow_query_options();
169        let retry_options = self.retry_options();
170        let server_addr = self.server_addr();
171        let query_timeout = self.query_timeout();
172        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
173            fetch_all_compiled(
174                self.client_mut(),
175                query,
176                tracing_options,
177                slow_query_options,
178                retry_options,
179                &server_addr,
180                query_timeout,
181            )
182            .await
183        })
184        .await
185    }
186
187    pub async fn fetch_one_with<T, F>(
188        &mut self,
189        query: CompiledQuery,
190        map: F,
191    ) -> Result<Option<T>, OrmError>
192    where
193        F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
194    {
195        let tracing_options = self.tracing_options();
196        let slow_query_options = self.slow_query_options();
197        let retry_options = self.retry_options();
198        let server_addr = self.server_addr();
199        let query_timeout = self.query_timeout();
200        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
201            fetch_one_compiled_with(
202                self.client_mut(),
203                query,
204                tracing_options,
205                slow_query_options,
206                retry_options,
207                &server_addr,
208                query_timeout,
209                map,
210            )
211            .await
212        })
213        .await
214    }
215
216    pub async fn fetch_all_with<T, F>(
217        &mut self,
218        query: CompiledQuery,
219        map: F,
220    ) -> Result<Vec<T>, OrmError>
221    where
222        F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
223    {
224        let tracing_options = self.tracing_options();
225        let slow_query_options = self.slow_query_options();
226        let retry_options = self.retry_options();
227        let server_addr = self.server_addr();
228        let query_timeout = self.query_timeout();
229        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
230            fetch_all_compiled_with(
231                self.client_mut(),
232                query,
233                tracing_options,
234                slow_query_options,
235                retry_options,
236                &server_addr,
237                query_timeout,
238                map,
239            )
240            .await
241        })
242        .await
243    }
244}
245
246pub(crate) async fn execute_compiled<S>(
247    client: &mut Client<S>,
248    query: CompiledQuery,
249    tracing_options: MssqlTracingOptions,
250    slow_query_options: MssqlSlowQueryOptions,
251    server_addr: &str,
252    query_timeout: Option<std::time::Duration>,
253) -> Result<ExecuteResult, OrmError>
254where
255    S: AsyncRead + AsyncWrite + Unpin + Send,
256{
257    let prepared = PreparedQuery::from_compiled(query);
258    let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
259    let result = trace_query(tracing_options, slow_query_options, trace, async {
260        prepared.validate_parameter_count()?;
261        prepared.execute(client).await
262    })
263    .await?;
264
265    Ok(ExecuteResult::new(result.rows_affected().to_vec()))
266}
267
268pub(crate) async fn query_raw_compiled<'a, S>(
269    client: &'a mut Client<S>,
270    query: CompiledQuery,
271    tracing_options: MssqlTracingOptions,
272    slow_query_options: MssqlSlowQueryOptions,
273    server_addr: &str,
274    query_timeout: Option<std::time::Duration>,
275) -> Result<QueryStream<'a>, OrmError>
276where
277    S: AsyncRead + AsyncWrite + Unpin + Send,
278{
279    let prepared = PreparedQuery::from_compiled(query);
280    let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
281    trace_query(tracing_options, slow_query_options, trace, async {
282        prepared.validate_parameter_count()?;
283        prepared.query(client).await
284    })
285    .await
286}
287
288pub(crate) async fn fetch_one_compiled<S, T>(
289    client: &mut Client<S>,
290    query: CompiledQuery,
291    tracing_options: MssqlTracingOptions,
292    slow_query_options: MssqlSlowQueryOptions,
293    retry_options: MssqlRetryOptions,
294    server_addr: &str,
295    query_timeout: Option<std::time::Duration>,
296) -> Result<Option<T>, OrmError>
297where
298    S: AsyncRead + AsyncWrite + Unpin + Send,
299    T: FromRow + Send,
300{
301    let retryable_query = is_retryable_read_query(&query, retry_options);
302    let mut attempt = 0;
303
304    let row = loop {
305        let prepared = PreparedQuery::from_compiled(query.clone());
306        prepared.validate_parameter_count()?;
307        let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
308
309        match trace_query(tracing_options, slow_query_options, trace, async {
310            prepared.query_driver(client).await?.into_row().await
311        })
312        .await
313        {
314            Ok(row) => break row,
315            Err(error)
316                if retryable_query
317                    && attempt < retry_options.max_retries
318                    && is_transient_tiberius_error(&error) =>
319            {
320                attempt += 1;
321                let delay = retry_delay(retry_options, attempt);
322
323                tracing::warn!(
324                    target: "orm.query.retry",
325                    server_addr = %server_addr,
326                    operation = %classify_sql(&query.sql),
327                    attempt,
328                    max_retries = retry_options.max_retries,
329                    delay_ms = delay.as_millis(),
330                    error_code = ?error.code(),
331                    error = %error,
332                );
333
334                tokio::time::sleep(delay).await;
335            }
336            Err(error) => {
337                return Err(map_tiberius_error(
338                    &error,
339                    TiberiusErrorContext::ExecuteQuery,
340                ));
341            }
342        }
343    };
344
345    row.as_ref()
346        .map(|row| T::from_row(&MssqlRow::new(row)))
347        .transpose()
348}
349
350async fn fetch_one_compiled_with<S, T, F>(
351    client: &mut Client<S>,
352    query: CompiledQuery,
353    tracing_options: MssqlTracingOptions,
354    slow_query_options: MssqlSlowQueryOptions,
355    retry_options: MssqlRetryOptions,
356    server_addr: &str,
357    query_timeout: Option<std::time::Duration>,
358    mut map: F,
359) -> Result<Option<T>, OrmError>
360where
361    S: AsyncRead + AsyncWrite + Unpin + Send,
362    F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
363{
364    let retryable_query = is_retryable_read_query(&query, retry_options);
365    let mut attempt = 0;
366
367    let row = loop {
368        let prepared = PreparedQuery::from_compiled(query.clone());
369        prepared.validate_parameter_count()?;
370        let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
371
372        match trace_query(tracing_options, slow_query_options, trace, async {
373            prepared.query_driver(client).await?.into_row().await
374        })
375        .await
376        {
377            Ok(row) => break row,
378            Err(error)
379                if retryable_query
380                    && attempt < retry_options.max_retries
381                    && is_transient_tiberius_error(&error) =>
382            {
383                attempt += 1;
384                let delay = retry_delay(retry_options, attempt);
385
386                tracing::warn!(
387                    target: "orm.query.retry",
388                    server_addr = %server_addr,
389                    operation = %classify_sql(&query.sql),
390                    attempt,
391                    max_retries = retry_options.max_retries,
392                    delay_ms = delay.as_millis(),
393                    error_code = ?error.code(),
394                    error = %error,
395                );
396
397                tokio::time::sleep(delay).await;
398            }
399            Err(error) => {
400                return Err(map_tiberius_error(
401                    &error,
402                    TiberiusErrorContext::ExecuteQuery,
403                ));
404            }
405        }
406    };
407
408    row.as_ref().map(|row| map(MssqlRow::new(row))).transpose()
409}
410
411pub(crate) async fn fetch_all_compiled<S, T>(
412    client: &mut Client<S>,
413    query: CompiledQuery,
414    tracing_options: MssqlTracingOptions,
415    slow_query_options: MssqlSlowQueryOptions,
416    retry_options: MssqlRetryOptions,
417    server_addr: &str,
418    query_timeout: Option<std::time::Duration>,
419) -> Result<Vec<T>, OrmError>
420where
421    S: AsyncRead + AsyncWrite + Unpin + Send,
422    T: FromRow + Send,
423{
424    let retryable_query = is_retryable_read_query(&query, retry_options);
425    let mut attempt = 0;
426
427    let rows = loop {
428        let prepared = PreparedQuery::from_compiled(query.clone());
429        prepared.validate_parameter_count()?;
430        let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
431
432        match trace_query(tracing_options, slow_query_options, trace, async {
433            prepared
434                .query_driver(client)
435                .await?
436                .into_first_result()
437                .await
438        })
439        .await
440        {
441            Ok(rows) => break rows,
442            Err(error)
443                if retryable_query
444                    && attempt < retry_options.max_retries
445                    && is_transient_tiberius_error(&error) =>
446            {
447                attempt += 1;
448                let delay = retry_delay(retry_options, attempt);
449
450                tracing::warn!(
451                    target: "orm.query.retry",
452                    server_addr = %server_addr,
453                    operation = %classify_sql(&query.sql),
454                    attempt,
455                    max_retries = retry_options.max_retries,
456                    delay_ms = delay.as_millis(),
457                    error_code = ?error.code(),
458                    error = %error,
459                );
460
461                tokio::time::sleep(delay).await;
462            }
463            Err(error) => {
464                return Err(map_tiberius_error(
465                    &error,
466                    TiberiusErrorContext::ExecuteQuery,
467                ));
468            }
469        }
470    };
471
472    rows.iter()
473        .map(|row| T::from_row(&MssqlRow::new(row)))
474        .collect()
475}
476
477async fn fetch_all_compiled_with<S, T, F>(
478    client: &mut Client<S>,
479    query: CompiledQuery,
480    tracing_options: MssqlTracingOptions,
481    slow_query_options: MssqlSlowQueryOptions,
482    retry_options: MssqlRetryOptions,
483    server_addr: &str,
484    query_timeout: Option<std::time::Duration>,
485    mut map: F,
486) -> Result<Vec<T>, OrmError>
487where
488    S: AsyncRead + AsyncWrite + Unpin + Send,
489    F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
490{
491    let retryable_query = is_retryable_read_query(&query, retry_options);
492    let mut attempt = 0;
493
494    let rows = loop {
495        let prepared = PreparedQuery::from_compiled(query.clone());
496        prepared.validate_parameter_count()?;
497        let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
498
499        match trace_query(tracing_options, slow_query_options, trace, async {
500            prepared
501                .query_driver(client)
502                .await?
503                .into_first_result()
504                .await
505        })
506        .await
507        {
508            Ok(rows) => break rows,
509            Err(error)
510                if retryable_query
511                    && attempt < retry_options.max_retries
512                    && is_transient_tiberius_error(&error) =>
513            {
514                attempt += 1;
515                let delay = retry_delay(retry_options, attempt);
516
517                tracing::warn!(
518                    target: "orm.query.retry",
519                    server_addr = %server_addr,
520                    operation = %classify_sql(&query.sql),
521                    attempt,
522                    max_retries = retry_options.max_retries,
523                    delay_ms = delay.as_millis(),
524                    error_code = ?error.code(),
525                    error = %error,
526                );
527
528                tokio::time::sleep(delay).await;
529            }
530            Err(error) => {
531                return Err(map_tiberius_error(
532                    &error,
533                    TiberiusErrorContext::ExecuteQuery,
534                ));
535            }
536        }
537    };
538
539    rows.iter().map(|row| map(MssqlRow::new(row))).collect()
540}
541
542fn is_retryable_read_query(query: &CompiledQuery, retry_options: MssqlRetryOptions) -> bool {
543    retry_options.enabled && retry_options.max_retries > 0 && classify_sql(&query.sql) == "select"
544}
545
546fn retry_delay(retry_options: MssqlRetryOptions, attempt: u32) -> Duration {
547    let multiplier = 1u32
548        .checked_shl(attempt.saturating_sub(1))
549        .unwrap_or(u32::MAX);
550    let base_millis = retry_options.base_delay.as_millis();
551    let max_millis = retry_options.max_delay.as_millis();
552    let scaled = base_millis.saturating_mul(u128::from(multiplier));
553
554    Duration::from_millis(scaled.min(max_millis) as u64)
555}
556
#[cfg(test)]
mod tests {
    use super::{
        ExecuteResult, fetch_all_compiled, fetch_one_compiled, is_retryable_read_query,
        query_raw_compiled, retry_delay,
    };
    use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
    use sql_orm_core::{FromRow, OrmError, Row};
    use sql_orm_query::CompiledQuery;
    use std::time::Duration;
    use tokio::net::TcpStream;
    use tokio_util::compat::Compat;

    /// Concrete transport used to monomorphise the shared helpers in compile-only checks.
    type TcpTransport = Compat<TcpStream>;

    /// Row model that ignores row contents; only its `FromRow` impl matters here.
    struct TestRowModel;

    impl FromRow for TestRowModel {
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(TestRowModel)
        }
    }

    #[test]
    fn execute_result_exposes_rows_affected_and_total() {
        let counts = ExecuteResult::new(vec![1, 2, 3]);

        assert_eq!(counts.total(), 6);
        assert_eq!(counts.rows_affected(), &[1, 2, 3]);
    }

    #[test]
    fn reuses_shared_execution_helpers_from_transaction_boundary() {
        // Naming the helpers with concrete generics proves they instantiate for a real
        // transport and row type; nothing is executed.
        let _query_raw = query_raw_compiled::<TcpTransport>;
        let _fetch_one = fetch_one_compiled::<TcpTransport, TestRowModel>;
        let _fetch_all = fetch_all_compiled::<TcpTransport, TestRowModel>;
    }

    #[test]
    fn compiled_query_helpers_accept_tracing_context_shape() {
        let slow_query = MssqlSlowQueryOptions::enabled(Duration::from_millis(250));
        let tracing = MssqlTracingOptions::enabled();

        assert!(tracing.enabled);
        assert!(slow_query.enabled);
    }

    #[test]
    fn retry_policy_only_targets_select_queries() {
        let retry =
            MssqlRetryOptions::enabled(2, Duration::from_millis(50), Duration::from_secs(1));
        let select = CompiledQuery::new("SELECT 1", vec![]);
        let insert = CompiledQuery::new("INSERT INTO [dbo].[users] DEFAULT VALUES", vec![]);

        assert!(is_retryable_read_query(&select, retry));
        assert!(!is_retryable_read_query(&insert, retry));
    }

    #[test]
    fn retry_delay_caps_at_max_delay() {
        let retry = MssqlRetryOptions::enabled(
            4,
            Duration::from_millis(100),
            Duration::from_millis(250),
        );

        for (attempt, millis) in [(1, 100), (2, 200), (3, 250)] {
            assert_eq!(
                retry_delay(retry, attempt),
                Duration::from_millis(millis),
                "attempt {attempt}"
            );
        }
    }
}