Skip to main content

sqlmodel_mysql/
tls.rs

1//! TLS/SSL support for MySQL connections.
2//!
3//! This module implements the TLS handshake for MySQL connections using rustls.
4//!
5//! # MySQL TLS Handshake Flow
6//!
7//! 1. Server sends initial handshake with `CLIENT_SSL` capability
8//! 2. If SSL is requested, client sends short SSL request packet:
9//!    - 4 bytes: capability flags (with `CLIENT_SSL`)
10//!    - 4 bytes: max packet size
11//!    - 1 byte: character set
12//!    - 23 bytes: reserved (zeros)
13//! 3. Client performs TLS handshake
14//! 4. Client sends full handshake response over TLS
15//! 5. Server sends auth result over TLS
16//!
17//! # Feature Flag
18//!
19//! TLS support requires the `tls` feature to be enabled:
20//!
21//! ```toml
22//! [dependencies]
23//! sqlmodel-mysql = { version = "0.1", features = ["tls"] }
24//! ```
25//!
26//! # Example
27//!
28//! ```rust,ignore
29//! use sqlmodel_mysql::{MySqlConfig, SslMode, TlsConfig};
30//!
31//! let config = MySqlConfig::new()
32//!     .host("db.example.com")
33//!     .ssl_mode(SslMode::VerifyCa)
34//!     .tls_config(TlsConfig::new()
35//!         .ca_cert("/etc/ssl/certs/ca.pem"));
36//!
37//! // Connection will use TLS after initial handshake
38//! let conn = MySqlConnection::connect(config)?;
39//! ```
40
41#![allow(clippy::cast_possible_truncation)]
42// The Error type is intentionally large to carry full context
43#![allow(clippy::result_large_err)]
44// Placeholder function takes stream by value since it just returns error
45#![allow(clippy::needless_pass_by_value)]
46
47use crate::config::{SslMode, TlsConfig};
48use crate::protocol::{PacketWriter, capabilities};
49use sqlmodel_core::Error;
50use sqlmodel_core::error::{ConnectionError, ConnectionErrorKind};
51
52#[cfg(feature = "tls")]
53use std::io::{Read, Write};
54#[cfg(feature = "tls")]
55use std::sync::Arc;
56
57/// Build an SSL request packet.
58///
59/// This packet is sent after receiving the server handshake and before
60/// performing the TLS handshake. It tells the server that we want to
61/// upgrade to TLS.
62///
63/// # Format
64///
65/// - capability_flags (4 bytes): Client capabilities with CLIENT_SSL set
66/// - max_packet_size (4 bytes): Maximum packet size
67/// - character_set (1 byte): Character set code
68/// - reserved (23 bytes): All zeros
69///
70/// Total: 32 bytes
71pub fn build_ssl_request_packet(
72    client_caps: u32,
73    max_packet_size: u32,
74    character_set: u8,
75    sequence_id: u8,
76) -> Vec<u8> {
77    let mut writer = PacketWriter::with_capacity(32);
78
79    // Capability flags with CLIENT_SSL
80    let caps_with_ssl = client_caps | capabilities::CLIENT_SSL;
81    writer.write_u32_le(caps_with_ssl);
82
83    // Max packet size
84    writer.write_u32_le(max_packet_size);
85
86    // Character set
87    writer.write_u8(character_set);
88
89    // Reserved (23 bytes of zeros)
90    writer.write_zeros(23);
91
92    writer.build_packet(sequence_id)
93}
94
95/// Check if the server supports SSL/TLS.
96///
97/// # Arguments
98///
99/// * `server_caps` - Server capability flags from handshake
100///
101/// # Returns
102///
103/// `true` if the server has the CLIENT_SSL capability flag set.
104pub const fn server_supports_ssl(server_caps: u32) -> bool {
105    server_caps & capabilities::CLIENT_SSL != 0
106}
107
108/// Validate SSL mode against server capabilities.
109///
110/// # Returns
111///
112/// - `Ok(true)` if SSL should be used
113/// - `Ok(false)` if SSL should not be used
114/// - `Err(_)` if SSL is required but not supported by server
115pub fn validate_ssl_mode(ssl_mode: SslMode, server_caps: u32) -> Result<bool, Error> {
116    let server_supports = server_supports_ssl(server_caps);
117
118    match ssl_mode {
119        SslMode::Disable => Ok(false),
120        SslMode::Preferred => Ok(server_supports),
121        SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => {
122            if server_supports {
123                Ok(true)
124            } else {
125                Err(tls_error("SSL required but server does not support it"))
126            }
127        }
128    }
129}
130
131/// Validate TLS configuration for the given SSL mode.
132///
133/// # Arguments
134///
135/// * `ssl_mode` - The requested SSL mode
136/// * `tls_config` - The TLS configuration
137///
138/// # Returns
139///
140/// `Ok(())` if configuration is valid, `Err(_)` with details if not.
141pub fn validate_tls_config(ssl_mode: SslMode, tls_config: &TlsConfig) -> Result<(), Error> {
142    match ssl_mode {
143        SslMode::Disable | SslMode::Preferred | SslMode::Required => {
144            // No certificate validation required for these modes
145            Ok(())
146        }
147        SslMode::VerifyCa | SslMode::VerifyIdentity => {
148            // Need CA certificate for server verification
149            if tls_config.ca_cert_path.is_none() && !tls_config.danger_skip_verify {
150                return Err(tls_error(
151                    "CA certificate required for VerifyCa/VerifyIdentity mode. \
152                     Set ca_cert_path or danger_skip_verify.",
153                ));
154            }
155
156            // If client cert is provided, key must also be provided
157            if tls_config.client_cert_path.is_some() && tls_config.client_key_path.is_none() {
158                return Err(tls_error(
159                    "Client certificate provided without client key. \
160                     Both must be set for mutual TLS.",
161                ));
162            }
163
164            Ok(())
165        }
166    }
167}
168
169/// Create a TLS-related connection error.
170fn tls_error(message: impl Into<String>) -> Error {
171    Error::Connection(ConnectionError {
172        kind: ConnectionErrorKind::Ssl,
173        message: message.into(),
174        source: None,
175    })
176}
177
178// ============================================================================
179// TLS Stream Implementation (feature-gated)
180// ============================================================================
181
/// Blocking TLS client stream built on rustls.
///
/// Pairs a `rustls::ClientConnection` with the transport `S` it encrypts. The
/// `Read` and `Write` implementations in this file move plaintext through the
/// rustls session and carry the resulting TLS records over `S`.
///
/// # SSL Modes
///
/// How the server certificate is checked is decided by `build_client_config`:
/// - `Disable`: rejected with an error (no `TlsStream` is created)
/// - `Preferred` / `Required`: checked against the bundled webpki roots, unless
///   `danger_skip_verify` is set, in which case any certificate is accepted
/// - `VerifyCa` / `VerifyIdentity`: checked against `ca_cert_path` when it is
///   set, otherwise against the webpki roots; `danger_skip_verify` turns the
///   check off. Both modes are configured identically.
#[cfg(feature = "tls")]
pub struct TlsStream<S>
where
    S: Read + Write,
{
    /// rustls session state (handshake progress, keys, buffered records)
    conn: rustls::ClientConnection,
    /// Transport that carries the encrypted TLS records
    stream: S,
}
202
#[cfg(feature = "tls")]
impl<S: Read + Write> std::fmt::Debug for TlsStream<S> {
    /// Print the negotiated protocol version and handshake state.
    ///
    /// The underlying stream and the rest of the rustls state are left out,
    /// which is why the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let version = self.conn.protocol_version();
        let handshaking = self.conn.is_handshaking();

