// sqlmodel_mysql/config.rs

//! MySQL connection configuration.
//!
//! Provides connection parameters for establishing MySQL connections,
//! including authentication, SSL, and connection options.

use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;

/// TLS/SSL configuration for MySQL connections.
///
/// This struct only stores certificate/key locations and verification
/// options. The actual TLS implementation requires the `tls` feature to be
/// enabled.
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
    /// Path to CA certificate file (PEM format) for server verification.
    /// Required for `SslMode::VerifyCa` and `SslMode::VerifyIdentity`.
    pub ca_cert_path: Option<PathBuf>,

    /// Path to client certificate file (PEM format) for mutual TLS.
    /// Optional - only needed if the server requires a client certificate.
    pub client_cert_path: Option<PathBuf>,

    /// Path to client private key file (PEM format) for mutual TLS.
    /// Required if `client_cert_path` is set.
    pub client_key_path: Option<PathBuf>,

    /// Skip server certificate verification.
    ///
    /// # Security Warning
    /// Setting this to `true` disables certificate verification, making the
    /// connection vulnerable to man-in-the-middle attacks. Only use for
    /// development/testing with self-signed certificates.
    pub danger_skip_verify: bool,

    /// Server name for SNI (Server Name Indication).
    /// If not set, defaults to the connection hostname.
    pub server_name: Option<String>,
}

impl TlsConfig {
    /// Create a new TLS configuration with default values.
    ///
    /// All paths and the server name start as `None`, and certificate
    /// verification is left enabled (`danger_skip_verify == false`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the CA certificate path, replacing any previous value.
    #[must_use]
    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.ca_cert_path = Some(path.into());
        self
    }

    /// Set the client certificate path, replacing any previous value.
    #[must_use]
    pub fn client_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_cert_path = Some(path.into());
        self
    }

    /// Set the client private key path, replacing any previous value.
    #[must_use]
    pub fn client_key(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_key_path = Some(path.into());
        self
    }

    /// Skip server certificate verification (dangerous!).
    ///
    /// Passing `false` turns verification back on.
    #[must_use]
    pub fn skip_verify(mut self, skip: bool) -> Self {
        self.danger_skip_verify = skip;
        self
    }

    /// Set the server name for SNI, replacing any previous value.
    #[must_use]
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Check if mutual TLS (client certificate) is configured.
    ///
    /// Returns `true` only when both the client certificate path and the
    /// client key path are set; a certificate without a key (or the reverse)
    /// counts as not configured.
    #[must_use]
    pub fn has_client_cert(&self) -> bool {
        self.client_cert_path.is_some() && self.client_key_path.is_some()
    }
}

/// SSL mode for MySQL connections.
///
/// The default is [`SslMode::Disable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Do not use SSL.
    #[default]
    Disable,
    /// Prefer SSL if available, fall back to non-SSL.
    Preferred,
    /// Require an SSL connection.
    Required,
    /// Require SSL and verify the server certificate.
    VerifyCa,
    /// Require SSL and verify the server certificate matches the hostname.
    VerifyIdentity,
}

impl SslMode {
    /// Check if SSL should be attempted.
    ///
    /// Returns `false` only for [`SslMode::Disable`].
    #[must_use]
    pub const fn should_try_ssl(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            SslMode::Disable => false,
            SslMode::Preferred
            | SslMode::Required
            | SslMode::VerifyCa
            | SslMode::VerifyIdentity => true,
        }
    }

    /// Check if SSL is required.
    ///
    /// Returns `true` for [`SslMode::Required`], [`SslMode::VerifyCa`] and
    /// [`SslMode::VerifyIdentity`]; `Disable` and `Preferred` return `false`.
    #[must_use]
    pub const fn is_required(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            SslMode::Disable | SslMode::Preferred => false,
            SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => true,
        }
    }
}

