1use std::sync::{Arc, RwLock};
2
3use derive_where::derive_where;
4use ignore_result::Ignore;
5use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer};
6use rustls::RootCertStore;
7
8use super::TlsClient;
9use crate::client::Result;
10use crate::Error;
11
/// Short alias for the PEM section kinds yielded by `rustls_pemfile::read_all`.
type PemItem = rustls_pemfile::Item;
13
14#[derive(Clone, Debug)]
16pub struct TlsCa {
17 pub(super) roots: RootCertStore,
18 pub(super) crls: Vec<CertificateRevocationListDer<'static>>,
19}
20
21impl TlsCa {
22 pub fn from_pem(pem: &str) -> Result<Self> {
24 let mut ca = Self { roots: RootCertStore::empty(), crls: Vec::new() };
25 for r in rustls_pemfile::read_all(&mut pem.as_bytes()) {
26 match r {
27 Ok(PemItem::X509Certificate(cert)) => ca.roots.add(cert).ignore(),
28 Ok(PemItem::Crl(crl)) => ca.crls.push(crl),
29 Ok(_) => continue,
30 Err(err) => return Err(Error::with_other("fail to read ca", err)),
31 }
32 }
33 if ca.roots.is_empty() {
34 return Err(Error::BadArguments(&"no valid tls trust anchor in pem"));
35 }
36 Ok(ca)
37 }
38
39 fn merge(&mut self, ca: TlsCa) {
40 self.roots.roots.extend(ca.roots.roots);
41 self.crls.extend(ca.crls);
42 }
43}
44
/// Client identity: a certificate chain together with its private key.
#[derive_where(Debug)]
pub struct TlsIdentity {
    /// Certificate chain parsed from PEM.
    pub(super) cert: Vec<CertificateDer<'static>>,

