1use sql_orm_core::OrmError;
2use std::fmt;
3use std::time::Duration;
4use tiberius::Config;
5
6#[derive(Clone)]
7pub struct MssqlConnectionConfig {
8 connection_string: String,
9 inner: Config,
10 options: MssqlOperationalOptions,
11}
12
13impl fmt::Debug for MssqlConnectionConfig {
14 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
15 formatter
16 .debug_struct("MssqlConnectionConfig")
17 .field("connection_string", &"<redacted>")
18 .field("addr", &self.addr())
19 .field("options", &self.options)
20 .finish()
21 }
22}
23
24impl MssqlConnectionConfig {
25 pub fn from_connection_string(connection_string: &str) -> Result<Self, OrmError> {
26 Self::from_connection_string_with_options(
27 connection_string,
28 MssqlOperationalOptions::default(),
29 )
30 }
31
32 pub fn from_connection_string_with_options(
33 connection_string: &str,
34 options: MssqlOperationalOptions,
35 ) -> Result<Self, OrmError> {
36 if connection_string.trim().is_empty() {
37 return Err(OrmError::connection("invalid SQL Server connection string"));
38 }
39
40 let inner = Config::from_ado_string(connection_string)
41 .map_err(|_| OrmError::connection("invalid SQL Server connection string"))?;
42 validate_config(&inner)?;
43
44 Ok(Self {
45 connection_string: connection_string.to_string(),
46 inner,
47 options,
48 })
49 }
50
51 pub fn with_options(mut self, options: MssqlOperationalOptions) -> Self {
52 self.options = options;
53 self
54 }
55
56 pub fn connection_string(&self) -> &str {
57 &self.connection_string
58 }
59
60 pub fn addr(&self) -> String {
61 self.inner.get_addr()
62 }
63
64 pub fn options(&self) -> &MssqlOperationalOptions {
65 &self.options
66 }
67
68 pub(crate) fn tiberius_config(&self) -> &Config {
69 &self.inner
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Default)]
74pub struct MssqlOperationalOptions {
75 pub timeouts: MssqlTimeoutOptions,
76 pub retry: MssqlRetryOptions,
77 pub tracing: MssqlTracingOptions,
78 pub slow_query: MssqlSlowQueryOptions,
79 pub health: MssqlHealthCheckOptions,
80 pub pool: MssqlPoolOptions,
81}
82
83impl MssqlOperationalOptions {
84 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn with_timeouts(mut self, timeouts: MssqlTimeoutOptions) -> Self {
89 self.timeouts = timeouts;
90 self
91 }
92
93 pub fn with_retry(mut self, retry: MssqlRetryOptions) -> Self {
94 self.retry = retry;
95 self
96 }
97
98 pub fn with_tracing(mut self, tracing: MssqlTracingOptions) -> Self {
99 self.tracing = tracing;
100 self
101 }
102
103 pub fn with_slow_query(mut self, slow_query: MssqlSlowQueryOptions) -> Self {
104 self.slow_query = slow_query;
105 self
106 }
107
108 pub fn with_health(mut self, health: MssqlHealthCheckOptions) -> Self {
109 self.health = health;
110 self
111 }
112
113 pub fn with_pool(mut self, pool: MssqlPoolOptions) -> Self {
114 self.pool = pool;
115 self
116 }
117}
118
/// Optional timeouts for SQL Server operations.
///
/// Every timeout defaults to `None`, meaning no timeout is configured here.
/// NOTE(review): `MssqlPoolOptions` also carries an `acquire_timeout`; which one
/// wins is decided by the code consuming these options. Confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MssqlTimeoutOptions {
    /// Optional connect timeout.
    pub connect_timeout: Option<Duration>,
    /// Optional query timeout.
    pub query_timeout: Option<Duration>,
    /// Optional connection-acquire timeout.
    pub acquire_timeout: Option<Duration>,
}

