1use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::{MssqlConnection, run_with_timeout};
3use crate::error::{TiberiusErrorContext, is_transient_tiberius_error, map_tiberius_error};
4use crate::parameter::PreparedQuery;
5use crate::row::MssqlRow;
6use crate::telemetry::{QueryTrace, classify_sql, trace_query};
7use crate::transaction::MssqlTransaction;
8use async_trait::async_trait;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::{CompiledQuery, QueryExecution};
12use std::time::Duration;
13use tiberius::Client;
14use tiberius::QueryStream;
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct ExecuteResult {
18 rows_affected: Vec<u64>,
19}
20
21#[derive(Clone, Copy)]
22pub(crate) struct QueryExecutionOptions<'a> {
23 pub(crate) tracing: MssqlTracingOptions,
24 pub(crate) slow_query: MssqlSlowQueryOptions,
25 pub(crate) retry: MssqlRetryOptions,
26 pub(crate) server_addr: &'a str,
27 pub(crate) timeout: Option<Duration>,
28}
29
30impl ExecuteResult {
31 pub fn new(rows_affected: Vec<u64>) -> Self {
32 Self { rows_affected }
33 }
34
35 pub fn rows_affected(&self) -> &[u64] {
36 &self.rows_affected
37 }
38
39 pub fn total(&self) -> u64 {
40 self.rows_affected.iter().sum()
41 }
42}
43
/// Async query-execution interface shared by connections and transactions.
///
/// In this module it is implemented for `MssqlConnection<S>` and
/// `MssqlTransaction<'_, S>`, both of which forward to their inherent methods
/// of the same name.
#[async_trait]
pub trait Executor {
    /// Runs a statement and returns the row counts reported for it.
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError>;
    /// Runs a query and decodes at most one row into `T`
    /// (`Ok(None)` presumably meaning no row came back — confirm per impl).
    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send;
    /// Runs a query and decodes the returned rows into `T`; the impls in this
    /// module decode only the first result set.
    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send;
}
54
/// [`Executor`] for a plain connection: every method forwards to the inherent
/// `MssqlConnection` method of the same name.
///
/// The calls use the `MssqlConnection::method(self, ..)` path form; inherent
/// items take precedence in `Type::method` paths, so each call reaches the
/// inherent implementation rather than re-entering this trait method.
#[async_trait]
impl<S> Executor for MssqlConnection<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
        MssqlConnection::execute(self, query).await
    }

    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlConnection::fetch_one(self, query).await
    }

    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlConnection::fetch_all(self, query).await
    }
}
78
/// [`Executor`] for an open transaction: every method forwards to the inherent
/// `MssqlTransaction` method of the same name.
///
/// As with the connection impl, the `MssqlTransaction::method(self, ..)` path
/// form resolves to the inherent method, not back into this trait impl.
#[async_trait]
impl<S> Executor for MssqlTransaction<'_, S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
        MssqlTransaction::execute(self, query).await
    }

    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlTransaction::fetch_one(self, query).await
    }

    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlTransaction::fetch_all(self, query).await
    }
}
102
103impl<S> MssqlConnection<S>
104where
105 S: AsyncRead + AsyncWrite + Unpin + Send,
106{
107 pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
108 let tracing_options = self.tracing_options();
109 let slow_query_options = self.slow_query_options();
110 let server_addr = self.server_addr();
111 let query_timeout = self.query_timeout();
112 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
113 execute_compiled(
114 self.client_mut(),
115 query,
116 tracing_options,
117 slow_query_options,
118 &server_addr,
119 query_timeout,
120 )
121 .await
122 })
123 .await
124 }
125
126 pub async fn query_raw<'a>(
127 &'a mut self,
128 query: CompiledQuery,
129 ) -> Result<QueryStream<'a>, OrmError> {
130 let tracing_options = self.tracing_options();
131 let slow_query_options = self.slow_query_options();
132 let server_addr = self.server_addr();
133 let query_timeout = self.query_timeout();
134 run_with_timeout(query_timeout, "SQL Server query timed out", async {
135 query_raw_compiled(
136 self.client_mut(),
137 query,
138 tracing_options,
139 slow_query_options,
140 &server_addr,
141 query_timeout,
142 )
143 .await
144 })
145 .await
146 }
147
148 pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
149 where
150 T: FromRow + Send,
151 {
152 let tracing_options = self.tracing_options();
153 let slow_query_options = self.slow_query_options();
154 let retry_options = self.retry_options();
155 let server_addr = self.server_addr();
156 let query_timeout = self.query_timeout();
157 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
158 fetch_one_compiled(
159 self.client_mut(),
160 query,
161 QueryExecutionOptions {
162 tracing: tracing_options,
163 slow_query: slow_query_options,
164 retry: retry_options,
165 server_addr: &server_addr,
166 timeout: query_timeout,
167 },
168 )
169 .await
170 })
171 .await
172 }
173
174 pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
175 where
176 T: FromRow + Send,
177 {
178 let tracing_options = self.tracing_options();
179 let slow_query_options = self.slow_query_options();
180 let retry_options = self.retry_options();
181 let server_addr = self.server_addr();
182 let query_timeout = self.query_timeout();
183 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
184 fetch_all_compiled(
185 self.client_mut(),
186 query,
187 QueryExecutionOptions {
188 tracing: tracing_options,
189 slow_query: slow_query_options,
190 retry: retry_options,
191 server_addr: &server_addr,
192 timeout: query_timeout,
193 },
194 )
195 .await
196 })
197 .await
198 }
199
200 pub async fn fetch_one_with<T, F>(
201 &mut self,
202 query: CompiledQuery,
203 map: F,
204 ) -> Result<Option<T>, OrmError>
205 where
206 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
207 {
208 let tracing_options = self.tracing_options();
209 let slow_query_options = self.slow_query_options();
210 let retry_options = self.retry_options();
211 let server_addr = self.server_addr();
212 let query_timeout = self.query_timeout();
213 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
214 fetch_one_compiled_with(
215 self.client_mut(),
216 query,
217 QueryExecutionOptions {
218 tracing: tracing_options,
219 slow_query: slow_query_options,
220 retry: retry_options,
221 server_addr: &server_addr,
222 timeout: query_timeout,
223 },
224 map,
225 )
226 .await
227 })
228 .await
229 }
230
231 pub async fn fetch_all_with<T, F>(
232 &mut self,
233 query: CompiledQuery,
234 map: F,
235 ) -> Result<Vec<T>, OrmError>
236 where
237 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
238 {
239 let tracing_options = self.tracing_options();
240 let slow_query_options = self.slow_query_options();
241 let retry_options = self.retry_options();
242 let server_addr = self.server_addr();
243 let query_timeout = self.query_timeout();
244 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
245 fetch_all_compiled_with(
246 self.client_mut(),
247 query,
248 QueryExecutionOptions {
249 tracing: tracing_options,
250 slow_query: slow_query_options,
251 retry: retry_options,
252 server_addr: &server_addr,
253 timeout: query_timeout,
254 },
255 map,
256 )
257 .await
258 })
259 .await
260 }
261}
262
263pub(crate) async fn execute_compiled<S>(
264 client: &mut Client<S>,
265 query: CompiledQuery,
266 tracing_options: MssqlTracingOptions,
267 slow_query_options: MssqlSlowQueryOptions,
268 server_addr: &str,
269 query_timeout: Option<std::time::Duration>,
270) -> Result<ExecuteResult, OrmError>
271where
272 S: AsyncRead + AsyncWrite + Unpin + Send,
273{
274 let prepared = PreparedQuery::from_compiled(query);
275 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
276 let result = trace_query(tracing_options, slow_query_options, trace, async {
277 prepared.validate_parameter_count()?;
278 prepared.execute(client).await
279 })
280 .await?;
281
282 Ok(ExecuteResult::new(result.rows_affected().to_vec()))
283}
284
285pub(crate) async fn query_raw_compiled<'a, S>(
286 client: &'a mut Client<S>,
287 query: CompiledQuery,
288 tracing_options: MssqlTracingOptions,
289 slow_query_options: MssqlSlowQueryOptions,
290 server_addr: &str,
291 query_timeout: Option<std::time::Duration>,
292) -> Result<QueryStream<'a>, OrmError>
293where
294 S: AsyncRead + AsyncWrite + Unpin + Send,
295{
296 let prepared = PreparedQuery::from_compiled(query);
297 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
298 trace_query(tracing_options, slow_query_options, trace, async {
299 prepared.validate_parameter_count()?;
300 prepared.query(client).await
301 })
302 .await
303}
304
305pub(crate) async fn fetch_one_compiled<S, T>(
306 client: &mut Client<S>,
307 query: CompiledQuery,
308 options: QueryExecutionOptions<'_>,
309) -> Result<Option<T>, OrmError>
310where
311 S: AsyncRead + AsyncWrite + Unpin + Send,
312 T: FromRow + Send,
313{
314 let retryable_query = is_retryable_read_query(&query, options.retry);
315 let mut attempt = 0;
316
317 let row = loop {
318 let prepared = PreparedQuery::from_compiled(query.clone());
319 prepared.validate_parameter_count()?;
320 let trace = QueryTrace::new(
321 options.server_addr,
322 options.timeout,
323 options.tracing,
324 &prepared,
325 );
326
327 match trace_query(options.tracing, options.slow_query, trace, async {
328 prepared.query_driver(client).await?.into_row().await
329 })
330 .await
331 {
332 Ok(row) => break row,
333 Err(error)
334 if retryable_query
335 && attempt < options.retry.max_retries
336 && is_transient_tiberius_error(&error) =>
337 {
338 attempt += 1;
339 let delay = retry_delay(options.retry, attempt);
340
341 tracing::warn!(
342 target: "orm.query.retry",
343 server_addr = %options.server_addr,
344 operation = %classify_sql(&query.sql),
345 attempt,
346 max_retries = options.retry.max_retries,
347 delay_ms = delay.as_millis(),
348 error_code = ?error.code(),
349 error = %error,
350 );
351
352 tokio::time::sleep(delay).await;
353 }
354 Err(error) => {
355 return Err(map_tiberius_error(
356 &error,
357 TiberiusErrorContext::ExecuteQuery,
358 ));
359 }
360 }
361 };
362
363 row.as_ref()
364 .map(|row| T::from_row(&MssqlRow::new(row)))
365 .transpose()
366}
367
368async fn fetch_one_compiled_with<S, T, F>(
369 client: &mut Client<S>,
370 query: CompiledQuery,
371 options: QueryExecutionOptions<'_>,
372 mut map: F,
373) -> Result<Option<T>, OrmError>
374where
375 S: AsyncRead + AsyncWrite + Unpin + Send,
376 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
377{
378 let retryable_query = is_retryable_read_query(&query, options.retry);
379 let mut attempt = 0;
380
381 let row = loop {
382 let prepared = PreparedQuery::from_compiled(query.clone());
383 prepared.validate_parameter_count()?;
384 let trace = QueryTrace::new(
385 options.server_addr,
386 options.timeout,
387 options.tracing,
388 &prepared,
389 );
390
391 match trace_query(options.tracing, options.slow_query, trace, async {
392 prepared.query_driver(client).await?.into_row().await
393 })
394 .await
395 {
396 Ok(row) => break row,
397 Err(error)
398 if retryable_query
399 && attempt < options.retry.max_retries
400 && is_transient_tiberius_error(&error) =>
401 {
402 attempt += 1;
403 let delay = retry_delay(options.retry, attempt);
404
405 tracing::warn!(
406 target: "orm.query.retry",
407 server_addr = %options.server_addr,
408 operation = %classify_sql(&query.sql),
409 attempt,
410 max_retries = options.retry.max_retries,
411 delay_ms = delay.as_millis(),
412 error_code = ?error.code(),
413 error = %error,
414 );
415
416 tokio::time::sleep(delay).await;
417 }
418 Err(error) => {
419 return Err(map_tiberius_error(
420 &error,
421 TiberiusErrorContext::ExecuteQuery,
422 ));
423 }
424 }
425 };
426
427 row.as_ref().map(|row| map(MssqlRow::new(row))).transpose()
428}
429
430pub(crate) async fn fetch_all_compiled<S, T>(
431 client: &mut Client<S>,
432 query: CompiledQuery,
433 options: QueryExecutionOptions<'_>,
434) -> Result<Vec<T>, OrmError>
435where
436 S: AsyncRead + AsyncWrite + Unpin + Send,
437 T: FromRow + Send,
438{
439 let retryable_query = is_retryable_read_query(&query, options.retry);
440 let mut attempt = 0;
441
442 let rows = loop {
443 let prepared = PreparedQuery::from_compiled(query.clone());
444 prepared.validate_parameter_count()?;
445 let trace = QueryTrace::new(
446 options.server_addr,
447 options.timeout,
448 options.tracing,
449 &prepared,
450 );
451
452 match trace_query(options.tracing, options.slow_query, trace, async {
453 prepared
454 .query_driver(client)
455 .await?
456 .into_first_result()
457 .await
458 })
459 .await
460 {
461 Ok(rows) => break rows,
462 Err(error)
463 if retryable_query
464 && attempt < options.retry.max_retries
465 && is_transient_tiberius_error(&error) =>
466 {
467 attempt += 1;
468 let delay = retry_delay(options.retry, attempt);
469
470 tracing::warn!(
471 target: "orm.query.retry",
472 server_addr = %options.server_addr,
473 operation = %classify_sql(&query.sql),
474 attempt,
475 max_retries = options.retry.max_retries,
476 delay_ms = delay.as_millis(),
477 error_code = ?error.code(),
478 error = %error,
479 );
480
481 tokio::time::sleep(delay).await;
482 }
483 Err(error) => {
484 return Err(map_tiberius_error(
485 &error,
486 TiberiusErrorContext::ExecuteQuery,
487 ));
488 }
489 }
490 };
491
492 rows.iter()
493 .map(|row| T::from_row(&MssqlRow::new(row)))
494 .collect()
495}
496
497async fn fetch_all_compiled_with<S, T, F>(
498 client: &mut Client<S>,
499 query: CompiledQuery,
500 options: QueryExecutionOptions<'_>,
501 mut map: F,
502) -> Result<Vec<T>, OrmError>
503where
504 S: AsyncRead + AsyncWrite + Unpin + Send,
505 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
506{
507 let retryable_query = is_retryable_read_query(&query, options.retry);
508 let mut attempt = 0;
509
510 let rows = loop {
511 let prepared = PreparedQuery::from_compiled(query.clone());
512 prepared.validate_parameter_count()?;
513 let trace = QueryTrace::new(
514 options.server_addr,
515 options.timeout,
516 options.tracing,
517 &prepared,
518 );
519
520 match trace_query(options.tracing, options.slow_query, trace, async {
521 prepared
522 .query_driver(client)
523 .await?
524 .into_first_result()
525 .await
526 })
527 .await
528 {
529 Ok(rows) => break rows,
530 Err(error)
531 if retryable_query
532 && attempt < options.retry.max_retries
533 && is_transient_tiberius_error(&error) =>
534 {
535 attempt += 1;
536 let delay = retry_delay(options.retry, attempt);
537
538 tracing::warn!(
539 target: "orm.query.retry",
540 server_addr = %options.server_addr,
541 operation = %classify_sql(&query.sql),
542 attempt,
543 max_retries = options.retry.max_retries,
544 delay_ms = delay.as_millis(),
545 error_code = ?error.code(),
546 error = %error,
547 );
548
549 tokio::time::sleep(delay).await;
550 }
551 Err(error) => {
552 return Err(map_tiberius_error(
553 &error,
554 TiberiusErrorContext::ExecuteQuery,
555 ));
556 }
557 }
558 };
559
560 rows.iter().map(|row| map(MssqlRow::new(row))).collect()
561}
562
563fn is_retryable_read_query(query: &CompiledQuery, retry_options: MssqlRetryOptions) -> bool {
564 retry_options.enabled
565 && retry_options.max_retries > 0
566 && query.execution == QueryExecution::ReadOnly
567}
568
569fn retry_delay(retry_options: MssqlRetryOptions, attempt: u32) -> Duration {
570 let multiplier = 1u32
571 .checked_shl(attempt.saturating_sub(1))
572 .unwrap_or(u32::MAX);
573 let base_millis = retry_options.base_delay.as_millis();
574 let max_millis = retry_options.max_delay.as_millis();
575 let scaled = base_millis.saturating_mul(u128::from(multiplier));
576
577 Duration::from_millis(scaled.min(max_millis) as u64)
578}
579
#[cfg(test)]
mod tests {
    use super::{
        ExecuteResult, fetch_all_compiled, fetch_one_compiled, is_retryable_read_query,
        query_raw_compiled, retry_delay,
    };
    use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
    use sql_orm_core::{FromRow, OrmError, Row};
    use sql_orm_query::{CompiledQuery, QueryExecution};
    use std::time::Duration;

