Skip to main content

sql_orm_tiberius/
config.rs

1use sql_orm_core::OrmError;
2use std::fmt;
3use std::time::Duration;
4use tiberius::Config;
5
6#[derive(Clone)]
7pub struct MssqlConnectionConfig {
8    connection_string: String,
9    inner: Config,
10    options: MssqlOperationalOptions,
11}
12
13impl fmt::Debug for MssqlConnectionConfig {
14    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
15        formatter
16            .debug_struct("MssqlConnectionConfig")
17            .field("connection_string", &"<redacted>")
18            .field("addr", &self.addr())
19            .field("options", &self.options)
20            .finish()
21    }
22}
23
24impl MssqlConnectionConfig {
25    pub fn from_connection_string(connection_string: &str) -> Result<Self, OrmError> {
26        Self::from_connection_string_with_options(
27            connection_string,
28            MssqlOperationalOptions::default(),
29        )
30    }
31
32    pub fn from_connection_string_with_options(
33        connection_string: &str,
34        options: MssqlOperationalOptions,
35    ) -> Result<Self, OrmError> {
36        if connection_string.trim().is_empty() {
37            return Err(OrmError::connection("invalid SQL Server connection string"));
38        }
39
40        let inner = Config::from_ado_string(connection_string)
41            .map_err(|_| OrmError::connection("invalid SQL Server connection string"))?;
42        validate_config(&inner)?;
43
44        Ok(Self {
45            connection_string: connection_string.to_string(),
46            inner,
47            options,
48        })
49    }
50
51    pub fn with_options(mut self, options: MssqlOperationalOptions) -> Self {
52        self.options = options;
53        self
54    }
55
56    pub fn connection_string(&self) -> &str {
57        &self.connection_string
58    }
59
60    pub fn addr(&self) -> String {
61        self.inner.get_addr()
62    }
63
64    pub fn options(&self) -> &MssqlOperationalOptions {
65        &self.options
66    }
67
68    pub(crate) fn tiberius_config(&self) -> &Config {
69        &self.inner
70    }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Default)]
74pub struct MssqlOperationalOptions {
75    pub timeouts: MssqlTimeoutOptions,
76    pub retry: MssqlRetryOptions,
77    pub tracing: MssqlTracingOptions,
78    pub slow_query: MssqlSlowQueryOptions,
79    pub health: MssqlHealthCheckOptions,
80    pub pool: MssqlPoolOptions,
81}
82
83impl MssqlOperationalOptions {
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    pub fn with_timeouts(mut self, timeouts: MssqlTimeoutOptions) -> Self {
89        self.timeouts = timeouts;
90        self
91    }
92
93    pub fn with_retry(mut self, retry: MssqlRetryOptions) -> Self {
94        self.retry = retry;
95        self
96    }
97
98    pub fn with_tracing(mut self, tracing: MssqlTracingOptions) -> Self {
99        self.tracing = tracing;
100        self
101    }
102
103    pub fn with_slow_query(mut self, slow_query: MssqlSlowQueryOptions) -> Self {
104        self.slow_query = slow_query;
105        self
106    }
107
108    pub fn with_health(mut self, health: MssqlHealthCheckOptions) -> Self {
109        self.health = health;
110        self
111    }
112
113    pub fn with_pool(mut self, pool: MssqlPoolOptions) -> Self {
114        self.pool = pool;
115        self
116    }
117}
118
/// Timeout settings for connection-level operations.
///
/// Every field defaults to `None`, i.e. the timeout is left unset by this
/// configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MssqlTimeoutOptions {
    /// Timeout for establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Timeout for executing a query.
    pub query_timeout: Option<Duration>,
    /// Timeout for acquiring a connection.
    pub acquire_timeout: Option<Duration>,
}

impl MssqlTimeoutOptions {
    /// Creates options with every timeout unset; equivalent to
    /// [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the connect timeout.
    #[must_use]
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Sets the query timeout.
    #[must_use]
    pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
        self.query_timeout = Some(timeout);
        self
    }

    /// Sets the acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }
}
146
/// Retry policy settings.
///
/// NOTE(review): how `base_delay` and `max_delay` are combined (e.g. the
/// backoff curve) is decided by the retry code outside this file — confirm
/// there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlRetryOptions {
    /// Master switch; `false` by default.
    pub enabled: bool,
    /// Maximum number of retries; `0` by default.
    pub max_retries: u32,
    /// Base retry delay; 100 ms by default.
    pub base_delay: Duration,
    /// Upper delay bound; 2 s by default.
    pub max_delay: Duration,
}

