static_web_server/
tls.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
6//! The module handles requests over TLS via [Rustls](tokio_rustls::rustls).
7//!
8
9// Most of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/tls.rs
10
11use futures_util::ready;
12use hyper::server::accept::Accept;
13use hyper::server::conn::{AddrIncoming, AddrStream};
14use std::fs::File;
15use std::future::Future;
16use std::io::{self, BufReader, Cursor, Read};
17use std::net::SocketAddr;
18use std::path::{Path, PathBuf};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
23use tokio_rustls::rustls::{pki_types::PrivateKeyDer, Error as TlsError, ServerConfig};
24
25use crate::transport::Transport;
26
27/// Represents errors that can occur building the TlsConfig
28#[derive(Debug)]
29pub enum TlsConfigError {
30    /// Error type for I/O operations
31    Io(io::Error),
32    /// An Error parsing the Certificate
33    CertParseError,
34    /// Identity PEM is invalid
35    InvalidIdentityPem,
36    /// An error from an empty key
37    EmptyKey,
38    /// Unknown private key format
39    UnknownPrivateKeyFormat,
40    /// An error from an invalid key
41    InvalidKey(TlsError),
42}
43
44impl std::fmt::Display for TlsConfigError {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            TlsConfigError::Io(err) => err.fmt(f),
48            TlsConfigError::CertParseError => write!(f, "certificate parse error"),
49            TlsConfigError::InvalidIdentityPem => write!(f, "identity PEM is invalid"),
50            TlsConfigError::UnknownPrivateKeyFormat => write!(f, "unknown private key format"),
51            TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
52            TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {err}"),
53        }
54    }
55}
56
57impl std::error::Error for TlsConfigError {}
58
/// Builder for a Rustls [`ServerConfig`] from a PEM certificate chain and a
/// PEM private key.
///
/// Both inputs are kept as boxed readers, so they can come either from
/// in-memory bytes or from files that are only opened when the configuration
/// is built.
pub struct TlsConfigBuilder {
    /// Source of the PEM-encoded certificate chain.
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM-encoded private key.
    key: Box<dyn Read + Send + Sync>,
}
64
65impl std::fmt::Debug for TlsConfigBuilder {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
67        f.debug_struct("TlsConfigBuilder").finish()
68    }
69}
70
71impl TlsConfigBuilder {
72    /// Create a new TlsConfigBuilder
73    pub fn new() -> TlsConfigBuilder {
74        TlsConfigBuilder {
75            key: Box::new(io::empty()),
76            cert: Box::new(io::empty()),
77        }
78    }
79
80    /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
81    pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
82        self.key = Box::new(LazyFile {
83            path: path.as_ref().into(),
84            file: None,
85        });
86        self
87    }
88
89    /// sets the Tls key via bytes slice
90    pub fn key(mut self, key: &[u8]) -> Self {
91        self.key = Box::new(Cursor::new(Vec::from(key)));
92        self
93    }
94
95    /// Specify the file path for the TLS certificate to use.
96    pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
97        self.cert = Box::new(LazyFile {
98            path: path.as_ref().into(),
99            file: None,
100        });
101        self
102    }
103
104    /// sets the Tls certificate via bytes slice
105    pub fn cert(mut self, cert: &[u8]) -> Self {
106        self.cert = Box::new(Cursor::new(Vec::from(cert)));
107        self
108    }
109
110    /// Builds TLS configuration.
111    pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
112        let mut cert_rdr = BufReader::new(self.cert);
113        let cert = rustls_pemfile::certs(&mut cert_rdr)
114            .collect::<Result<Vec<_>, _>>()
115            .map_err(|_e| TlsConfigError::CertParseError)?;
116
117        // convert it to Vec<u8> to allow reading it again if key is RSA
118        let mut key_buf = Vec::new();
119        self.key
120            .read_to_end(&mut key_buf)
121            .map_err(TlsConfigError::Io)?;
122
123        if key_buf.is_empty() {
124            return Err(TlsConfigError::EmptyKey);
125        }
126
127        let mut key: Option<PrivateKeyDer<'_>> = None;
128        let mut reader = Cursor::new(key_buf);
129        for item in std::iter::from_fn(|| rustls_pemfile::read_one(&mut reader).transpose()) {
130            match item.map_err(|_e| TlsConfigError::InvalidIdentityPem)? {
131                // rsa pkcs1 key
132                rustls_pemfile::Item::Pkcs1Key(k) => key = Some(k.into()),
133                // pkcs8 key
134                rustls_pemfile::Item::Pkcs8Key(k) => key = Some(k.into()),
135                // sec1 ec key
136                rustls_pemfile::Item::Sec1Key(k) => key = Some(k.into()),
137                // unknown format
138                _ => return Err(TlsConfigError::UnknownPrivateKeyFormat),
139            }
140        }
141
142        let key = match key {
143            Some(k) => k,
144            _ => return Err(TlsConfigError::EmptyKey),
145        };
146
147        let mut config = ServerConfig::builder()
148            .with_no_client_auth()
149            .with_single_cert(cert, key)
150            .map_err(TlsConfigError::InvalidKey)?;
151        config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
152        Ok(config)
153    }
154}
155
156impl Default for TlsConfigBuilder {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
/// A reader over a file that is only opened on its first read.
///
/// Deferring the open means a bad path is reported by the read that first
/// needs the data, not when the path is stored.
struct LazyFile {
    /// Path of the file to open.
    path: PathBuf,
    /// The opened file; `None` until an open succeeds.
    file: Option<File>,
}

impl LazyFile {
    /// Opens the file if needed, then reads from it into `buf`.
    ///
    /// A failed open leaves `file` as `None`, so the next call retries it.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let file = match self.file {
            Some(ref mut opened) => opened,
            None => self.file.insert(File::open(&self.path)?),
        };
        file.read(buf)
    }
}

