1use crate::config::MssqlConnectionConfig;
2use crate::error::{TiberiusErrorContext, map_tiberius_error};
3use crate::executor::fetch_one_compiled;
4use crate::telemetry::trace_connection;
5use crate::transaction::{
6 MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
7};
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::OrmError;
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::Client;
13use tokio::net::TcpStream;
14use tokio::time::timeout;
15use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
16
/// Stream type used by connections opened through `MssqlConnection::connect`: a tokio
/// [`TcpStream`] wrapped in `tokio_util`'s [`Compat`] adapter so it satisfies the
/// `futures_io` `AsyncRead`/`AsyncWrite` traits that tiberius' `Client` requires.
pub type TokioConnectionStream = Compat<TcpStream>;
18
19pub struct MssqlConnection<S: AsyncRead + AsyncWrite + Unpin + Send = TokioConnectionStream> {
20 client: Client<S>,
21 config: MssqlConnectionConfig,
22}
23
24impl MssqlConnection<TokioConnectionStream> {
25 pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
26 let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
27 Self::connect_with_config(config).await
28 }
29
30 pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
31 let tracing_options = config.options().tracing;
32 let connect_timeout = config.options().timeouts.connect_timeout;
33 let addr = config.addr();
34 let trace_addr = addr.clone();
35 let tiberius_config = config.tiberius_config().clone();
36
37 let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
38 run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
39 let tcp = TcpStream::connect(addr).await.map_err(|error| {
40 map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
41 })?;
42 tcp.set_nodelay(true).map_err(|error| {
43 map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
44 })?;
45
46 Client::connect(tiberius_config, tcp.compat_write())
47 .await
48 .map_err(|error| {
49 map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
50 })
51 })
52 .await
53 })
54 .await?;
55
56 Ok(Self { client, config })
57 }
58}
59
60impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
61 pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
62 Self { client, config }
63 }
64
65 pub fn config(&self) -> &MssqlConnectionConfig {
66 &self.config
67 }
68
69 pub fn client(&self) -> &Client<S> {
70 &self.client
71 }
72
73 pub fn client_mut(&mut self) -> &mut Client<S> {
74 &mut self.client
75 }
76
77 pub(crate) fn query_timeout(&self) -> Option<Duration> {
78 self.config.options().timeouts.query_timeout
79 }
80
81 pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
82 self.config.options().tracing
83 }
84
85 pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
86 self.config.options().slow_query
87 }
88
89 pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
90 self.config.options().retry
91 }
92
93 pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
94 self.config.options().health
95 }
96
97 pub(crate) fn server_addr(&self) -> String {
98 self.config.addr()
99 }
100
101 pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
102 let query_timeout = self.query_timeout();
103 let tracing_options = self.tracing_options();
104 let slow_query_options = self.slow_query_options();
105 let server_addr = self.server_addr();
106 MssqlTransaction::begin(
107 self.client_mut(),
108 query_timeout,
109 tracing_options,
110 slow_query_options,
111 server_addr,
112 )
113 .await
114 }
115
116 pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
117 let query_timeout = self.query_timeout();
118 let tracing_options = self.tracing_options();
119 let server_addr = self.server_addr();
120 begin_transaction_scope(
121 self.client_mut(),
122 query_timeout,
123 tracing_options,
124 &server_addr,
125 )
126 .await
127 }
128
129 pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
130 let query_timeout = self.query_timeout();
131 let tracing_options = self.tracing_options();
132 let server_addr = self.server_addr();
133 commit_transaction_scope(
134 self.client_mut(),
135 query_timeout,
136 tracing_options,
137 &server_addr,
138 )
139 .await
140 }
141
142 pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
143 let query_timeout = self.query_timeout();
144 let tracing_options = self.tracing_options();
145 let server_addr = self.server_addr();
146 rollback_transaction_scope(
147 self.client_mut(),
148 query_timeout,
149 tracing_options,
150 &server_addr,
151 )
152 .await
153 }
154
155 pub async fn health_check(&mut self) -> Result<(), OrmError> {
156 let tracing_options = self.tracing_options();
157 let slow_query_options = self.slow_query_options();
158 let retry_options = self.retry_options();
159 let server_addr = self.server_addr();
160 let health_options = self.health_options();
161 let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
162 let query = build_health_check_query(health_options);
163
164 let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
165 fetch_one_compiled::<_, HealthCheckRow>(
166 self.client_mut(),
167 query,
168 tracing_options,
169 slow_query_options,
170 retry_options,
171 &server_addr,
172 health_timeout,
173 )
174 .await
175 })
176 .await?;
177
178 match row {
179 Some(HealthCheckRow { value: 1 }) => Ok(()),
180 Some(_) => Err(OrmError::new(
181 "SQL Server health check returned an unexpected value",
182 )),
183 None => Err(OrmError::new(
184 "SQL Server health check did not return a row",
185 )),
186 }
187 }
188
189 pub fn into_inner(self) -> Client<S> {
190 self.client
191 }
192}
193
194struct HealthCheckRow {
195 value: i32,
196}
197
198impl sql_orm_core::FromRow for HealthCheckRow {
199 fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
200 Ok(Self {
201 value: row.get_required_typed::<i32>("health_check")?,
202 })
203 }
204}
205
206fn resolve_health_timeout(
207 health_options: crate::config::MssqlHealthCheckOptions,
208 query_timeout: Option<Duration>,
209) -> Option<Duration> {
210 health_options.timeout.or(query_timeout)
211}
212
213fn build_health_check_query(
214 health_options: crate::config::MssqlHealthCheckOptions,
215) -> CompiledQuery {
216 CompiledQuery::new(health_options.query.sql().to_string(), vec![])
217}
218
219pub(crate) async fn run_with_timeout<F, T>(
220 duration: Option<Duration>,
221 timeout_message: &'static str,
222 future: F,
223) -> Result<T, OrmError>
224where
225 F: core::future::Future<Output = Result<T, OrmError>>,
226{
227 match duration {
228 Some(duration) => timeout(duration, future)
229 .await
230 .map_err(|_| OrmError::new(timeout_message))?,
231 None => future.await,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use std::time::Duration;

    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(Duration::from_secs(3));

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        assert_eq!(
            resolve_health_timeout(health, Some(Duration::from_secs(30))),
            Some(Duration::from_secs(30))
        );
        assert_eq!(resolve_health_timeout(health, None), None);
    }

    #[test]
    fn health_check_builds_expected_compiled_query() {
        let query = build_health_check_query(MssqlHealthCheckOptions::enabled(
            MssqlHealthCheckQuery::SelectOne,
        ));

        assert_eq!(query.sql, "SELECT 1 AS [health_check]");
        assert!(query.params.is_empty());
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let value = run_with_timeout(None, "timeout", async {
            Ok::<_, sql_orm_core::OrmError>(7)
        })
        .await
        .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_returns_future_result_within_deadline() {
        let value = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Ok::<_, sql_orm_core::OrmError>(7)
        })
        .await
        .unwrap();

        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn run_with_timeout_propagates_inner_error() {
        let error = run_with_timeout(Some(Duration::from_secs(5)), "timeout", async {
            Err::<(), _>(sql_orm_core::OrmError::new("inner failure"))
        })
        .await
        .unwrap_err();

        assert_eq!(error.message(), "inner failure");
    }

    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        // The wrapped future never resolves, so the only way this call can finish is via
        // the timeout branch. A finite sleep (as previously used) can race the deadline:
        // tokio's `Timeout` polls the inner future before checking its deadline, so on a
        // stalled scheduler a short sleep may complete first and make the test flaky.
        let error = run_with_timeout(
            Some(Duration::from_millis(5)),
            "SQL Server connection timed out",
            std::future::pending::<Result<(), sql_orm_core::OrmError>>(),
        )
        .await
        .unwrap_err();

        assert_eq!(error.message(), "SQL Server connection timed out");
    }
}