1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::{QueryExecutionOptions, fetch_one_compiled};
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6 MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
17pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20 client: Client<S>,
21 config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25 pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26 let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27 Self::connect_with_config(config).await
28 }
29
30 pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31 let tracing_options = config.options().tracing;
32 let connect_timeout = config.options().timeouts.connect_timeout;
33 let addr = config.addr();
34 let trace_addr = addr.clone();
35 let tiberius_config = config.tiberius_config().clone();
36
37 let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38 run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39 let tcp = TcpStream::connect(addr).await.map_err(|error| {
40 map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41 })?;
42 tcp.set_nodelay(true).map_err(|error| {
43 map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44 })?;
45
46 Client::connect(tiberius_config, tcp.compat_write())
47 .await
48 .map_err(|error| {
49 map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50 })
51 })
52 .await
53 })
54 .await?;
55
56 Ok(Self { client, config })
57 }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61 pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62 Self { client, config }
63 }
64
65 pub fn config(&self) -> &MssqlConnectionConfig {
66 &self.config
67 }
68
69 pub fn client(&self) -> &Client<S> {
70 &self.client
71 }
72
73 pub fn client_mut(&mut self) -> &mut Client<S> {
74 &mut self.client
75 }
76
77 pub(crate) fn query_timeout(&self) -> Option<Duration> {
78 self.config.options().timeouts.query_timeout
79 }
80
81 pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82 self.config.options().tracing
83 }
84
85 pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86 self.config.options().slow_query
87 }
88
89 pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90 self.config.options().retry
91 }
92
93 #[doc(hidden)]
94 pub fn replace_retry_options(
95 &mut self,
96 retry: crate::config::MssqlRetryOptions,
97 ) -> crate::config::MssqlRetryOptions {
98 let previous = self.config.options().retry;
99 let options = self.config.options().clone().with_retry(retry);
100 self.config = self.config.clone().with_options(options);
101 previous
102 }
103
104 pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
105 self.config.options().health
106 }
107
108 pub(crate) fn server_addr(&self) -> String {
109 self.config.addr()
110 }
111
112 pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
113 let query_timeout = self.query_timeout();
114 let tracing_options = self.tracing_options();
115 let slow_query_options = self.slow_query_options();
116 let server_addr = self.server_addr();
117 MssqlTransaction::begin(
118 self.client_mut(),
119 query_timeout,
120 tracing_options,
121 slow_query_options,
122 server_addr,
123 )
124 .await
125 }
126
127 pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
128 let query_timeout = self.query_timeout();
129 let tracing_options = self.tracing_options();
130 let server_addr = self.server_addr();
131 begin_transaction_scope(
132 self.client_mut(),
133 query_timeout,
134 tracing_options,
135 &server_addr,
136 )
137 .await
138 }
139
140 pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
141 let query_timeout = self.query_timeout();
142 let tracing_options = self.tracing_options();
143 let server_addr = self.server_addr();
144 commit_transaction_scope(
145 self.client_mut(),
146 query_timeout,
147 tracing_options,
148 &server_addr,
149 )
150 .await
151 }
152
153 pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
154 let query_timeout = self.query_timeout();
155 let tracing_options = self.tracing_options();
156 let server_addr = self.server_addr();
157 rollback_transaction_scope(
158 self.client_mut(),
159 query_timeout,
160 tracing_options,
161 &server_addr,
162 )
163 .await
164 }
165
166 pub async fn health_check(&mut self) -> Result<(), OrmError> {
167 let tracing_options = self.tracing_options();
168 let slow_query_options = self.slow_query_options();
169 let retry_options = self.retry_options();
170 let server_addr = self.server_addr();
171 let health_options = self.health_options();
172 let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
173 let query = build_health_check_query(health_options);
174
175 let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
176 fetch_one_compiled::<_, HealthCheckRow>(
177 self.client_mut(),
178 query,
179 QueryExecutionOptions {
180 tracing: tracing_options,
181 slow_query: slow_query_options,
182 retry: retry_options,
183 server_addr: &server_addr,
184 timeout: health_timeout,
185 },
186 )
187 .await
188 })
189 .await?;
190
191 match row {
192 Some(HealthCheckRow { value: 1 }) => Ok(()),
193 Some(_) => Err(OrmError::connection(
194 "SQL Server health check returned an unexpected value",
195 )),
196 None => Err(OrmError::connection(
197 "SQL Server health check did not return a row",
198 )),
199 }
200 }
201
202 pub fn into_inner(self) -> Client<S> {
203 self.client
204 }
205}
206
207struct HealthCheckRow {
208 value: i32,
209}
210
211impl sql_orm_core::FromRow for HealthCheckRow {
212 fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
213 Ok(Self {
214 value: row.get_required_typed::<i32>("health_check")?,
215 })
216 }
217}
218
219fn resolve_health_timeout(
220 health_options: crate::config::MssqlHealthCheckOptions,
221 query_timeout: Option<Duration>,
222) -> Option<Duration> {
223 health_options.timeout.or(query_timeout)
224}
225
226fn build_health_check_query(
227 health_options: crate::config::MssqlHealthCheckOptions,
228) -> CompiledQuery {
229 CompiledQuery::read_only(health_options.query.sql().to_string(), vec![])
230}
231
232pub(crate) async fn run_with_timeout<F, T>(
233 duration: Option<Duration>,
234 timeout_message: &'static str,
235 future: F,
236) -> Result<T, OrmError>
237where
238 F: core::future::Future<Output = Result<T, OrmError>>,
239{
240 match duration {
241 Some(duration) => timeout(duration, future)
242 .await
243 .map_err(|_| timeout_error(timeout_message))?,
244 None => future.await,
245 }
246}
247
248fn timeout_error(message: &'static str) -> OrmError {
249 if message.contains("connection") || message.contains("health check") {
250 OrmError::connection(message)
251 } else {
252 OrmError::execution(message)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use sql_orm_core::{OrmError, OrmErrorKind};
    use std::time::Duration;

    /// Drives `run_with_timeout` with a future that sleeps well past a 5 ms
    /// deadline, and returns the resulting timeout error for `message`.
    async fn expire_with(message: &'static str) -> OrmError {
        run_with_timeout(Some(Duration::from_millis(5)), message, async {
            tokio::time::sleep(Duration::from_millis(25)).await;
            Ok::<_, OrmError>(())
        })
        .await
        .unwrap_err()
    }

    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(Duration::from_secs(3));

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(30))
        );
        assert_eq!(resolve_health_timeout(health, None), None);
    }

    #[test]
    fn health_check_builds_expected_compiled_query() {
        let query = build_health_check_query(MssqlHealthCheckOptions::enabled(
            MssqlHealthCheckQuery::SelectOne,
        ));

        assert_eq!(query.sql, "SELECT 1 AS [health_check]");
        assert!(query.params.is_empty());
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let value = run_with_timeout(None, "timeout", async { Ok::<_, OrmError>(7) })
            .await
            .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_when_deadline_is_not_reached() {
        let value = run_with_timeout(
            Some(Duration::from_secs(5)),
            "SQL Server query timed out",
            async { Ok::<_, OrmError>(42) },
        )
        .await
        .unwrap();

        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn run_with_timeout_passes_inner_errors_through_unchanged() {
        // An error produced by the future itself must not be reclassified
        // using the timeout message.
        let error = run_with_timeout(
            Some(Duration::from_secs(5)),
            "SQL Server connection timed out",
            async { Err::<(), _>(OrmError::execution("inner failure")) },
        )
        .await
        .unwrap_err();

        assert_eq!(error.message(), "inner failure");
        assert_eq!(error.kind(), OrmErrorKind::Execution);
    }

    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        let error = expire_with("SQL Server connection timed out").await;

        assert_eq!(error.message(), "SQL Server connection timed out");
        assert_eq!(error.kind(), OrmErrorKind::Connection);
    }

    #[tokio::test]
    async fn run_with_timeout_classifies_health_check_timeout_as_connection() {
        let error = expire_with("SQL Server health check timed out").await;

        assert_eq!(error.message(), "SQL Server health check timed out");
        assert_eq!(error.kind(), OrmErrorKind::Connection);
    }

    #[tokio::test]
    async fn run_with_timeout_classifies_query_timeout_as_execution() {
        let error = expire_with("SQL Server query timed out").await;

        assert_eq!(error.message(), "SQL Server query timed out");
        assert_eq!(error.kind(), OrmErrorKind::Execution);
    }
}