impl Read for LazyFile {
    /// Reads from the lazily opened file. Errors keep their original
    /// [`io::ErrorKind`] but get a message prefixed with the file path.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.lazy_read(buf) {
            Ok(count) => Ok(count),
            Err(cause) => Err(io::Error::new(
                cause.kind(),
                format!("error reading file ({:?}): {}", self.path.display(), cause),
            )),
        }
    }
}
188
189impl Transport for TlsStream {
190    fn remote_addr(&self) -> Option<SocketAddr> {
191        Some(self.remote_addr)
192    }
193}
194
195enum State {
196    Handshaking(tokio_rustls::Accept<AddrStream>),
197    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
198}
199
200/// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first.
201///
202/// tokio_rustls::server::TlsStream doesn't expose constructor methods,
203/// so we have to TlsAcceptor::accept and handshake to have access to it.
204pub struct TlsStream {
205    state: State,
206    remote_addr: SocketAddr,
207}
208
209impl TlsStream {
210    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
211        let remote_addr = stream.remote_addr();
212        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
213        TlsStream {
214            state: State::Handshaking(accept),
215            remote_addr,
216        }
217    }
218}
219
220impl AsyncRead for TlsStream {
221    fn poll_read(
222        self: Pin<&mut Self>,
223        cx: &mut Context<'_>,
224        buf: &mut ReadBuf<'_>,
225    ) -> Poll<io::Result<()>> {
226        let pin = self.get_mut();
227        match pin.state {
228            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
229                Ok(mut stream) => {
230                    let result = Pin::new(&mut stream).poll_read(cx, buf);
231                    pin.state = State::Streaming(stream);
232                    result
233                }
234                Err(err) => Poll::Ready(Err(err)),
235            },
236            State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
237        }
238    }
239}
240
241impl AsyncWrite for TlsStream {
242    fn poll_write(
243        self: Pin<&mut Self>,
244        cx: &mut Context<'_>,
245        buf: &[u8],
246    ) -> Poll<io::Result<usize>> {
247        let pin = self.get_mut();
248        match pin.state {
249            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
250                Ok(mut stream) => {
251                    let result = Pin::new(&mut stream).poll_write(cx, buf);
252                    pin.state = State::Streaming(stream);
253                    result
254                }
255                Err(err) => Poll::Ready(Err(err)),
256            },
257            State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
258        }
259    }
260
261    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
262        match self.state {
263            State::Handshaking(_) => Poll::Ready(Ok(())),
264            State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
265        }
266    }
267
268    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
269        match self.state {
270            State::Handshaking(_) => Poll::Ready(Ok(())),
271            State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
272        }
273    }
274}
275
276/// Type to intercept Tls incoming connections.
277pub struct TlsAcceptor {
278    config: Arc<ServerConfig>,
279    incoming: AddrIncoming,
280}
281
282impl TlsAcceptor {
283    /// Creates a new Tls interceptor.
284    pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
285        TlsAcceptor {
286            config: Arc::new(config),
287            incoming,
288        }
289    }
290}
291
292impl Accept for TlsAcceptor {
293    type Conn = TlsStream;
294    type Error = io::Error;
295
296    fn poll_accept(
297        self: Pin<&mut Self>,
298        cx: &mut Context<'_>,
299    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
300        let pin = self.get_mut();
301        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
302            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
303            Some(Err(e)) => Poll::Ready(Some(Err(e))),
304            None => Poll::Ready(None),
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from certificate/key file paths, panicking on failure.
    fn assert_builds_from_paths(cert: &str, key: &str) {
        TlsConfigBuilder::new()
            .cert_path(cert)
            .key_path(key)
            .build()
            .unwrap();
    }

    /// Builds a config from in-memory PEM text, panicking on failure.
    fn assert_builds_from_bytes(cert: &str, key: &str) {
        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        assert_builds_from_paths(
            "tests/tls/local.dev_cert.rsa_pkcs1.pem",
            "tests/tls/local.dev_key.rsa_pkcs1.pem",
        );
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem"),
            include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem"),
        );
    }

    #[test]
    fn file_cert_key_pkcs8() {
        assert_builds_from_paths(
            "tests/tls/local.dev_cert.pkcs8.pem",
            "tests/tls/local.dev_key.pkcs8.pem",
        );
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.pkcs8.pem"),
            include_str!("../tests/tls/local.dev_key.pkcs8.pem"),
        );
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        assert_builds_from_paths(
            "tests/tls/local.dev_cert.sec1_ec.pem",
            "tests/tls/local.dev_key.sec1_ec.pem",
        );
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.sec1_ec.pem"),
            include_str!("../tests/tls/local.dev_key.sec1_ec.pem"),
        );
    }

    #[test]
    fn missing_key_is_empty_key_error() {
        let result = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .build();
        assert!(matches!(result, Err(TlsConfigError::EmptyKey)));
    }

    #[test]
    fn unreadable_key_path_is_io_error() {
        let result = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .key_path("tests/tls/this_key_file_does_not_exist.pem")
            .build();
        assert!(matches!(result, Err(TlsConfigError::Io(_))));
    }

    #[test]
    fn certificate_as_key_is_unknown_key_format() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let result = TlsConfigBuilder::new()
            .cert(cert.as_bytes())
            .key(cert.as_bytes())
            .build();
        assert!(matches!(
            result,
            Err(TlsConfigError::UnknownPrivateKeyFormat)
        ));
    }
}