impl MssqlTimeoutOptions {
    /// Returns options with no timeouts set; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the connect timeout.
    #[must_use]
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Sets the query timeout.
    #[must_use]
    pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
        self.query_timeout = Some(timeout);
        self
    }

    /// Sets the connection-acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }
}
146
/// Retry policy for SQL Server operations.
///
/// Off by default (zero retries, 100 ms base delay, 2 s max delay). This type
/// only carries data. NOTE(review): how `base_delay` and `max_delay` shape the
/// wait between attempts is decided by the code consuming these options.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlRetryOptions {
    /// Master switch for retrying.
    pub enabled: bool,
    /// Number of retries permitted.
    pub max_retries: u32,
    /// Base delay between attempts.
    pub base_delay: Duration,
    /// Maximum delay between attempts.
    pub max_delay: Duration,
}

impl Default for MssqlRetryOptions {
    /// Retries off, zero retries, 100 ms base delay, 2 s max delay.
    fn default() -> Self {
        Self {
            enabled: false,
            ..Self::enabled(0, Self::DEFAULT_BASE_DELAY, Self::DEFAULT_MAX_DELAY)
        }
    }
}

impl MssqlRetryOptions {
    /// Base delay used by [`Default`].
    const DEFAULT_BASE_DELAY: Duration = Duration::from_millis(100);
    /// Maximum delay used by [`Default`].
    const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(2);

    /// Returns a policy with retries turned off (the default policy).
    pub fn disabled() -> Self {
        Default::default()
    }

    /// Returns an active policy with the given retry count and delays.
    pub fn enabled(max_retries: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_retries,
            base_delay,
            max_delay,
            enabled: true,
        }
    }
}
180
/// Tracing settings for SQL Server operations.
///
/// Disabled by default. All three `emit_*` flags default to `true`, and
/// parameters default to [`MssqlParameterLogMode::Redacted`].
/// NOTE(review): the `emit_*` flags presumably select which events are produced
/// once tracing is enabled; confirm in the tracing code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlTracingOptions {
    /// Master switch for tracing.
    pub enabled: bool,
    /// How query parameters are treated in trace output.
    pub parameter_logging: MssqlParameterLogMode,
    /// Whether a start event is emitted.
    pub emit_start_event: bool,
    /// Whether a finish event is emitted.
    pub emit_finish_event: bool,
    /// Whether an error event is emitted.
    pub emit_error_event: bool,
}

impl Default for MssqlTracingOptions {
    /// Tracing off, redacted parameters, all event kinds on.
    fn default() -> Self {
        Self {
            enabled: false,
            parameter_logging: MssqlParameterLogMode::Redacted,
            emit_start_event: true,
            emit_finish_event: true,
            emit_error_event: true,
        }
    }
}

impl MssqlTracingOptions {
    /// Returns tracing settings with tracing turned off (the default).
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns the default settings with tracing turned on.
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Self::default()
        }
    }

    /// Sets how query parameters are treated in trace output.
    #[must_use]
    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
        self.parameter_logging = parameter_logging;
        self
    }
}

/// How query parameters are treated in tracing and slow-query output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlParameterLogMode {
    /// Parameters are not logged.
    Disabled,
    /// Parameters are logged in redacted form.
    /// NOTE(review): exact rendering is decided by the logging code; confirm.
    Redacted,
}

/// Slow-query detection settings.
///
/// Disabled by default, with a 500 ms threshold and redacted parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlSlowQueryOptions {
    /// Master switch for slow-query detection.
    pub enabled: bool,
    /// Duration threshold used for slow-query detection.
    pub threshold: Duration,
    /// How query parameters are treated in slow-query output.
    pub parameter_logging: MssqlParameterLogMode,
}

impl Default for MssqlSlowQueryOptions {
    /// Detection off, 500 ms threshold, redacted parameters.
    fn default() -> Self {
        Self {
            enabled: false,
            threshold: Duration::from_millis(500),
            parameter_logging: MssqlParameterLogMode::Redacted,
        }
    }
}

impl MssqlSlowQueryOptions {
    /// Returns slow-query settings with detection turned off (the default).
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns settings with detection turned on at the given `threshold`.
    pub fn enabled(threshold: Duration) -> Self {
        Self {
            enabled: true,
            threshold,
            ..Self::default()
        }
    }