        let mut out = f.debug_struct("TlsStream");
        out.field("protocol_version", &version);
        out.field("is_handshaking", &handshaking);
        out.finish_non_exhaustive()
    }
}
212
#[cfg(feature = "tls")]
impl<S: Read + Write> TlsStream<S> {
    /// Create a new TLS stream and perform the handshake.
    ///
    /// The handshake is driven synchronously on `stream`: queued TLS records
    /// are written out, then more records are read and processed, until
    /// rustls reports that the handshake is complete.
    ///
    /// # Arguments
    ///
    /// * `stream` - The underlying TCP stream (already connected)
    /// * `tls_config` - TLS configuration (certificates, verification options)
    /// * `server_name` - Server hostname for SNI and certificate verification;
    ///   overridden by `tls_config.server_name` when that is set
    /// * `ssl_mode` - The SSL mode to use for verification
    ///
    /// # Errors
    ///
    /// Returns an SSL connection error when:
    /// - the rustls client configuration cannot be built,
    /// - the server name is not accepted by rustls,
    /// - an I/O error occurs while exchanging handshake records,
    /// - the peer closes the connection before the handshake completes,
    /// - rustls rejects the handshake.
    pub fn new(
        mut stream: S,
        tls_config: &TlsConfig,
        server_name: &str,
        ssl_mode: SslMode,
    ) -> Result<Self, Error> {
        // Build the rustls ClientConfig based on SSL mode
        let config = build_client_config(tls_config, ssl_mode)?;

        // An explicit name in the TLS config wins over the connection host.
        let sni_name = tls_config.server_name.as_deref().unwrap_or(server_name);

        let server_name = sni_name
            .to_string()
            .try_into()
            .map_err(|e| tls_error(format!("Invalid server name '{}': {}", sni_name, e)))?;

        // Create the rustls client connection
        let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name)
            .map_err(|e| tls_error(format!("Failed to create TLS connection: {}", e)))?;

