Skip to main content

sql_orm_tiberius/
config.rs

1use sql_orm_core::OrmError;
2use std::time::Duration;
3use tiberius::Config;
4
5#[derive(Debug, Clone)]
6pub struct MssqlConnectionConfig {
7    connection_string: String,
8    inner: Config,
9    options: MssqlOperationalOptions,
10}
11
12impl MssqlConnectionConfig {
13    pub fn from_connection_string(connection_string: &str) -> Result<Self, OrmError> {
14        Self::from_connection_string_with_options(
15            connection_string,
16            MssqlOperationalOptions::default(),
17        )
18    }
19
20    pub fn from_connection_string_with_options(
21        connection_string: &str,
22        options: MssqlOperationalOptions,
23    ) -> Result<Self, OrmError> {
24        if connection_string.trim().is_empty() {
25            return Err(OrmError::new("invalid SQL Server connection string"));
26        }
27
28        let inner = Config::from_ado_string(connection_string)
29            .map_err(|_| OrmError::new("invalid SQL Server connection string"))?;
30        validate_config(&inner)?;
31
32        Ok(Self {
33            connection_string: connection_string.to_string(),
34            inner,
35            options,
36        })
37    }
38
39    pub fn with_options(mut self, options: MssqlOperationalOptions) -> Self {
40        self.options = options;
41        self
42    }
43
44    pub fn connection_string(&self) -> &str {
45        &self.connection_string
46    }
47
48    pub fn addr(&self) -> String {
49        self.inner.get_addr()
50    }
51
52    pub fn options(&self) -> &MssqlOperationalOptions {
53        &self.options
54    }
55
56    pub(crate) fn tiberius_config(&self) -> &Config {
57        &self.inner
58    }
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Default)]
62pub struct MssqlOperationalOptions {
63    pub timeouts: MssqlTimeoutOptions,
64    pub retry: MssqlRetryOptions,
65    pub tracing: MssqlTracingOptions,
66    pub slow_query: MssqlSlowQueryOptions,
67    pub health: MssqlHealthCheckOptions,
68    pub pool: MssqlPoolOptions,
69}
70
71impl MssqlOperationalOptions {
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    pub fn with_timeouts(mut self, timeouts: MssqlTimeoutOptions) -> Self {
77        self.timeouts = timeouts;
78        self
79    }
80
81    pub fn with_retry(mut self, retry: MssqlRetryOptions) -> Self {
82        self.retry = retry;
83        self
84    }
85
86    pub fn with_tracing(mut self, tracing: MssqlTracingOptions) -> Self {
87        self.tracing = tracing;
88        self
89    }
90
91    pub fn with_slow_query(mut self, slow_query: MssqlSlowQueryOptions) -> Self {
92        self.slow_query = slow_query;
93        self
94    }
95
96    pub fn with_health(mut self, health: MssqlHealthCheckOptions) -> Self {
97        self.health = health;
98        self
99    }
100
101    pub fn with_pool(mut self, pool: MssqlPoolOptions) -> Self {
102        self.pool = pool;
103        self
104    }
105}
106
/// Optional timeouts for SQL Server operations. Every timeout is `None` (unset) by default.
///
/// NOTE(review): `MssqlPoolOptions` also carries an `acquire_timeout`; confirm which of the two
/// the pool actually honours when both are set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MssqlTimeoutOptions {
    /// Connect timeout; `None` means none is configured here.
    pub connect_timeout: Option<Duration>,
    /// Query timeout; `None` means none is configured here.
    pub query_timeout: Option<Duration>,
    /// Connection-acquire timeout; `None` means none is configured here.
    pub acquire_timeout: Option<Duration>,
}

impl MssqlTimeoutOptions {
    /// Creates options with every timeout unset; identical to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the connect timeout.
    #[must_use]
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Sets the query timeout.
    #[must_use]
    pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
        self.query_timeout = Some(timeout);
        self
    }

    /// Sets the connection-acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }
}
134
/// Retry policy for SQL Server operations.
///
/// Off by default: `enabled` is `false`, `max_retries` is `0`, and the delay bounds are 100 ms
/// (`base_delay`) and 2 s (`max_delay`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlRetryOptions {
    /// Whether retrying is turned on.
    pub enabled: bool,
    /// Maximum number of retries (presumably excluding the first attempt — confirm in caller).
    pub max_retries: u32,
    /// Base delay between attempts (presumably the starting backoff — confirm in caller).
    pub base_delay: Duration,
    /// Upper bound on the delay between attempts (presumably a backoff cap — confirm in caller).
    pub max_delay: Duration,
}

