Skip to main content

sql_orm_tiberius/
connection.rs

1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::fetch_one_compiled;
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6    MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
/// Default transport for [`MssqlConnection`]: a Tokio [`TcpStream`] wrapped in the
/// `tokio_util` compatibility adapter so it can satisfy the `futures_io`
/// `AsyncRead`/`AsyncWrite` bounds that `MssqlConnection` places on its stream type.
pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20    client: Client<S>,
21    config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25    pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26        let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27        Self::connect_with_config(config).await
28    }
29
30    pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31        let tracing_options = config.options().tracing;
32        let connect_timeout = config.options().timeouts.connect_timeout;
33        let addr = config.addr();
34        let trace_addr = addr.clone();
35        let tiberius_config = config.tiberius_config().clone();
36
37        let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38            run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39                let tcp = TcpStream::connect(addr).await.map_err(|error| {
40                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41                })?;
42                tcp.set_nodelay(true).map_err(|error| {
43                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44                })?;
45
46                Client::connect(tiberius_config, tcp.compat_write())
47                    .await
48                    .map_err(|error| {
49                        map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50                    })
51            })
52            .await
53        })
54        .await?;
55
56        Ok(Self { client, config })
57    }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61    pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62        Self { client, config }
63    }
64
65    pub fn config(&self) -> &MssqlConnectionConfig {
66        &self.config
67    }
68
69    pub fn client(&self) -> &Client<S> {
70        &self.client
71    }
72
73    pub fn client_mut(&mut self) -> &mut Client<S> {
74        &mut self.client
75    }
76
77    pub(crate) fn query_timeout(&self) -> Option<Duration> {
78        self.config.options().timeouts.query_timeout
79    }
80
81    pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82        self.config.options().tracing
83    }
84
85    pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86        self.config.options().slow_query
87    }
88
89    pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90        self.config.options().retry
91    }
92
93    pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
94        self.config.options().health
95    }
96
97    pub(crate) fn server_addr(&self) -> String {
98        self.config.addr()
99    }
100
101    pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
102        let query_timeout = self.query_timeout();
103        let tracing_options = self.tracing_options();
104        let slow_query_options = self.slow_query_options();
105        let server_addr = self.server_addr();
106        MssqlTransaction::begin(
107            self.client_mut(),
108            query_timeout,
109            tracing_options,
110            slow_query_options,
111            server_addr,
112        )
113        .await
114    }
115
116    pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
117        let query_timeout = self.query_timeout();
118        let tracing_options = self.tracing_options();
119        let server_addr = self.server_addr();
120        begin_transaction_scope(
121            self.client_mut(),
122            query_timeout,
123            tracing_options,
124            &server_addr,
125        )
126        .await
127    }
128
129    pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
130        let query_timeout = self.query_timeout();
131        let tracing_options = self.tracing_options();
132        let server_addr = self.server_addr();
133        commit_transaction_scope(
134            self.client_mut(),
135            query_timeout,
136            tracing_options,
137            &server_addr,
138        )
139        .await
140    }
141
142    pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
143        let query_timeout = self.query_timeout();
144        let tracing_options = self.tracing_options();
145        let server_addr = self.server_addr();
146        rollback_transaction_scope(
147            self.client_mut(),
148            query_timeout,
149            tracing_options,
150            &server_addr,
151        )
152        .await
153    }
154
155    pub async fn health_check(&mut self) -> Result<(), OrmError> {
156        let tracing_options = self.tracing_options();
157        let slow_query_options = self.slow_query_options();
158        let retry_options = self.retry_options();
159        let server_addr = self.server_addr();
160        let health_options = self.health_options();
161        let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
162        let query = build_health_check_query(health_options);
163
164        let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
165            fetch_one_compiled::<_, HealthCheckRow>(
166                self.client_mut(),
167                query,
168                tracing_options,
169                slow_query_options,
170                retry_options,
171                &server_addr,
172                health_timeout,
173            )
174            .await
175        })
176        .await?;
177
178        match row {
179            Some(HealthCheckRow { value: 1 }) => Ok(()),
180            Some(_) => Err(OrmError::new(
181                "SQL Server health check returned an unexpected value",
182            )),
183            None => Err(OrmError::new(
184                "SQL Server health check did not return a row",
185            )),
186        }
187    }
188
189    pub fn into_inner(self) -> Client<S> {
190        self.client
191    }
192}
193
/// Row decoded from the health-check query result.
struct HealthCheckRow {
    /// Value read from the `health_check` column; the health check expects `1`.
    value: i32,
}
197
198impl sql_orm_core::FromRow for HealthCheckRow {
199    fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
200        Ok(Self {
201            value: row.get_required_typed::<i32>("health_check")?,
202        })
203    }
204}
205
206fn resolve_health_timeout(
207    health_options: crate::config::MssqlHealthCheckOptions,
208    query_timeout: Option<Duration>,
209) -> Option<Duration> {
210    health_options.timeout.or(query_timeout)
211}
212
213fn build_health_check_query(
214    health_options: crate::config::MssqlHealthCheckOptions,
215) -> CompiledQuery {
216    CompiledQuery::new(health_options.query.sql().to_string(), vec![])
217}
218
219pub(crate) async fn run_with_timeout<F, T>(
220    duration: Option<Duration>,
221    timeout_message: &'static str,
222    future: F,
223) -> Result<T, OrmError>
224where
225    F: core::future::Future<Output = Result<T, OrmError>>,
226{
227    match duration {
228        Some(duration) => timeout(duration, future)
229            .await
230            .map_err(|_| OrmError::new(timeout_message))?,
231        None => future.await,
232    }
233}
234
/// Unit tests for the health-check helpers and `run_with_timeout`.
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use std::time::Duration;

    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(Duration::from_secs(3));

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(30))
        );
        assert_eq!(resolve_health_timeout(health, None), None);
    }

    #[test]
    fn health_check_builds_expected_compiled_query() {
        let query = build_health_check_query(MssqlHealthCheckOptions::enabled(
            MssqlHealthCheckQuery::SelectOne,
        ));

        assert_eq!(query.sql, "SELECT 1 AS [health_check]");
        assert!(query.params.is_empty());
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let value = run_with_timeout(None, "timeout", async {
            Ok::<_, sql_orm_core::OrmError>(7)
        })
        .await
        .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_within_deadline() {
        // The future is ready on its first poll, so the generous deadline never fires.
        let value = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Ok::<_, sql_orm_core::OrmError>(7)
        })
        .await
        .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_propagates_inner_error_before_deadline() {
        // The inner error must come through untouched, not be replaced by the timeout message.
        let error = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Err::<(), _>(sql_orm_core::OrmError::new("inner failure"))
        })
        .await
        .unwrap_err();

        assert_eq!(error.message(), "inner failure");
    }

    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        // A future that never resolves guarantees the deadline is what ends the call.
        // Racing a real sleep against the timeout could fail spuriously if the test
        // task stalled long enough for the sleep to finish before the next poll.
        let error = run_with_timeout(
            Some(Duration::from_millis(5)),
            "SQL Server connection timed out",
            std::future::pending::<Result<(), sql_orm_core::OrmError>>(),
        )
        .await
        .unwrap_err();

        assert_eq!(error.message(), "SQL Server connection timed out");
    }
}