// sqlx_sqlserver/options.rs

1use percent_encoding::percent_decode_str;
2use sqlx_core::connection::ConnectOptions;
3use sqlx_core::error::Error;
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::time::Duration;
7use thiserror::Error;
8use url::Url;
9
10use crate::MssqlConnection;
11
/// SQL Server connection encryption preference.
///
/// Selected through the `encrypt` query parameter of a connection URL; the accepted
/// spellings for each level are listed on the variants below (matched ASCII
/// case-insensitively).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Encrypt {
    /// Encryption is not supported by the client (`encrypt=not_supported`).
    NotSupported,
    /// Use encryption when the server supports it (`encrypt=optional`, `false`, or `no`).
    Off,
    /// Require encryption (`encrypt=mandatory`, `true`, or `yes`).
    On,
    /// Require encryption and certificate validation (`encrypt=strict`).
    Required,
}
24
25/// SQL Server connection options.
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct MssqlConnectOptions {
28    host: String,
29    port: Option<u16>,
30    username: String,
31    password: Option<String>,
32    database: String,
33    instance: Option<String>,
34    encrypt: Encrypt,
35    trust_server_certificate: bool,
36    hostname_in_certificate: Option<String>,
37    ssl_root_cert: Option<PathBuf>,
38    requested_packet_size: u32,
39    client_program_version: u32,
40    client_pid: u32,
41    hostname: String,
42    app_name: String,
43    server_name: String,
44    client_interface_name: String,
45    language: String,
46}
47
48impl Default for MssqlConnectOptions {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl MssqlConnectOptions {
55    /// Creates options with SQL Server defaults.
56    pub fn new() -> Self {
57        Self {
58            host: "localhost".to_owned(),
59            port: None,
60            username: "sa".to_owned(),
61            password: None,
62            database: "master".to_owned(),
63            instance: None,
64            encrypt: Encrypt::On,
65            trust_server_certificate: true,
66            hostname_in_certificate: None,
67            ssl_root_cert: None,
68            requested_packet_size: 4096,
69            client_program_version: 0,
70            client_pid: 0,
71            hostname: String::new(),
72            app_name: String::new(),
73            server_name: String::new(),
74            client_interface_name: String::new(),
75            language: String::new(),
76        }
77    }
78
79    /// Parses SQL Server connection options while preserving detailed parser errors.
80    pub fn parse_url(input: &str) -> Result<Self, MssqlInvalidOption> {
81        parse_url(input)
82    }
83
84    /// Returns the configured host.
85    pub fn host(&self) -> &str {
86        &self.host
87    }
88
89    /// Returns the configured port, if one was explicitly set.
90    pub fn port(&self) -> Option<u16> {
91        self.port
92    }
93
94    /// Returns the configured username.
95    pub fn username(&self) -> &str {
96        &self.username
97    }
98
99    /// Returns the configured password.
100    pub fn password(&self) -> Option<&str> {
101        self.password.as_deref()
102    }
103
104    /// Returns the configured database.
105    pub fn database(&self) -> &str {
106        &self.database
107    }
108
109    /// Returns the configured named instance, if any.
110    pub fn instance(&self) -> Option<&str> {
111        self.instance.as_deref()
112    }
113
114    /// Returns the encryption preference.
115    pub fn encrypt(&self) -> Encrypt {
116        self.encrypt
117    }
118
119    /// Returns whether server certificate validation is bypassed.
120    pub fn trust_server_certificate(&self) -> bool {
121        self.trust_server_certificate
122    }
123
124    /// Returns the hostname expected in the server certificate.
125    pub fn hostname_in_certificate(&self) -> Option<&str> {
126        self.hostname_in_certificate.as_deref()
127    }
128
129    /// Returns the configured root certificate path.
130    pub fn ssl_root_cert(&self) -> Option<&Path> {
131        self.ssl_root_cert.as_deref()
132    }
133
134    /// Returns the requested TDS packet size.
135    pub fn requested_packet_size(&self) -> u32 {
136        self.requested_packet_size
137    }
138
139    /// Returns the client program version sent during login.
140    pub fn client_program_version(&self) -> u32 {
141        self.client_program_version
142    }
143
144    /// Returns the client process ID sent during login.
145    pub fn client_pid(&self) -> u32 {
146        self.client_pid
147    }
148
149    /// Returns the client host name sent during login.
150    pub fn hostname(&self) -> &str {
151        &self.hostname
152    }
153
154    /// Returns the application name sent during login.
155    pub fn app_name(&self) -> &str {
156        &self.app_name
157    }
158
159    /// Returns the server name sent during login.
160    pub fn server_name(&self) -> &str {
161        &self.server_name
162    }
163
164    /// Returns the client interface name sent during login.
165    pub fn client_interface_name(&self) -> &str {
166        &self.client_interface_name
167    }
168
169    /// Returns the language sent during login.
170    pub fn language(&self) -> &str {
171        &self.language
172    }
173
174    fn set_requested_packet_size(&mut self, size: u32) -> Result<(), MssqlInvalidOption> {
175        if size < 512 {
176            return Err(MssqlInvalidOption::InvalidValue {
177                key: "packet_size".to_owned(),
178                value: size.to_string(),
179                message: "packet_size must be at least 512 bytes".to_owned(),
180            });
181        }
182
183        self.requested_packet_size = size;
184        Ok(())
185    }
186
187    #[cfg(test)]
188    pub(crate) fn set_hostname_for_test(&mut self, hostname: String) {
189        self.hostname = hostname;
190    }
191
192    #[cfg(feature = "migrate")]
193    pub(crate) fn set_database_for_maintenance(&mut self) {
194        self.database = "master".to_owned();
195    }
196}
197
198impl FromStr for MssqlConnectOptions {
199    type Err = Error;
200
201    fn from_str(input: &str) -> Result<Self, Self::Err> {
202        Self::parse_url(input).map_err(Error::config)
203    }
204}
205
206impl ConnectOptions for MssqlConnectOptions {
207    type Connection = MssqlConnection;
208
209    fn from_url(url: &Url) -> Result<Self, Error> {
210        Self::parse_url(url.as_str()).map_err(Error::config)
211    }
212
213    async fn connect(&self) -> Result<Self::Connection, Error>
214    where
215        Self::Connection: Sized,
216    {
217        MssqlConnection::establish(self).await
218    }
219
220    fn log_statements(self, _level: log::LevelFilter) -> Self {
221        self
222    }
223
224    fn log_slow_statements(self, _level: log::LevelFilter, _duration: Duration) -> Self {
225        self
226    }
227}
228
229fn parse_url(input: &str) -> Result<MssqlConnectOptions, MssqlInvalidOption> {
230    let url = Url::parse(input).map_err(MssqlInvalidOption::Url)?;
231    match url.scheme() {
232        "mssql" | "sqlserver" => {}
233        scheme => return Err(MssqlInvalidOption::UnsupportedScheme(scheme.to_owned())),
234    }
235
236    let mut options = MssqlConnectOptions::new();
237
238    if let Some(host) = url.host_str() {
239        options.host = host.to_owned();
240    }
241
242    options.port = url.port();
243
244    let username = url.username();
245    if !username.is_empty() {
246        options.username = percent_decode_str(username)
247            .decode_utf8()
248            .map_err(MssqlInvalidOption::Utf8)?
249            .into_owned();
250    }
251
252    if let Some(password) = url.password() {
253        options.password = Some(
254            percent_decode_str(password)
255                .decode_utf8()
256                .map_err(MssqlInvalidOption::Utf8)?
257                .into_owned(),
258        );
259    }
260
261    let path = url.path().trim_start_matches('/');
262    if !path.is_empty() {
263        options.database = percent_decode_str(path)
264            .decode_utf8()
265            .map_err(MssqlInvalidOption::Utf8)?
266            .into_owned();
267    }
268
269    for (key, value) in url.query_pairs() {
270        match key.as_ref() {
271            "instance" => options.instance = Some(value.into_owned()),
272            "encrypt" => {
273                options.encrypt =
274                    parse_encrypt(&value).ok_or_else(|| MssqlInvalidOption::InvalidValue {
275                        key: "encrypt".to_owned(),
276                        value: value.into_owned(),
277                        message: "expected strict, mandatory, optional, not_supported, true, false, yes, or no"
278                            .to_owned(),
279                    })?;
280            }
281            "sslrootcert" | "ssl-root-cert" | "ssl-ca" => {
282                options.ssl_root_cert = Some(PathBuf::from(value.as_ref()));
283            }
284            "trust_server_certificate" => {
285                options.trust_server_certificate =
286                    parse_bool(&value).ok_or_else(|| MssqlInvalidOption::InvalidValue {
287                        key: key.into_owned(),
288                        value: value.into_owned(),
289                        message: "expected true, false, yes, or no".to_owned(),
290                    })?;
291            }
292            "hostname_in_certificate" => {
293                options.hostname_in_certificate = Some(value.into_owned());
294            }
295            "packet_size" => {
296                let size = value
297                    .parse()
298                    .map_err(|_| MssqlInvalidOption::InvalidValue {
299                        key: "packet_size".to_owned(),
300                        value: value.to_string(),
301                        message: "expected an integer".to_owned(),
302                    })?;
303                options.set_requested_packet_size(size)?;
304            }
305            "client_program_version" => options.client_program_version = parse_u32(&key, &value)?,
306            "client_pid" => options.client_pid = parse_u32(&key, &value)?,
307            "hostname" => options.hostname = value.into_owned(),
308            "app_name" => options.app_name = value.into_owned(),
309            "server_name" => options.server_name = value.into_owned(),
310            "client_interface_name" => options.client_interface_name = value.into_owned(),
311            "language" => options.language = value.into_owned(),
312            _ => return Err(MssqlInvalidOption::UnknownOption(key.into_owned())),
313        }
314    }
315
316    Ok(options)
317}
318
319fn parse_encrypt(value: &str) -> Option<Encrypt> {
320    match value.to_ascii_lowercase().as_str() {
321        "strict" => Some(Encrypt::Required),
322        "mandatory" | "true" | "yes" => Some(Encrypt::On),
323        "optional" | "false" | "no" => Some(Encrypt::Off),
324        "not_supported" => Some(Encrypt::NotSupported),
325        _ => None,
326    }
327}
328
/// Interprets a boolean query value: `true`/`yes` or `false`/`no`, in any ASCII letter
/// case. Any other value (including surrounding whitespace) yields `None`.
fn parse_bool(value: &str) -> Option<bool> {
    let matches_word = |word: &str| value.eq_ignore_ascii_case(word);

    if matches_word("true") || matches_word("yes") {
        Some(true)
    } else if matches_word("false") || matches_word("no") {
        Some(false)
    } else {
        None
    }
}
336
337fn parse_u32(key: &str, value: &str) -> Result<u32, MssqlInvalidOption> {
338    value.parse().map_err(|_| MssqlInvalidOption::InvalidValue {
339        key: key.to_owned(),
340        value: value.to_owned(),
341        message: "expected an integer".to_owned(),
342    })
343}
344
/// Error returned while parsing SQL Server connection options.
///
/// Produced by `MssqlConnectOptions::parse_url`; the `FromStr` and
/// `ConnectOptions::from_url` implementations wrap it with `Error::config`.
#[derive(Debug, Error)]
pub enum MssqlInvalidOption {
    /// URL syntax was invalid (`Url::parse` failed).
    #[error("invalid SQL Server URL: {0}")]
    Url(#[from] url::ParseError),
    /// A percent-decoded URL component (user name, password, or database path) was not
    /// valid UTF-8.
    #[error("invalid UTF-8 in SQL Server URL component: {0}")]
    Utf8(#[from] std::str::Utf8Error),
    /// The URL scheme is not supported by this driver (only `mssql` and `sqlserver` are
    /// accepted). Holds the rejected scheme.
    #[error("unsupported SQL Server URL scheme `{0}`")]
    UnsupportedScheme(String),
    /// A query parameter is not recognized. Holds the parameter name.
    #[error("unknown SQL Server connection option `{0}`")]
    UnknownOption(String),
    /// A query parameter value is invalid.
    #[error("invalid value `{value}` for SQL Server connection option `{key}`: {message}")]
    InvalidValue {
        /// Option name.
        key: String,
        /// Offending value, either as supplied in the URL or, for packet-size range
        /// errors, the parsed integer re-rendered as text.
        value: String,
        /// Human-readable explanation of what was expected.
        message: String,
    },
}
371
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_username_with_at_sign() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://user%40domain:secret@example.com/database")
                .unwrap();

        assert_eq!("user@domain", opts.username());
        assert_eq!(Some("secret"), opts.password());
    }

    #[test]
    fn parses_password_with_at_sign() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://username:p%40ssw0rd@example.com/database")
                .unwrap();

        assert_eq!(Some("p@ssw0rd"), opts.password());
    }

    #[test]
    fn parses_named_instance_without_resolving_port() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://sa:secret@example.com/master?instance=SQLEXPRESS",
        )
        .unwrap();

        assert_eq!("example.com", opts.host());
        assert_eq!(None, opts.port());
        assert_eq!(Some("SQLEXPRESS"), opts.instance());
    }

    #[test]
    fn keeps_explicit_port_with_named_instance() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://sa:secret@example.com:1434/master?instance=SQLEXPRESS",
        )
        .unwrap();

        assert_eq!(Some(1434), opts.port());
        assert_eq!(Some("SQLEXPRESS"), opts.instance());
    }

    #[test]
    fn parses_encryption_options() {
        let strict =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=strict").unwrap();
        let optional =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=optional").unwrap();
        let disabled =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=not_supported")
                .unwrap();

        assert_eq!(Encrypt::Required, strict.encrypt());
        assert_eq!(Encrypt::Off, optional.encrypt());
        assert_eq!(Encrypt::NotSupported, disabled.encrypt());
    }

    #[test]
    fn rejects_invalid_packet_size() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=128")
            .expect_err("packet_size below 512 should be rejected");

        assert!(err.to_string().contains("packet_size"));
    }

    #[test]
    fn rejects_unknown_options() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?mars=true")
            .expect_err("unsupported options should fail loudly");

        assert!(matches!(err, MssqlInvalidOption::UnknownOption(_)));
    }

    #[test]
    fn rejects_unsupported_scheme() {
        let err = MssqlConnectOptions::parse_url("postgres://localhost/master")
            .expect_err("non-SQL Server schemes should be rejected");

        assert!(matches!(
            err,
            MssqlInvalidOption::UnsupportedScheme(ref scheme) if scheme == "postgres"
        ));
    }

    #[test]
    fn applies_defaults_for_sqlserver_scheme() {
        let opts = MssqlConnectOptions::parse_url("sqlserver://example.com").unwrap();

        assert_eq!("example.com", opts.host());
        assert_eq!(None, opts.port());
        assert_eq!("sa", opts.username());
        assert_eq!(None, opts.password());
        assert_eq!("master", opts.database());
        assert_eq!(None, opts.instance());
        assert_eq!(Encrypt::On, opts.encrypt());
        assert!(opts.trust_server_certificate());
        assert_eq!(4096, opts.requested_packet_size());
    }

    #[test]
    fn decodes_percent_encoded_database() {
        let opts = MssqlConnectOptions::parse_url("mssql://localhost/my%20db").unwrap();

        assert_eq!("my db", opts.database());
    }

    #[test]
    fn rejects_invalid_utf8_in_username() {
        let err = MssqlConnectOptions::parse_url("mssql://%FF@localhost/master")
            .expect_err("user names that are not UTF-8 should be rejected");

        assert!(matches!(err, MssqlInvalidOption::Utf8(_)));
    }

    #[test]
    fn parses_boolean_options_case_insensitively() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://localhost/master?trust_server_certificate=NO",
        )
        .unwrap();

        assert!(!opts.trust_server_certificate());
    }

    #[test]
    fn rejects_invalid_boolean_option() {
        let err = MssqlConnectOptions::parse_url(
            "mssql://localhost/master?trust_server_certificate=maybe",
        )
        .expect_err("unrecognized boolean values should be rejected");

        assert!(matches!(
            err,
            MssqlInvalidOption::InvalidValue { ref key, .. } if key == "trust_server_certificate"
        ));
    }

    #[test]
    fn accepts_minimum_packet_size() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=512").unwrap();

        assert_eq!(512, opts.requested_packet_size());
    }

    #[test]
    fn rejects_non_numeric_packet_size() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=big")
            .expect_err("non-numeric packet_size should be rejected");

        assert!(err.to_string().contains("expected an integer"));
    }

    #[test]
    fn parses_login_metadata() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://localhost/master?app_name=my%20app&client_pid=42&language=us_english",
        )
        .unwrap();

        assert_eq!("my app", opts.app_name());
        assert_eq!(42, opts.client_pid());
        assert_eq!("us_english", opts.language());
    }
}