impl MssqlRetryOptions {
    const DEFAULT_BASE_DELAY: Duration = Duration::from_millis(100);
    const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(2);

    /// Returns a policy with retries switched off and the default delay bounds.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            max_retries: 0,
            base_delay: Self::DEFAULT_BASE_DELAY,
            max_delay: Self::DEFAULT_MAX_DELAY,
        }
    }

    /// Returns a policy with retries switched on, using exactly the given limits.
    pub fn enabled(max_retries: u32, base_delay: Duration, max_delay: Duration) -> Self {
        let mut policy = Self::disabled();
        policy.enabled = true;
        policy.max_retries = max_retries;
        policy.base_delay = base_delay;
        policy.max_delay = max_delay;
        policy
    }
}

impl Default for MssqlRetryOptions {
    /// Same as [`MssqlRetryOptions::disabled`].
    fn default() -> Self {
        Self::disabled()
    }
}
168
/// Tracing configuration for query execution.
///
/// Disabled by default. [`MssqlTracingOptions::enabled`] turns it on and keeps the other
/// defaults: start, finish and error events all `true`, and parameters logged as
/// [`MssqlParameterLogMode::Redacted`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlTracingOptions {
    /// Whether tracing is enabled (default `false`).
    pub enabled: bool,
    /// How query parameters appear in tracing output (default `Redacted`).
    pub parameter_logging: MssqlParameterLogMode,
    /// Whether the start event is emitted (default `true`).
    pub emit_start_event: bool,
    /// Whether the finish event is emitted (default `true`).
    pub emit_finish_event: bool,
    /// Whether the error event is emitted (default `true`).
    pub emit_error_event: bool,
}

impl Default for MssqlTracingOptions {
    fn default() -> Self {
        Self {
            enabled: false,
            parameter_logging: MssqlParameterLogMode::Redacted,
            emit_start_event: true,
            emit_finish_event: true,
            emit_error_event: true,
        }
    }
}

impl MssqlTracingOptions {
    /// Returns tracing switched off; identical to [`Default::default`].
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns tracing switched on, with every other field at its default.
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Self::default()
        }
    }

    /// Sets how query parameters are represented in tracing output.
    #[must_use]
    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
        self.parameter_logging = parameter_logging;
        self
    }
}

/// How query parameter values are treated by tracing and slow-query logging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlParameterLogMode {
    /// Parameter logging is turned off.
    Disabled,
    /// Parameters are logged in redacted form.
    // NOTE(review): presumably parameter values are masked while their presence is still
    // reported — confirm against the tracing implementation.
    Redacted,
}
213
214#[derive(Debug, Clone, Copy, PartialEq, Eq)]
215pub struct MssqlSlowQueryOptions {
216    pub enabled: bool,
217    pub threshold: Duration,
218    pub parameter_logging: MssqlParameterLogMode,
219}
220
221impl Default for MssqlSlowQueryOptions {
222    fn default() -> Self {
223        Self {
224            enabled: false,
225            threshold: Duration::from_millis(500),
226            parameter_logging: MssqlParameterLogMode::Redacted,
227        }
228    }
229}
230
231impl MssqlSlowQueryOptions {
232    pub fn disabled() -> Self {
233        Self::default()
234    }
235
236    pub fn enabled(threshold: Duration) -> Self {
237        Self {
238            enabled: true,
239            threshold,
240            ..Self::default()
241        }
242    }
243
244    pub fn with_parameter_logging(mut self, parameter_logging: MssqlParameterLogMode) -> Self {
245        self.parameter_logging = parameter_logging;
246        self
247    }
248}
249
/// Health-check configuration.
///
/// Disabled by default, using [`MssqlHealthCheckQuery::SelectOne`] and no timeout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlHealthCheckOptions {
    /// Whether health checks are enabled (default `false`).
    pub enabled: bool,
    /// Which probe query to run (default `SelectOne`).
    pub query: MssqlHealthCheckQuery,
    /// Optional timeout for the probe; `None` means none is configured here.
    pub timeout: Option<Duration>,
}

