// sql_orm_tiberius/executor.rs

1use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::{MssqlConnection, run_with_timeout};
3use crate::error::{TiberiusErrorContext, is_transient_tiberius_error, map_tiberius_error};
4use crate::parameter::PreparedQuery;
5use crate::row::MssqlRow;
6use crate::telemetry::{QueryTrace, classify_sql, trace_query};
7use crate::transaction::MssqlTransaction;
8use async_trait::async_trait;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::{CompiledQuery, QueryExecution};
12use std::time::Duration;
13use tiberius::Client;
14use tiberius::QueryStream;
15
/// Row counts reported for an executed command.
///
/// Values are stored exactly as handed to [`ExecuteResult::new`]; see
/// [`ExecuteResult::rows_affected`] and [`ExecuteResult::total`] for access.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExecuteResult {
    /// The individual counts, in the order they were supplied.
    rows_affected: Vec<u64>,
}
20
21#[derive(Clone, Copy)]
22pub(crate) struct QueryExecutionOptions<'a> {
23    pub(crate) tracing: MssqlTracingOptions,
24    pub(crate) slow_query: MssqlSlowQueryOptions,
25    pub(crate) retry: MssqlRetryOptions,
26    pub(crate) server_addr: &'a str,
27    pub(crate) timeout: Option<Duration>,
28}
29
30impl ExecuteResult {
31    pub fn new(rows_affected: Vec<u64>) -> Self {
32        Self { rows_affected }
33    }
34
35    pub fn rows_affected(&self) -> &[u64] {
36        &self.rows_affected
37    }
38
39    pub fn total(&self) -> u64 {
40        self.rows_affected.iter().sum()
41    }
42}
43
44#[async_trait]
45pub trait Executor {
46    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError>;
47    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
48    where
49        T: FromRow + Send;
50    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
51    where
52        T: FromRow + Send;
53}
54
55#[async_trait]
56impl<S> Executor for MssqlConnection<S>
57where
58    S: AsyncRead + AsyncWrite + Unpin + Send,
59{
60    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
61        MssqlConnection::execute(self, query).await
62    }
63
64    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
65    where
66        T: FromRow + Send,
67    {
68        MssqlConnection::fetch_one(self, query).await
69    }
70
71    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
72    where
73        T: FromRow + Send,
74    {
75        MssqlConnection::fetch_all(self, query).await
76    }
77}
78
79#[async_trait]
80impl<S> Executor for MssqlTransaction<'_, S>
81where
82    S: AsyncRead + AsyncWrite + Unpin + Send,
83{
84    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
85        MssqlTransaction::execute(self, query).await
86    }
87
88    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
89    where
90        T: FromRow + Send,
91    {
92        MssqlTransaction::fetch_one(self, query).await
93    }
94
95    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
96    where
97        T: FromRow + Send,
98    {
99        MssqlTransaction::fetch_all(self, query).await
100    }
101}
102
103impl<S> MssqlConnection<S>
104where
105    S: AsyncRead + AsyncWrite + Unpin + Send,
106{
107    pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
108        let tracing_options = self.tracing_options();
109        let slow_query_options = self.slow_query_options();
110        let server_addr = self.server_addr();
111        let query_timeout = self.query_timeout();
112        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
113            execute_compiled(
114                self.client_mut(),
115                query,
116                tracing_options,
117                slow_query_options,
118                &server_addr,
119                query_timeout,
120            )
121            .await
122        })
123        .await
124    }
125
126    pub async fn query_raw<'a>(
127        &'a mut self,
128        query: CompiledQuery,
129    ) -> Result<QueryStream<'a>, OrmError> {
130        let tracing_options = self.tracing_options();
131        let slow_query_options = self.slow_query_options();
132        let server_addr = self.server_addr();
133        let query_timeout = self.query_timeout();
134        run_with_timeout(query_timeout, "SQL Server query timed out", async {
135            query_raw_compiled(
136                self.client_mut(),
137                query,
138                tracing_options,
139                slow_query_options,
140                &server_addr,
141                query_timeout,
142            )
143            .await
144        })
145        .await
146    }
147
148    pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
149    where
150        T: FromRow + Send,
151    {
152        let tracing_options = self.tracing_options();
153        let slow_query_options = self.slow_query_options();
154        let retry_options = self.retry_options();
155        let server_addr = self.server_addr();
156        let query_timeout = self.query_timeout();
157        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
158            fetch_one_compiled(
159                self.client_mut(),
160                query,
161                QueryExecutionOptions {
162                    tracing: tracing_options,
163                    slow_query: slow_query_options,
164                    retry: retry_options,
165                    server_addr: &server_addr,
166                    timeout: query_timeout,
167                },
168            )
169            .await
170        })
171        .await
172    }
173
174    pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
175    where
176        T: FromRow + Send,
177    {
178        let tracing_options = self.tracing_options();
179        let slow_query_options = self.slow_query_options();
180        let retry_options = self.retry_options();
181        let server_addr = self.server_addr();
182        let query_timeout = self.query_timeout();
183        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
184            fetch_all_compiled(
185                self.client_mut(),
186                query,
187                QueryExecutionOptions {
188                    tracing: tracing_options,
189                    slow_query: slow_query_options,
190                    retry: retry_options,
191                    server_addr: &server_addr,
192                    timeout: query_timeout,
193                },
194            )
195            .await
196        })
197        .await
198    }
199
200    pub async fn fetch_one_with<T, F>(
201        &mut self,
202        query: CompiledQuery,
203        map: F,
204    ) -> Result<Option<T>, OrmError>
205    where
206        F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
207    {
208        let tracing_options = self.tracing_options();
209        let slow_query_options = self.slow_query_options();
210        let retry_options = self.retry_options();
211        let server_addr = self.server_addr();
212        let query_timeout = self.query_timeout();
213        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
214            fetch_one_compiled_with(
215                self.client_mut(),
216                query,
217                QueryExecutionOptions {
218                    tracing: tracing_options,
219                    slow_query: slow_query_options,
220                    retry: retry_options,
221                    server_addr: &server_addr,
222                    timeout: query_timeout,
223                },
224                map,
225            )
226            .await
227        })
228        .await
229    }
230
231    pub async fn fetch_all_with<T, F>(
232        &mut self,
233        query: CompiledQuery,
234        map: F,
235    ) -> Result<Vec<T>, OrmError>
236    where
237        F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
238    {
239        let tracing_options = self.tracing_options();
240        let slow_query_options = self.slow_query_options();
241        let retry_options = self.retry_options();
242        let server_addr = self.server_addr();
243        let query_timeout = self.query_timeout();
244        run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
245            fetch_all_compiled_with(
246                self.client_mut(),
247                query,
248                QueryExecutionOptions {
249                    tracing: tracing_options,
250                    slow_query: slow_query_options,
251                    retry: retry_options,
252                    server_addr: &server_addr,
253                    timeout: query_timeout,
254                },
255                map,
256            )
257            .await
258        })
259        .await
260    }
261}
262
263pub(crate) async fn execute_compiled<S>(
264    client: &mut Client<S>,
265    query: CompiledQuery,
266    tracing_options: MssqlTracingOptions,
267    slow_query_options: MssqlSlowQueryOptions,
268    server_addr: &str,
269    query_timeout: Option<std::time::Duration>,
270) -> Result<ExecuteResult, OrmError>
271where
272    S: AsyncRead + AsyncWrite + Unpin + Send,
273{
274    let prepared = PreparedQuery::from_compiled(query);
275    let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
276    let result = trace_query(tracing_options, slow_query_options, trace, async {
277        prepared.validate_parameter_count()?;
278        prepared.execute(client).await
279    })
280    .await?;
281
282    Ok(ExecuteResult::new(result.rows_affected().to_vec()))
283}
284
285pub(crate) async fn query_raw_compiled<'a, S>(
286    client: &'a mut Client<S>,
287    query: CompiledQuery,
288    tracing_options: MssqlTracingOptions,
289    slow_query_options: MssqlSlowQueryOptions,
290    server_addr: &str,
291    query_timeout: Option<std::time::Duration>,
292) -> Result<QueryStream<'a>, OrmError>
293where
294    S: AsyncRead + AsyncWrite + Unpin + Send,
295{
296    let prepared = PreparedQuery::from_compiled(query);
297    let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
298    trace_query(tracing_options, slow_query_options, trace, async {
299        prepared.validate_parameter_count()?;
300        prepared.query(client).await
301    })
302    .await
303}
304
305pub(crate) async fn fetch_one_compiled<S, T>(
306    client: &mut Client<S>,
307    query: CompiledQuery,
308    options: QueryExecutionOptions<'_>,
309) -> Result<Option<T>, OrmError>
310where
311    S: AsyncRead + AsyncWrite + Unpin + Send,
312    T: FromRow + Send,
313{
314    let retryable_query = is_retryable_read_query(&query, options.retry);
315    let mut attempt = 0;
316
317    let row = loop {
318        let prepared = PreparedQuery::from_compiled(query.clone());
319        prepared.validate_parameter_count()?;
320        let trace = QueryTrace::new(
321            options.server_addr,
322            options.timeout,
323            options.tracing,
324            &prepared,
325        );
326
327        match trace_query(options.tracing, options.slow_query, trace, async {
328            prepared.query_driver(client).await?.into_row().await
329        })
330        .await
331        {
332            Ok(row) => break row,
333            Err(error)
334                if retryable_query
335                    && attempt < options.retry.max_retries
336                    && is_transient_tiberius_error(&error) =>
337            {
338                attempt += 1;
339                let delay = retry_delay(options.retry, attempt);
340
341                tracing::warn!(
342                    target: "orm.query.retry",
343                    server_addr = %options.server_addr,
344                    operation = %classify_sql(&query.sql),
345                    attempt,
346                    max_retries = options.retry.max_retries,
347                    delay_ms = delay.as_millis(),
348                    error_code = ?error.code(),
349                    error = %error,
350                );
351
352                tokio::time::sleep(delay).await;
353            }
354            Err(error) => {
355                return Err(map_tiberius_error(
356                    &error,
357                    TiberiusErrorContext::ExecuteQuery,
358                ));
359            }
360        }
361    };
362
363    row.as_ref()
364        .map(|row| T::from_row(&MssqlRow::new(row)))
365        .transpose()
366}
367
368async fn fetch_one_compiled_with<S, T, F>(
369    client: &mut Client<S>,
370    query: CompiledQuery,
371    options: QueryExecutionOptions<'_>,
372    mut map: F,
373) -> Result<Option<T>, OrmError>
374where
375    S: AsyncRead + AsyncWrite + Unpin + Send,
376    F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
377{
378    let retryable_query = is_retryable_read_query(&query, options.retry);
379    let mut attempt = 0;
380
381    let row = loop {
382        let prepared = PreparedQuery::from_compiled(query.clone());
383        prepared.validate_parameter_count()?;
384        let trace = QueryTrace::new(
385            options.server_addr,
386            options.timeout,
387            options.tracing,
388            &prepared,
389        );
390
391        match trace_query(options.tracing, options.slow_query, trace, async {
392            prepared.query_driver(client).await?.into_row().await
393        })
394        .await
395        {
396            Ok(row) => break row,
397            Err(error)
398                if retryable_query
399                    && attempt < options.retry.max_retries
400                    && is_transient_tiberius_error(&error) =>
401            {
402                attempt += 1;
403                let delay = retry_delay(options.retry, attempt);
404
405                tracing::warn!(
406                    target: "orm.query.retry",
407                    server_addr = %options.server_addr,
408                    operation = %classify_sql(&query.sql),
409                    attempt,
410                    max_retries = options.retry.max_retries,
411                    delay_ms = delay.as_millis(),
412                    error_code = ?error.code(),
413                    error = %error,
414                );
415
416                tokio::time::sleep(delay).await;
417            }
418            Err(error) => {
419                return Err(map_tiberius_error(
420                    &error,
421                    TiberiusErrorContext::ExecuteQuery,
422                ));
423            }
424        }
425    };
426
427    row.as_ref().map(|row| map(MssqlRow::new(row))).transpose()
428}
429
430pub(crate) async fn fetch_all_compiled<S, T>(
431    client: &mut Client<S>,
432    query: CompiledQuery,
433    options: QueryExecutionOptions<'_>,
434) -> Result<Vec<T>, OrmError>
435where
436    S: AsyncRead + AsyncWrite + Unpin + Send,
437    T: FromRow + Send,
438{
439    let retryable_query = is_retryable_read_query(&query, options.retry);
440    let mut attempt = 0;
441
442    let rows = loop {
443        let prepared = PreparedQuery::from_compiled(query.clone());
444        prepared.validate_parameter_count()?;
445        let trace = QueryTrace::new(
446            options.server_addr,
447            options.timeout,
448            options.tracing,
449            &prepared,
450        );
451
452        match trace_query(options.tracing, options.slow_query, trace, async {
453            prepared
454                .query_driver(client)
455                .await?
456                .into_first_result()
457                .await
458        })
459        .await
460        {
461            Ok(rows) => break rows,
462            Err(error)
463                if retryable_query
464                    && attempt < options.retry.max_retries
465                    && is_transient_tiberius_error(&error) =>
466            {
467                attempt += 1;
468                let delay = retry_delay(options.retry, attempt);
469
470                tracing::warn!(
471                    target: "orm.query.retry",
472                    server_addr = %options.server_addr,
473                    operation = %classify_sql(&query.sql),
474                    attempt,
475                    max_retries = options.retry.max_retries,
476                    delay_ms = delay.as_millis(),
477                    error_code = ?error.code(),
478                    error = %error,
479                );
480
481                tokio::time::sleep(delay).await;
482            }
483            Err(error) => {
484                return Err(map_tiberius_error(
485                    &error,
486                    TiberiusErrorContext::ExecuteQuery,
487                ));
488            }
489        }
490    };
491
492    rows.iter()
493        .map(|row| T::from_row(&MssqlRow::new(row)))
494        .collect()
495}
496
497async fn fetch_all_compiled_with<S, T, F>(
498    client: &mut Client<S>,
499    query: CompiledQuery,
500    options: QueryExecutionOptions<'_>,
501    mut map: F,
502) -> Result<Vec<T>, OrmError>
503where
504    S: AsyncRead + AsyncWrite + Unpin + Send,
505    F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
506{
507    let retryable_query = is_retryable_read_query(&query, options.retry);
508    let mut attempt = 0;
509
510    let rows = loop {
511        let prepared = PreparedQuery::from_compiled(query.clone());
512        prepared.validate_parameter_count()?;
513        let trace = QueryTrace::new(
514            options.server_addr,
515            options.timeout,
516            options.tracing,
517            &prepared,
518        );
519
520        match trace_query(options.tracing, options.slow_query, trace, async {
521            prepared
522                .query_driver(client)
523                .await?
524                .into_first_result()
525                .await
526        })
527        .await
528        {
529            Ok(rows) => break rows,
530            Err(error)
531                if retryable_query
532                    && attempt < options.retry.max_retries
533                    && is_transient_tiberius_error(&error) =>
534            {
535                attempt += 1;
536                let delay = retry_delay(options.retry, attempt);
537
538                tracing::warn!(
539                    target: "orm.query.retry",
540                    server_addr = %options.server_addr,
541                    operation = %classify_sql(&query.sql),
542                    attempt,
543                    max_retries = options.retry.max_retries,
544                    delay_ms = delay.as_millis(),
545                    error_code = ?error.code(),
546                    error = %error,
547                );
548
549                tokio::time::sleep(delay).await;
550            }
551            Err(error) => {
552                return Err(map_tiberius_error(
553                    &error,
554                    TiberiusErrorContext::ExecuteQuery,
555                ));
556            }
557        }
558    };
559
560    rows.iter().map(|row| map(MssqlRow::new(row))).collect()
561}
562
563fn is_retryable_read_query(query: &CompiledQuery, retry_options: MssqlRetryOptions) -> bool {
564    retry_options.enabled
565        && retry_options.max_retries > 0
566        && query.execution == QueryExecution::ReadOnly
567}
568
569fn retry_delay(retry_options: MssqlRetryOptions, attempt: u32) -> Duration {
570    let multiplier = 1u32
571        .checked_shl(attempt.saturating_sub(1))
572        .unwrap_or(u32::MAX);
573    let base_millis = retry_options.base_delay.as_millis();
574    let max_millis = retry_options.max_delay.as_millis();
575    let scaled = base_millis.saturating_mul(u128::from(multiplier));
576
577    Duration::from_millis(scaled.min(max_millis) as u64)
578}
579
#[cfg(test)]
mod tests {
    use super::{
        ExecuteResult, fetch_all_compiled, fetch_one_compiled, is_retryable_read_query,
        query_raw_compiled, retry_delay,
    };
    use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
    use sql_orm_core::{FromRow, OrmError, Row};
    use sql_orm_query::{CompiledQuery, QueryExecution};
    use std::time::Duration;

