1use sql_orm_core::OrmError;
2use std::time::Duration;
3use tiberius::Config;
4
5#[derive(Debug, Clone)]
6pub struct MssqlConnectionConfig {
7 connection_string: String,
8 inner: Config,
9 options: MssqlOperationalOptions,
10}
11
12impl MssqlConnectionConfig {
13 pub fn from_connection_string(connection_string: &str) -> Result<Self, OrmError> {
14 Self::from_connection_string_with_options(
15 connection_string,
16 MssqlOperationalOptions::default(),
17 )
18 }
19
20 pub fn from_connection_string_with_options(
21 connection_string: &str,
22 options: MssqlOperationalOptions,
23 ) -> Result<Self, OrmError> {
24 if connection_string.trim().is_empty() {
25 return Err(OrmError::new("invalid SQL Server connection string"));
26 }
27
28 let inner = Config::from_ado_string(connection_string)
29 .map_err(|_| OrmError::new("invalid SQL Server connection string"))?;
30 validate_config(&inner)?;
31
32 Ok(Self {
33 connection_string: connection_string.to_string(),
34 inner,
35 options,
36 })
37 }
38
39 pub fn with_options(mut self, options: MssqlOperationalOptions) -> Self {
40 self.options = options;
41 self
42 }
43
44 pub fn connection_string(&self) -> &str {
45 &self.connection_string
46 }
47
48 pub fn addr(&self) -> String {
49 self.inner.get_addr()
50 }
51
52 pub fn options(&self) -> &MssqlOperationalOptions {
53 &self.options
54 }
55
56 pub(crate) fn tiberius_config(&self) -> &Config {
57 &self.inner
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Default)]
62pub struct MssqlOperationalOptions {
63 pub timeouts: MssqlTimeoutOptions,
64 pub retry: MssqlRetryOptions,
65 pub tracing: MssqlTracingOptions,
66 pub slow_query: MssqlSlowQueryOptions,
67 pub health: MssqlHealthCheckOptions,
68 pub pool: MssqlPoolOptions,
69}
70
71impl MssqlOperationalOptions {
72 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn with_timeouts(mut self, timeouts: MssqlTimeoutOptions) -> Self {
77 self.timeouts = timeouts;
78 self
79 }
80
81 pub fn with_retry(mut self, retry: MssqlRetryOptions) -> Self {
82 self.retry = retry;
83 self
84 }
85
86 pub fn with_tracing(mut self, tracing: MssqlTracingOptions) -> Self {
87 self.tracing = tracing;
88 self
89 }
90
91 pub fn with_slow_query(mut self, slow_query: MssqlSlowQueryOptions) -> Self {
92 self.slow_query = slow_query;
93 self
94 }
95
96 pub fn with_health(mut self, health: MssqlHealthCheckOptions) -> Self {
97 self.health = health;
98 self
99 }
100
101 pub fn with_pool(mut self, pool: MssqlPoolOptions) -> Self {
102 self.pool = pool;
103 self
104 }
105}
106
/// Optional timeouts; `None` means "no explicit timeout configured here".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MssqlTimeoutOptions {
    /// Limit for establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Limit for a single query.
    pub query_timeout: Option<Duration>,
    /// Limit for acquiring a connection.
    ///
    /// NOTE(review): `MssqlPoolOptions` has its own `acquire_timeout`; confirm which one the
    /// consumer honours.
    pub acquire_timeout: Option<Duration>,
}

impl MssqlTimeoutOptions {
    /// All timeouts unset (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a copy with `connect_timeout` set.
    pub fn with_connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy with `query_timeout` set.
    pub fn with_query_timeout(self, timeout: Duration) -> Self {
        Self {
            query_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy with `acquire_timeout` set.
    pub fn with_acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: Some(timeout),
            ..self
        }
    }
}
134
/// Retry policy settings. Disabled by default.
///
/// NOTE(review): how `base_delay` / `max_delay` are applied (e.g. backoff growth and capping)
/// is decided by the code that consumes these options. Confirm there; nothing here enforces
/// `base_delay <= max_delay`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlRetryOptions {
    /// Master switch for retries.
    pub enabled: bool,
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Starting delay between attempts.
    pub base_delay: Duration,
    /// Upper bound on the delay between attempts.
    pub max_delay: Duration,
}

