1use futures_util::ready;
12use hyper::server::accept::Accept;
13use hyper::server::conn::{AddrIncoming, AddrStream};
14use std::fs::File;
15use std::future::Future;
16use std::io::{self, BufReader, Cursor, Read};
17use std::net::SocketAddr;
18use std::path::{Path, PathBuf};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
23use tokio_rustls::rustls::{pki_types::PrivateKeyDer, Error as TlsError, ServerConfig};
24
25use crate::transport::Transport;
26
27#[derive(Debug)]
29pub enum TlsConfigError {
30 Io(io::Error),
32 CertParseError,
34 InvalidIdentityPem,
36 EmptyKey,
38 UnknownPrivateKeyFormat,
40 InvalidKey(TlsError),
42}
43
44impl std::fmt::Display for TlsConfigError {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 TlsConfigError::Io(err) => err.fmt(f),
48 TlsConfigError::CertParseError => write!(f, "certificate parse error"),
49 TlsConfigError::InvalidIdentityPem => write!(f, "identity PEM is invalid"),
50 TlsConfigError::UnknownPrivateKeyFormat => write!(f, "unknown private key format"),
51 TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
52 TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {err}"),
53 }
54 }
55}
56
57impl std::error::Error for TlsConfigError {}
58
/// Builder that produces a rustls [`ServerConfig`] from PEM-encoded
/// certificate and private-key input.
///
/// Inputs are held as boxed readers; their contents are consumed by
/// `TlsConfigBuilder::build`.
pub struct TlsConfigBuilder {
    /// Source of the PEM certificate chain.
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM private key.
    key: Box<dyn Read + Send + Sync>,
}
64
65impl std::fmt::Debug for TlsConfigBuilder {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
67 f.debug_struct("TlsConfigBuilder").finish()
68 }
69}
70
71impl TlsConfigBuilder {
72 pub fn new() -> TlsConfigBuilder {
74 TlsConfigBuilder {
75 key: Box::new(io::empty()),
76 cert: Box::new(io::empty()),
77 }
78 }
79
80 pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
82 self.key = Box::new(LazyFile {
83 path: path.as_ref().into(),
84 file: None,
85 });
86 self
87 }
88
89 pub fn key(mut self, key: &[u8]) -> Self {
91 self.key = Box::new(Cursor::new(Vec::from(key)));
92 self
93 }
94
95 pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
97 self.cert = Box::new(LazyFile {
98 path: path.as_ref().into(),
99 file: None,
100 });
101 self
102 }
103
104 pub fn cert(mut self, cert: &[u8]) -> Self {
106 self.cert = Box::new(Cursor::new(Vec::from(cert)));
107 self
108 }
109
110 pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
112 let mut cert_rdr = BufReader::new(self.cert);
113 let cert = rustls_pemfile::certs(&mut cert_rdr)
114 .collect::<Result<Vec<_>, _>>()
115 .map_err(|_e| TlsConfigError::CertParseError)?;
116
117 let mut key_buf = Vec::new();
119 self.key
120 .read_to_end(&mut key_buf)
121 .map_err(TlsConfigError::Io)?;
122
123 if key_buf.is_empty() {
124 return Err(TlsConfigError::EmptyKey);
125 }
126
127 let mut key: Option<PrivateKeyDer<'_>> = None;
128 let mut reader = Cursor::new(key_buf);
129 for item in std::iter::from_fn(|| rustls_pemfile::read_one(&mut reader).transpose()) {
130 match item.map_err(|_e| TlsConfigError::InvalidIdentityPem)? {
131 rustls_pemfile::Item::Pkcs1Key(k) => key = Some(k.into()),
133 rustls_pemfile::Item::Pkcs8Key(k) => key = Some(k.into()),
135 rustls_pemfile::Item::Sec1Key(k) => key = Some(k.into()),
137 _ => return Err(TlsConfigError::UnknownPrivateKeyFormat),
139 }
140 }
141
142 let key = match key {
143 Some(k) => k,
144 _ => return Err(TlsConfigError::EmptyKey),
145 };
146
147 let mut config = ServerConfig::builder()
148 .with_no_client_auth()
149 .with_single_cert(cert, key)
150 .map_err(TlsConfigError::InvalidKey)?;
151 config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
152 Ok(config)
153 }
154}
155
156impl Default for TlsConfigBuilder {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
/// A `Read` source that opens `path` on the first read instead of at construction.
///
/// Any open or read failure surfaces from `read`, with the path included in the
/// error message and the original `io::ErrorKind` preserved.
struct LazyFile {
    /// File to open on first read.
    path: PathBuf,
    /// The open handle, once a read has successfully opened the file.
    file: Option<File>,
}

impl LazyFile {
    /// Reads into `buf`, opening the file first if it is not open yet.
    ///
    /// If opening fails, the error is returned and `file` stays `None`, so the
    /// next call tries to open the file again. Once opened, the same handle is
    /// reused, so successive reads continue where the previous one stopped.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let file = match &mut self.file {
            Some(file) => file,
            slot => slot.insert(File::open(&self.path)?),
        };
        file.read(buf)
    }
}