    /// Concrete stream type used to instantiate the generic helpers.
    type TestStream = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Row model that ignores the row contents and always decodes successfully.
    struct TestRowModel;

    impl FromRow for TestRowModel {
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(Self)
        }
    }

    #[test]
    fn execute_result_exposes_rows_affected_and_total() {
        let result = ExecuteResult::new(vec![1, 2, 3]);

        assert_eq!(result.rows_affected(), &[1, 2, 3]);
        assert_eq!(result.total(), 6);
    }

    // Compile-time check only: the helpers must be nameable with a concrete stream type.
    #[test]
    fn reuses_shared_execution_helpers_from_transaction_boundary() {
        let query_raw = query_raw_compiled::<TestStream>;
        let fetch_one = fetch_one_compiled::<TestStream, TestRowModel>;
        let fetch_all = fetch_all_compiled::<TestStream, TestRowModel>;

        let _ = (query_raw, fetch_one, fetch_all);
    }

    #[test]
    fn compiled_query_helpers_accept_tracing_context_shape() {
        let tracing = MssqlTracingOptions::enabled();
        let slow_query = MssqlSlowQueryOptions::enabled(Duration::from_millis(250));

        assert!(tracing.enabled);
        assert!(slow_query.enabled);
    }

    #[test]
    fn retry_policy_only_targets_explicit_read_only_queries() {
        let retry =
            MssqlRetryOptions::enabled(2, Duration::from_millis(50), Duration::from_secs(1));

        let read_only = CompiledQuery::read_only("EXEC dbo.read_only_proc", vec![]);
        let write = CompiledQuery::write(
            "SELECT * INTO [dbo].[users_copy] FROM [dbo].[users]",
            vec![],
        );
        let raw_no_retry =
            CompiledQuery::with_execution("SELECT 1", vec![], QueryExecution::RawNoRetry);

        assert!(is_retryable_read_query(&read_only, retry));
        assert!(!is_retryable_read_query(&write, retry));
        assert!(!is_retryable_read_query(&raw_no_retry, retry));
    }

    #[test]
    fn retry_delay_caps_at_max_delay() {
        let retry =
            MssqlRetryOptions::enabled(4, Duration::from_millis(100), Duration::from_millis(250));

        let expected_millis = [(1, 100), (2, 200), (3, 250)];
        for (attempt, millis) in expected_millis {
            assert_eq!(retry_delay(retry, attempt), Duration::from_millis(millis));
        }
    }
}