impl Default for MssqlHealthCheckOptions {
    fn default() -> Self {
        Self {
            enabled: false,
            query: MssqlHealthCheckQuery::SelectOne,
            timeout: None,
        }
    }
}

impl MssqlHealthCheckOptions {
    /// Returns health checks switched off; identical to [`Default::default`].
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns health checks switched on with the given probe query and no timeout.
    pub fn enabled(query: MssqlHealthCheckQuery) -> Self {
        // Built from the defaults, like the other option types' `enabled` constructors.
        Self {
            enabled: true,
            query,
            ..Self::default()
        }
    }

    /// Sets the probe timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}

/// The probe query used by health checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlHealthCheckQuery {
    /// `SELECT 1 AS [health_check]`.
    SelectOne,
}

impl MssqlHealthCheckQuery {
    /// Returns the fixed SQL text for this probe.
    pub(crate) fn sql(self) -> &'static str {
        match self {
            Self::SelectOne => "SELECT 1 AS [health_check]",
        }
    }
}
298
/// Connection-pool configuration.
///
/// Pooling is disabled by default. [`MssqlPoolOptions::bb8`] turns it on with the bb8 backend.
/// Defaults are `max_size` 10 with `min_idle` and all timeouts unset.
///
/// NOTE(review): nothing here validates `max_size > 0` or `min_idle <= max_size`. bb8's
/// `Builder` documents panics for those cases, so confirm the pool construction code checks them.
/// NOTE(review): `MssqlTimeoutOptions` also carries an `acquire_timeout`; confirm which one wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MssqlPoolOptions {
    /// Whether pooling is enabled (default `false`).
    pub enabled: bool,
    /// Pool implementation (default `Bb8`).
    pub backend: MssqlPoolBackend,
    /// Maximum number of pooled connections (default 10).
    pub max_size: u32,
    /// Minimum idle connections; `None` leaves it unset.
    pub min_idle: Option<u32>,
    /// Connection-acquire timeout; `None` leaves it unset.
    pub acquire_timeout: Option<Duration>,
    /// Idle timeout; `None` leaves it unset.
    pub idle_timeout: Option<Duration>,
    /// Maximum connection lifetime; `None` leaves it unset.
    pub max_lifetime: Option<Duration>,
}

impl Default for MssqlPoolOptions {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: MssqlPoolBackend::Bb8,
            max_size: 10,
            min_idle: None,
            acquire_timeout: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }
}

impl MssqlPoolOptions {
    /// Returns pooling switched off; identical to [`Default::default`].
    pub fn disabled() -> Self {
        Self::default()
    }

    /// Returns pooling switched on with the bb8 backend and the given maximum size.
    ///
    /// All other fields keep their defaults.
    pub fn bb8(max_size: u32) -> Self {
        Self {
            enabled: true,
            backend: MssqlPoolBackend::Bb8,
            max_size,
            ..Self::default()
        }
    }

    /// Sets the minimum number of idle connections.
    #[must_use]
    pub fn with_min_idle(mut self, min_idle: u32) -> Self {
        self.min_idle = Some(min_idle);
        self
    }

    /// Sets the connection-acquire timeout.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = Some(timeout);
        self
    }

    /// Sets the idle timeout.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = Some(timeout);
        self
    }

    /// Sets the maximum connection lifetime.
    #[must_use]
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = Some(lifetime);
        self
    }
}

/// Pool implementation used when pooling is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MssqlPoolBackend {
    /// The `bb8` connection-pool crate.
    Bb8,
}
363
364fn validate_config(config: &Config) -> Result<(), OrmError> {
365    let addr = config.get_addr();
366
367    if addr.is_empty() || addr.starts_with(':') {
368        return Err(OrmError::new("invalid SQL Server connection string"));
369    }
370
371    Ok(())
372}
373
#[cfg(test)]
mod tests {
    use super::{
        MssqlConnectionConfig, MssqlHealthCheckOptions, MssqlHealthCheckQuery,
        MssqlOperationalOptions, MssqlParameterLogMode, MssqlPoolBackend, MssqlPoolOptions,
        MssqlRetryOptions, MssqlSlowQueryOptions, MssqlTimeoutOptions, MssqlTracingOptions,
    };
    use std::time::Duration;