impl Read for LazyFile {
    /// Delegates to `lazy_read`, rewrapping any error so that its message names
    /// the file while keeping the original error kind.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.lazy_read(buf).map_err(|err| {
            io::Error::new(
                err.kind(),
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )
        })
    }
}
188
189impl Transport for TlsStream {
190 fn remote_addr(&self) -> Option<SocketAddr> {
191 Some(self.remote_addr)
192 }
193}
194
195enum State {
196 Handshaking(tokio_rustls::Accept<AddrStream>),
197 Streaming(tokio_rustls::server::TlsStream<AddrStream>),
198}
199
200pub struct TlsStream {
205 state: State,
206 remote_addr: SocketAddr,
207}
208
209impl TlsStream {
210 fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
211 let remote_addr = stream.remote_addr();
212 let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
213 TlsStream {
214 state: State::Handshaking(accept),
215 remote_addr,
216 }
217 }
218}
219
220impl AsyncRead for TlsStream {
221 fn poll_read(
222 self: Pin<&mut Self>,
223 cx: &mut Context<'_>,
224 buf: &mut ReadBuf<'_>,
225 ) -> Poll<io::Result<()>> {
226 let pin = self.get_mut();
227 match pin.state {
228 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
229 Ok(mut stream) => {
230 let result = Pin::new(&mut stream).poll_read(cx, buf);
231 pin.state = State::Streaming(stream);
232 result
233 }
234 Err(err) => Poll::Ready(Err(err)),
235 },
236 State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
237 }
238 }
239}
240
241impl AsyncWrite for TlsStream {
242 fn poll_write(
243 self: Pin<&mut Self>,
244 cx: &mut Context<'_>,
245 buf: &[u8],
246 ) -> Poll<io::Result<usize>> {
247 let pin = self.get_mut();
248 match pin.state {
249 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
250 Ok(mut stream) => {
251 let result = Pin::new(&mut stream).poll_write(cx, buf);
252 pin.state = State::Streaming(stream);
253 result
254 }
255 Err(err) => Poll::Ready(Err(err)),
256 },
257 State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
258 }
259 }
260
261 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
262 match self.state {
263 State::Handshaking(_) => Poll::Ready(Ok(())),
264 State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
265 }
266 }
267
268 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
269 match self.state {
270 State::Handshaking(_) => Poll::Ready(Ok(())),
271 State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
272 }
273 }
274}
275
276pub struct TlsAcceptor {
278 config: Arc<ServerConfig>,
279 incoming: AddrIncoming,
280}
281
282impl TlsAcceptor {
283 pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
285 TlsAcceptor {
286 config: Arc::new(config),
287 incoming,
288 }
289 }
290}
291
292impl Accept for TlsAcceptor {
293 type Conn = TlsStream;
294 type Error = io::Error;
295
296 fn poll_accept(
297 self: Pin<&mut Self>,
298 cx: &mut Context<'_>,
299 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
300 let pin = self.get_mut();
301 match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
302 Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
303 Some(Err(e)) => Poll::Ready(Some(Err(e))),
304 None => Poll::Ready(None),
305 }
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from a certificate file and a key file, panicking on failure.
    fn build_from_files(cert_path: &str, key_path: &str) {
        TlsConfigBuilder::new()
            .cert_path(cert_path)
            .key_path(key_path)
            .build()
            .unwrap();
    }

    /// Builds a config from in-memory PEM text, panicking on failure.
    fn build_from_pem(cert_pem: &str, key_pem: &str) {
        TlsConfigBuilder::new()
            .key(key_pem.as_bytes())
            .cert(cert_pem.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        build_from_files(
            "tests/tls/local.dev_cert.rsa_pkcs1.pem",
            "tests/tls/local.dev_key.rsa_pkcs1.pem",
        );
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        build_from_pem(
            include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem"),
            include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem"),
        );
    }

    #[test]
    fn file_cert_key_pkcs8() {
        build_from_files(
            "tests/tls/local.dev_cert.pkcs8.pem",
            "tests/tls/local.dev_key.pkcs8.pem",
        );
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        build_from_pem(
            include_str!("../tests/tls/local.dev_cert.pkcs8.pem"),
            include_str!("../tests/tls/local.dev_key.pkcs8.pem"),
        );
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        build_from_files(
            "tests/tls/local.dev_cert.sec1_ec.pem",
            "tests/tls/local.dev_key.sec1_ec.pem",
        );
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        build_from_pem(
            include_str!("../tests/tls/local.dev_cert.sec1_ec.pem"),
            include_str!("../tests/tls/local.dev_key.sec1_ec.pem"),
        );
    }
}