viz/server/tls/
rustls.rs

1use std::{
2    io::{Error as IoError, ErrorKind, Result as IoResult},
3    net::SocketAddr,
4};
5
6use tokio::net::{TcpListener, TcpStream};
7use tokio_rustls::{
8    rustls::{pki_types::PrivateKeyDer, server::WebPkiClientVerifier, RootCertStore, ServerConfig},
9    server::TlsStream,
10};
11
12use crate::{Error, Result};
13
14pub use tokio_rustls::TlsAcceptor;
15
/// Whether, and how strictly, the server verifies TLS client certificates.
///
/// The `Vec<u8>` carried by the verifying variants is the PEM-encoded set of
/// CA certificates that client certificates are checked against.
#[derive(Debug)]
pub(crate) enum ClientAuth {
    /// Client certificates are not requested at all.
    Off,
    /// Anonymous clients are accepted; a certificate that is presented must
    /// still verify against the trust anchors.
    Optional(Vec<u8>),
    /// Every client must present a certificate that verifies against the
    /// trust anchors.
    Required(Vec<u8>),
}
26
27/// `rustls`'s config.
28#[derive(Debug)]
29pub struct Config {
30    cert: Vec<u8>,
31    key: Vec<u8>,
32    ocsp_resp: Vec<u8>,
33    client_auth: ClientAuth,
34}
35
36impl Default for Config {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl Config {
43    /// Create a new Tls config
44    #[must_use]
45    pub fn new() -> Self {
46        Self {
47            cert: Vec::new(),
48            key: Vec::new(),
49            client_auth: ClientAuth::Off,
50            ocsp_resp: Vec::new(),
51        }
52    }
53
54    /// sets the Tls certificate
55    #[must_use]
56    pub fn cert(mut self, cert: impl Into<Vec<u8>>) -> Self {
57        self.cert = cert.into();
58        self
59    }
60
61    /// sets the Tls key
62    #[must_use]
63    pub fn key(mut self, key: impl Into<Vec<u8>>) -> Self {
64        self.key = key.into();
65        self
66    }
67
68    /// Sets the trust anchor for optional Tls client authentication
69    #[must_use]
70    pub fn client_auth_optional(mut self, trust_anchor: impl Into<Vec<u8>>) -> Self {
71        self.client_auth = ClientAuth::Optional(trust_anchor.into());
72        self
73    }
74
75    /// Sets the trust anchor for required Tls client authentication
76    #[must_use]
77    pub fn client_auth_required(mut self, trust_anchor: impl Into<Vec<u8>>) -> Self {
78        self.client_auth = ClientAuth::Required(trust_anchor.into());
79        self
80    }
81
82    /// sets the DER-encoded OCSP response
83    #[must_use]
84    pub fn ocsp_resp(mut self, ocsp_resp: impl Into<Vec<u8>>) -> Self {
85        self.ocsp_resp = ocsp_resp.into();
86        self
87    }
88
89    /// builds the Tls `ServerConfig`
90    ///
91    /// # Errors
92    pub fn build(self) -> Result<ServerConfig> {
93        fn read_trust_anchor(mut trust_anchor: &[u8]) -> Result<RootCertStore> {
94            let certs = rustls_pemfile::certs(&mut trust_anchor)
95                .collect::<IoResult<Vec<_>>>()
96                .map_err(Error::boxed)?;
97            let mut store = RootCertStore::empty();
98            for cert in certs {
99                store.add(cert).map_err(Error::boxed)?;
100            }
101            Ok(store)
102        }
103
104        let certs = rustls_pemfile::certs(&mut self.cert.as_slice())
105            .collect::<Result<Vec<_>, _>>()
106            .map_err(Error::boxed)?;
107
108        let keys = {
109            let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut self.key.as_slice())
110                .collect::<Result<Vec<_>, _>>()
111                .map_err(Error::boxed)?;
112            if pkcs8.is_empty() {
113                let mut rsa = rustls_pemfile::rsa_private_keys(&mut self.key.as_slice())
114                    .collect::<Result<Vec<_>, _>>()
115                    .map_err(Error::boxed)?;
116
117                if rsa.is_empty() {
118                    return Err(Error::boxed(IoError::new(
119                        ErrorKind::InvalidData,
120                        "failed to parse tls private keys",
121                    )));
122                }
123                PrivateKeyDer::Pkcs1(rsa.remove(0))
124            } else {
125                PrivateKeyDer::Pkcs8(pkcs8.remove(0))
126            }
127        };
128
129        let client_auth = match self.client_auth {
130            ClientAuth::Off => WebPkiClientVerifier::no_client_auth(),
131            ClientAuth::Optional(trust_anchor) => {
132                WebPkiClientVerifier::builder(read_trust_anchor(&trust_anchor)?.into())
133                    .allow_unauthenticated()
134                    .build()
135                    .map_err(Error::boxed)?
136            }
137            ClientAuth::Required(trust_anchor) => {
138                WebPkiClientVerifier::builder(read_trust_anchor(&trust_anchor)?.into())
139                    .build()
140                    .map_err(Error::boxed)?
141            }
142        };
143
144        ServerConfig::builder()
145            .with_client_cert_verifier(client_auth)
146            .with_single_cert_with_ocsp(certs, keys, self.ocsp_resp)
147            .map_err(Error::boxed)
148    }
149}
150
151impl crate::Listener for crate::tls::TlsListener<TcpListener, TlsAcceptor> {
152    type Io = TlsStream<TcpStream>;
153    type Addr = SocketAddr;
154
155    async fn accept(&self) -> IoResult<(Self::Io, Self::Addr)> {
156        let (stream, addr) = self.inner.accept().await?;
157        let stream = self.acceptor.accept(stream).await?;
158        Ok((stream, addr))
159    }
160
161    fn local_addr(&self) -> IoResult<Self::Addr> {
162        self.inner.local_addr()
163    }
164}