1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Duration;
9
/// TLS-related settings carried by a [`MySqlConfig`].
///
/// Every field is optional: `TlsConfig::default()` leaves all paths and the
/// server name unset and `danger_skip_verify` at `false`. The struct only
/// stores the values; nothing in this type reads the referenced files.
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
    /// Location of the CA certificate file, when one has been configured.
    pub ca_cert_path: Option<PathBuf>,

    /// Location of the client certificate file, when one has been configured.
    pub client_cert_path: Option<PathBuf>,

    /// Location of the client private-key file, when one has been configured.
    pub client_key_path: Option<PathBuf>,

    /// Requests that verification be skipped when `true`.
    ///
    /// NOTE(review): what exactly is skipped is decided by the code that
    /// consumes this config, which is not part of this type — confirm there.
    pub danger_skip_verify: bool,

    /// Server name to use for the TLS session, when one has been configured.
    ///
    /// NOTE(review): presumably used for SNI / host-name checks — confirm
    /// against the TLS setup code.
    pub server_name: Option<String>,
}

impl TlsConfig {
    /// Creates an empty configuration; identical to `TlsConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the CA certificate path and returns the updated configuration.
    #[must_use]
    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.ca_cert_path = Some(path.into());
        self
    }

    /// Sets the client certificate path and returns the updated configuration.
    ///
    /// The key path is set separately through [`TlsConfig::client_key`].
    #[must_use]
    pub fn client_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_cert_path = Some(path.into());
        self
    }

    /// Sets the client private-key path and returns the updated configuration.
    #[must_use]
    pub fn client_key(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_key_path = Some(path.into());
        self
    }

    /// Sets the `danger_skip_verify` flag and returns the updated configuration.
    #[must_use]
    pub fn skip_verify(mut self, skip: bool) -> Self {
        self.danger_skip_verify = skip;
        self
    }

    /// Sets the server name and returns the updated configuration.
    #[must_use]
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Returns `true` only when both the client certificate path and the
    /// client key path are set; either one alone is not enough.
    #[must_use]
    pub fn has_client_cert(&self) -> bool {
        self.client_cert_path.is_some() && self.client_key_path.is_some()
    }
}
82
/// How strongly the client insists on an SSL/TLS connection.
///
/// The default is [`SslMode::Disable`]. The variant names appear to mirror
/// MySQL's `ssl-mode` option; the verification that `VerifyCa` and
/// `VerifyIdentity` imply is not implemented in this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// SSL is neither attempted nor required.
    #[default]
    Disable,
    /// SSL is attempted but not required.
    Preferred,
    /// SSL is attempted and required.
    Required,
    /// SSL is attempted and required (see the type-level note on verification).
    VerifyCa,
    /// SSL is attempted and required (see the type-level note on verification).
    VerifyIdentity,
}

impl SslMode {
    /// Reports whether SSL should be attempted: `false` for
    /// [`SslMode::Disable`], `true` for every other mode.
    pub const fn should_try_ssl(self) -> bool {
        match self {
            Self::Disable => false,
            Self::Preferred | Self::Required | Self::VerifyCa | Self::VerifyIdentity => true,
        }
    }

