Skip to main content

sql_orm_tiberius/
connection.rs

1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::{QueryExecutionOptions, fetch_one_compiled};
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6    MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
17pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20    client: Client<S>,
21    config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25    pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26        let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27        Self::connect_with_config(config).await
28    }
29
30    pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31        let tracing_options = config.options().tracing;
32        let connect_timeout = config.options().timeouts.connect_timeout;
33        let addr = config.addr();
34        let trace_addr = addr.clone();
35        let tiberius_config = config.tiberius_config().clone();
36
37        let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38            run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39                let tcp = TcpStream::connect(addr).await.map_err(|error| {
40                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41                })?;
42                tcp.set_nodelay(true).map_err(|error| {
43                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44                })?;
45
46                Client::connect(tiberius_config, tcp.compat_write())
47                    .await
48                    .map_err(|error| {
49                        map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50                    })
51            })
52            .await
53        })
54        .await?;
55
56        Ok(Self { client, config })
57    }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61    pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62        Self { client, config }
63    }
64
65    pub fn config(&self) -> &MssqlConnectionConfig {
66        &self.config
67    }
68
69    pub fn client(&self) -> &Client<S> {
70        &self.client
71    }
72
73    pub fn client_mut(&mut self) -> &mut Client<S> {
74        &mut self.client
75    }
76
77    pub(crate) fn query_timeout(&self) -> Option<Duration> {
78        self.config.options().timeouts.query_timeout
79    }
80
81    pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82        self.config.options().tracing
83    }
84
85    pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86        self.config.options().slow_query
87    }
88
89    pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90        self.config.options().retry
91    }
92
93    #[doc(hidden)]
94    pub fn replace_retry_options(
95        &mut self,
96        retry: crate::config::MssqlRetryOptions,
97    ) -> crate::config::MssqlRetryOptions {
98        let previous = self.config.options().retry;
99        let options = self.config.options().clone().with_retry(retry);
100        self.config = self.config.clone().with_options(options);
101        previous
102    }
103
104    pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
105        self.config.options().health
106    }
107
108    pub(crate) fn server_addr(&self) -> String {
109        self.config.addr()
110    }
111
112    pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
113        let query_timeout = self.query_timeout();
114        let tracing_options = self.tracing_options();
115        let slow_query_options = self.slow_query_options();
116        let server_addr = self.server_addr();
117        MssqlTransaction::begin(
118            self.client_mut(),
119            query_timeout,
120            tracing_options,
121            slow_query_options,
122            server_addr,
123        )
124        .await
125    }
126
127    pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
128        let query_timeout = self.query_timeout();
129        let tracing_options = self.tracing_options();
130        let server_addr = self.server_addr();
131        begin_transaction_scope(
132            self.client_mut(),
133            query_timeout,
134            tracing_options,
135            &server_addr,
136        )
137        .await
138    }
139
140    pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
141        let query_timeout = self.query_timeout();
142        let tracing_options = self.tracing_options();
143        let server_addr = self.server_addr();
144        commit_transaction_scope(
145            self.client_mut(),
146            query_timeout,
147            tracing_options,
148            &server_addr,
149        )
150        .await
151    }
152
153    pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
154        let query_timeout = self.query_timeout();
155        let tracing_options = self.tracing_options();
156        let server_addr = self.server_addr();
157        rollback_transaction_scope(
158            self.client_mut(),
159            query_timeout,
160            tracing_options,
161            &server_addr,
162        )
163        .await
164    }
165
166    pub async fn health_check(&mut self) -> Result<(), OrmError> {
167        let tracing_options = self.tracing_options();
168        let slow_query_options = self.slow_query_options();
169        let retry_options = self.retry_options();
170        let server_addr = self.server_addr();
171        let health_options = self.health_options();
172        let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
173        let query = build_health_check_query(health_options);
174
175        let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
176            fetch_one_compiled::<_, HealthCheckRow>(
177                self.client_mut(),
178                query,
179                QueryExecutionOptions {
180                    tracing: tracing_options,
181                    slow_query: slow_query_options,
182                    retry: retry_options,
183                    server_addr: &server_addr,
184                    timeout: health_timeout,
185                },
186            )
187            .await
188        })
189        .await?;
190
191        match row {
192            Some(HealthCheckRow { value: 1 }) => Ok(()),
193            Some(_) => Err(OrmError::connection(
194                "SQL Server health check returned an unexpected value",
195            )),
196            None => Err(OrmError::connection(
197                "SQL Server health check did not return a row",
198            )),
199        }
200    }
201
202    pub fn into_inner(self) -> Client<S> {
203        self.client
204    }
205}
206
207struct HealthCheckRow {
208    value: i32,
209}
210
211impl sql_orm_core::FromRow for HealthCheckRow {
212    fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
213        Ok(Self {
214            value: row.get_required_typed::<i32>("health_check")?,
215        })
216    }
217}
218
219fn resolve_health_timeout(
220    health_options: crate::config::MssqlHealthCheckOptions,
221    query_timeout: Option<Duration>,
222) -> Option<Duration> {
223    health_options.timeout.or(query_timeout)
224}
225
226fn build_health_check_query(
227    health_options: crate::config::MssqlHealthCheckOptions,
228) -> CompiledQuery {
229    CompiledQuery::read_only(health_options.query.sql().to_string(), vec![])
230}
231
232pub(crate) async fn run_with_timeout<F, T>(
233    duration: Option<Duration>,
234    timeout_message: &'static str,
235    future: F,
236) -> Result<T, OrmError>
237where
238    F: core::future::Future<Output = Result<T, OrmError>>,
239{
240    match duration {
241        Some(duration) => timeout(duration, future)
242            .await
243            .map_err(|_| timeout_error(timeout_message))?,
244        None => future.await,
245    }
246}
247
248fn timeout_error(message: &'static str) -> OrmError {
249    if message.contains("connection") || message.contains("health check") {
250        OrmError::connection(message)
251    } else {
252        OrmError::execution(message)
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use sql_orm_core::{OrmError, OrmErrorKind};
    use std::time::Duration;

    /// Runs a 25 ms sleep under a 5 ms deadline with `message` and returns the timeout error.
    async fn expire_with(message: &'static str) -> OrmError {
        run_with_timeout(Some(Duration::from_millis(5)), message, async {
            tokio::time::sleep(Duration::from_millis(25)).await;
            Ok::<_, OrmError>(())
        })
        .await
        .unwrap_err()
    }

    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let dedicated = Duration::from_secs(3);
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(dedicated);

        let resolved = resolve_health_timeout(health, Some(Duration::from_secs(30)));
        assert_eq!(resolved, Some(dedicated));
    }

    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);
        let query_timeout = Some(Duration::from_secs(30));

        assert_eq!(resolve_health_timeout(health, query_timeout), query_timeout);
        assert_eq!(resolve_health_timeout(health, None), None);
    }

    #[test]
    fn health_check_builds_expected_compiled_query() {
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);
        let query = build_health_check_query(options);

        assert_eq!(query.sql, "SELECT 1 AS [health_check]");
        assert!(query.params.is_empty());
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let outcome = run_with_timeout(None, "timeout", async { Ok::<_, OrmError>(7) }).await;

        assert_eq!(outcome.unwrap(), 7);
    }

    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        let error = expire_with("SQL Server connection timed out").await;

        assert_eq!(error.message(), "SQL Server connection timed out");
        assert_eq!(error.kind(), OrmErrorKind::Connection);
    }

    #[tokio::test]
    async fn run_with_timeout_classifies_query_timeout_as_execution() {
        let error = expire_with("SQL Server query timed out").await;

        assert_eq!(error.message(), "SQL Server query timed out");
        assert_eq!(error.kind(), OrmErrorKind::Execution);
    }
}