    #[test]
    fn parses_valid_ado_connection_string() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests",
        )
        .unwrap();

        assert_eq!(
            config.connection_string(),
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true;Application Name=sql-orm-tests"
        );
        assert_eq!(config.addr(), "localhost:1433");
        assert_eq!(config.options(), &MssqlOperationalOptions::default());
    }

    #[test]
    fn preserves_explicit_operational_options() {
        let options = MssqlOperationalOptions::new()
            .with_timeouts(
                MssqlTimeoutOptions::new()
                    .with_connect_timeout(Duration::from_secs(5))
                    .with_query_timeout(Duration::from_secs(30))
                    .with_acquire_timeout(Duration::from_secs(2)),
            )
            .with_retry(MssqlRetryOptions::enabled(
                3,
                Duration::from_millis(50),
                Duration::from_secs(1),
            ))
            .with_tracing(
                MssqlTracingOptions::enabled()
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_slow_query(
                MssqlSlowQueryOptions::enabled(Duration::from_millis(250))
                    .with_parameter_logging(MssqlParameterLogMode::Disabled),
            )
            .with_health(
                MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
                    .with_timeout(Duration::from_secs(3)),
            )
            .with_pool(
                MssqlPoolOptions::bb8(16)
                    .with_min_idle(4)
                    .with_acquire_timeout(Duration::from_secs(2))
                    .with_idle_timeout(Duration::from_secs(30))
                    .with_max_lifetime(Duration::from_secs(300)),
            );

        let config = MssqlConnectionConfig::from_connection_string_with_options(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
            options.clone(),
        )
        .unwrap();

        assert_eq!(config.options(), &options);
        assert_eq!(config.options().pool.backend, MssqlPoolBackend::Bb8);
        assert!(config.options().retry.enabled);
        assert!(config.options().tracing.enabled);
        assert!(config.options().slow_query.enabled);
        assert!(config.options().health.enabled);
        assert!(config.options().pool.enabled);
    }

    #[test]
    fn can_replace_options_on_existing_config() {
        let config = MssqlConnectionConfig::from_connection_string(
            "server=tcp:localhost,1433;database=AppDb;user=sa;password=Password123;TrustServerCertificate=true",
        )
        .unwrap()
        .with_options(MssqlOperationalOptions::new().with_tracing(MssqlTracingOptions::enabled()));

        assert!(config.options().tracing.enabled);
    }

    #[test]
    fn rejects_invalid_connection_string() {
        let error = MssqlConnectionConfig::from_connection_string("server=").unwrap_err();

        assert_eq!(error.message(), "invalid SQL Server connection string");
    }

    // Covers the early blank-string guard, which runs before `tiberius` parsing.
    #[test]
    fn rejects_blank_connection_string() {
        for &blank in &["", "   ", "\t\n"] {
            let error = MssqlConnectionConfig::from_connection_string(blank).unwrap_err();

            assert_eq!(error.message(), "invalid SQL Server connection string");
        }
    }

    // Pins the contract that a config built without explicit options turns nothing on.
    #[test]
    fn default_operational_options_disable_every_feature() {
        let options = MssqlOperationalOptions::new();

        assert_eq!(options, MssqlOperationalOptions::default());
        assert_eq!(options.timeouts, MssqlTimeoutOptions::default());
        assert_eq!(options.timeouts.connect_timeout, None);
        assert_eq!(options.retry, MssqlRetryOptions::disabled());
        assert!(!options.retry.enabled);
        assert!(!options.tracing.enabled);
        assert!(!options.slow_query.enabled);
        assert!(!options.health.enabled);
        assert!(!options.pool.enabled);
        assert_eq!(options.pool.backend, MssqlPoolBackend::Bb8);
    }

    #[test]
    fn health_check_query_uses_stable_sql() {
        assert_eq!(
            MssqlHealthCheckQuery::SelectOne.sql(),
            "SELECT 1 AS [health_check]"
        );
    }
}