    /// Reports whether SSL is mandatory: `true` for `Required`, `VerifyCa`
    /// and `VerifyIdentity`, `false` for `Disable` and `Preferred`.
    pub const fn is_required(self) -> bool {
        match self {
            Self::Disable | Self::Preferred => false,
            Self::Required | Self::VerifyCa | Self::VerifyIdentity => true,
        }
    }
}
113
114#[derive(Debug, Clone)]
116pub struct MySqlConfig {
117 pub host: String,
119 pub port: u16,
121 pub user: String,
123 pub password: Option<String>,
125 pub database: Option<String>,
127 pub charset: u8,
129 pub connect_timeout: Duration,
131 pub ssl_mode: SslMode,
133 pub tls_config: TlsConfig,
135 pub compression: bool,
137 pub attributes: HashMap<String, String>,
139 pub local_infile: bool,
141 pub max_packet_size: u32,
143}
144
145impl Default for MySqlConfig {
146 fn default() -> Self {
147 Self {
148 host: "localhost".to_string(),
149 port: 3306,
150 user: String::new(),
151 password: None,
152 database: None,
153 charset: crate::protocol::charset::UTF8MB4_0900_AI_CI,
154 connect_timeout: Duration::from_secs(30),
155 ssl_mode: SslMode::default(),
156 tls_config: TlsConfig::default(),
157 compression: false,
158 attributes: HashMap::new(),
159 local_infile: false,
160 max_packet_size: 64 * 1024 * 1024, }
162 }
163}
164
165impl MySqlConfig {
166 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn host(mut self, host: impl Into<String>) -> Self {
173 self.host = host.into();
174 self
175 }
176
177 pub fn port(mut self, port: u16) -> Self {
179 self.port = port;
180 self
181 }
182
183 pub fn user(mut self, user: impl Into<String>) -> Self {
185 self.user = user.into();
186 self
187 }
188
189 pub fn password(mut self, password: impl Into<String>) -> Self {
191 self.password = Some(password.into());
192 self
193 }
194
195 pub fn database(mut self, database: impl Into<String>) -> Self {
197 self.database = Some(database.into());
198 self
199 }
200
201 pub fn charset(mut self, charset: u8) -> Self {
203 self.charset = charset;
204 self
205 }
206
207 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
209 self.connect_timeout = timeout;
210 self
211 }
212
213 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
215 self.ssl_mode = mode;
216 self
217 }
218
219 pub fn tls_config(mut self, config: TlsConfig) -> Self {
221 self.tls_config = config;
222 self
223 }
224
225 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
232 self.tls_config.ca_cert_path = Some(path.into());
233 self
234 }
235
236 pub fn client_cert(
240 mut self,
241 cert_path: impl Into<PathBuf>,
242 key_path: impl Into<PathBuf>,
243 ) -> Self {
244 self.tls_config.client_cert_path = Some(cert_path.into());
245 self.tls_config.client_key_path = Some(key_path.into());
246 self
247 }
248
249 pub fn compression(mut self, enabled: bool) -> Self {
251 self.compression = enabled;
252 self
253 }
254
255 pub fn attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
257 self.attributes.insert(key.into(), value.into());
258 self
259 }
260
261 pub fn local_infile(mut self, enabled: bool) -> Self {
267 self.local_infile = enabled;
268 self
269 }
270
271 pub fn max_packet_size(mut self, size: u32) -> Self {
273 self.max_packet_size = size;
274 self
275 }
276
277 pub fn socket_addr(&self) -> String {
279 format!("{}:{}", self.host, self.port)
280 }
281
282 pub fn capability_flags(&self) -> u32 {
284 use crate::protocol::capabilities::{
285 CLIENT_COMPRESS, CLIENT_CONNECT_ATTRS, CLIENT_CONNECT_WITH_DB, CLIENT_LOCAL_FILES,
286 CLIENT_SSL, DEFAULT_CLIENT_FLAGS,
287 };
288
289 let mut flags = DEFAULT_CLIENT_FLAGS;
290
291 if self.database.is_some() {
292 flags |= CLIENT_CONNECT_WITH_DB;
293 }
294
295 if self.ssl_mode.should_try_ssl() {
296 flags |= CLIENT_SSL;
297 }
298
299 if self.compression {
300 flags |= CLIENT_COMPRESS;
301 }
302
303 if self.local_infile {
304 flags |= CLIENT_LOCAL_FILES;
305 }
306
307 if !self.attributes.is_empty() {
308 flags |= CLIENT_CONNECT_ATTRS;
309 }
310
311 flags
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `MySqlConfig` builder call must land in the matching field.
    #[test]
    fn test_config_builder() {
        let timeout = Duration::from_secs(10);
        let config = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .user("myuser")
            .password("secret")
            .database("testdb")
            .connect_timeout(timeout)
            .ssl_mode(SslMode::Required)
            .compression(true)
            .attribute("program_name", "myapp");

        assert_eq!(config.host.as_str(), "db.example.com");
        assert_eq!(config.port, 3307);
        assert_eq!(config.user.as_str(), "myuser");
        assert_eq!(config.password.as_deref(), Some("secret"));
        assert_eq!(config.database.as_deref(), Some("testdb"));
        assert_eq!(config.connect_timeout, timeout);
        assert_eq!(config.ssl_mode, SslMode::Required);
        assert!(config.compression);

        let program_name = config.attributes.get("program_name").map(String::as_str);
        assert_eq!(program_name, Some("myapp"));
    }

    /// `socket_addr` joins host and port with a colon.
    #[test]
    fn test_socket_addr() {
        let addr = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .socket_addr();
        assert_eq!(addr, "db.example.com:3307");
    }

    /// Table of (mode, should_try_ssl, is_required) expectations.
    #[test]
    fn test_ssl_mode_properties() {
        let expectations = [
            (SslMode::Disable, false, false),
            (SslMode::Preferred, true, false),
            (SslMode::Required, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyIdentity, true, true),
        ];

        for (mode, tries_ssl, required) in expectations {
            assert_eq!(mode.should_try_ssl(), tries_ssl);
            assert_eq!(mode.is_required(), required);
        }
    }

    /// Flags requested through the builder, plus two baseline flags, must be set.
    #[test]
    fn test_capability_flags() {
        use crate::protocol::capabilities::*;

        let flags = MySqlConfig::new()
            .database("test")
            .compression(true)
            .capability_flags();

        assert_ne!(flags & CLIENT_CONNECT_WITH_DB, 0);
        assert_ne!(flags & CLIENT_COMPRESS, 0);
        assert_ne!(flags & CLIENT_PROTOCOL_41, 0);
        assert_ne!(flags & CLIENT_SECURE_CONNECTION, 0);
    }

    /// Spot-check of the values produced by `Default`.
    #[test]
    fn test_default_config() {
        let MySqlConfig {
            host,
            port,
            ssl_mode,
            compression,
            local_infile,
            ..
        } = MySqlConfig::default();

        assert_eq!(host, "localhost");
        assert_eq!(port, 3306);
        assert_eq!(ssl_mode, SslMode::Disable);
        assert!(!compression);
        assert!(!local_infile);
    }

    /// Each `TlsConfig` builder call must land in the matching field.
    #[test]
    fn test_tls_config_builder() {
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem")
            .server_name("db.example.com");

        let expected_ca = PathBuf::from("/path/to/ca.pem");
        let expected_cert = PathBuf::from("/path/to/client.pem");
        let expected_key = PathBuf::from("/path/to/client-key.pem");

        assert_eq!(tls.ca_cert_path, Some(expected_ca));
        assert_eq!(tls.client_cert_path, Some(expected_cert));
        assert_eq!(tls.client_key_path, Some(expected_key));
        assert_eq!(tls.server_name.as_deref(), Some("db.example.com"));
        assert!(!tls.danger_skip_verify);
        assert!(tls.has_client_cert());
    }

    /// `skip_verify(true)` raises the `danger_skip_verify` flag.
    #[test]
    fn test_tls_config_skip_verify() {
        let skipping = TlsConfig::new().skip_verify(true);
        assert!(skipping.danger_skip_verify);
    }

    /// TLS helpers on `MySqlConfig` write through to the embedded `TlsConfig`.
    #[test]
    fn test_mysql_config_with_tls() {
        let config = MySqlConfig::new()
            .host("db.example.com")
            .ssl_mode(SslMode::VerifyCa)
            .ca_cert("/etc/ssl/certs/ca.pem")
            .client_cert(
                "/home/user/.mysql/client-cert.pem",
                "/home/user/.mysql/client-key.pem",
            );

        assert_eq!(config.ssl_mode, SslMode::VerifyCa);

        let expected_ca = PathBuf::from("/etc/ssl/certs/ca.pem");
        assert_eq!(config.tls_config.ca_cert_path, Some(expected_ca));
        assert!(config.tls_config.has_client_cert());
    }

    /// A client certificate without its key does not count as a client cert.
    #[test]
    fn test_tls_config_no_client_cert() {
        let ca_only = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert!(!ca_only.has_client_cert());

        let cert_without_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(!cert_without_key.has_client_cert());
    }
}