1use std::path::PathBuf;
4
5use crate::{PemSource, TlsError};
6
7#[derive(Debug, Clone)]
11pub struct ClientTlsConfig {
12 pub ca: PemSource,
14 pub client_cert: Option<PemSource>,
16 pub client_key: Option<PemSource>,
18 pub alpn: Vec<Vec<u8>>,
20}
21
22impl ClientTlsConfig {
23 pub fn builder() -> ClientTlsConfigBuilder {
25 ClientTlsConfigBuilder::default()
26 }
27
28 pub fn into_rustls_config(self) -> Result<rustls::ClientConfig, TlsError> {
35 crate::ensure_default_provider();
36
37 let ca_bytes = self.ca.read()?;
38 let ca_certs = crate::load_certs_from_pem(ca_bytes.as_slice())?;
39 let mut roots = rustls::RootCertStore::empty();
40 for ca in ca_certs {
41 roots.add(ca)?;
42 }
43
44 let builder = rustls::ClientConfig::builder().with_root_certificates(roots);
45
46 let mut config = match (self.client_cert, self.client_key) {
47 (Some(cert_src), Some(key_src)) => {
48 let cert_bytes = cert_src.read()?;
49 let key_bytes = key_src.read()?;
50 let certs = crate::load_certs_from_pem(cert_bytes.as_slice())?;
51 let key = crate::load_key_from_pem(key_bytes.as_slice())?;
52 builder.with_client_auth_cert(certs, key)?
53 }
54 _ => builder.with_no_client_auth(),
55 };
56
57 config.alpn_protocols = self.alpn;
58 Ok(config)
59 }
60}
61
62#[derive(Debug, Default, Clone)]
64pub struct ClientTlsConfigBuilder {
65 client_cert: Option<PemSource>,
66 client_key: Option<PemSource>,
67 ca: Option<PemSource>,
68 alpn: Vec<Vec<u8>>,
69}
70
71impl ClientTlsConfigBuilder {
72 pub fn ca(mut self, src: PemSource) -> Self {
74 self.ca = Some(src);
75 self
76 }
77
78 pub fn client_cert(mut self, src: PemSource) -> Self {
80 self.client_cert = Some(src);
81 self
82 }
83
84 pub fn client_key(mut self, src: PemSource) -> Self {
86 self.client_key = Some(src);
87 self
88 }
89
90 pub fn with_alpn<I, S>(mut self, protocols: I) -> Self
94 where
95 I: IntoIterator<Item = S>,
96 S: Into<Vec<u8>>,
97 {
98 self.alpn = protocols.into_iter().map(Into::into).collect();
99 self
100 }
101
102 pub fn ca_pem_file(self, path: impl Into<PathBuf>) -> Self {
104 self.ca(PemSource::Path(path.into()))
105 }
106
107 pub fn ca_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
109 self.ca(PemSource::Bytes(bytes.into()))
110 }
111
112 pub fn client_cert_pem_file(self, path: impl Into<PathBuf>) -> Self {
114 self.client_cert(PemSource::Path(path.into()))
115 }
116
117 pub fn client_cert_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
119 self.client_cert(PemSource::Bytes(bytes.into()))
120 }
121
122 pub fn client_key_pem_file(self, path: impl Into<PathBuf>) -> Self {
124 self.client_key(PemSource::Path(path.into()))
125 }
126
127 pub fn client_key_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
129 self.client_key(PemSource::Bytes(bytes.into()))
130 }
131
132 pub fn build(self) -> Result<ClientTlsConfig, TlsError> {
134 let ca = self.ca.ok_or(TlsError::MissingField("ca"))?;
135 match (&self.client_cert, &self.client_key) {
136 (Some(_), None) => return Err(TlsError::MissingField("client_key")),
137 (None, Some(_)) => return Err(TlsError::MissingField("client_cert")),
138 _ => {}
139 }
140 Ok(ClientTlsConfig {
141 ca,
142 client_cert: self.client_cert,
143 client_key: self.client_key,
144 alpn: self.alpn,
145 })
146 }
147}
148
#[cfg(test)]
mod tests {
    // The glob import also brings in the parent module's own imports (`PathBuf`, `PemSource`,
    // `TlsError`), so no separate `use crate::PemSource;` is needed.
    use super::*;