impl Default for MssqlRetryOptions {
    /// Retries disabled, zero retries, 100 ms base delay, 2 s max delay.
    fn default() -> Self {
        Self {
            enabled: false,
            max_retries: 0,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
        }
    }
}

impl MssqlRetryOptions {
    /// Returns the disabled default policy; equivalent to [`Default::default`].
    #[must_use]
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns an enabled policy with the given retry count and delay bounds.
    ///
    /// The values are stored as given; no ordering between `base_delay` and
    /// `max_delay` is enforced here.
    #[must_use]
    pub fn enabled(max_retries: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self {
            enabled: true,
            max_retries,
            base_delay,
            max_delay,
        }
    }
}
180
/// Tracing settings for query execution.
///
/// NOTE(review): the code that reads these flags lives outside this file;
/// whether the `emit_*` flags matter while `enabled` is `false` should be
/// confirmed there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlTracingOptions {
    /// Master switch; `false` by default.
    pub enabled: bool,
    /// How query parameters are logged; [`MssqlParameterLogMode::Redacted`] by default.
    pub parameter_logging: MssqlParameterLogMode,
    /// Flag for the query-start event; `true` by default.
    pub emit_start_event: bool,
    /// Flag for the query-finish event; `true` by default.
    pub emit_finish_event: bool,
    /// Flag for the query-error event; `true` by default.
    pub emit_error_event: bool,
}

impl Default for MssqlTracingOptions {
    /// Tracing disabled, redacted parameter logging, all event flags on.
    fn default() -> Self {
        Self {
            enabled: false,
            parameter_logging: MssqlParameterLogMode::Redacted,
            emit_start_event: true,
            emit_finish_event: true,
            emit_error_event: true,
        }
    }
}

impl MssqlTracingOptions {
    /// Returns the disabled default; equivalent to [`Default::default`].
    #[must_use]
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns the defaults with `enabled` set to `true`.
    #[must_use]
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Self::default()
        }
    }

    /// Sets the parameter logging mode.
    #[must_use]
    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
        self.parameter_logging = parameter_logging;
        self
    }
}

/// How query parameters appear in tracing and slow-query output.
///
/// NOTE(review): the exact rendering of each mode is implemented outside this
/// file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlParameterLogMode {
    /// Parameters are not logged (presumably).
    Disabled,
    /// Parameters are logged in redacted form (presumably).
    Redacted,
}
225
226#[derive(Debug, Clone, Copy, PartialEq, Eq)]
227pub struct MssqlSlowQueryOptions {
228    pub enabled: bool,
229    pub threshold: Duration,
230    pub parameter_logging: MssqlParameterLogMode,
231}
232
233impl Default for MssqlSlowQueryOptions {
234    fn default() -> Self {
235        Self {
236            enabled: false,
237            threshold: Duration::from_millis(500),
238            parameter_logging: MssqlParameterLogMode::Redacted,
239        }
240    }
241}
242
243impl MssqlSlowQueryOptions {
244    pub fn disabled() -> Self {
245        Self::default()
246    }
247
248    pub fn enabled(threshold: Duration) -> Self {
249        Self {
250            enabled: true,
251            threshold,
252            ..Self::default()
253        }
254    }
255
256    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
257        self.parameter_logging = parameter_logging;
258        self
259    }
260}
261
/// Health-check settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlHealthCheckOptions {
    /// Master switch; `false` by default.
    pub enabled: bool,
    /// The query used for the check; [`MssqlHealthCheckQuery::SelectOne`] by default.
    pub query: MssqlHealthCheckQuery,
    /// Optional timeout for the check; `None` by default.
    pub timeout: Option<Duration>,
}

impl Default for MssqlHealthCheckOptions {
    /// Disabled, `SELECT 1` query, no timeout.
    fn default() -> Self {
        Self {
            enabled: false,
            query: MssqlHealthCheckQuery::SelectOne,
            timeout: None,
        }
    }
}