    /// Private key parsed from PEM. It is skipped in the derived `Debug` output,
    /// presumably so key material never ends up in logs.
    #[derive_where(skip)]
    pub(super) key: PrivateKeyDer<'static>,
}
55
56impl TlsIdentity {
57 pub fn from_pem(cert: &str, key: &str) -> Result<Self> {
59 let r: Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
60 let cert = match r {
61 Err(err) => return Err(Error::with_other("fail to read cert", err)),
62 Ok(cert) => cert,
63 };
64 let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
65 Err(err) => return Err(Error::with_other("fail to read client private key", err)),
66 Ok(None) => return Err(Error::BadArguments(&"no client private key")),
67 Ok(Some(key)) => key,
68 };
69 Ok(Self { cert, key })
70 }
71}
72
73impl Clone for TlsIdentity {
74 fn clone(&self) -> Self {
75 Self { cert: self.cert.clone(), key: self.key.clone_key() }
76 }
77}
78
79#[derive(Clone, Debug)]
81pub struct TlsCerts {
82 pub(super) ca: TlsCa,
84 pub(super) identity: Option<TlsIdentity>,
86}
87
88impl TlsCerts {
89 pub fn builder() -> TlsCertsBuilder {
91 TlsCertsBuilder::new()
92 }
93}
94
95#[derive(Clone, Debug)]
97pub struct TlsCertsBuilder {
98 ca: Option<TlsCa>,
99 identity: Option<TlsIdentity>,
101}
102
103impl TlsCertsBuilder {
104 fn new() -> Self {
106 Self { ca: None, identity: None }
107 }
108
109 pub fn with_ca(mut self, ca: TlsCa) -> Self {
111 self.ca = Some(ca);
112 self
113 }
114
115 pub fn with_identity(mut self, identity: TlsIdentity) -> Self {
117 self.identity = Some(identity);
118 self
119 }
120
121 pub fn build(self) -> Result<TlsCerts> {
123 let ca = match self.ca {
124 None => return Err(Error::BadArguments(&"no tls ca")),
125 Some(ca) => ca,
126 };
127 Ok(TlsCerts { ca, identity: self.identity })
128 }
129}
130
131#[derive(Clone, Debug)]
133pub struct TlsCertsOptions {
134 certs: TlsInnerCerts,
135}
136
137#[derive(Clone, Debug)]
138pub(super) enum TlsInnerCerts {
139 Static(TlsCerts),
140 Dynamic(TlsDynamicCerts),
141}
142
143impl From<TlsCertsOptions> for TlsInnerCerts {
144 fn from(options: TlsCertsOptions) -> Self {
145 options.certs
146 }
147}
148
149impl From<TlsInnerCerts> for TlsCertsOptions {
150 fn from(certs: TlsInnerCerts) -> Self {
151 Self { certs }
152 }
153}
154
155impl From<TlsCerts> for TlsCertsOptions {
156 fn from(certs: TlsCerts) -> Self {
157 TlsInnerCerts::Static(certs).into()
158 }
159}
160
161impl From<TlsDynamicCerts> for TlsCertsOptions {
162 fn from(certs: TlsDynamicCerts) -> Self {
163 TlsInnerCerts::Dynamic(certs).into()
164 }
165}
166
167#[derive(Clone, Debug)]
172pub struct TlsDynamicCerts {
173 certs: Arc<RwLock<(u64, Arc<TlsCerts>)>>,
174}
175
176impl TlsDynamicCerts {
177 pub fn new(certs: TlsCerts) -> Self {
179 let certs = certs.into();
180 Self { certs: Arc::new(RwLock::new((1, certs))) }
181 }
182
183 pub fn update(&self, certs: TlsCerts) {
185 let certs = certs.into();
186 let mut writer = self.certs.write().unwrap();
187 writer.0 += 1;
188 let old = std::mem::replace(&mut writer.1, certs);
189 drop(writer);
190 drop(old);
191 }
192
193 pub fn update_ca(&self, ca: TlsCa) {
195 self.update_partially(|certs| certs.ca = ca.clone())
196 }
197
198 pub fn update_identity(&self, identity: Option<TlsIdentity>) {
200 self.update_partially(|certs| certs.identity = identity.clone())
201 }
202
203 fn update_versioned(&self, version: u64, certs: TlsCerts) -> bool {
204 let certs = certs.into();
205 let mut writer = self.certs.write().unwrap();
206 if writer.0 != version {
207 return false;
208 }
209 writer.0 += 1;
210 let old = std::mem::replace(&mut writer.1, certs);
211 drop(writer);
212 drop(old);
213 true
214 }
215
216 fn update_partially(&self, update: impl Fn(&mut TlsCerts)) {
217 loop {
218 let (version, certs) = self.get_versioned();
219 let mut certs = (*certs).clone();
220 update(&mut certs);
221 if self.update_versioned(version, certs) {
222 break;
223 }
224 }
225 }
226
227 pub(crate) fn get_versioned(&self) -> (u64, Arc<TlsCerts>) {
228 self.certs.read().unwrap().clone()
229 }
230
231 pub(crate) fn get_updated(&self, version: u64) -> Option<(u64, Arc<TlsCerts>)> {
232 let locked = self.certs.read().unwrap();
233 if version >= locked.0 {
234 return None;
235 }
236 Some(locked.clone())
237 }
238}
239
240#[derive(Clone, Debug)]
242pub struct TlsOptions {
243 ca: Option<TlsCa>,
244 identity: Option<TlsIdentity>,
245 certs: Option<TlsCertsOptions>,
246 hostname_verification: bool,
247 #[cfg(all(feature = "fips", not(feature = "fips-only")))]
248 fips: bool,
249}
250
251impl Default for TlsOptions {
252 fn default() -> Self {
254 Self::new()
255 }
256}
257
258impl TlsOptions {
259 #[deprecated(since = "0.10.0", note = "use TlsOptions::new instead")]
261 pub fn no_ca() -> Self {
262 Self::new()
263 }
264
265 pub fn new() -> Self {
267 Self {
268 ca: None,
269 identity: None,
270 certs: None,
271 hostname_verification: true,
272 #[cfg(all(feature = "fips", not(feature = "fips-only")))]
273 fips: false,
274 }
275 }
276
277 pub unsafe fn with_no_hostname_verification(mut self) -> Self {
282 self.hostname_verification = false;
283 self
284 }
285
286 #[cfg(feature = "fips")]
290 #[cfg_attr(docsrs, doc(cfg(any(feature = "fips", feature = "fips-only"))))]
291 pub fn with_fips(self) -> Self {
292 self.with_fips_internal()
293 }
294
295 #[cfg(all(feature = "fips", not(feature = "fips-only")))]
296 fn with_fips_internal(mut self) -> Self {
297 self.fips = true;
298 self
299 }
300
301 #[cfg(feature = "fips-only")]
302 fn with_fips_internal(self) -> Self {
303 self
304 }
305
306 #[deprecated(since = "0.10.0", note = "use TlsOptions::with_pem_ca instead")]
312 pub fn with_pem_ca_certs(mut self, certs: &str) -> Result<Self> {
313 let mut ca = TlsCa::from_pem(certs)?;
314 ca.crls.clear();
315 match self.ca.as_mut() {
316 None => self.ca = Some(ca),
317 Some(existing_ca) => existing_ca.merge(ca),
318 };
319 Ok(self)
320 }
321
322 pub fn with_pem_ca(mut self, ca: &str) -> Result<Self> {
326 self.ca = Some(TlsCa::from_pem(ca)?);
327 Ok(self)
328 }
329
330 pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result<Self> {
334 self.identity = Some(TlsIdentity::from_pem(cert, key)?);
335 Ok(self)
336 }
337
338 pub fn with_certs(mut self, certs: impl Into<TlsCertsOptions>) -> Self {
341 self.certs = Some(certs.into());
342 self
343 }
344
345 pub(super) fn into_client_options(self) -> Result<TlsClientOptions<TlsInnerCerts>> {
346 let certs = match self.certs.map(TlsInnerCerts::from) {
347 None => {
348 let certs = TlsCertsBuilder { ca: self.ca, identity: self.identity }.build()?;
349 TlsInnerCerts::Static(certs)
350 },
351 Some(certs) => certs,
352 };
353 Ok(TlsClientOptions {
354 certs,
355 hostname_verification: self.hostname_verification,
356 #[cfg(all(feature = "fips", not(feature = "fips-only")))]
357 fips: self.fips,
358 })
359 }
360
361 pub(crate) fn into_client(self) -> Result<TlsClient> {
362 let options = self.into_client_options()?;
363 TlsClient::new(options)
364 }
365}
366
/// Resolved options handed to `TlsClient::new`, generic over the certificate source type.
pub(super) struct TlsClientOptions<Certs> {
    /// Certificate source.
    pub certs: Certs,
    /// Whether the server hostname is verified.
    pub hostname_verification: bool,
    /// Whether FIPS mode was requested.
    #[cfg(all(feature = "fips", not(feature = "fips-only")))]
    pub fips: bool,
}
372}