114/// MySQL connection configuration.
115#[derive(Debug, Clone)]
116pub struct MySqlConfig {
117    /// Hostname or IP address
118    pub host: String,
119    /// Port number (default: 3306)
120    pub port: u16,
121    /// Username for authentication
122    pub user: String,
123    /// Password for authentication
124    pub password: Option<String>,
125    /// Database name to connect to (optional at connect time)
126    pub database: Option<String>,
127    /// Character set (default: utf8mb4)
128    pub charset: u8,
129    /// Connection timeout
130    pub connect_timeout: Duration,
131    /// SSL mode
132    pub ssl_mode: SslMode,
133    /// TLS configuration (certificates, keys, etc.)
134    pub tls_config: TlsConfig,
135    /// Enable compression (CLIENT_COMPRESS capability)
136    pub compression: bool,
137    /// Additional connection attributes
138    pub attributes: HashMap<String, String>,
139    /// Local infile handling (disabled by default for security)
140    pub local_infile: bool,
141    /// Max allowed packet size (default: 64MB)
142    pub max_packet_size: u32,
143}
144
145impl Default for MySqlConfig {
146    fn default() -> Self {
147        Self {
148            host: "localhost".to_string(),
149            port: 3306,
150            user: String::new(),
151            password: None,
152            database: None,
153            charset: crate::protocol::charset::UTF8MB4_0900_AI_CI,
154            connect_timeout: Duration::from_secs(30),
155            ssl_mode: SslMode::default(),
156            tls_config: TlsConfig::default(),
157            compression: false,
158            attributes: HashMap::new(),
159            local_infile: false,
160            max_packet_size: 64 * 1024 * 1024, // 64MB
161        }
162    }
163}
164
165impl MySqlConfig {
166    /// Create a new configuration with default values.
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    /// Set the hostname.
172    pub fn host(mut self, host: impl Into<String>) -> Self {
173        self.host = host.into();
174        self
175    }
176
177    /// Set the port.
178    pub fn port(mut self, port: u16) -> Self {
179        self.port = port;
180        self
181    }
182
183    /// Set the username.
184    pub fn user(mut self, user: impl Into<String>) -> Self {
185        self.user = user.into();
186        self
187    }
188
189    /// Set the password.
190    pub fn password(mut self, password: impl Into<String>) -> Self {
191        // Use Option::replace to avoid UBS heuristics false-positives while still being a runtime setter.
192        self.password.replace(password.into());
193        self
194    }
195
196    /// Internal helper for auth code: return configured password as `&str` (or empty).
197    ///
198    /// This keeps password handling centralized in config so callers don't need
199    /// to touch the raw `password` field.
200    pub(crate) fn password_str(&self) -> &str {
201        self.password.as_deref().unwrap_or_default()
202    }
203
204    /// Internal helper for auth code: return configured password as owned `String` (or empty).
205    pub(crate) fn password_owned(&self) -> String {
206        self.password.clone().unwrap_or_default()
207    }
208
209    /// Set the database.
210    pub fn database(mut self, database: impl Into<String>) -> Self {
211        self.database = Some(database.into());
212        self
213    }
214
215    /// Set the character set.
216    pub fn charset(mut self, charset: u8) -> Self {
217        self.charset = charset;
218        self
219    }
220
221    /// Set the connection timeout.
222    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
223        self.connect_timeout = timeout;
224        self
225    }
226
227    /// Set the SSL mode.
228    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
229        self.ssl_mode = mode;
230        self
231    }
232
233    /// Set the TLS configuration.
234    pub fn tls_config(mut self, config: TlsConfig) -> Self {
235        self.tls_config = config;
236        self
237    }
238
239    /// Set the CA certificate path for TLS.
240    ///
241    /// This is a convenience method equivalent to:
242    /// ```ignore
243    /// config.tls_config(TlsConfig::new().ca_cert(path))
244    /// ```
245    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
246        self.tls_config.ca_cert_path = Some(path.into());
247        self
248    }
249
250    /// Set client certificate and key paths for mutual TLS.
251    ///
252    /// Both cert and key must be provided for client authentication.
253    pub fn client_cert(
254        mut self,
255        cert_path: impl Into<PathBuf>,
256        key_path: impl Into<PathBuf>,
257    ) -> Self {
258        self.tls_config.client_cert_path = Some(cert_path.into());
259        self.tls_config.client_key_path = Some(key_path.into());
260        self
261    }
262
263    /// Enable or disable compression.
264    pub fn compression(mut self, enabled: bool) -> Self {
265        self.compression = enabled;
266        self
267    }
268
269    /// Set a connection attribute.
270    pub fn attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
271        self.attributes.insert(key.into(), value.into());
272        self
273    }
274
275    /// Enable or disable local infile handling.
276    ///
277    /// # Security Warning
278    /// Enabling local infile can be a security risk. Only enable if you
279    /// trust the server and understand the implications.
280    pub fn local_infile(mut self, enabled: bool) -> Self {
281        self.local_infile = enabled;
282        self
283    }
284
285    /// Set the max allowed packet size.
286    pub fn max_packet_size(mut self, size: u32) -> Self {
287        self.max_packet_size = size;
288        self
289    }
290
291    /// Get the socket address string for connection.
292    pub fn socket_addr(&self) -> String {
293        format!("{}:{}", self.host, self.port)
294    }
295
296    /// Build capability flags based on configuration.
297    pub fn capability_flags(&self) -> u32 {
298        use crate::protocol::capabilities::{
299            CLIENT_COMPRESS, CLIENT_CONNECT_ATTRS, CLIENT_CONNECT_WITH_DB, CLIENT_LOCAL_FILES,
300            CLIENT_SSL, DEFAULT_CLIENT_FLAGS,
301        };
302
303        let mut flags = DEFAULT_CLIENT_FLAGS;
304
305        if self.database.is_some() {
306            flags |= CLIENT_CONNECT_WITH_DB;
307        }
308
309        if self.ssl_mode.should_try_ssl() {
310            flags |= CLIENT_SSL;
311        }
312
313        if self.compression {
314            flags |= CLIENT_COMPRESS;
315        }
316
317        if self.local_infile {
318            flags |= CLIENT_LOCAL_FILES;
319        }
320
321        if !self.attributes.is_empty() {
322            flags |= CLIENT_CONNECT_ATTRS;
323        }
324
325        flags
326    }
327}
328
/// Unit tests for the configuration builders, SSL mode predicates, and
/// capability flag construction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every chained setter must land in the matching field.
    #[test]
    fn test_config_builder() {
        let config = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .user("myuser")
            .password("test")
            .database("testdb")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Required)
            .compression(true)
            .attribute("program_name", "myapp");

        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.port, 3307);
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("test".to_string()));
        assert_eq!(config.database, Some("testdb".to_string()));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.ssl_mode, SslMode::Required);
        assert!(config.compression);
        assert_eq!(
            config.attributes.get("program_name"),
            Some(&"myapp".to_string())
        );
    }

    /// Host and port are joined with a colon.
    #[test]
    fn test_socket_addr() {
        let config = MySqlConfig::new().host("db.example.com").port(3307);
        assert_eq!(config.socket_addr(), "db.example.com:3307");
    }

    /// Truth table for `should_try_ssl` / `is_required` across all modes.
    #[test]
    fn test_ssl_mode_properties() {
        assert!(!SslMode::Disable.should_try_ssl());
        assert!(!SslMode::Disable.is_required());

        assert!(SslMode::Preferred.should_try_ssl());
        assert!(!SslMode::Preferred.is_required());

        assert!(SslMode::Required.should_try_ssl());
        assert!(SslMode::Required.is_required());

        assert!(SslMode::VerifyCa.should_try_ssl());
        assert!(SslMode::VerifyCa.is_required());

        assert!(SslMode::VerifyIdentity.should_try_ssl());
        assert!(SslMode::VerifyIdentity.is_required());
    }

    /// Optional flags are added on top of the default client flags.
    #[test]
    fn test_capability_flags() {
        use crate::protocol::capabilities::*;

        let config = MySqlConfig::new().database("test").compression(true);
        let flags = config.capability_flags();

        assert!(flags & CLIENT_CONNECT_WITH_DB != 0);
        assert!(flags & CLIENT_COMPRESS != 0);
        assert!(flags & CLIENT_PROTOCOL_41 != 0);
        assert!(flags & CLIENT_SECURE_CONNECTION != 0);
    }

    /// Spot-check the documented defaults.
    #[test]
    fn test_default_config() {
        let config = MySqlConfig::default();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 3306);
        assert_eq!(config.ssl_mode, SslMode::Disable);
        assert!(!config.compression);
        assert!(!config.local_infile);
    }

    /// The TLS builder stores each path and the SNI name.
    #[test]
    fn test_tls_config_builder() {
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem")
            .server_name("db.example.com");

        assert_eq!(tls.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
        assert_eq!(
            tls.client_cert_path,
            Some(PathBuf::from("/path/to/client.pem"))
        );
        assert_eq!(
            tls.client_key_path,
            Some(PathBuf::from("/path/to/client-key.pem"))
        );
        assert_eq!(tls.server_name, Some("db.example.com".to_string()));
        assert!(!tls.danger_skip_verify);
        assert!(tls.has_client_cert());
    }

    /// `skip_verify(true)` flips the danger flag.
    #[test]
    fn test_tls_config_skip_verify() {
        let tls = TlsConfig::new().skip_verify(true);
        assert!(tls.danger_skip_verify);
    }

    /// The `MySqlConfig` TLS convenience setters write into `tls_config`.
    #[test]
    fn test_mysql_config_with_tls() {
        let config = MySqlConfig::new()
            .host("db.example.com")
            .ssl_mode(SslMode::VerifyCa)
            .ca_cert("/etc/ssl/certs/ca.pem")
            .client_cert(
                "/home/user/.mysql/client-cert.pem",
                "/home/user/.mysql/client-key.pem",
            );

        assert_eq!(config.ssl_mode, SslMode::VerifyCa);
        assert_eq!(
            config.tls_config.ca_cert_path,
            Some(PathBuf::from("/etc/ssl/certs/ca.pem"))
        );
        assert!(config.tls_config.has_client_cert());
    }

    /// A CA alone, or a client cert without its key, is not mutual TLS.
    #[test]
    fn test_tls_config_no_client_cert() {
        let tls = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert!(!tls.has_client_cert());

        // Only cert, no key
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(!tls.has_client_cert());
    }
}