impl MssqlHealthCheckOptions {
    /// Returns the disabled default; equivalent to [`Default::default`].
    #[must_use]
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns an enabled configuration using `query`; the timeout stays unset.
    #[must_use]
    pub fn enabled(query: MssqlHealthCheckQuery) -> Self {
        // Struct update keeps this consistent with the other `enabled`
        // constructors; the default timeout is `None`.
        Self {
            enabled: true,
            query,
            ..Self::default()
        }
    }

    /// Sets the health-check timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}

/// The query issued by a health check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlHealthCheckQuery {
    /// `SELECT 1 AS [health_check]`.
    SelectOne,
}

impl MssqlHealthCheckQuery {
    /// Returns the fixed SQL text for this query.
    #[must_use]
    pub(crate) fn sql(self) -> &'static str {
        match self {
            Self::SelectOne => "SELECT 1 AS [health_check]",
        }
    }
}
310
/// Connection-pool settings.
///
/// Values are stored as given. NOTE(review): constraints such as
/// `max_size > 0` or `min_idle <= max_size` are not checked here; confirm
/// they are validated where the pool is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlPoolOptions {
    /// Master switch; `false` by default.
    pub enabled: bool,
    /// Pool implementation; [`MssqlPoolBackend::Bb8`] by default.
    pub backend: MssqlPoolBackend,
    /// Maximum pool size; `10` by default.
    pub max_size: u32,
    /// Minimum idle connections; `None` by default.
    pub min_idle: Option<u32>,
    /// Timeout for acquiring a pooled connection; `None` by default.
    pub acquire_timeout: Option<Duration>,
    /// Idle timeout for pooled connections; `None` by default.
    pub idle_timeout: Option<Duration>,
    /// Maximum lifetime of a pooled connection; `None` by default.
    pub max_lifetime: Option<Duration>,
}

impl Default for MssqlPoolOptions {
    /// Pooling disabled, bb8 backend, max size 10, all optional limits unset.
    fn default() -> Self {
        Self {
            enabled: false,
            backend: MssqlPoolBackend::Bb8,
            max_size: 10,
            min_idle: None,
            acquire_timeout: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }
}

impl MssqlPoolOptions {
    /// Returns the disabled default; equivalent to [`Default::default`].
    #[must_use]
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns an enabled bb8 pool configuration with the given maximum size;
    /// optional limits stay unset.
    #[must_use]
    pub fn bb8(max_size: u32) -> Self {
        Self {
            enabled: true,
            backend: MssqlPoolBackend::Bb8,
            max_size,
            ..Self::default()
        }
    }

    /// Sets the minimum number of idle connections.
    #[must_use]
    pub fn with_min_idle(mut self, min_idle: u32) -> Self {
        self.min_idle = Some(min_idle);
        self
    }

    /// Sets the connection acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }

    /// Sets the idle timeout.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = Some(timeout);
        self
    }

    /// Sets the maximum connection lifetime.
    #[must_use]
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = Some(lifetime);
        self
    }
}

/// Supported connection-pool implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlPoolBackend {
    /// The `bb8` pool.
    Bb8,
}
375
376fn validate_config(config: &Config) -> Result<(), OrmError> {
377    let addr = config.get_addr();
378
379    if addr.is_empty() || addr.starts_with(':') {
380        return Err(OrmError::connection("invalid SQL Server connection string"));
381    }
382
383    Ok(())
384}
385
#[cfg(test)]
mod tests {
    use super::{
        MssqlConnectionConfig, MssqlHealthCheckOptions, MssqlHealthCheckQuery,
        MssqlOperationalOptions, MssqlParameterLogMode, MssqlPoolBackend, MssqlPoolOptions,
        MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTimeoutOptions, MssqlTracingOptions,
    };
    use std::time::Duration;