        while conn.is_handshaking() {
            // Write any pending TLS data to the stream
            while conn.wants_write() {
                conn.write_tls(&mut stream)
                    .map_err(|e| tls_error(format!("TLS handshake write error: {}", e)))?;
            }

            // Read TLS data from the stream if needed
            if conn.wants_read() {
                let received = conn
                    .read_tls(&mut stream)
                    .map_err(|e| tls_error(format!("TLS handshake read error: {}", e)))?;

                // Zero bytes means the peer closed the connection. Without
                // this check the loop would spin forever, because rustls keeps
                // asking for more handshake data that can never arrive.
                if received == 0 {
                    return Err(tls_error(
                        "TLS handshake read error: connection closed before handshake completed",
                    ));
                }

                // Process the TLS data
                if let Err(e) = conn.process_new_packets() {
                    // Best effort: let the peer see the alert rustls queued
                    // for this failure. The handshake error is what matters to
                    // the caller, so a failed alert write is ignored on purpose.
                    let _ = conn.write_tls(&mut stream);
                    return Err(tls_error(format!("TLS handshake error: {}", e)));
                }
            }
        }

        // The last handshake flight may have been queued by the final
        // `process_new_packets` call; send it before handing the stream over.
        while conn.wants_write() {
            conn.write_tls(&mut stream)
                .map_err(|e| tls_error(format!("TLS handshake write error: {}", e)))?;
        }

        Ok(TlsStream { conn, stream })
    }

    /// Get the negotiated protocol version, as reported by rustls.
    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
        self.conn.protocol_version()
    }

    /// Get the negotiated cipher suite, as reported by rustls.
    pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
        self.conn.negotiated_cipher_suite()
    }

    /// Check if the connection is using TLS 1.3.
    pub fn is_tls13(&self) -> bool {
        self.conn.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
    }
}
286
#[cfg(feature = "tls")]
impl<S: Read + Write> Read for TlsStream<S> {
    /// Read decrypted application data into `buf`.
    ///
    /// Plaintext already buffered by rustls is returned first. When none is
    /// available, TLS records are pulled from the underlying stream and
    /// processed until plaintext appears.
    ///
    /// # Returns
    ///
    /// The number of bytes copied into `buf`. `Ok(0)` is returned when `buf`
    /// is empty, when the underlying stream reports end of file, or when
    /// rustls has no plaintext and does not want more input.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the underlying stream and from the rustls
    /// reader (other than `WouldBlock`), and reports TLS protocol failures as
    /// `std::io::Error::other`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // A zero-length read must not touch the network: the rustls reader
        // answers `Ok(0)` for an empty buffer, which the loop below would
        // mistake for "no plaintext yet" and then block on the socket.
        if buf.is_empty() {
            return Ok(0);
        }

