1use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::{MssqlConnection, run_with_timeout};
3use crate::error::{TiberiusErrorContext, is_transient_tiberius_error, map_tiberius_error};
4use crate::parameter::PreparedQuery;
5use crate::row::MssqlRow;
6use crate::telemetry::{QueryTrace, classify_sql, trace_query};
7use crate::transaction::MssqlTransaction;
8use async_trait::async_trait;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::CompiledQuery;
12use std::time::Duration;
13use tiberius::Client;
14use tiberius::QueryStream;
15
/// Affected-row counts produced by executing a statement.
///
/// The counts are kept exactly as supplied to [`ExecuteResult::new`], in the same order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecuteResult {
    rows_affected: Vec<u64>,
}

impl ExecuteResult {
    /// Wraps the given affected-row counts.
    pub fn new(rows_affected: Vec<u64>) -> Self {
        ExecuteResult { rows_affected }
    }

    /// Borrows the individual affected-row counts, in insertion order.
    pub fn rows_affected(&self) -> &[u64] {
        self.rows_affected.as_slice()
    }

    /// Sum of every affected-row count (`0` when there are none).
    pub fn total(&self) -> u64 {
        self.rows_affected
            .iter()
            .fold(0u64, |running, &count| running + count)
    }
}
34
/// Query-execution operations shared by [`MssqlConnection`] and [`MssqlTransaction`].
///
/// Both implementations in this module forward each method to the inherent method of the
/// same name on the implementing type.
#[async_trait]
pub trait Executor {
    /// Executes `query` and reports the affected-row counts.
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError>;
    /// Runs `query` and decodes at most one returned row into `T`.
    ///
    /// Yields `Ok(None)` when no row comes back.
    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send;
    /// Runs `query` and decodes the returned rows into `T`.
    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send;
}
45
/// Exposes a plain [`MssqlConnection`] through the [`Executor`] trait.
///
/// Every method forwards to the inherent `MssqlConnection` method of the same name, defined
/// later in this module. In a `Type::method` path, inherent items take precedence over trait
/// items, so these calls do not recurse into this impl.
#[async_trait]
impl<S> Executor for MssqlConnection<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Forwards to the inherent `MssqlConnection::execute`.
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
        MssqlConnection::execute(self, query).await
    }

    /// Forwards to the inherent `MssqlConnection::fetch_one`.
    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlConnection::fetch_one(self, query).await
    }

    /// Forwards to the inherent `MssqlConnection::fetch_all`.
    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlConnection::fetch_all(self, query).await
    }
}
69
/// Exposes an open [`MssqlTransaction`] through the [`Executor`] trait.
///
/// Every method forwards to the `MssqlTransaction` method of the same name. These are
/// presumably inherent methods defined in `crate::transaction` (not visible here). The path
/// call must resolve to them rather than to this trait impl, or it would recurse.
#[async_trait]
impl<S> Executor for MssqlTransaction<'_, S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Forwards to `MssqlTransaction::execute`.
    async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
        MssqlTransaction::execute(self, query).await
    }

    /// Forwards to `MssqlTransaction::fetch_one`.
    async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlTransaction::fetch_one(self, query).await
    }

    /// Forwards to `MssqlTransaction::fetch_all`.
    async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send,
    {
        MssqlTransaction::fetch_all(self, query).await
    }
}
93
94impl<S> MssqlConnection<S>
95where
96 S: AsyncRead + AsyncWrite + Unpin + Send,
97{
98 pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
99 let tracing_options = self.tracing_options();
100 let slow_query_options = self.slow_query_options();
101 let server_addr = self.server_addr();
102 let query_timeout = self.query_timeout();
103 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
104 execute_compiled(
105 self.client_mut(),
106 query,
107 tracing_options,
108 slow_query_options,
109 &server_addr,
110 query_timeout,
111 )
112 .await
113 })
114 .await
115 }
116
117 pub async fn query_raw<'a>(
118 &'a mut self,
119 query: CompiledQuery,
120 ) -> Result<QueryStream<'a>, OrmError> {
121 let tracing_options = self.tracing_options();
122 let slow_query_options = self.slow_query_options();
123 let server_addr = self.server_addr();
124 let query_timeout = self.query_timeout();
125 run_with_timeout(query_timeout, "SQL Server query timed out", async {
126 query_raw_compiled(
127 self.client_mut(),
128 query,
129 tracing_options,
130 slow_query_options,
131 &server_addr,
132 query_timeout,
133 )
134 .await
135 })
136 .await
137 }
138
139 pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
140 where
141 T: FromRow + Send,
142 {
143 let tracing_options = self.tracing_options();
144 let slow_query_options = self.slow_query_options();
145 let retry_options = self.retry_options();
146 let server_addr = self.server_addr();
147 let query_timeout = self.query_timeout();
148 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
149 fetch_one_compiled(
150 self.client_mut(),
151 query,
152 tracing_options,
153 slow_query_options,
154 retry_options,
155 &server_addr,
156 query_timeout,
157 )
158 .await
159 })
160 .await
161 }
162
163 pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
164 where
165 T: FromRow + Send,
166 {
167 let tracing_options = self.tracing_options();
168 let slow_query_options = self.slow_query_options();
169 let retry_options = self.retry_options();
170 let server_addr = self.server_addr();
171 let query_timeout = self.query_timeout();
172 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
173 fetch_all_compiled(
174 self.client_mut(),
175 query,
176 tracing_options,
177 slow_query_options,
178 retry_options,
179 &server_addr,
180 query_timeout,
181 )
182 .await
183 })
184 .await
185 }
186
187 pub async fn fetch_one_with<T, F>(
188 &mut self,
189 query: CompiledQuery,
190 map: F,
191 ) -> Result<Option<T>, OrmError>
192 where
193 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
194 {
195 let tracing_options = self.tracing_options();
196 let slow_query_options = self.slow_query_options();
197 let retry_options = self.retry_options();
198 let server_addr = self.server_addr();
199 let query_timeout = self.query_timeout();
200 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
201 fetch_one_compiled_with(
202 self.client_mut(),
203 query,
204 tracing_options,
205 slow_query_options,
206 retry_options,
207 &server_addr,
208 query_timeout,
209 map,
210 )
211 .await
212 })
213 .await
214 }
215
216 pub async fn fetch_all_with<T, F>(
217 &mut self,
218 query: CompiledQuery,
219 map: F,
220 ) -> Result<Vec<T>, OrmError>
221 where
222 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
223 {
224 let tracing_options = self.tracing_options();
225 let slow_query_options = self.slow_query_options();
226 let retry_options = self.retry_options();
227 let server_addr = self.server_addr();
228 let query_timeout = self.query_timeout();
229 run_with_timeout(self.query_timeout(), "SQL Server query timed out", async {
230 fetch_all_compiled_with(
231 self.client_mut(),
232 query,
233 tracing_options,
234 slow_query_options,
235 retry_options,
236 &server_addr,
237 query_timeout,
238 map,
239 )
240 .await
241 })
242 .await
243 }
244}
245
246pub(crate) async fn execute_compiled<S>(
247 client: &mut Client<S>,
248 query: CompiledQuery,
249 tracing_options: MssqlTracingOptions,
250 slow_query_options: MssqlSlowQueryOptions,
251 server_addr: &str,
252 query_timeout: Option<std::time::Duration>,
253) -> Result<ExecuteResult, OrmError>
254where
255 S: AsyncRead + AsyncWrite + Unpin + Send,
256{
257 let prepared = PreparedQuery::from_compiled(query);
258 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
259 let result = trace_query(tracing_options, slow_query_options, trace, async {
260 prepared.validate_parameter_count()?;
261 prepared.execute(client).await
262 })
263 .await?;
264
265 Ok(ExecuteResult::new(result.rows_affected().to_vec()))
266}
267
268pub(crate) async fn query_raw_compiled<'a, S>(
269 client: &'a mut Client<S>,
270 query: CompiledQuery,
271 tracing_options: MssqlTracingOptions,
272 slow_query_options: MssqlSlowQueryOptions,
273 server_addr: &str,
274 query_timeout: Option<std::time::Duration>,
275) -> Result<QueryStream<'a>, OrmError>
276where
277 S: AsyncRead + AsyncWrite + Unpin + Send,
278{
279 let prepared = PreparedQuery::from_compiled(query);
280 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
281 trace_query(tracing_options, slow_query_options, trace, async {
282 prepared.validate_parameter_count()?;
283 prepared.query(client).await
284 })
285 .await
286}
287
288pub(crate) async fn fetch_one_compiled<S, T>(
289 client: &mut Client<S>,
290 query: CompiledQuery,
291 tracing_options: MssqlTracingOptions,
292 slow_query_options: MssqlSlowQueryOptions,
293 retry_options: MssqlRetryOptions,
294 server_addr: &str,
295 query_timeout: Option<std::time::Duration>,
296) -> Result<Option<T>, OrmError>
297where
298 S: AsyncRead + AsyncWrite + Unpin + Send,
299 T: FromRow + Send,
300{
301 let retryable_query = is_retryable_read_query(&query, retry_options);
302 let mut attempt = 0;
303
304 let row = loop {
305 let prepared = PreparedQuery::from_compiled(query.clone());
306 prepared.validate_parameter_count()?;
307 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
308
309 match trace_query(tracing_options, slow_query_options, trace, async {
310 prepared.query_driver(client).await?.into_row().await
311 })
312 .await
313 {
314 Ok(row) => break row,
315 Err(error)
316 if retryable_query
317 && attempt < retry_options.max_retries
318 && is_transient_tiberius_error(&error) =>
319 {
320 attempt += 1;
321 let delay = retry_delay(retry_options, attempt);
322
323 tracing::warn!(
324 target: "orm.query.retry",
325 server_addr = %server_addr,
326 operation = %classify_sql(&query.sql),
327 attempt,
328 max_retries = retry_options.max_retries,
329 delay_ms = delay.as_millis(),
330 error_code = ?error.code(),
331 error = %error,
332 );
333
334 tokio::time::sleep(delay).await;
335 }
336 Err(error) => {
337 return Err(map_tiberius_error(
338 &error,
339 TiberiusErrorContext::ExecuteQuery,
340 ));
341 }
342 }
343 };
344
345 row.as_ref()
346 .map(|row| T::from_row(&MssqlRow::new(row)))
347 .transpose()
348}
349
350async fn fetch_one_compiled_with<S, T, F>(
351 client: &mut Client<S>,
352 query: CompiledQuery,
353 tracing_options: MssqlTracingOptions,
354 slow_query_options: MssqlSlowQueryOptions,
355 retry_options: MssqlRetryOptions,
356 server_addr: &str,
357 query_timeout: Option<std::time::Duration>,
358 mut map: F,
359) -> Result<Option<T>, OrmError>
360where
361 S: AsyncRead + AsyncWrite + Unpin + Send,
362 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
363{
364 let retryable_query = is_retryable_read_query(&query, retry_options);
365 let mut attempt = 0;
366
367 let row = loop {
368 let prepared = PreparedQuery::from_compiled(query.clone());
369 prepared.validate_parameter_count()?;
370 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
371
372 match trace_query(tracing_options, slow_query_options, trace, async {
373 prepared.query_driver(client).await?.into_row().await
374 })
375 .await
376 {
377 Ok(row) => break row,
378 Err(error)
379 if retryable_query
380 && attempt < retry_options.max_retries
381 && is_transient_tiberius_error(&error) =>
382 {
383 attempt += 1;
384 let delay = retry_delay(retry_options, attempt);
385
386 tracing::warn!(
387 target: "orm.query.retry",
388 server_addr = %server_addr,
389 operation = %classify_sql(&query.sql),
390 attempt,
391 max_retries = retry_options.max_retries,
392 delay_ms = delay.as_millis(),
393 error_code = ?error.code(),
394 error = %error,
395 );
396
397 tokio::time::sleep(delay).await;
398 }
399 Err(error) => {
400 return Err(map_tiberius_error(
401 &error,
402 TiberiusErrorContext::ExecuteQuery,
403 ));
404 }
405 }
406 };
407
408 row.as_ref().map(|row| map(MssqlRow::new(row))).transpose()
409}
410
411pub(crate) async fn fetch_all_compiled<S, T>(
412 client: &mut Client<S>,
413 query: CompiledQuery,
414 tracing_options: MssqlTracingOptions,
415 slow_query_options: MssqlSlowQueryOptions,
416 retry_options: MssqlRetryOptions,
417 server_addr: &str,
418 query_timeout: Option<std::time::Duration>,
419) -> Result<Vec<T>, OrmError>
420where
421 S: AsyncRead + AsyncWrite + Unpin + Send,
422 T: FromRow + Send,
423{
424 let retryable_query = is_retryable_read_query(&query, retry_options);
425 let mut attempt = 0;
426
427 let rows = loop {
428 let prepared = PreparedQuery::from_compiled(query.clone());
429 prepared.validate_parameter_count()?;
430 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
431
432 match trace_query(tracing_options, slow_query_options, trace, async {
433 prepared
434 .query_driver(client)
435 .await?
436 .into_first_result()
437 .await
438 })
439 .await
440 {
441 Ok(rows) => break rows,
442 Err(error)
443 if retryable_query
444 && attempt < retry_options.max_retries
445 && is_transient_tiberius_error(&error) =>
446 {
447 attempt += 1;
448 let delay = retry_delay(retry_options, attempt);
449
450 tracing::warn!(
451 target: "orm.query.retry",
452 server_addr = %server_addr,
453 operation = %classify_sql(&query.sql),
454 attempt,
455 max_retries = retry_options.max_retries,
456 delay_ms = delay.as_millis(),
457 error_code = ?error.code(),
458 error = %error,
459 );
460
461 tokio::time::sleep(delay).await;
462 }
463 Err(error) => {
464 return Err(map_tiberius_error(
465 &error,
466 TiberiusErrorContext::ExecuteQuery,
467 ));
468 }
469 }
470 };
471
472 rows.iter()
473 .map(|row| T::from_row(&MssqlRow::new(row)))
474 .collect()
475}
476
477async fn fetch_all_compiled_with<S, T, F>(
478 client: &mut Client<S>,
479 query: CompiledQuery,
480 tracing_options: MssqlTracingOptions,
481 slow_query_options: MssqlSlowQueryOptions,
482 retry_options: MssqlRetryOptions,
483 server_addr: &str,
484 query_timeout: Option<std::time::Duration>,
485 mut map: F,
486) -> Result<Vec<T>, OrmError>
487where
488 S: AsyncRead + AsyncWrite + Unpin + Send,
489 F: FnMut(MssqlRow<'_>) -> Result<T, OrmError> + Send,
490{
491 let retryable_query = is_retryable_read_query(&query, retry_options);
492 let mut attempt = 0;
493
494 let rows = loop {
495 let prepared = PreparedQuery::from_compiled(query.clone());
496 prepared.validate_parameter_count()?;
497 let trace = QueryTrace::new(server_addr, query_timeout, tracing_options, &prepared);
498
499 match trace_query(tracing_options, slow_query_options, trace, async {
500 prepared
501 .query_driver(client)
502 .await?
503 .into_first_result()
504 .await
505 })
506 .await
507 {
508 Ok(rows) => break rows,
509 Err(error)
510 if retryable_query
511 && attempt < retry_options.max_retries
512 && is_transient_tiberius_error(&error) =>
513 {
514 attempt += 1;
515 let delay = retry_delay(retry_options, attempt);
516
517 tracing::warn!(
518 target: "orm.query.retry",
519 server_addr = %server_addr,
520 operation = %classify_sql(&query.sql),
521 attempt,
522 max_retries = retry_options.max_retries,
523 delay_ms = delay.as_millis(),
524 error_code = ?error.code(),
525 error = %error,
526 );
527
528 tokio::time::sleep(delay).await;
529 }
530 Err(error) => {
531 return Err(map_tiberius_error(
532 &error,
533 TiberiusErrorContext::ExecuteQuery,
534 ));
535 }
536 }
537 };
538
539 rows.iter().map(|row| map(MssqlRow::new(row))).collect()
540}
541
542fn is_retryable_read_query(query: &CompiledQuery, retry_options: MssqlRetryOptions) -> bool {
543 retry_options.enabled && retry_options.max_retries > 0 && classify_sql(&query.sql) == "select"
544}
545
546fn retry_delay(retry_options: MssqlRetryOptions, attempt: u32) -> Duration {
547 let multiplier = 1u32
548 .checked_shl(attempt.saturating_sub(1))
549 .unwrap_or(u32::MAX);
550 let base_millis = retry_options.base_delay.as_millis();
551 let max_millis = retry_options.max_delay.as_millis();
552 let scaled = base_millis.saturating_mul(u128::from(multiplier));
553
554 Duration::from_millis(scaled.min(max_millis) as u64)
555}
556
#[cfg(test)]
mod tests {
    use super::{
        ExecuteResult, fetch_all_compiled, fetch_one_compiled, is_retryable_read_query,
        query_raw_compiled, retry_delay,
    };
    use crate::config::{MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTracingOptions};
    use sql_orm_core::{FromRow, OrmError, Row};
    use sql_orm_query::CompiledQuery;
    use std::time::Duration;

    /// Concrete TCP-backed stream type used to instantiate the generic helpers.
    type TcpCompat = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Row model that ignores its input; it only exists to satisfy `FromRow` bounds.
    struct TestRowModel;

    impl FromRow for TestRowModel {
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(TestRowModel)
        }
    }

    #[test]
    fn execute_result_exposes_rows_affected_and_total() {
        let counts = ExecuteResult::new(vec![1, 2, 3]);

        assert_eq!(counts.total(), 6);
        assert_eq!(counts.rows_affected(), &[1, 2, 3]);
    }

    #[test]
    fn reuses_shared_execution_helpers_from_transaction_boundary() {
        // Naming each helper as a concrete fn item is a compile-time check that the generic
        // signatures can be instantiated with a real stream type.
        let _query_raw = query_raw_compiled::<TcpCompat>;
        let _fetch_one = fetch_one_compiled::<TcpCompat, TestRowModel>;
        let _fetch_all = fetch_all_compiled::<TcpCompat, TestRowModel>;
    }

    #[test]
    fn compiled_query_helpers_accept_tracing_context_shape() {
        let slow_query = MssqlSlowQueryOptions::enabled(Duration::from_millis(250));
        let tracing = MssqlTracingOptions::enabled();

        assert!(tracing.enabled);
        assert!(slow_query.enabled);
    }

    #[test]
    fn retry_policy_only_targets_select_queries() {
        let retry =
            MssqlRetryOptions::enabled(2, Duration::from_millis(50), Duration::from_secs(1));
        let select = CompiledQuery::new("SELECT 1", vec![]);
        let insert = CompiledQuery::new("INSERT INTO [dbo].[users] DEFAULT VALUES", vec![]);

        assert!(is_retryable_read_query(&select, retry));
        assert!(!is_retryable_read_query(&insert, retry));
    }

    #[test]
    fn retry_delay_caps_at_max_delay() {
        let retry = MssqlRetryOptions::enabled(
            4,
            Duration::from_millis(100),
            Duration::from_millis(250),
        );
        let expected = [(1, 100), (2, 200), (3, 250)];

        for &(attempt, millis) in &expected {
            assert_eq!(retry_delay(retry, attempt), Duration::from_millis(millis));
        }
    }
}