impl Default for MssqlRetryOptions {
    /// Same as [`MssqlRetryOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}

impl MssqlRetryOptions {
    /// Retries off: `max_retries = 0`, with delays of 100 ms (base) and 2 s (max).
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            max_retries: 0,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
        }
    }

    /// Retries on, with the given retry count and delay bounds.
    pub fn enabled(max_retries: u32, base_delay: Duration, max_delay: Duration) -> Self {
        let mut policy = Self::disabled();
        policy.enabled = true;
        policy.max_retries = max_retries;
        policy.base_delay = base_delay;
        policy.max_delay = max_delay;
        policy
    }
}
168
/// Tracing settings. Disabled by default, with all three event toggles on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlTracingOptions {
    /// Master switch for tracing.
    pub enabled: bool,
    /// How query parameters are treated in tracing output.
    pub parameter_logging: MssqlParameterLogMode,
    /// Start-event toggle (defaults to `true`).
    pub emit_start_event: bool,
    /// Finish-event toggle (defaults to `true`).
    pub emit_finish_event: bool,
    /// Error-event toggle (defaults to `true`).
    pub emit_error_event: bool,
}

impl Default for MssqlTracingOptions {
    /// Same as [`MssqlTracingOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}

impl MssqlTracingOptions {
    /// Tracing off, parameters `Redacted`, every event toggle on.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            parameter_logging: MssqlParameterLogMode::Redacted,
            emit_start_event: true,
            emit_finish_event: true,
            emit_error_event: true,
        }
    }

    /// Like [`MssqlTracingOptions::disabled`], but with the master switch turned on.
    pub fn enabled() -> Self {
        let mut opts = Self::disabled();
        opts.enabled = true;
        opts
    }

    /// Returns a copy using the given parameter-logging mode.
    pub fn with_parameter_logging(self, parameter_logging: MssqlParameterLogMode) -> Self {
        Self {
            parameter_logging,
            ..self
        }
    }
}

/// Parameter handling mode shared by tracing and slow-query reporting.
///
/// NOTE(review): assumed meaning is that `Disabled` omits parameters and `Redacted` masks
/// their values. Confirm against the code that consumes this setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlParameterLogMode {
    Disabled,
    Redacted,
}

/// Slow-query reporting settings. Disabled by default with a 500 ms threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlSlowQueryOptions {
    /// Master switch for slow-query reporting.
    pub enabled: bool,
    /// Duration used as the slow-query threshold.
    pub threshold: Duration,
    /// How query parameters are treated in slow-query output.
    pub parameter_logging: MssqlParameterLogMode,
}

impl Default for MssqlSlowQueryOptions {
    /// Same as [`MssqlSlowQueryOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}

impl MssqlSlowQueryOptions {
    /// Reporting off, 500 ms threshold, parameters `Redacted`.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            threshold: Duration::from_millis(500),
            parameter_logging: MssqlParameterLogMode::Redacted,
        }
    }

    /// Reporting on with the given threshold; parameter logging stays `Redacted`.
    pub fn enabled(threshold: Duration) -> Self {
        Self {
            threshold,
            enabled: true,
            ..Self::disabled()
        }
    }

    /// Returns a copy using the given parameter-logging mode.
    pub fn with_parameter_logging(self, parameter_logging: MssqlParameterLogMode) -> Self {
        Self {
            parameter_logging,
            ..self
        }
    }
}
249
/// Health-check settings. Disabled by default, using [`MssqlHealthCheckQuery::SelectOne`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlHealthCheckOptions {
    /// Master switch for health checks.
    pub enabled: bool,
    /// Which probe query to run.
    pub query: MssqlHealthCheckQuery,
    /// Optional time limit for the probe; `None` means none is configured here.
    pub timeout: Option<Duration>,
}