    /// Sets how query parameters are treated in slow-query output.
    #[must_use]
    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
        self.parameter_logging = parameter_logging;
        self
    }
}
261
/// Health-check settings for SQL Server connections.
///
/// Disabled by default, using [`MssqlHealthCheckQuery::SelectOne`] with no timeout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlHealthCheckOptions {
    /// Master switch for health checks.
    pub enabled: bool,
    /// Query executed by the health check.
    pub query: MssqlHealthCheckQuery,
    /// Optional timeout for the health check.
    pub timeout: Option<Duration>,
}

impl Default for MssqlHealthCheckOptions {
    /// Health checks off, `SelectOne` query, no timeout.
    fn default() -> Self {
        Self {
            enabled: false,
            query: MssqlHealthCheckQuery::SelectOne,
            timeout: None,
        }
    }
}

impl MssqlHealthCheckOptions {
    /// Returns health-check settings with checks turned off (the default).
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns settings with checks turned on using `query` and no timeout.
    pub fn enabled(query: MssqlHealthCheckQuery) -> Self {
        // Struct update keeps this in line with the other option types: any
        // field not overridden here comes from `Default`.
        Self {
            enabled: true,
            query,
            ..Self::default()
        }
    }

    /// Sets the health-check timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}

/// The query executed by a health check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlHealthCheckQuery {
    /// Runs `SELECT 1 AS [health_check]`.
    SelectOne,
}

impl MssqlHealthCheckQuery {
    /// Returns the SQL text executed for this health-check query.
    pub(crate) fn sql(self) -> &'static str {
        match self {
            Self::SelectOne => "SELECT 1 AS [health_check]",
        }
    }
}
310
/// Connection-pool settings.
///
/// Disabled by default, with the bb8 backend, `max_size` 10 and every
/// optional limit unset.
/// NOTE(review): `max_size` and `min_idle` are not validated here (e.g. a zero
/// `max_size`, or `min_idle > max_size`); presumably the pool builder rejects
/// such values. Confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlPoolOptions {
    /// Master switch for pooling.
    pub enabled: bool,
    /// Pool implementation to use.
    pub backend: MssqlPoolBackend,
    /// Maximum pool size.
    pub max_size: u32,
    /// Optional minimum idle connection count.
    pub min_idle: Option<u32>,
    /// Optional connection-acquire timeout.
    pub acquire_timeout: Option<Duration>,
    /// Optional idle timeout.
    pub idle_timeout: Option<Duration>,
    /// Optional maximum connection lifetime.
    pub max_lifetime: Option<Duration>,
}

impl Default for MssqlPoolOptions {
    /// Pooling off, bb8 backend, `max_size` 10, no optional limits.
    fn default() -> Self {
        Self {
            enabled: false,
            backend: MssqlPoolBackend::Bb8,
            max_size: 10,
            min_idle: None,
            acquire_timeout: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }
}

impl MssqlPoolOptions {
    /// Returns pool settings with pooling turned off (the default).
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns settings with pooling turned on, using bb8 with `max_size`.
    pub fn bb8(max_size: u32) -> Self {
        Self {
            enabled: true,
            backend: MssqlPoolBackend::Bb8,
            max_size,
            ..Self::default()
        }
    }

    /// Sets the minimum idle connection count.
    #[must_use]
    pub fn with_min_idle(mut self, min_idle: u32) -> Self {
        self.min_idle = Some(min_idle);
        self
    }

    /// Sets the connection-acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }

    /// Sets the idle timeout.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = Some(timeout);
        self
    }

    /// Sets the maximum connection lifetime.
    #[must_use]
    pub fn with_max_lifetime(mut self, timeout: Duration) -> Self {
        self.max_lifetime = Some(timeout);
        self
    }
}

/// Pool implementation backing [`MssqlPoolOptions`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlPoolBackend {
    /// The bb8 connection pool.
    Bb8,
}
375
376fn validate_config(config: &Config) -> Result<(), OrmError> {
377 let addr = config.get_addr();
378
379 if addr.is_empty() || addr.starts_with(':') {
380 return Err(OrmError::connection("invalid SQL Server connection string"));
381 }
382
383 Ok(())
384}
385
#[cfg(test)]
mod tests {
    use super::{
        MssqlConnectionConfig, MssqlHealthCheckOptions, MssqlHealthCheckQuery,
        MssqlOperationalOptions, MssqlParameterLogMode, MssqlPoolBackend, MssqlPoolOptions,
        MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTimeoutOptions, MssqlTracingOptions,
    };
    use std::time::Duration;