    /// Concrete transport used to instantiate the generic helpers.
    type TestStream = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Row model that ignores the row contents; only its `FromRow` impl matters.
    struct TestRowModel;

    impl FromRow for TestRowModel {
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(TestRowModel)
        }
    }

    #[test]
    fn execute_result_exposes_rows_affected_and_total() {
        let counts: Vec<u64> = vec![1, 2, 3];
        let result = ExecuteResult::new(counts.clone());

        assert_eq!(result.rows_affected(), counts.as_slice());
        assert_eq!(result.total(), 6);
    }

    #[test]
    fn reuses_shared_execution_helpers_from_transaction_boundary() {
        // Naming each helper with concrete type arguments proves it
        // monomorphises for a real transport; nothing is executed.
        let _query_raw = query_raw_compiled::<TestStream>;
        let _fetch_one = fetch_one_compiled::<TestStream, TestRowModel>;
        let _fetch_all = fetch_all_compiled::<TestStream, TestRowModel>;
    }

    #[test]
    fn compiled_query_helpers_accept_tracing_context_shape() {
        let slow_query = MssqlSlowQueryOptions::enabled(Duration::from_millis(250));
        let tracing = MssqlTracingOptions::enabled();

        assert!(tracing.enabled);
        assert!(slow_query.enabled);
    }

    #[test]
    fn retry_policy_only_targets_explicit_read_only_queries() {
        let retry =
            MssqlRetryOptions::enabled(2, Duration::from_millis(50), Duration::from_secs(1));

        let read_only_proc = CompiledQuery::read_only("EXEC dbo.read_only_proc", vec![]);
        let select_into = CompiledQuery::write(
            "SELECT * INTO [dbo].[users_copy] FROM [dbo].[users]",
            vec![],
        );
        let raw_select =
            CompiledQuery::with_execution("SELECT 1", vec![], QueryExecution::RawNoRetry);

        assert!(is_retryable_read_query(&read_only_proc, retry));
        assert!(!is_retryable_read_query(&select_into, retry));
        assert!(!is_retryable_read_query(&raw_select, retry));
    }

    #[test]
    fn retry_delay_caps_at_max_delay() {
        let retry = MssqlRetryOptions::enabled(
            4,
            Duration::from_millis(100),
            Duration::from_millis(250),
        );

        // Attempts 1, 2, 3 double from the base delay until the cap kicks in.
        let expected_millis = [100, 200, 250];
        for (attempt, millis) in (1u32..).zip(expected_millis) {
            assert_eq!(retry_delay(retry, attempt), Duration::from_millis(millis));
        }
    }
}