impl Default for MssqlHealthCheckOptions {
    /// Same as [`MssqlHealthCheckOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}

impl MssqlHealthCheckOptions {
    /// Health checks off, `SelectOne` probe, no timeout.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            query: MssqlHealthCheckQuery::SelectOne,
            timeout: None,
        }
    }

    /// Health checks on with the given probe and no timeout.
    pub fn enabled(query: MssqlHealthCheckQuery) -> Self {
        Self {
            enabled: true,
            query,
            ..Self::disabled()
        }
    }

    /// Returns a copy with the probe timeout set.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}

/// Available health-check probe queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlHealthCheckQuery {
    SelectOne,
}

impl MssqlHealthCheckQuery {
    /// Returns the fixed SQL text for this probe.
    pub(crate) fn sql(self) -> &'static str {
        const SELECT_ONE_SQL: &str = "SELECT 1 AS [health_check]";

        match self {
            Self::SelectOne => SELECT_ONE_SQL,
        }
    }
}
298
/// Connection-pool settings. Disabled by default (bb8 backend, `max_size` 10).
///
/// NOTE(review): no validation happens here. `max_size == 0` and `min_idle > max_size` are
/// accepted, and bb8's builder is expected to reject them. Confirm that the pool
/// construction path checks these values first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlPoolOptions {
    /// Master switch for pooling.
    pub enabled: bool,
    /// Pool implementation to use.
    pub backend: MssqlPoolBackend,
    /// Maximum number of pooled connections.
    pub max_size: u32,
    /// Optional minimum number of idle connections.
    pub min_idle: Option<u32>,
    /// Optional limit for acquiring a pooled connection.
    pub acquire_timeout: Option<Duration>,
    /// Optional idle-connection timeout.
    pub idle_timeout: Option<Duration>,
    /// Optional maximum connection lifetime.
    pub max_lifetime: Option<Duration>,
}

impl Default for MssqlPoolOptions {
    /// Same as [`MssqlPoolOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}

impl MssqlPoolOptions {
    /// Pooling off, bb8 backend, `max_size` 10, all optional limits unset.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            backend: MssqlPoolBackend::Bb8,
            max_size: 10,
            min_idle: None,
            acquire_timeout: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }

    /// Pooling on with the bb8 backend and the given maximum size.
    pub fn bb8(max_size: u32) -> Self {
        Self {
            max_size,
            enabled: true,
            backend: MssqlPoolBackend::Bb8,
            ..Self::disabled()
        }
    }

    /// Returns a copy with `min_idle` set.
    pub fn with_min_idle(self, min_idle: u32) -> Self {
        Self {
            min_idle: Some(min_idle),
            ..self
        }
    }

    /// Returns a copy with `acquire_timeout` set.
    pub fn with_acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy with `idle_timeout` set.
    pub fn with_idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy with `max_lifetime` set.
    pub fn with_max_lifetime(self, timeout: Duration) -> Self {
        Self {
            max_lifetime: Some(timeout),
            ..self
        }
    }
}

/// Supported pool implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlPoolBackend {
    Bb8,
}
363
364fn validate_config(config: &Config) -> Result<(), OrmError> {
365 let addr = config.get_addr();
366
367 if addr.is_empty() || addr.starts_with(':') {
368 return Err(OrmError::new("invalid SQL Server connection string"));
369 }
370
371 Ok(())
372}
373
#[cfg(test)]
mod tests {
    // Unit tests covering connection-string parsing/validation and the option builders.
    use super::{
        MssqlConnectionConfig, MssqlHealthCheckOptions, MssqlHealthCheckQuery,
        MssqlOperationalOptions, MssqlParameterLogMode, MssqlPoolBackend, MssqlPoolOptions,
        MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTimeoutOptions, MssqlTracingOptions,
    };
    use std::time::Duration;

