1use percent_encoding::percent_decode_str;
2use sqlx_core::connection::ConnectOptions;
3use sqlx_core::error::Error;
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::time::Duration;
7use thiserror::Error;
8use url::Url;
9
10use crate::MssqlConnection;
11
/// Encryption level requested for a SQL Server connection.
///
/// Parsed (ASCII case-insensitively) from the `encrypt` URL query option.
/// [`MssqlConnectOptions::new`] starts out with [`Encrypt::On`].
///
/// NOTE(review): the variant names appear to mirror the TDS PRELOGIN
/// `ENCRYPT_NOT_SUP` / `ENCRYPT_OFF` / `ENCRYPT_ON` / `ENCRYPT_REQ` values —
/// confirm against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Encrypt {
    /// Selected by `encrypt=not_supported`.
    NotSupported,
    /// Selected by `encrypt=optional`, `encrypt=false`, or `encrypt=no`.
    Off,
    /// Selected by `encrypt=mandatory`, `encrypt=true`, or `encrypt=yes`.
    On,
    /// Selected by `encrypt=strict`.
    Required,
}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct MssqlConnectOptions {
28 host: String,
29 port: Option<u16>,
30 username: String,
31 password: Option<String>,
32 database: String,
33 instance: Option<String>,
34 encrypt: Encrypt,
35 trust_server_certificate: bool,
36 hostname_in_certificate: Option<String>,
37 ssl_root_cert: Option<PathBuf>,
38 requested_packet_size: u32,
39 client_program_version: u32,
40 client_pid: u32,
41 hostname: String,
42 app_name: String,
43 server_name: String,
44 client_interface_name: String,
45 language: String,
46}
47
48impl Default for MssqlConnectOptions {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl MssqlConnectOptions {
55 pub fn new() -> Self {
57 Self {
58 host: "localhost".to_owned(),
59 port: None,
60 username: "sa".to_owned(),
61 password: None,
62 database: "master".to_owned(),
63 instance: None,
64 encrypt: Encrypt::On,
65 trust_server_certificate: true,
66 hostname_in_certificate: None,
67 ssl_root_cert: None,
68 requested_packet_size: 4096,
69 client_program_version: 0,
70 client_pid: 0,
71 hostname: String::new(),
72 app_name: String::new(),
73 server_name: String::new(),
74 client_interface_name: String::new(),
75 language: String::new(),
76 }
77 }
78
79 pub fn parse_url(input: &str) -> Result<Self, MssqlInvalidOption> {
81 parse_url(input)
82 }
83
84 pub fn host(&self) -> &str {
86 &self.host
87 }
88
89 pub fn port(&self) -> Option<u16> {
91 self.port
92 }
93
94 pub fn username(&self) -> &str {
96 &self.username
97 }
98
99 pub fn password(&self) -> Option<&str> {
101 self.password.as_deref()
102 }
103
104 pub fn database(&self) -> &str {
106 &self.database
107 }
108
109 pub fn instance(&self) -> Option<&str> {
111 self.instance.as_deref()
112 }
113
114 pub fn encrypt(&self) -> Encrypt {
116 self.encrypt
117 }
118
119 pub fn trust_server_certificate(&self) -> bool {
121 self.trust_server_certificate
122 }
123
124 pub fn hostname_in_certificate(&self) -> Option<&str> {
126 self.hostname_in_certificate.as_deref()
127 }
128
129 pub fn ssl_root_cert(&self) -> Option<&Path> {
131 self.ssl_root_cert.as_deref()
132 }
133
134 pub fn requested_packet_size(&self) -> u32 {
136 self.requested_packet_size
137 }
138
139 pub fn client_program_version(&self) -> u32 {
141 self.client_program_version
142 }
143
144 pub fn client_pid(&self) -> u32 {
146 self.client_pid
147 }
148
149 pub fn hostname(&self) -> &str {
151 &self.hostname
152 }
153
154 pub fn app_name(&self) -> &str {
156 &self.app_name
157 }
158
159 pub fn server_name(&self) -> &str {
161 &self.server_name
162 }
163
164 pub fn client_interface_name(&self) -> &str {
166 &self.client_interface_name
167 }
168
169 pub fn language(&self) -> &str {
171 &self.language
172 }
173
174 fn set_requested_packet_size(&mut self, size: u32) -> Result<(), MssqlInvalidOption> {
175 if size < 512 {
176 return Err(MssqlInvalidOption::InvalidValue {
177 key: "packet_size".to_owned(),
178 value: size.to_string(),
179 message: "packet_size must be at least 512 bytes".to_owned(),
180 });
181 }
182
183 self.requested_packet_size = size;
184 Ok(())
185 }
186
187 #[cfg(test)]
188 pub(crate) fn set_hostname_for_test(&mut self, hostname: String) {
189 self.hostname = hostname;
190 }
191
192 #[cfg(feature = "migrate")]
193 pub(crate) fn set_database_for_maintenance(&mut self) {
194 self.database = "master".to_owned();
195 }
196}
197
198impl FromStr for MssqlConnectOptions {
199 type Err = Error;
200
201 fn from_str(input: &str) -> Result<Self, Self::Err> {
202 Self::parse_url(input).map_err(Error::config)
203 }
204}
205
206impl ConnectOptions for MssqlConnectOptions {
207 type Connection = MssqlConnection;
208
209 fn from_url(url: &Url) -> Result<Self, Error> {
210 Self::parse_url(url.as_str()).map_err(Error::config)
211 }
212
213 async fn connect(&self) -> Result<Self::Connection, Error>
214 where
215 Self::Connection: Sized,
216 {
217 MssqlConnection::establish(self).await
218 }
219
220 fn log_statements(self, _level: log::LevelFilter) -> Self {
221 self
222 }
223
224 fn log_slow_statements(self, _level: log::LevelFilter, _duration: Duration) -> Self {
225 self
226 }
227}
228
229fn parse_url(input: &str) -> Result<MssqlConnectOptions, MssqlInvalidOption> {
230 let url = Url::parse(input).map_err(MssqlInvalidOption::Url)?;
231 match url.scheme() {
232 "mssql" | "sqlserver" => {}
233 scheme => return Err(MssqlInvalidOption::UnsupportedScheme(scheme.to_owned())),
234 }
235
236 let mut options = MssqlConnectOptions::new();
237
238 if let Some(host) = url.host_str() {
239 options.host = host.to_owned();
240 }
241
242 options.port = url.port();
243
244 let username = url.username();
245 if !username.is_empty() {
246 options.username = percent_decode_str(username)
247 .decode_utf8()
248 .map_err(MssqlInvalidOption::Utf8)?
249 .into_owned();
250 }
251
252 if let Some(password) = url.password() {
253 options.password = Some(
254 percent_decode_str(password)
255 .decode_utf8()
256 .map_err(MssqlInvalidOption::Utf8)?
257 .into_owned(),
258 );
259 }
260
261 let path = url.path().trim_start_matches('/');
262 if !path.is_empty() {
263 options.database = percent_decode_str(path)
264 .decode_utf8()
265 .map_err(MssqlInvalidOption::Utf8)?
266 .into_owned();
267 }
268
269 for (key, value) in url.query_pairs() {
270 match key.as_ref() {
271 "instance" => options.instance = Some(value.into_owned()),
272 "encrypt" => {
273 options.encrypt =
274 parse_encrypt(&value).ok_or_else(|| MssqlInvalidOption::InvalidValue {
275 key: "encrypt".to_owned(),
276 value: value.into_owned(),
277 message: "expected strict, mandatory, optional, not_supported, true, false, yes, or no"
278 .to_owned(),
279 })?;
280 }
281 "sslrootcert" | "ssl-root-cert" | "ssl-ca" => {
282 options.ssl_root_cert = Some(PathBuf::from(value.as_ref()));
283 }
284 "trust_server_certificate" => {
285 options.trust_server_certificate =
286 parse_bool(&value).ok_or_else(|| MssqlInvalidOption::InvalidValue {
287 key: key.into_owned(),
288 value: value.into_owned(),
289 message: "expected true, false, yes, or no".to_owned(),
290 })?;
291 }
292 "hostname_in_certificate" => {
293 options.hostname_in_certificate = Some(value.into_owned());
294 }
295 "packet_size" => {
296 let size = value
297 .parse()
298 .map_err(|_| MssqlInvalidOption::InvalidValue {
299 key: "packet_size".to_owned(),
300 value: value.to_string(),
301 message: "expected an integer".to_owned(),
302 })?;
303 options.set_requested_packet_size(size)?;
304 }
305 "client_program_version" => options.client_program_version = parse_u32(&key, &value)?,
306 "client_pid" => options.client_pid = parse_u32(&key, &value)?,
307 "hostname" => options.hostname = value.into_owned(),
308 "app_name" => options.app_name = value.into_owned(),
309 "server_name" => options.server_name = value.into_owned(),
310 "client_interface_name" => options.client_interface_name = value.into_owned(),
311 "language" => options.language = value.into_owned(),
312 _ => return Err(MssqlInvalidOption::UnknownOption(key.into_owned())),
313 }
314 }
315
316 Ok(options)
317}
318
319fn parse_encrypt(value: &str) -> Option<Encrypt> {
320 match value.to_ascii_lowercase().as_str() {
321 "strict" => Some(Encrypt::Required),
322 "mandatory" | "true" | "yes" => Some(Encrypt::On),
323 "optional" | "false" | "no" => Some(Encrypt::Off),
324 "not_supported" => Some(Encrypt::NotSupported),
325 _ => None,
326 }
327}
328
/// Parses a boolean option value, ignoring ASCII case.
///
/// `true`/`yes` give `Some(true)`, `false`/`no` give `Some(false)`, and any
/// other text (including surrounding whitespace) gives `None`.
fn parse_bool(value: &str) -> Option<bool> {
    let matches_any = |words: &[&str]| words.iter().any(|word| value.eq_ignore_ascii_case(word));

    if matches_any(&["true", "yes"]) {
        Some(true)
    } else if matches_any(&["false", "no"]) {
        Some(false)
    } else {
        None
    }
}
336
337fn parse_u32(key: &str, value: &str) -> Result<u32, MssqlInvalidOption> {
338 value.parse().map_err(|_| MssqlInvalidOption::InvalidValue {
339 key: key.to_owned(),
340 value: value.to_owned(),
341 message: "expected an integer".to_owned(),
342 })
343}
344
345#[derive(Debug, Error)]
347pub enum MssqlInvalidOption {
348 #[error("invalid SQL Server URL: {0}")]
350 Url(#[from] url::ParseError),
351 #[error("invalid UTF-8 in SQL Server URL component: {0}")]
353 Utf8(#[from] std::str::Utf8Error),
354 #[error("unsupported SQL Server URL scheme `{0}`")]
356 UnsupportedScheme(String),
357 #[error("unknown SQL Server connection option `{0}`")]
359 UnknownOption(String),
360 #[error("invalid value `{value}` for SQL Server connection option `{key}`: {message}")]
362 InvalidValue {
363 key: String,
365 value: String,
367 message: String,
369 },
370}
371
#[cfg(test)]
mod tests {
    //! Unit tests for URL parsing into `MssqlConnectOptions`.

    use super::*;

    #[test]
    fn parses_username_with_at_sign() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://user%40domain:secret@example.com/database")
                .unwrap();

        assert_eq!("user@domain", opts.username());
        assert_eq!(Some("secret"), opts.password());
    }

    #[test]
    fn parses_password_with_at_sign() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://username:p%40ssw0rd@example.com/database")
                .unwrap();

        assert_eq!(Some("p@ssw0rd"), opts.password());
    }

    #[test]
    fn parses_named_instance_without_resolving_port() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://sa:secret@example.com/master?instance=SQLEXPRESS",
        )
        .unwrap();

        assert_eq!("example.com", opts.host());
        assert_eq!(None, opts.port());
        assert_eq!(Some("SQLEXPRESS"), opts.instance());
    }

    #[test]
    fn keeps_explicit_port_with_named_instance() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://sa:secret@example.com:1434/master?instance=SQLEXPRESS",
        )
        .unwrap();

        assert_eq!(Some(1434), opts.port());
        assert_eq!(Some("SQLEXPRESS"), opts.instance());
    }

    #[test]
    fn parses_encryption_options() {
        let strict =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=strict").unwrap();
        let optional =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=optional").unwrap();
        let disabled =
            MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=not_supported")
                .unwrap();

        assert_eq!(Encrypt::Required, strict.encrypt());
        assert_eq!(Encrypt::Off, optional.encrypt());
        assert_eq!(Encrypt::NotSupported, disabled.encrypt());
    }

    #[test]
    fn rejects_invalid_packet_size() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=128")
            .expect_err("packet_size below 512 should be rejected");

        assert!(err.to_string().contains("packet_size"));
    }

    #[test]
    fn accepts_minimum_packet_size() {
        let opts =
            MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=512").unwrap();

        assert_eq!(512, opts.requested_packet_size());
    }

    #[test]
    fn rejects_non_numeric_packet_size() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?packet_size=big")
            .expect_err("non-numeric packet_size should be rejected");

        assert!(matches!(
            err,
            MssqlInvalidOption::InvalidValue { ref key, .. } if key == "packet_size"
        ));
    }

    #[test]
    fn rejects_unknown_options() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?mars=true")
            .expect_err("unsupported options should fail loudly");

        assert!(matches!(err, MssqlInvalidOption::UnknownOption(_)));
    }

    #[test]
    fn applies_defaults_for_minimal_url() {
        let opts = MssqlConnectOptions::parse_url("mssql://localhost").unwrap();

        assert_eq!("localhost", opts.host());
        assert_eq!(None, opts.port());
        assert_eq!("sa", opts.username());
        assert_eq!(None, opts.password());
        assert_eq!("master", opts.database());
        assert_eq!(None, opts.instance());
        assert_eq!(Encrypt::On, opts.encrypt());
        assert!(opts.trust_server_certificate());
        assert_eq!(4096, opts.requested_packet_size());
    }

    #[test]
    fn accepts_sqlserver_scheme_and_decodes_database() {
        let opts =
            MssqlConnectOptions::parse_url("sqlserver://db.example.com:1433/my%20db").unwrap();

        assert_eq!("db.example.com", opts.host());
        assert_eq!(Some(1433), opts.port());
        assert_eq!("my db", opts.database());
    }

    #[test]
    fn rejects_unsupported_scheme() {
        let err = MssqlConnectOptions::parse_url("postgres://localhost/db")
            .expect_err("non-mssql schemes should be rejected");

        assert!(matches!(
            err,
            MssqlInvalidOption::UnsupportedScheme(ref scheme) if scheme == "postgres"
        ));
    }

    #[test]
    fn rejects_invalid_encrypt_value() {
        let err = MssqlConnectOptions::parse_url("mssql://localhost/master?encrypt=maybe")
            .expect_err("unrecognised encrypt values should be rejected");

        assert!(matches!(
            err,
            MssqlInvalidOption::InvalidValue { ref key, .. } if key == "encrypt"
        ));
    }

    #[test]
    fn parses_boolean_options_case_insensitively() {
        let opts = MssqlConnectOptions::parse_url(
            "mssql://localhost/master?trust_server_certificate=NO&encrypt=Yes",
        )
        .unwrap();

        assert!(!opts.trust_server_certificate());
        assert_eq!(Encrypt::On, opts.encrypt());
    }
}
448}