Skip to main content

solti_tls/
client.rs

1//! Client-side TLS configuration.
2
3use std::path::PathBuf;
4
5use crate::{PemSource, TlsError};
6
7/// Client-side TLS configuration.
8///
9/// _Construct via [`ClientTlsConfig::builder`]_.
10#[derive(Debug, Clone)]
11pub struct ClientTlsConfig {
12    /// Trusted CA bundle for verifying the server's certificate.
13    pub ca: PemSource,
14    /// Client certificate chain (`None` = no client cert).
15    pub client_cert: Option<PemSource>,
16    /// Client private key.
17    pub client_key: Option<PemSource>,
18    /// ALPN protocol list, in preference order.
19    pub alpn: Vec<Vec<u8>>,
20}
21
22impl ClientTlsConfig {
23    /// Start a new builder.
24    pub fn builder() -> ClientTlsConfigBuilder {
25        ClientTlsConfigBuilder::default()
26    }
27
28    /// Build a [`rustls::ClientConfig`].
29    ///
30    /// Reads PEM sources, builds a `RootCertStore` from CA, optionally adds the client cert+key for mTLS, and applies ALPN.
31    /// All I/O surfaces here.
32    ///
33    /// Auto-installs the `ring` `CryptoProvider` if no provider is set process-wide.
34    pub fn into_rustls_config(self) -> Result<rustls::ClientConfig, TlsError> {
35        crate::ensure_default_provider();
36
37        let ca_bytes = self.ca.read()?;
38        let ca_certs = crate::load_certs_from_pem(ca_bytes.as_slice())?;
39        let mut roots = rustls::RootCertStore::empty();
40        for ca in ca_certs {
41            roots.add(ca)?;
42        }
43
44        let builder = rustls::ClientConfig::builder().with_root_certificates(roots);
45
46        let mut config = match (self.client_cert, self.client_key) {
47            (Some(cert_src), Some(key_src)) => {
48                let cert_bytes = cert_src.read()?;
49                let key_bytes = key_src.read()?;
50                let certs = crate::load_certs_from_pem(cert_bytes.as_slice())?;
51                let key = crate::load_key_from_pem(key_bytes.as_slice())?;
52                builder.with_client_auth_cert(certs, key)?
53            }
54            _ => builder.with_no_client_auth(),
55        };
56
57        config.alpn_protocols = self.alpn;
58        Ok(config)
59    }
60}
61
62/// Incremental builder for [`ClientTlsConfig`].
63#[derive(Debug, Default, Clone)]
64pub struct ClientTlsConfigBuilder {
65    client_cert: Option<PemSource>,
66    client_key: Option<PemSource>,
67    ca: Option<PemSource>,
68    alpn: Vec<Vec<u8>>,
69}
70
71impl ClientTlsConfigBuilder {
72    /// Set the trusted CA bundle (verifies the server's certificate).
73    pub fn ca(mut self, src: PemSource) -> Self {
74        self.ca = Some(src);
75        self
76    }
77
78    /// Set the client certificate chain.
79    pub fn client_cert(mut self, src: PemSource) -> Self {
80        self.client_cert = Some(src);
81        self
82    }
83
84    /// Set the client private key.
85    pub fn client_key(mut self, src: PemSource) -> Self {
86        self.client_key = Some(src);
87        self
88    }
89
90    /// Set the ALPN protocol list, in preference order.
91    ///
92    /// Pass `["h2"]` for gRPC-only, `["h2", "http/1.1"]` for HTTP (default is empty).
93    pub fn with_alpn<I, S>(mut self, protocols: I) -> Self
94    where
95        I: IntoIterator<Item = S>,
96        S: Into<Vec<u8>>,
97    {
98        self.alpn = protocols.into_iter().map(Into::into).collect();
99        self
100    }
101
102    /// Convenience: trusted CA bundle from a file path.
103    pub fn ca_pem_file(self, path: impl Into<PathBuf>) -> Self {
104        self.ca(PemSource::Path(path.into()))
105    }
106
107    /// Convenience: trusted CA bundle from in-memory bytes.
108    pub fn ca_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
109        self.ca(PemSource::Bytes(bytes.into()))
110    }
111
112    /// Convenience: client cert chain from a file path.
113    pub fn client_cert_pem_file(self, path: impl Into<PathBuf>) -> Self {
114        self.client_cert(PemSource::Path(path.into()))
115    }
116
117    /// Convenience: client cert chain from in-memory bytes.
118    pub fn client_cert_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
119        self.client_cert(PemSource::Bytes(bytes.into()))
120    }
121
122    /// Convenience: client private key from a file path.
123    pub fn client_key_pem_file(self, path: impl Into<PathBuf>) -> Self {
124        self.client_key(PemSource::Path(path.into()))
125    }
126
127    /// Convenience: client private key from in-memory bytes.
128    pub fn client_key_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
129        self.client_key(PemSource::Bytes(bytes.into()))
130    }
131
132    /// Build.
133    pub fn build(self) -> Result<ClientTlsConfig, TlsError> {
134        let ca = self.ca.ok_or(TlsError::MissingField("ca"))?;
135        match (&self.client_cert, &self.client_key) {
136            (Some(_), None) => return Err(TlsError::MissingField("client_key")),
137            (None, Some(_)) => return Err(TlsError::MissingField("client_cert")),
138            _ => {}
139        }
140        Ok(ClientTlsConfig {
141            ca,
142            client_cert: self.client_cert,
143            client_key: self.client_key,
144            alpn: self.alpn,
145        })
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PemSource;

    /// Generate a throwaway self-signed certificate for `example.com`.
    ///
    /// Returns `(certificate PEM, private key PEM)`.
    fn self_signed_pem_pair() -> (Vec<u8>, Vec<u8>) {
        let generated = rcgen::generate_simple_self_signed(vec!["example.com".into()]).unwrap();
        let cert_pem = generated.cert.pem().into_bytes();
        let key_pem = generated.signing_key.serialize_pem().into_bytes();
        (cert_pem, key_pem)
    }

    // ---- builder validation (no PEM parsing happens here) ----

    #[test]
    fn builder_returns_config_with_ca() {
        let config = ClientTlsConfig::builder()
            .ca_pem_bytes(b"--FAKE CA--".to_vec())
            .build()
            .unwrap();

        assert!(matches!(config.ca, PemSource::Bytes(_)));
        assert!(config.client_cert.is_none());
        assert!(config.client_key.is_none());
        assert!(config.alpn.is_empty());
    }

    #[test]
    fn builder_errors_when_ca_is_missing() {
        let error = ClientTlsConfig::builder().build().unwrap_err();
        assert!(matches!(error, TlsError::MissingField("ca")));
    }

    #[test]
    fn with_client_cert_pair_enables_mtls() {
        let config = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_cert_pem_bytes(b"cert".to_vec())
            .client_key_pem_bytes(b"key".to_vec())
            .build()
            .unwrap();

        assert!(matches!(config.client_cert, Some(PemSource::Bytes(_))));
        assert!(matches!(config.client_key, Some(PemSource::Bytes(_))));
    }

    #[test]
    fn builder_errors_when_client_cert_without_key() {
        let error = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_cert_pem_bytes(b"cert".to_vec())
            .build()
            .unwrap_err();
        assert!(matches!(error, TlsError::MissingField("client_key")));
    }

    #[test]
    fn builder_errors_when_client_key_without_cert() {
        let error = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_key_pem_bytes(b"key".to_vec())
            .build()
            .unwrap_err();
        assert!(matches!(error, TlsError::MissingField("client_cert")));
    }

    #[test]
    fn with_alpn_sets_protocols() {
        let config = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .with_alpn(["h2", "http/1.1"])
            .build()
            .unwrap();

        let expected = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        assert_eq!(config.alpn, expected);
    }

    // ---- conversion into rustls (real PEM material via rcgen) ----

    #[test]
    fn into_rustls_config_succeeds_with_ca_only() {
        let (ca_pem, _) = self_signed_pem_pair();
        let config = ClientTlsConfig::builder().ca_pem_bytes(ca_pem).build().unwrap();
        let _rustls = config.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_succeeds_with_mtls_client_cert() {
        let (ca_pem, _) = self_signed_pem_pair();
        let (cert_pem, key_pem) = self_signed_pem_pair();
        let config = ClientTlsConfig::builder()
            .ca_pem_bytes(ca_pem)
            .client_cert_pem_bytes(cert_pem)
            .client_key_pem_bytes(key_pem)
            .build()
            .unwrap();
        let _rustls = config.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_propagates_alpn_to_rustls() {
        let (ca_pem, _) = self_signed_pem_pair();
        let config = ClientTlsConfig::builder()
            .ca_pem_bytes(ca_pem)
            .with_alpn(["h2"])
            .build()
            .unwrap();

        let rustls_config = config.into_rustls_config().unwrap();
        assert_eq!(rustls_config.alpn_protocols, vec![b"h2".to_vec()]);
    }
}