    #[test]
    fn parses_valid_ado_connection_string() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests",
        )
        .unwrap();

        assert_eq!(
            config.connection_string(),
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests"
        );
        assert_eq!(config.addr(), "localhost:1433");
        assert_eq!(config.options(), &MssqlOperationalOptions::default());
    }

    #[test]
    fn preserves_explicit_operational_options() {
        let options = MssqlOperationalOptions::new()
            .with_timeouts(
                MssqlTimeoutOptions::new()
                    .with_connect_timeout(Duration::from_secs(5))
                    .with_query_timeout(Duration::from_secs(30))
                    .with_acquire_timeout(Duration::from_secs(2)),
            )
            .with_retry(MssqlRetryOptions::enabled(
                3,
                Duration::from_millis(50),
                Duration::from_secs(1),
            ))
            .with_tracing(
                MssqlTracingOptions::enabled()
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_slow_query(
                MssqlSlowQueryOptions::enabled(Duration::from_millis(250))
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_health(
                MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
                    .with_timeout(Duration::from_secs(3)),
            )
            .with_pool(
                MssqlPoolOptions::bb8(16)
                    .with_min_idle(4)
                    .with_acquire_timeout(Duration::from_secs(2))
                    .with_idle_timeout(Duration::from_secs(30))
                    .with_max_lifetime(Duration::from_secs(300)),
            );

        let config = MssqlConnectionConfig::from_connection_string_with_options(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
            options.clone(),
        )
        .unwrap();

        assert_eq!(config.options(), &options);
        assert_eq!(config.options().pool.backend, MssqlPoolBackend::Bb8);
        assert!(config.options().retry.enabled);
        assert!(config.options().tracing.enabled);
        assert!(config.options().slow_query.enabled);
        assert!(config.options().health.enabled);
        assert!(config.options().pool.enabled);
    }

    #[test]
    fn can_replace_options_on_existing_config() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
        )
        .unwrap()
        .with_options(MssqlOperationalOptions::new().with_tracing(MssqlTracingOptions::enabled()));

        assert!(config.options().tracing.enabled);
    }

    #[test]
    fn rejects_invalid_connection_string() {
        let error = MssqlConnectionConfig::from_connection_string("server=").unwrap_err();

        assert_eq!(error.message(), "invalid SQL Server connection string");
    }

    // Exercises the early `trim().is_empty()` guard, which previously had no coverage.
    #[test]
    fn rejects_blank_connection_string() {
        for input in ["", "   ", "\t\n"] {
            let error = MssqlConnectionConfig::from_connection_string(input).unwrap_err();

            assert_eq!(error.message(), "invalid SQL Server connection string");
        }
    }

    // Pins the documented defaults so an accidental change to any option group is caught.
    #[test]
    fn default_operational_options_are_all_disabled() {
        let options = MssqlOperationalOptions::default();

        assert_eq!(options, MssqlOperationalOptions::new());

        assert_eq!(options.timeouts, MssqlTimeoutOptions::new());
        assert_eq!(options.timeouts.connect_timeout, None);
        assert_eq!(options.timeouts.query_timeout, None);
        assert_eq!(options.timeouts.acquire_timeout, None);

        assert_eq!(options.retry, MssqlRetryOptions::disabled());
        assert!(!options.retry.enabled);
        assert_eq!(options.retry.max_retries, 0);
        assert_eq!(options.retry.base_delay, Duration::from_millis(100));
        assert_eq!(options.retry.max_delay, Duration::from_secs(2));

        assert_eq!(options.tracing, MssqlTracingOptions::disabled());
        assert!(!options.tracing.enabled);
        assert_eq!(options.tracing.parameter_logging, MssqlParameterLogMode::Redacted);

        assert_eq!(options.slow_query, MssqlSlowQueryOptions::disabled());
        assert!(!options.slow_query.enabled);
        assert_eq!(options.slow_query.threshold, Duration::from_millis(500));

        assert_eq!(options.health, MssqlHealthCheckOptions::disabled());
        assert!(!options.health.enabled);
        assert_eq!(options.health.query, MssqlHealthCheckQuery::SelectOne);
        assert_eq!(options.health.timeout, None);

        assert_eq!(options.pool, MssqlPoolOptions::disabled());
        assert!(!options.pool.enabled);
        assert_eq!(options.pool.backend, MssqlPoolBackend::Bb8);
        assert_eq!(options.pool.max_size, 10);
    }

    #[test]
    fn health_check_query_uses_stable_sql() {
        assert_eq!(
            MssqlHealthCheckQuery::SelectOne.sql(),
            "SELECT 1 AS [health_check]"
        );
    }
}