Skip to main content

sqlx_mssql_odbc_core/
options.rs

1use crate::{MssqlConnection, Result};
2use log::LevelFilter;
3use std::fmt::{self, Debug, Formatter};
4use std::str::FromStr;
5use std::time::Duration;
6use url::Url;
7
/// Controls how the MSSQL ODBC driver fetches result rows.
///
/// The fetch mode is selected by `max_column_size`:
/// - `Some(limit)` — buffered fetching. Text and binary columns are read into buffers bounded
///   by `limit`, so longer values can be truncated to that size.
/// - `None` — unbuffered fetching. Variable-sized values are not truncated by this crate's
///   buffer allocation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MssqlBufferSettings {
    /// How many rows are requested per fetch batch.
    pub batch_size: usize,
    /// Upper bound for text/binary column buffers in buffered mode; `None` selects unbuffered mode.
    pub max_column_size: Option<usize>,
}
20
21impl Default for MssqlBufferSettings {
22    fn default() -> Self {
23        Self {
24            batch_size: 64,
25            max_column_size: None,
26        }
27    }
28}
29
30/// Connection options for an MSSQL ODBC data source.
31#[derive(Clone)]
32pub struct MssqlConnectOptions {
33    pub(crate) conn_str: String,
34    pub(crate) buffer_settings: MssqlBufferSettings,
35    pub(crate) statement_cache_capacity: usize,
36    pub(crate) log_statements: LevelFilter,
37    pub(crate) log_slow_statements: LevelFilter,
38    pub(crate) log_slow_statement_duration: Duration,
39}
40
41impl MssqlConnectOptions {
42    /// Returns the normalized ODBC connection string.
43    pub fn connection_string(&self) -> &str {
44        &self.conn_str
45    }
46
47    /// Sets the buffer configuration for this connection.
48    pub fn buffer_settings(&mut self, settings: MssqlBufferSettings) -> &mut Self {
49        assert!(settings.batch_size > 0, "batch_size must be greater than 0");
50        if let Some(size) = settings.max_column_size {
51            assert!(size > 0, "max_column_size must be greater than 0");
52        }
53
54        self.buffer_settings = settings;
55        self
56    }
57
58    /// Returns the current buffer settings.
59    pub fn buffer_settings_ref(&self) -> &MssqlBufferSettings {
60        &self.buffer_settings
61    }
62
63    /// Sets the number of rows fetched in each batch.
64    pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
65        assert!(batch_size > 0, "batch_size must be greater than 0");
66        self.buffer_settings.batch_size = batch_size;
67        self
68    }
69
70    /// Sets the maximum buffered column size, or `None` for unbuffered fetching.
71    pub fn max_column_size(&mut self, max_column_size: Option<usize>) -> &mut Self {
72        if let Some(size) = max_column_size {
73            assert!(size > 0, "max_column_size must be greater than 0");
74        }
75
76        self.buffer_settings.max_column_size = max_column_size;
77        self
78    }
79
80    /// Sets the maximum number of prepared statements kept in this connection's cache.
81    pub fn statement_cache_capacity(&mut self, capacity: usize) -> &mut Self {
82        self.statement_cache_capacity = capacity;
83        self
84    }
85
86    /// Sets regular statement logging level.
87    pub fn log_statements(&mut self, level: LevelFilter) -> &mut Self {
88        self.log_statements = level;
89        self
90    }
91
92    /// Sets slow statement logging level and threshold.
93    pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self {
94        self.log_slow_statements = level;
95        self.log_slow_statement_duration = duration;
96        self
97    }
98
99    /// Enables or disables TLS encryption for the connection.
100    ///
101    /// When enabled, adds `Encrypt=yes` to the connection string.
102    pub fn encrypt(&mut self, enable: bool) -> &mut Self {
103        if enable && !self.conn_str.contains("Encrypt=") {
104            self.conn_str.push_str(";Encrypt=yes");
105        }
106        self
107    }
108
109    /// Enables or disables server certificate validation.
110    ///
111    /// When enabled alongside `encrypt(true)`, adds `TrustServerCertificate=yes`
112    /// to the connection string. Useful for development environments with
113    /// self-signed certificates.
114    pub fn trust_certificate(&mut self, enable: bool) -> &mut Self {
115        if enable && !self.conn_str.contains("TrustServerCertificate=") {
116            self.conn_str.push_str(";TrustServerCertificate=yes");
117        }
118        self
119    }
120
121    /// Opens a blocking MSSQL ODBC connection.
122    pub fn connect_blocking(&self) -> Result<MssqlConnection> {
123        MssqlConnection::connect_blocking(self)
124    }
125
126    /// Returns a copy of these options with the database name replaced.
127    ///
128    /// The connection string is searched case-insensitively for a `Database=` key;
129    /// if found it is replaced, otherwise the new database is appended.
130    #[cfg(feature = "migrate")]
131    pub(crate) fn with_database(&self, database: &str) -> Self {
132        let mut new = self.clone();
133        let escaped = escape_odbc_value(database);
134        let upper = new.conn_str.to_uppercase();
135        let search = "DATABASE=";
136
137        if let Some(pos) = upper.find(search) {
138            let start = pos;
139            let end = new.conn_str[start..]
140                .find(';')
141                .map(|i| start + i)
142                .unwrap_or(new.conn_str.len());
143            let before = new.conn_str[..start].to_owned();
144            let after = new.conn_str[end..].to_owned();
145            new.conn_str = format!("{before}Database={escaped}{after}");
146        } else {
147            new.conn_str.push_str(&format!(";Database={escaped}"));
148        }
149        new
150    }
151}
152
153impl Debug for MssqlConnectOptions {
154    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155        f.debug_struct("MssqlConnectOptions")
156            .field("conn_str", &"<redacted>")
157            .field("buffer_settings", &self.buffer_settings)
158            .field("statement_cache_capacity", &self.statement_cache_capacity)
159            .field("log_statements", &self.log_statements)
160            .field("log_slow_statements", &self.log_slow_statements)
161            .field(
162                "log_slow_statement_duration",
163                &self.log_slow_statement_duration,
164            )
165            .finish()
166    }
167}
168
/// Escapes a value for use in an ODBC connection string.
///
/// A value containing any of `;`, `{`, `}` or `=` is enclosed in braces. Every literal `}`
/// inside is doubled, per the ODBC connection-string escaping convention. Other values are
/// returned unchanged.
fn escape_odbc_value(value: &str) -> String {
    let needs_braces = value.chars().any(|c| matches!(c, ';' | '{' | '}' | '='));
    if !needs_braces {
        return value.to_owned();
    }

    let mut escaped = String::with_capacity(value.len() + 2);
    escaped.push('{');
    for c in value.chars() {
        if c == '}' {
            escaped.push_str("}}");
        } else {
            escaped.push(c);
        }
    }
    escaped.push('}');
    escaped
}
180
181/// Builds an ODBC connection string from a `mssql://` URL.
182///
183/// Supported URL format:
184/// `mssql://user:password@host:port/database?param=value`
185///
186/// Supported query parameters:
187/// - `trust_certificate=true` — adds `TrustServerCertificate=yes`
188/// - `encrypt=true` — adds `Encrypt=yes`
189/// - `driver=...` — custom ODBC driver name
190fn mssql_url_to_connection_string(url: &Url) -> String {
191    let scheme = url.scheme();
192    let is_mssql = scheme.eq_ignore_ascii_case("mssql");
193
194    // Only handle mssql:// URLs; odbc:// or other schemes pass through
195    if !is_mssql && !scheme.eq_ignore_ascii_case("odbc") {
196        return url.as_str().to_owned();
197    }
198
199    let host = url.host_str().unwrap_or("localhost");
200    let port = url.port().unwrap_or(1433);
201    let database = url.path().trim_start_matches('/');
202    let username = url.username();
203    let password = url.password().unwrap_or_default();
204
205    let mut conn_str = format!(
206        "Driver={{ODBC Driver 18 for SQL Server}};Server={host},{port}"
207    );
208
209    if !database.is_empty() {
210        conn_str.push_str(&format!(";Database={}", escape_odbc_value(database)));
211    }
212    if !username.is_empty() {
213        conn_str.push_str(&format!(";UID={}", escape_odbc_value(username)));
214    }
215    if !password.is_empty() {
216        conn_str.push_str(&format!(";PWD={}", escape_odbc_value(password)));
217    }
218
219    // Parse query parameters
220    for (key, value) in url.query_pairs() {
221        match key.as_ref() {
222            "trust_certificate" if value == "true" => {
223                if !conn_str.contains("TrustServerCertificate=") {
224                    conn_str.push_str(";TrustServerCertificate=yes");
225                }
226            }
227            "encrypt" if value == "true" => {
228                if !conn_str.contains("Encrypt=") {
229                    conn_str.push_str(";Encrypt=yes");
230                }
231            }
232            "driver" => {
233                let driver_val = format!("Driver={value}");
234                if let Some(pos) = conn_str.find("Driver=") {
235                    let end = conn_str[pos..].find(';').map(|i| pos + i).unwrap_or(conn_str.len());
236                    conn_str.replace_range(pos..end, &driver_val);
237                }
238            }
239            _ => {}
240        }
241    }
242
243    conn_str
244}
245
246impl FromStr for MssqlConnectOptions {
247    type Err = sqlx_core::Error;
248
249    fn from_str(input: &str) -> std::result::Result<Self, Self::Err> {
250        let trimmed = input.trim();
251
252        // Legacy support: strip odbc: prefix before URL parsing
253        let (trimmed, _had_odbc_prefix) = if let Some(rest) = trimmed.strip_prefix("odbc:") {
254            (rest, true)
255        } else {
256            (trimmed, false)
257        };
258
259        // Try to parse as a mssql:// URL (only for actual mssql:// scheme)
260        if trimmed.starts_with("mssql://") || trimmed.starts_with("mssql:") {
261            if let Ok(url) = Url::parse(trimmed) {
262                let scheme = url.scheme();
263                if scheme.eq_ignore_ascii_case("mssql") {
264                    let conn_str = mssql_url_to_connection_string(&url);
265
266                    return Ok(Self {
267                        conn_str,
268                        buffer_settings: MssqlBufferSettings::default(),
269                        statement_cache_capacity: 100,
270                        log_statements: LevelFilter::Debug,
271                        log_slow_statements: LevelFilter::Warn,
272                        log_slow_statement_duration: Duration::from_secs(1),
273                    });
274                }
275            }
276        }
277
278        // Treat as raw ODBC connection string (or bare DSN)
279        let conn_str = if trimmed.contains('=') {
280            trimmed.to_owned()
281        } else {
282            format!("DSN={trimmed}")
283        };
284
285        Ok(Self {
286            conn_str,
287            buffer_settings: MssqlBufferSettings::default(),
288            statement_cache_capacity: 100,
289            log_statements: LevelFilter::Debug,
290            log_slow_statements: LevelFilter::Warn,
291            log_slow_statement_duration: Duration::from_secs(1),
292        })
293    }
294}
295
296impl sqlx_core::connection::ConnectOptions for MssqlConnectOptions {
297    type Connection = MssqlConnection;
298
299    fn from_url(url: &Url) -> std::result::Result<Self, sqlx_core::Error> {
300        Self::from_str(url.as_str())
301    }
302
303    async fn connect(&self) -> std::result::Result<Self::Connection, sqlx_core::Error> {
304        self.connect_blocking().map_err(Into::into)
305    }
306
307    fn log_statements(mut self, level: LevelFilter) -> Self {
308        self.log_statements = level;
309        self
310    }
311
312    fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
313        self.log_slow_statements = level;
314        self.log_slow_statement_duration = duration;
315        self
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_mssql_url_with_all_components() {
        let url = "mssql://sa:Password1!@server.example.com:1433/testdb";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("Driver={ODBC Driver 18 for SQL Server}"));
        assert!(cs.contains("Server=server.example.com,1433"));
        assert!(cs.contains("Database=testdb"));
        assert!(cs.contains("UID=sa"));
        assert!(cs.contains("PWD=Password1!"));
    }

    #[test]
    fn parses_mssql_url_with_default_port() {
        let url = "mssql://user:pass@localhost/mydb";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("Server=localhost,1433"));
    }

    #[test]
    fn parses_mssql_url_without_credentials() {
        let url = "mssql://localhost/mydb";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("Server=localhost,1433"));
        assert!(cs.contains("Database=mydb"));
        assert!(!cs.contains("UID="));
        assert!(!cs.contains("PWD="));
    }

    #[test]
    fn parses_mssql_url_with_trust_certificate() {
        let url = "mssql://localhost/mydb?trust_certificate=true";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("TrustServerCertificate=yes"));
    }

    #[test]
    fn parses_mssql_url_with_encrypt() {
        let url = "mssql://localhost/mydb?encrypt=true";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("Encrypt=yes"));
    }

    #[test]
    fn parses_mssql_url_with_custom_driver() {
        let url = "mssql://localhost/mydb?driver={ODBC Driver 17 for SQL Server}";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(cs.contains("Driver={ODBC Driver 17 for SQL Server}"));
    }

    #[test]
    fn builds_exact_connection_string_from_url() {
        let url = "mssql://sa:pw@db.local:1444/app?encrypt=true&trust_certificate=true";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        assert_eq!(
            options.connection_string(),
            "Driver={ODBC Driver 18 for SQL Server};Server=db.local,1444;Database=app;\
             UID=sa;PWD=pw;Encrypt=yes;TrustServerCertificate=yes"
        );
    }

    #[test]
    fn repeated_url_flags_are_added_once() {
        let url = "mssql://localhost/db?encrypt=true&encrypt=true\
                   &trust_certificate=true&trust_certificate=true";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert_eq!(cs.matches("Encrypt=yes").count(), 1);
        assert_eq!(cs.matches("TrustServerCertificate=yes").count(), 1);
    }

    #[test]
    fn preserves_raw_odbc_connection_strings() {
        let input = "Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test";
        let options = MssqlConnectOptions::from_str(input).unwrap();
        assert_eq!(options.connection_string(), input);
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let options = MssqlConnectOptions::from_str("  DSN=Test \n").unwrap();
        assert_eq!(options.connection_string(), "DSN=Test");
    }

    #[test]
    fn supports_dsn_format() {
        let options = MssqlConnectOptions::from_str("MyMssqlDSN").unwrap();
        assert_eq!(options.connection_string(), "DSN=MyMssqlDSN");
    }

    #[test]
    fn strips_legacy_odbc_prefix() {
        let options = MssqlConnectOptions::from_str("odbc:DSN=Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn default_options_use_documented_defaults() {
        let options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        assert_eq!(options.buffer_settings, MssqlBufferSettings::default());
        assert_eq!(options.buffer_settings.batch_size, 64);
        assert_eq!(options.buffer_settings.max_column_size, None);
        assert_eq!(options.statement_cache_capacity, 100);
        assert_eq!(options.log_statements, LevelFilter::Debug);
        assert_eq!(options.log_slow_statements, LevelFilter::Warn);
        assert_eq!(options.log_slow_statement_duration, Duration::from_secs(1));
    }

    #[test]
    fn encrypt_method_adds_encrypt() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.encrypt(true);
        assert!(options.connection_string().contains("Encrypt=yes"));
    }

    #[test]
    fn encrypt_keeps_explicit_setting() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test;Encrypt=no").unwrap();
        options.encrypt(true);
        assert_eq!(options.connection_string(), "DSN=Test;Encrypt=no");
    }

    #[test]
    fn disabled_flags_leave_connection_string_unchanged() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.encrypt(false).trust_certificate(false);
        assert_eq!(options.connection_string(), "DSN=Test");
    }

    #[test]
    fn trust_certificate_method_adds_flag() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.trust_certificate(true);
        assert!(options.connection_string().contains("TrustServerCertificate=yes"));
    }

    #[test]
    fn updates_buffer_settings_incrementally() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.batch_size(128).max_column_size(Some(2048));
        assert_eq!(options.buffer_settings.batch_size, 128);
        assert_eq!(options.buffer_settings.max_column_size, Some(2048));

        options.max_column_size(None);
        assert_eq!(options.buffer_settings_ref().max_column_size, None);
    }

    #[test]
    #[should_panic(expected = "batch_size must be greater than 0")]
    fn batch_size_rejects_zero() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.batch_size(0);
    }

    #[test]
    #[should_panic(expected = "max_column_size must be greater than 0")]
    fn max_column_size_rejects_zero() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.max_column_size(Some(0));
    }

    #[test]
    #[should_panic(expected = "batch_size must be greater than 0")]
    fn buffer_settings_rejects_zero_batch_size() {
        let mut options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        options.buffer_settings(MssqlBufferSettings {
            batch_size: 0,
            max_column_size: None,
        });
    }

    #[test]
    fn debug_redacts_connection_string() {
        let options = MssqlConnectOptions::from_str("Driver={x};UID=sa;PWD=hunter2").unwrap();
        let rendered = format!("{options:?}");
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("hunter2"));
    }

    #[test]
    fn escape_odbc_value_preserves_safe_values() {
        assert_eq!(escape_odbc_value("simple"), "simple");
        assert_eq!(escape_odbc_value(""), "");
        assert_eq!(escape_odbc_value("abc123"), "abc123");
    }

    #[test]
    fn escape_odbc_value_wraps_values_with_special_chars() {
        assert_eq!(escape_odbc_value("pass;word"), "{pass;word}");
        assert_eq!(escape_odbc_value("pass=word"), "{pass=word}");
        assert_eq!(escape_odbc_value("pass{word"), "{pass{word}");
        assert_eq!(escape_odbc_value("pass}word"), "{pass}}word}");
        assert_eq!(escape_odbc_value("a}b}c"), "{a}}b}}c}");
    }

    #[test]
    fn parses_mssql_url_with_special_chars_in_password() {
        // `Url::password` returns the password still percent-encoded, and the URL conversion
        // forwards it as is. The encoded form contains no ODBC-special characters, so
        // `escape_odbc_value` leaves it unbraced.
        // NOTE(review): ODBC connection strings have no percent-decoding, so the driver likely
        // receives the literal `a%3Bb%3Dc%7Dd` rather than `a;b=c}d`. Confirm this
        // pass-through is intended.
        let url = "mssql://user:a%3Bb%3Dc%7Dd@localhost/mydb";
        let options = MssqlConnectOptions::from_str(url).unwrap();
        let cs = options.connection_string();
        assert!(
            cs.contains("PWD=a%3Bb%3Dc%7Dd"),
            "password not included correctly; got: {cs}"
        );
    }

    #[cfg(feature = "migrate")]
    #[test]
    fn with_database_replaces_existing_attribute() {
        let options = MssqlConnectOptions::from_str("Server=x;DATABASE=old;UID=u").unwrap();
        assert_eq!(
            options.with_database("new").connection_string(),
            "Server=x;Database=new;UID=u"
        );
    }

    #[cfg(feature = "migrate")]
    #[test]
    fn with_database_appends_and_escapes() {
        let options = MssqlConnectOptions::from_str("DSN=Test").unwrap();
        assert_eq!(
            options.with_database("a;b").connection_string(),
            "DSN=Test;Database={a;b}"
        );
    }
}