1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::fetch_one_compiled;
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6 MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
17pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20 client: Client<S>,
21 config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25 pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26 let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27 Self::connect_with_config(config).await
28 }
29
30 pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31 let tracing_options = config.options().tracing;
32 let connect_timeout = config.options().timeouts.connect_timeout;
33 let addr = config.addr();
34 let trace_addr = addr.clone();
35 let tiberius_config = config.tiberius_config().clone();
36
37 let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38 run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39 let tcp = TcpStream::connect(addr).await.map_err(|error| {
40 map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41 })?;
42 tcp.set_nodelay(true).map_err(|error| {
43 map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44 })?;
45
46 Client::connect(tiberius_config, tcp.compat_write())
47 .await
48 .map_err(|error| {
49 map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50 })
51 })
52 .await
53 })
54 .await?;
55
56 Ok(Self { client, config })
57 }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61 pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62 Self { client, config }
63 }
64
65 pub fn config(&self) -> &MssqlConnectionConfig {
66 &self.config
67 }
68
69 pub fn client(&self) -> &Client<S> {
70 &self.client
71 }
72
73 pub fn client_mut(&mut self) -> &mut Client<S> {
74 &mut self.client
75 }
76
77 pub(crate) fn query_timeout(&self) -> Option<Duration> {
78 self.config.options().timeouts.query_timeout
79 }
80
81 pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82 self.config.options().tracing
83 }
84
85 pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86 self.config.options().slow_query
87 }
88
89 pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90 self.config.options().retry
91 }
92
93 #[doc(hidden)]
94 pub fn replace_retry_options(
95 &mut self,
96 retry: crate::config::MssqlRetryOptions,
97 ) -> crate::config::MssqlRetryOptions {
98 let previous = self.config.options().retry;
99 let options = self.config.options().clone().with_retry(retry);
100 self.config = self.config.clone().with_options(options);
101 previous
102 }
103
104 pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
105 self.config.options().health
106 }
107
108 pub(crate) fn server_addr(&self) -> String {
109 self.config.addr()
110 }
111
112 pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
113 let query_timeout = self.query_timeout();
114 let tracing_options = self.tracing_options();
115 let slow_query_options = self.slow_query_options();
116 let server_addr = self.server_addr();
117 MssqlTransaction::begin(
118 self.client_mut(),
119 query_timeout,
120 tracing_options,
121 slow_query_options,
122 server_addr,
123 )
124 .await
125 }
126
127 pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
128 let query_timeout = self.query_timeout();
129 let tracing_options = self.tracing_options();
130 let server_addr = self.server_addr();
131 begin_transaction_scope(
132 self.client_mut(),
133 query_timeout,
134 tracing_options,
135 &server_addr,
136 )
137 .await
138 }
139
140 pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
141 let query_timeout = self.query_timeout();
142 let tracing_options = self.tracing_options();
143 let server_addr = self.server_addr();
144 commit_transaction_scope(
145 self.client_mut(),
146 query_timeout,
147 tracing_options,
148 &server_addr,
149 )
150 .await
151 }
152
153 pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
154 let query_timeout = self.query_timeout();
155 let tracing_options = self.tracing_options();
156 let server_addr = self.server_addr();
157 rollback_transaction_scope(
158 self.client_mut(),
159 query_timeout,
160 tracing_options,
161 &server_addr,
162 )
163 .await
164 }
165
166 pub async fn health_check(&mut self) -> Result<(), OrmError> {
167 let tracing_options = self.tracing_options();
168 let slow_query_options = self.slow_query_options();
169 let retry_options = self.retry_options();
170 let server_addr = self.server_addr();
171 let health_options = self.health_options();
172 let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
173 let query = build_health_check_query(health_options);
174
175 let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
176 fetch_one_compiled::<_, HealthCheckRow>(
177 self.client_mut(),
178 query,
179 tracing_options,
180 slow_query_options,
181 retry_options,
182 &server_addr,
183 health_timeout,
184 )
185 .await
186 })
187 .await?;
188
189 match row {
190 Some(HealthCheckRow { value: 1 }) => Ok(()),
191 Some(_) => Err(OrmError::new(
192 "SQL Server health check returned an unexpected value",
193 )),
194 None => Err(OrmError::new(
195 "SQL Server health check did not return a row",
196 )),
197 }
198 }
199
200 pub fn into_inner(self) -> Client<S> {
201 self.client
202 }
203}
204
205struct HealthCheckRow {
206 value: i32,
207}
208
209impl sql_orm_core::FromRow for HealthCheckRow {
210 fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
211 Ok(Self {
212 value: row.get_required_typed::<i32>("health_check")?,
213 })
214 }
215}
216
217fn resolve_health_timeout(
218 health_options: crate::config::MssqlHealthCheckOptions,
219 query_timeout: Option<Duration>,
220) -> Option<Duration> {
221 health_options.timeout.or(query_timeout)
222}
223
224fn build_health_check_query(
225 health_options: crate::config::MssqlHealthCheckOptions,
226) -> CompiledQuery {
227 CompiledQuery::new(health_options.query.sql().to_string(), vec![])
228}
229
230pub(crate) async fn run_with_timeout<F, T>(
231 duration: Option<Duration>,
232 timeout_message: &'static str,
233 future: F,
234) -> Result<T, OrmError>
235where
236 F: core::future::Future<Output = Result<T, OrmError>>,
237{
238 match duration {
239 Some(duration) => timeout(duration, future)
240 .await
241 .map_err(|_| OrmError::new(timeout_message))?,
242 None => future.await,
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use sql_orm_core::OrmError;
    use std::time::Duration;

    /// A dedicated health timeout takes precedence over the query timeout.
    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let dedicated = Duration::from_secs(3);
        let general = Duration::from_secs(30);
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(dedicated);

        let resolved = resolve_health_timeout(options, Some(general));

        assert_eq!(resolved, Some(dedicated));
    }

    /// Without a dedicated timeout the query timeout (or `None`) is used.
    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let general = Duration::from_secs(30);
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        assert_eq!(resolve_health_timeout(options, Some(general)), Some(general));
        assert_eq!(resolve_health_timeout(options, None), None);
    }

    /// The `SelectOne` probe compiles to the expected SQL with no parameters.
    #[test]
    fn health_check_builds_expected_compiled_query() {
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        let compiled = build_health_check_query(options);

        assert_eq!(compiled.sql, "SELECT 1 AS [health_check]");
        assert!(compiled.params.is_empty());
    }

    /// With no deadline the wrapped future's result is passed through.
    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let outcome = run_with_timeout(None, "timeout", async { Ok::<_, OrmError>(7) }).await;

        assert_eq!(outcome.unwrap(), 7);
    }

    /// A future that outlives the deadline produces the timeout message.
    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        let deadline = Some(Duration::from_millis(5));
        let slow_future = async {
            tokio::time::sleep(Duration::from_millis(25)).await;
            Ok::<_, OrmError>(())
        };

        let outcome =
            run_with_timeout(deadline, "SQL Server connection timed out", slow_future).await;

        let error = outcome.unwrap_err();
        assert_eq!(error.message(), "SQL Server connection timed out");
    }
}