    #[test]
    fn parses_valid_ado_connection_string() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests",
        )
        .unwrap();

        assert_eq!(
            config.connection_string(),
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests"
        );
        assert_eq!(config.addr(), "localhost:1433");
        assert_eq!(config.options(), &MssqlOperationalOptions::default());
    }

    #[test]
    fn preserves_explicit_operational_options() {
        let options = MssqlOperationalOptions::new()
            .with_timeouts(
                MssqlTimeoutOptions::new()
                    .with_connect_timeout(Duration::from_secs(5))
                    .with_query_timeout(Duration::from_secs(30))
                    .with_acquire_timeout(Duration::from_secs(2)),
            )
            .with_retry(MssqlRetryOptions::enabled(
                3,
                Duration::from_millis(50),
                Duration::from_secs(1),
            ))
            .with_tracing(
                MssqlTracingOptions::enabled()
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_slow_query(
                MssqlSlowQueryOptions::enabled(Duration::from_millis(250))
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_health(
                MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
                    .with_timeout(Duration::from_secs(3)),
            )
            .with_pool(
                MssqlPoolOptions::bb8(16)
                    .with_min_idle(4)
                    .with_acquire_timeout(Duration::from_secs(2))
                    .with_idle_timeout(Duration::from_secs(30))
                    .with_max_lifetime(Duration::from_secs(300)),
            );

        let config = MssqlConnectionConfig::from_connection_string_with_options(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
            options.clone(),
        )
        .unwrap();

        assert_eq!(config.options(), &options);
        assert_eq!(config.options().pool.backend, MssqlPoolBackend::Bb8);
        assert!(config.options().retry.enabled);
        assert!(config.options().tracing.enabled);
        assert!(config.options().slow_query.enabled);
        assert!(config.options().health.enabled);
        assert!(config.options().pool.enabled);
    }

    #[test]
    fn can_replace_options_on_existing_config() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
        )
        .unwrap()
        .with_options(MssqlOperationalOptions::new().with_tracing(MssqlTracingOptions::enabled()));

        assert!(config.options().tracing.enabled);
    }

    #[test]
    fn debug_redacts_connection_string() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=SecretDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=secret-app",
        )
        .unwrap();

        let debug = format!("{config:?}");

        assert!(debug.contains("MssqlConnectionConfig"));
        assert!(debug.contains("connection_string: \"<redacted>\""));
        assert!(debug.contains("addr: \"localhost:1433\""));
        assert!(!debug.contains("SecretDb"));
        assert!(!debug.contains("Password123"));
        assert!(!debug.contains("secret-app"));
        assert!(!debug.contains("user=sa"));
    }

    #[test]
    fn rejects_invalid_connection_string() {
        let error = MssqlConnectionConfig::from_connection_string("server=").unwrap_err();

        assert_eq!(error.message(), "invalid SQL Server connection string");
    }

    /// The blank-input guard runs before `tiberius` parsing and must produce
    /// the same generic error as any other malformed string.
    #[test]
    fn rejects_blank_connection_string() {
        for input in ["", "   ", "\t\n"] {
            let error = MssqlConnectionConfig::from_connection_string(input).unwrap_err();

            assert_eq!(error.message(), "invalid SQL Server connection string");
        }
    }

    /// Pins the defaults: every optional feature is off and timeouts are unset.
    #[test]
    fn default_operational_options_leave_features_disabled() {
        let options = MssqlOperationalOptions::default();

        assert_eq!(options, MssqlOperationalOptions::new());
        assert_eq!(options.timeouts, MssqlTimeoutOptions::default());
        assert_eq!(options.timeouts.connect_timeout, None);
        assert_eq!(options.timeouts.query_timeout, None);
        assert_eq!(options.timeouts.acquire_timeout, None);
        assert!(!options.retry.enabled);
        assert_eq!(options.retry.max_retries, 0);
        assert!(!options.tracing.enabled);
        assert_eq!(
            options.tracing.parameter_logging,
            MssqlParameterLogMode::Redacted
        );
        assert!(!options.slow_query.enabled);
        assert_eq!(options.slow_query.threshold, Duration::from_millis(500));
        assert!(!options.health.enabled);
        assert_eq!(options.health.query, MssqlHealthCheckQuery::SelectOne);
        assert!(!options.pool.enabled);
        assert_eq!(options.pool.backend, MssqlPoolBackend::Bb8);
        assert_eq!(options.pool.max_size, 10);
    }

    #[test]
    fn health_check_query_uses_stable_sql() {
        assert_eq!(
            MssqlHealthCheckQuery::SelectOne.sql(),
            "SELECT 1 AS [health_check]"
        );
    }
}