        loop {
            // First, try to read from the plaintext buffer
            match self.conn.reader().read(buf) {
                Ok(0) => {}
                Ok(n) => return Ok(n),
                // `WouldBlock` here only means rustls has no plaintext buffered yet.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            // No plaintext and rustls is not asking for input: nothing more to deliver.
            if !self.conn.wants_read() {
                return Ok(0);
            }

            // If no data available, read more TLS records
            if self.conn.read_tls(&mut self.stream)? == 0 {
                return Ok(0); // EOF
            }

            // Process the TLS records
            self.conn
                .process_new_packets()
                .map_err(|e| std::io::Error::other(format!("TLS error: {}", e)))?;
        }
    }
}
317
#[cfg(feature = "tls")]
impl<S: Read + Write> TlsStream<S> {
    /// Send every TLS record rustls has queued to the underlying stream.
    fn drain_pending_tls(&mut self) -> std::io::Result<()> {
        while self.conn.wants_write() {
            self.conn.write_tls(&mut self.stream)?;
        }
        Ok(())
    }
}

#[cfg(feature = "tls")]
impl<S: Read + Write> Write for TlsStream<S> {
    /// Encrypt `buf` and send the resulting TLS records.
    ///
    /// Returns the number of plaintext bytes accepted by the rustls writer.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Hand the plaintext to rustls, then push out whatever it queued.
        let accepted = self.conn.writer().write(buf)?;
        self.drain_pending_tls()?;
        Ok(accepted)
    }

    /// Flush the rustls writer, send all queued TLS records, then flush the
    /// underlying stream.
    fn flush(&mut self) -> std::io::Result<()> {
        self.conn.writer().flush()?;
        self.drain_pending_tls()?;
        self.stream.flush()
    }
}
340
/// Pick and build the rustls `ClientConfig` matching `ssl_mode` and `tls_config`.
///
/// Selection rules:
/// - `Disable`: error, a TLS stream must not be created in this mode.
/// - `danger_skip_verify` set (any other mode): certificate checks are disabled.
/// - `VerifyCa` / `VerifyIdentity` with `ca_cert_path`: trust only that CA file.
/// - everything else: trust the bundled webpki roots.
///
/// # Errors
///
/// Returns an SSL connection error for `Disable`, or whatever the selected
/// builder reports.
#[cfg(feature = "tls")]
pub(crate) fn build_client_config(
    tls_config: &TlsConfig,
    ssl_mode: SslMode,
) -> Result<rustls::ClientConfig, Error> {
    // Only the verify modes look at the configured CA file.
    let custom_ca = match ssl_mode {
        SslMode::Disable => {
            // This shouldn't happen - TlsStream shouldn't be created for Disable mode
            return Err(tls_error("TlsStream created with SslMode::Disable"));
        }
        // NOTE(review): these modes still verify the server certificate against
        // the webpki roots unless `danger_skip_verify` is set, and ignore
        // `ca_cert_path`. MySQL's own clients do not verify in the equivalent
        // modes - confirm which behavior is intended.
        SslMode::Preferred | SslMode::Required => None,
        SslMode::VerifyCa | SslMode::VerifyIdentity => tls_config.ca_cert_path.as_deref(),
    };

    // ring-backed crypto provider shared by all builders below
    let provider = Arc::new(rustls::crypto::ring::default_provider());

    if tls_config.danger_skip_verify {
        // User explicitly wants to skip verification (dangerous!)
        return build_no_verify_config(&provider);
    }

    match custom_ca {
        Some(ca_path) => build_custom_ca_config(&provider, tls_config, ca_path),
        None => build_webpki_config(&provider, tls_config),
    }
}
381
/// Build a `ClientConfig` that accepts any server certificate (dangerous!).
///
/// The installed verifier approves every certificate and every handshake
/// signature without looking at them, so the connection is encrypted but the
/// peer is not authenticated. TLS 1.2 and TLS 1.3 are enabled, and no client
/// certificate is configured.
///
/// # Errors
///
/// Returns an SSL connection error if rustls rejects the protocol versions
/// for the given provider.
#[cfg(feature = "tls")]
fn build_no_verify_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};

    /// Verifier that approves everything it is shown (insecure!).
    #[derive(Debug)]
    struct NoVerifier;

    impl ServerCertVerifier for NoVerifier {
        // Certificate chain, name and validity period are all ignored.
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, RustlsError> {
            Ok(ServerCertVerified::assertion())
        }

        // Handshake signatures are asserted valid without being checked.
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        // Schemes advertised to the server as acceptable.
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            [
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP521_SHA512,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::ED25519,
            ]
            .to_vec()
        }
    }

    let versioned = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?;

    Ok(versioned
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoVerifier))
        .with_no_client_auth())
}
450
/// Build a `ClientConfig` that trusts the CA bundle shipped with `webpki_roots`.
///
/// TLS 1.2 and TLS 1.3 are enabled. Client authentication is added by
/// `add_client_auth` according to `tls_config`.
///
/// # Errors
///
/// Returns an SSL connection error if rustls rejects the protocol versions or
/// if `add_client_auth` fails.
#[cfg(feature = "tls")]
fn build_webpki_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
    tls_config: &TlsConfig,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::RootCertStore;

    // Start empty and copy in every bundled trust anchor.
    let mut trusted_roots = RootCertStore::empty();
    trusted_roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let versioned = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?;

    // Add client certificate if configured
    add_client_auth(versioned.with_root_certificates(trusted_roots), tls_config)
}
472
/// Build a `ClientConfig` that trusts only the certificates found in `ca_path`.
///
/// The file is read as PEM and every certificate in it becomes a trust root.
/// TLS 1.2 and TLS 1.3 are enabled. Client authentication is added by
/// `add_client_auth` according to `tls_config`.
///
/// # Errors
///
/// Returns an SSL connection error when the file cannot be opened or parsed,
/// contains no certificate, holds a certificate rustls refuses as a root,
/// when the protocol versions are rejected, or when `add_client_auth` fails.
#[cfg(feature = "tls")]
fn build_custom_ca_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
    tls_config: &TlsConfig,
    ca_path: &std::path::Path,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::RootCertStore;
    use std::fs::File;
    use std::io::BufReader;

    // Load CA certificate(s)
    let mut pem_input = File::open(ca_path).map(BufReader::new).map_err(|e| {
        tls_error(format!(
            "Failed to open CA certificate '{}': {}",
            ca_path.display(),
            e
        ))
    })?;

    let ca_certs = rustls_pemfile::certs(&mut pem_input)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| tls_error(format!("Failed to parse CA certificate: {}", e)))?;

    // A readable file without any certificate would leave the store empty.
    if ca_certs.is_empty() {
        return Err(tls_error(format!(
            "No certificates found in CA file '{}'",
            ca_path.display()
        )));
    }

    // Stops at the first certificate rustls refuses.
    let mut trusted_roots = RootCertStore::empty();
    ca_certs
        .into_iter()
        .try_for_each(|cert| trusted_roots.add(cert))
        .map_err(|e| tls_error(format!("Failed to add CA certificate: {}", e)))?;

    let versioned = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {}", e)))?;

    // Add client certificate if configured
    add_client_auth(versioned.with_root_certificates(trusted_roots), tls_config)
}
522
/// Finish a `ClientConfig` by adding client authentication if configured.
///
/// - Both `client_cert_path` and `client_key_path` set: the PEM certificate
///   chain and the private key are loaded and installed for mutual TLS.
/// - Neither set: the configuration is finished without client authentication.
/// - Only one of them set: an error is returned instead of silently
///   connecting without the client certificate the caller asked for.
///
/// # Errors
///
/// Returns an SSL connection error when only one of the two paths is set,
/// when either file cannot be opened or parsed, when the certificate file
/// contains no certificate, when the key file contains no private key, or
/// when rustls rejects the certificate/key pair.
#[cfg(feature = "tls")]
fn add_client_auth(
    builder: rustls::ConfigBuilder<rustls::ClientConfig, rustls::client::WantsClientCert>,
    tls_config: &TlsConfig,
) -> Result<rustls::ClientConfig, Error> {
    use std::fs::File;
    use std::io::BufReader;

    let (cert_path, key_path) = match (&tls_config.client_cert_path, &tls_config.client_key_path)
    {
        (Some(cert_path), Some(key_path)) => (cert_path, key_path),
        (None, None) => return Ok(builder.with_no_client_auth()),
        (Some(_), None) => {
            return Err(tls_error(
                "Client certificate provided without client key. \
                 Both must be set for mutual TLS.",
            ));
        }
        (None, Some(_)) => {
            return Err(tls_error(
                "Client key provided without client certificate. \
                 Both must be set for mutual TLS.",
            ));
        }
    };

    // Load client certificate
    let cert_file = File::open(cert_path).map_err(|e| {
        tls_error(format!(
            "Failed to open client cert '{}': {}",
            cert_path.display(),
            e
        ))
    })?;
    let mut cert_reader = BufReader::new(cert_file);

    let certs = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| tls_error(format!("Failed to parse client certificate: {}", e)))?;

    if certs.is_empty() {
        return Err(tls_error(format!(
            "No certificates found in client cert file '{}'",
            cert_path.display()
        )));
    }

    // Load client private key
    let key_file = File::open(key_path).map_err(|e| {
        tls_error(format!(
            "Failed to open client key '{}': {}",
            key_path.display(),
            e
        ))
    })?;
    let mut key_reader = BufReader::new(key_file);

    let key = rustls_pemfile::private_key(&mut key_reader)
        .map_err(|e| tls_error(format!("Failed to parse client key: {}", e)))?
        .ok_or_else(|| tls_error(format!("No private key found in '{}'", key_path.display())))?;

    builder
        .with_client_auth_cert(certs, key)
        .map_err(|e| tls_error(format!("Failed to configure client auth: {}", e)))
}
579
580// ============================================================================
581// Stand-in types when TLS feature is disabled
582// ============================================================================
583
584/// TLS connection wrapper when `tls` feature is disabled.
585#[cfg(not(feature = "tls"))]
586#[derive(Debug)]
587pub struct TlsStream<S> {
588    /// The underlying stream
589    #[allow(dead_code)]
590    inner: S,
591}
592
593#[cfg(not(feature = "tls"))]
594impl<S> TlsStream<S> {
595    /// Create a new TLS stream.
596    ///
597    /// # Note
598    ///
599    /// This always returns an error when the `tls` feature is disabled.
600    /// Enable the `tls` feature in Cargo.toml to use TLS connections.
601    #[allow(unused_variables)]
602    pub fn new(
603        stream: S,
604        tls_config: &TlsConfig,
605        server_name: &str,
606        ssl_mode: SslMode,
607    ) -> Result<Self, Error> {
608        Err(tls_error(
609            "TLS support requires the 'tls' feature. \
610             Add `sqlmodel-mysql = { features = [\"tls\"] }` to your Cargo.toml.",
611        ))
612    }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::charset;

    /// The SSL request is a 4-byte header followed by the 32-byte payload.
    #[test]
    fn test_build_ssl_request_packet() {
        let sixteen_mib = 16 * 1024 * 1024;
        let packet = build_ssl_request_packet(
            capabilities::DEFAULT_CLIENT_FLAGS,
            sixteen_mib,
            charset::UTF8MB4_0900_AI_CI,
            1,
        );

        // Header (4) + payload (32) = 36 bytes
        assert_eq!(packet.len(), 36);

        // Header: 3-byte little-endian payload length, then the sequence id
        assert_eq!(packet[0], 32);
        assert_eq!(packet[1], 0);
        assert_eq!(packet[2], 0);
        assert_eq!(packet[3], 1);

        // CLIENT_SSL must be present in the capability flags
        let caps = u32::from_le_bytes([packet[4], packet[5], packet[6], packet[7]]);
        assert!(caps & capabilities::CLIENT_SSL != 0);
    }

    /// Only the CLIENT_SSL bit decides whether the server supports TLS.
    #[test]
    fn test_server_supports_ssl() {
        let ssl_only = capabilities::CLIENT_SSL;
        let ssl_and_more = capabilities::CLIENT_SSL | capabilities::CLIENT_PROTOCOL_41;
        for caps in [ssl_only, ssl_and_more] {
            assert!(server_supports_ssl(caps));
        }

        for caps in [0, capabilities::CLIENT_PROTOCOL_41] {
            assert!(!server_supports_ssl(caps));
        }
    }

    /// `Disable` never uses TLS, whatever the server offers.
    #[test]
    fn test_validate_ssl_mode_disable() {
        for caps in [0, capabilities::CLIENT_SSL] {
            assert!(!validate_ssl_mode(SslMode::Disable, caps).unwrap());
        }
    }

    /// `Preferred` follows the server.
    #[test]
    fn test_validate_ssl_mode_preferred() {
        // Preferred without SSL support -> no SSL
        let without_support = validate_ssl_mode(SslMode::Preferred, 0).unwrap();
        assert!(!without_support);

        // Preferred with SSL support -> use SSL
        let with_support = validate_ssl_mode(SslMode::Preferred, capabilities::CLIENT_SSL).unwrap();
        assert!(with_support);
    }

    /// `Required` fails without server support and uses TLS with it.
    #[test]
    fn test_validate_ssl_mode_required() {
        // Required without SSL support -> error
        let without_support = validate_ssl_mode(SslMode::Required, 0);
        assert!(without_support.is_err());

        // Required with SSL support -> use SSL
        let with_support = validate_ssl_mode(SslMode::Required, capabilities::CLIENT_SSL).unwrap();
        assert!(with_support);
    }

    /// The verify modes behave like `Required` as far as server support goes.
    #[test]
    fn test_validate_ssl_mode_verify() {
        // VerifyCa/VerifyIdentity without SSL support -> error
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_ssl_mode(mode, 0).is_err());
        }

        // With SSL support -> use SSL
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_ssl_mode(mode, capabilities::CLIENT_SSL).unwrap());
        }
    }

    /// Modes without certificate checks accept a default configuration.
    #[test]
    fn test_validate_tls_config_basic_modes() {
        let config = TlsConfig::new();

        // Basic modes don't require CA cert
        for mode in [SslMode::Disable, SslMode::Preferred, SslMode::Required] {
            assert!(validate_tls_config(mode, &config).is_ok());
        }
    }

    /// The verify modes need a CA certificate or an explicit opt-out.
    #[test]
    fn test_validate_tls_config_verify_modes() {
        // VerifyCa without CA cert -> error
        let bare = TlsConfig::new();
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_tls_config(mode, &bare).is_err());
        }

        // With CA cert -> ok
        let with_ca = TlsConfig::new().ca_cert("/path/to/ca.pem");
        for mode in [SslMode::VerifyCa, SslMode::VerifyIdentity] {
            assert!(validate_tls_config(mode, &with_ca).is_ok());
        }

        // With skip_verify -> ok (dangerous but valid config)
        let skipping = TlsConfig::new().skip_verify(true);
        assert!(validate_tls_config(SslMode::VerifyCa, &skipping).is_ok());
    }

    /// A client certificate is only accepted together with its key.
    #[test]
    fn test_validate_tls_config_client_cert() {
        // Client cert without key -> error
        let cert_only = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(validate_tls_config(SslMode::VerifyCa, &cert_only).is_err());

        // Client cert with key -> ok
        let cert_and_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem");
        assert!(validate_tls_config(SslMode::VerifyCa, &cert_and_key).is_ok());
    }
}