1use std::{
2 io::{Error as IoError, ErrorKind, Result as IoResult},
3 net::SocketAddr,
4};
5
6use tokio::net::{TcpListener, TcpStream};
7use tokio_rustls::{
8 rustls::{pki_types::PrivateKeyDer, server::WebPkiClientVerifier, RootCertStore, ServerConfig},
9 server::TlsStream,
10};
11
12use crate::{Error, Result};
13
14pub use tokio_rustls::TlsAcceptor;
15
/// Policy for TLS client certificates.
///
/// The byte payloads carry PEM data from which the trusted root certificates
/// for client verification are read.
#[derive(Debug)]
pub(crate) enum ClientAuth {
    /// No client certificate verification is configured.
    Off,
    /// Presented client certificates are verified against the trust anchor,
    /// but clients without a certificate are still let through.
    Optional(Vec<u8>),
    /// Clients must present a certificate that verifies against the trust anchor.
    Required(Vec<u8>),
}
26
27#[derive(Debug)]
29pub struct Config {
30 cert: Vec<u8>,
31 key: Vec<u8>,
32 ocsp_resp: Vec<u8>,
33 client_auth: ClientAuth,
34}
35
36impl Default for Config {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl Config {
43 #[must_use]
45 pub fn new() -> Self {
46 Self {
47 cert: Vec::new(),
48 key: Vec::new(),
49 client_auth: ClientAuth::Off,
50 ocsp_resp: Vec::new(),
51 }
52 }
53
54 #[must_use]
56 pub fn cert(mut self, cert: impl Into<Vec<u8>>) -> Self {
57 self.cert = cert.into();
58 self
59 }
60
61 #[must_use]
63 pub fn key(mut self, key: impl Into<Vec<u8>>) -> Self {
64 self.key = key.into();
65 self
66 }
67
68 #[must_use]
70 pub fn client_auth_optional(mut self, trust_anchor: impl Into<Vec<u8>>) -> Self {
71 self.client_auth = ClientAuth::Optional(trust_anchor.into());
72 self
73 }
74
75 #[must_use]
77 pub fn client_auth_required(mut self, trust_anchor: impl Into<Vec<u8>>) -> Self {
78 self.client_auth = ClientAuth::Required(trust_anchor.into());
79 self
80 }
81
82 #[must_use]
84 pub fn ocsp_resp(mut self, ocsp_resp: impl Into<Vec<u8>>) -> Self {
85 self.ocsp_resp = ocsp_resp.into();
86 self
87 }
88
89 pub fn build(self) -> Result<ServerConfig> {
93 fn read_trust_anchor(mut trust_anchor: &[u8]) -> Result<RootCertStore> {
94 let certs = rustls_pemfile::certs(&mut trust_anchor)
95 .collect::<IoResult<Vec<_>>>()
96 .map_err(Error::boxed)?;
97 let mut store = RootCertStore::empty();
98 for cert in certs {
99 store.add(cert).map_err(Error::boxed)?;
100 }
101 Ok(store)
102 }
103
104 let certs = rustls_pemfile::certs(&mut self.cert.as_slice())
105 .collect::<Result<Vec<_>, _>>()
106 .map_err(Error::boxed)?;
107
108 let keys = {
109 let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut self.key.as_slice())
110 .collect::<Result<Vec<_>, _>>()
111 .map_err(Error::boxed)?;
112 if pkcs8.is_empty() {
113 let mut rsa = rustls_pemfile::rsa_private_keys(&mut self.key.as_slice())
114 .collect::<Result<Vec<_>, _>>()
115 .map_err(Error::boxed)?;
116
117 if rsa.is_empty() {
118 return Err(Error::boxed(IoError::new(
119 ErrorKind::InvalidData,
120 "failed to parse tls private keys",
121 )));
122 }
123 PrivateKeyDer::Pkcs1(rsa.remove(0))
124 } else {
125 PrivateKeyDer::Pkcs8(pkcs8.remove(0))
126 }
127 };
128
129 let client_auth = match self.client_auth {
130 ClientAuth::Off => WebPkiClientVerifier::no_client_auth(),
131 ClientAuth::Optional(trust_anchor) => {
132 WebPkiClientVerifier::builder(read_trust_anchor(&trust_anchor)?.into())
133 .allow_unauthenticated()
134 .build()
135 .map_err(Error::boxed)?
136 }
137 ClientAuth::Required(trust_anchor) => {
138 WebPkiClientVerifier::builder(read_trust_anchor(&trust_anchor)?.into())
139 .build()
140 .map_err(Error::boxed)?
141 }
142 };
143
144 ServerConfig::builder()
145 .with_client_cert_verifier(client_auth)
146 .with_single_cert_with_ocsp(certs, keys, self.ocsp_resp)
147 .map_err(Error::boxed)
148 }
149}
150
151impl crate::Listener for crate::tls::TlsListener<TcpListener, TlsAcceptor> {
152 type Io = TlsStream<TcpStream>;
153 type Addr = SocketAddr;
154
155 async fn accept(&self) -> IoResult<(Self::Io, Self::Addr)> {
156 let (stream, addr) = self.inner.accept().await?;
157 let stream = self.acceptor.accept(stream).await?;
158 Ok((stream, addr))
159 }
160
161 fn local_addr(&self) -> IoResult<Self::Addr> {
162 self.inner.local_addr()
163 }
164}