    #[test]
    fn builder_returns_config_with_ca() {
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(b"--FAKE CA--".to_vec())
            .build()
            .unwrap();
        assert!(matches!(cfg.ca, PemSource::Bytes(_)));
        assert!(cfg.client_cert.is_none());
        assert!(cfg.client_key.is_none());
        assert!(cfg.alpn.is_empty());
    }

    #[test]
    fn builder_errors_when_ca_is_missing() {
        let err = ClientTlsConfig::builder().build().unwrap_err();
        assert!(matches!(err, TlsError::MissingField("ca")));
    }

    #[test]
    fn with_client_cert_pair_enables_mtls() {
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_cert_pem_bytes(b"cert".to_vec())
            .client_key_pem_bytes(b"key".to_vec())
            .build()
            .unwrap();
        assert!(matches!(cfg.client_cert, Some(PemSource::Bytes(_))));
        assert!(matches!(cfg.client_key, Some(PemSource::Bytes(_))));
    }

    #[test]
    fn pem_file_helpers_store_paths() {
        let cfg = ClientTlsConfig::builder()
            .ca_pem_file("ca.pem")
            .client_cert_pem_file("client.pem")
            .client_key_pem_file("client.key")
            .build()
            .unwrap();
        assert!(matches!(cfg.ca, PemSource::Path(ref p) if *p == PathBuf::from("ca.pem")));
        assert!(matches!(
            cfg.client_cert,
            Some(PemSource::Path(ref p)) if *p == PathBuf::from("client.pem")
        ));
        assert!(matches!(
            cfg.client_key,
            Some(PemSource::Path(ref p)) if *p == PathBuf::from("client.key")
        ));
    }

    #[test]
    fn builder_errors_when_client_cert_without_key() {
        let err = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_cert_pem_bytes(b"cert".to_vec())
            .build()
            .unwrap_err();
        assert!(matches!(err, TlsError::MissingField("client_key")));
    }

    #[test]
    fn builder_errors_when_client_key_without_cert() {
        let err = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .client_key_pem_bytes(b"key".to_vec())
            .build()
            .unwrap_err();
        assert!(matches!(err, TlsError::MissingField("client_cert")));
    }

    #[test]
    fn with_alpn_sets_protocols() {
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .with_alpn(["h2", "http/1.1"])
            .build()
            .unwrap();
        assert_eq!(cfg.alpn, vec![b"h2".to_vec(), b"http/1.1".to_vec()]);
    }

    #[test]
    fn with_alpn_replaces_previous_protocols() {
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(vec![1])
            .with_alpn(["h2"])
            .with_alpn(["http/1.1"])
            .build()
            .unwrap();
        assert_eq!(cfg.alpn, vec![b"http/1.1".to_vec()]);
    }

    /// Generates a fresh self-signed certificate for `example.com`.
    /// Returns `(certificate PEM, private key PEM)` as bytes.
    fn rcgen_self_signed() -> (Vec<u8>, Vec<u8>) {
        let b = rcgen::generate_simple_self_signed(vec!["example.com".into()]).unwrap();
        (
            b.cert.pem().into_bytes(),
            b.signing_key.serialize_pem().into_bytes(),
        )
    }

    #[test]
    fn into_rustls_config_succeeds_with_ca_only() {
        let (ca, _) = rcgen_self_signed();
        let cfg = ClientTlsConfig::builder().ca_pem_bytes(ca).build().unwrap();
        let _rustls = cfg.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_succeeds_with_mtls_client_cert() {
        let (ca, _) = rcgen_self_signed();
        let (cert, key) = rcgen_self_signed();
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(ca)
            .client_cert_pem_bytes(cert)
            .client_key_pem_bytes(key)
            .build()
            .unwrap();
        let _rustls = cfg.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_propagates_alpn_to_rustls() {
        let (ca, _) = rcgen_self_signed();
        let cfg = ClientTlsConfig::builder()
            .ca_pem_bytes(ca)
            .with_alpn(["h2"])
            .build()
            .unwrap();
        let rustls = cfg.into_rustls_config().unwrap();
        assert_eq!(rustls.alpn_protocols, vec![b"h2".to_vec()]);
    }

    #[test]
    fn into_rustls_config_leaves_alpn_empty_by_default() {
        let (ca, _) = rcgen_self_signed();
        let cfg = ClientTlsConfig::builder().ca_pem_bytes(ca).build().unwrap();
        let rustls = cfg.into_rustls_config().unwrap();
        assert!(rustls.alpn_protocols.is_empty());
    }
}