    /// Valid local connection string for tests that only need a parseable config.
    const LOCAL_CONNECTION: &str = "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true";

    /// Parses `connection_string` with default options, panicking if it is rejected.
    fn parse(connection_string: &str) -> MssqlConnectionConfig {
        MssqlConnectionConfig::from_connection_string(connection_string).unwrap()
    }

    #[test]
    fn parses_valid_ado_connection_string() {
        let raw = "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests";

        let config = parse(raw);

        // The raw string is kept verbatim, and the address is resolved from it.
        assert_eq!(config.connection_string(), raw);
        assert_eq!(config.addr(), "localhost:1433");
        assert_eq!(config.options(), &MssqlOperationalOptions::default());
    }

    #[test]
    fn preserves_explicit_operational_options() {
        let timeouts = MssqlTimeoutOptions::new()
            .with_connect_timeout(Duration::from_secs(5))
            .with_query_timeout(Duration::from_secs(30))
            .with_acquire_timeout(Duration::from_secs(2));
        let retry =
            MssqlRetryOptions::enabled(3, Duration::from_millis(50), Duration::from_secs(1));
        let tracing_options = MssqlTracingOptions::enabled()
            .with_parameter_logging(MssqlParameterLogMode::Disabled);
        let slow_query = MssqlSlowQueryOptions::enabled(Duration::from_millis(250))
            .with_parameter_logging(MssqlParameterLogMode::Disabled);
        let health = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(Duration::from_secs(3));
        let pool = MssqlPoolOptions::bb8(16)
            .with_min_idle(4)
            .with_acquire_timeout(Duration::from_secs(2))
            .with_idle_timeout(Duration::from_secs(30))
            .with_max_lifetime(Duration::from_secs(300));

        let options = MssqlOperationalOptions::new()
            .with_timeouts(timeouts)
            .with_retry(retry)
            .with_tracing(tracing_options)
            .with_slow_query(slow_query)
            .with_health(health)
            .with_pool(pool);

        let config = MssqlConnectionConfig::from_connection_string_with_options(
            LOCAL_CONNECTION,
            options.clone(),
        )
        .unwrap();
        let applied = config.options();

        assert_eq!(applied, &options);
        assert_eq!(applied.pool.backend, MssqlPoolBackend::Bb8);
        assert!(applied.retry.enabled);
        assert!(applied.tracing.enabled);
        assert!(applied.slow_query.enabled);
        assert!(applied.health.enabled);
        assert!(applied.pool.enabled);
    }

    #[test]
    fn can_replace_options_on_existing_config() {
        let replacement =
            MssqlOperationalOptions::new().with_tracing(MssqlTracingOptions::enabled());

        let config = parse(LOCAL_CONNECTION).with_options(replacement);

        assert!(config.options().tracing.enabled);
    }

    #[test]
    fn debug_redacts_connection_string() {
        let config = parse(
            "server=tcp:localhost,1433;database=SecretDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=secret-app",
        );

        let debug = format!("{config:?}");

        assert!(debug.contains("MssqlConnectionConfig"));
        assert!(debug.contains("connection_string: \"<redacted>\""));
        assert!(debug.contains("addr: \"localhost:1433\""));
        // No fragment of the raw connection string may leak into Debug output.
        for secret in ["SecretDb", "Password123", "secret-app", "user=sa"] {
            assert!(!debug.contains(secret), "debug output leaked `{secret}`: {debug}");
        }
    }

    #[test]
    fn rejects_invalid_connection_string() {
        let error = MssqlConnectionConfig::from_connection_string("server=").unwrap_err();

        assert_eq!(error.message(), "invalid SQL Server connection string");
    }

    #[test]
    fn health_check_query_uses_stable_sql() {
        let sql = MssqlHealthCheckQuery::SelectOne.sql();

        assert_eq!(sql, "SELECT 1 AS [health_check]");
    }
}