Skip to main content

sql_orm_tiberius/
connection.rs

1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::fetch_one_compiled;
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6    MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
17pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20    client: Client<S>,
21    config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25    pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26        let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27        Self::connect_with_config(config).await
28    }
29
30    pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31        let tracing_options = config.options().tracing;
32        let connect_timeout = config.options().timeouts.connect_timeout;
33        let addr = config.addr();
34        let trace_addr = addr.clone();
35        let tiberius_config = config.tiberius_config().clone();
36
37        let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38            run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39                let tcp = TcpStream::connect(addr).await.map_err(|error| {
40                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41                })?;
42                tcp.set_nodelay(true).map_err(|error| {
43                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44                })?;
45
46                Client::connect(tiberius_config, tcp.compat_write())
47                    .await
48                    .map_err(|error| {
49                        map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50                    })
51            })
52            .await
53        })
54        .await?;
55
56        Ok(Self { client, config })
57    }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61    pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62        Self { client, config }
63    }
64
65    pub fn config(&self) -> &MssqlConnectionConfig {
66        &self.config
67    }
68
69    pub fn client(&self) -> &Client<S> {
70        &self.client
71    }
72
73    pub fn client_mut(&mut self) -> &mut Client<S> {
74        &mut self.client
75    }
76
77    pub(crate) fn query_timeout(&self) -> Option<Duration> {
78        self.config.options().timeouts.query_timeout
79    }
80
81    pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82        self.config.options().tracing
83    }
84
85    pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86        self.config.options().slow_query
87    }
88
89    pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90        self.config.options().retry
91    }
92
93    #[doc(hidden)]
94    pub fn replace_retry_options(
95        &mut self,
96        retry: crate::config::MssqlRetryOptions,
97    ) -> crate::config::MssqlRetryOptions {
98        let previous = self.config.options().retry;
99        let options = self.config.options().clone().with_retry(retry);
100        self.config = self.config.clone().with_options(options);
101        previous
102    }
103
104    pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
105        self.config.options().health
106    }
107
108    pub(crate) fn server_addr(&self) -> String {
109        self.config.addr()
110    }
111
112    pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
113        let query_timeout = self.query_timeout();
114        let tracing_options = self.tracing_options();
115        let slow_query_options = self.slow_query_options();
116        let server_addr = self.server_addr();
117        MssqlTransaction::begin(
118            self.client_mut(),
119            query_timeout,
120            tracing_options,
121            slow_query_options,
122            server_addr,
123        )
124        .await
125    }
126
127    pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
128        let query_timeout = self.query_timeout();
129        let tracing_options = self.tracing_options();
130        let server_addr = self.server_addr();
131        begin_transaction_scope(
132            self.client_mut(),
133            query_timeout,
134            tracing_options,
135            &server_addr,
136        )
137        .await
138    }
139
140    pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
141        let query_timeout = self.query_timeout();
142        let tracing_options = self.tracing_options();
143        let server_addr = self.server_addr();
144        commit_transaction_scope(
145            self.client_mut(),
146            query_timeout,
147            tracing_options,
148            &server_addr,
149        )
150        .await
151    }
152
153    pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
154        let query_timeout = self.query_timeout();
155        let tracing_options = self.tracing_options();
156        let server_addr = self.server_addr();
157        rollback_transaction_scope(
158            self.client_mut(),
159            query_timeout,
160            tracing_options,
161            &server_addr,
162        )
163        .await
164    }
165
166    pub async fn health_check(&mut self) -> Result<(), OrmError> {
167        let tracing_options = self.tracing_options();
168        let slow_query_options = self.slow_query_options();
169        let retry_options = self.retry_options();
170        let server_addr = self.server_addr();
171        let health_options = self.health_options();
172        let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
173        let query = build_health_check_query(health_options);
174
175        let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
176            fetch_one_compiled::<_, HealthCheckRow>(
177                self.client_mut(),
178                query,
179                tracing_options,
180                slow_query_options,
181                retry_options,
182                &server_addr,
183                health_timeout,
184            )
185            .await
186        })
187        .await?;
188
189        match row {
190            Some(HealthCheckRow { value: 1 }) => Ok(()),
191            Some(_) => Err(OrmError::new(
192                "SQL Server health check returned an unexpected value",
193            )),
194            None => Err(OrmError::new(
195                "SQL Server health check did not return a row",
196            )),
197        }
198    }
199
200    pub fn into_inner(self) -> Client<S> {
201        self.client
202    }
203}
204
205struct HealthCheckRow {
206    value: i32,
207}
208
209impl sql_orm_core::FromRow for HealthCheckRow {
210    fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
211        Ok(Self {
212            value: row.get_required_typed::<i32>("health_check")?,
213        })
214    }
215}
216
217fn resolve_health_timeout(
218    health_options: crate::config::MssqlHealthCheckOptions,
219    query_timeout: Option<Duration>,
220) -> Option<Duration> {
221    health_options.timeout.or(query_timeout)
222}
223
224fn build_health_check_query(
225    health_options: crate::config::MssqlHealthCheckOptions,
226) -> CompiledQuery {
227    CompiledQuery::new(health_options.query.sql().to_string(), vec![])
228}
229
230pub(crate) async fn run_with_timeout<F, T>(
231    duration: Option<Duration>,
232    timeout_message: &'static str,
233    future: F,
234) -> Result<T, OrmError>
235where
236    F: core::future::Future<Output = Result<T, OrmError>>,
237{
238    match duration {
239        Some(duration) => timeout(duration, future)
240            .await
241            .map_err(|_| OrmError::new(timeout_message))?,
242        None => future.await,
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use sql_orm_core::OrmError;
    use std::time::Duration;

    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(Duration::from_secs(3));

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(3))
        );
        // The dedicated timeout also applies when no query timeout is configured.
        assert_eq!(
            resolve_health_timeout(health, None),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(30))
        );
        assert_eq!(resolve_health_timeout(health, None), None);
    }

    #[test]
    fn health_check_builds_expected_compiled_query() {
        let query = build_health_check_query(MssqlHealthCheckOptions::enabled(
            MssqlHealthCheckQuery::SelectOne,
        ));

        assert_eq!(query.sql, "SELECT 1 AS [health_check]");
        assert!(query.params.is_empty());
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let value = run_with_timeout(None, "timeout", async { Ok::<_, OrmError>(7) })
            .await
            .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_within_deadline() {
        let value = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Ok::<_, OrmError>(7)
        })
        .await
        .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_passes_through_inner_error() {
        let error = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Err::<(), _>(OrmError::new("inner failure"))
        })
        .await
        .unwrap_err();

        // The future's own error must not be replaced by the timeout message.
        assert_eq!(error.message(), "inner failure");
    }

    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        let error = run_with_timeout(
            Some(Duration::from_millis(5)),
            "SQL Server connection timed out",
            async {
                // A future that never completes keeps this deterministic. With a finite
                // sleep, a stalled runtime can observe both timers as expired in the same
                // poll; `tokio::time::Timeout` polls the inner future first, so the sleep
                // would win and `unwrap_err` would panic.
                std::future::pending::<()>().await;
                Ok::<_, OrmError>(())
            },
        )
        .await
        .unwrap_err();

        assert_eq!(error.message(), "SQL Server connection timed out");
    }
}