Skip to main content

static_web_server/
tls.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
6//! The module handles requests over TLS via [Rustls](tokio_rustls::rustls).
7//!
8
// Ensure exactly one TLS crypto provider is enabled.
//
// The `http2-ring` and `http2-fips` features each select a TLS crypto provider, and
// exactly one of them must be active whenever this module is compiled. Both
// misconfigurations (both enabled, neither enabled) abort compilation with an
// explicit message.
// NOTE(review): presumably `http2-fips` selects aws-lc-rs in FIPS mode (see the
// `fips_mode_is_active` test below) — confirm against Cargo.toml.

// Reject enabling both providers at once.
#[cfg(all(feature = "http2-ring", feature = "http2-fips"))]
compile_error!(
    "Only one TLS crypto provider may be enabled. Choose one of: http2-ring, http2-fips"
);

// Reject enabling no provider at all.
#[cfg(not(any(feature = "http2-ring", feature = "http2-fips")))]
compile_error!("http2 requires a crypto provider: enable http2-ring or http2-fips");
17
18// Most of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/tls.rs
19
20use futures_util::ready;
21use hyper::server::accept::Accept;
22use hyper::server::conn::{AddrIncoming, AddrStream};
23use rustls_pki_types::pem::PemObject;
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use std::fs::File;
26use std::future::Future;
27use std::io::{self, BufReader, Cursor, Read};
28use std::net::SocketAddr;
29use std::path::{Path, PathBuf};
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
34use tokio_rustls::rustls::{Error as TlsError, ServerConfig};
35
36use crate::transport::Transport;
37
38/// Represents errors that can occur building the TlsConfig
39#[derive(Debug)]
40pub enum TlsConfigError {
41    /// Error type for I/O operations
42    Io(io::Error),
43    /// An Error parsing the Certificate
44    CertParseError,
45    /// Identity PEM is invalid
46    InvalidIdentityPem,
47    /// An error from an empty key
48    EmptyKey,
49    /// Unknown private key format
50    UnknownPrivateKeyFormat,
51    /// An error from an invalid key
52    InvalidKey(TlsError),
53    /// Illegal section start in PEM
54    IllegalSectionStart(Vec<u8>),
55    /// Illegal section end in PEM
56    IllegalSectionEnd(Vec<u8>),
57}
58
59impl std::fmt::Display for TlsConfigError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            TlsConfigError::Io(err) => err.fmt(f),
63            TlsConfigError::CertParseError => write!(f, "failed to parse certificate"),
64            TlsConfigError::InvalidIdentityPem => write!(f, "the identity PEM provided is invalid"),
65            TlsConfigError::UnknownPrivateKeyFormat => {
66                write!(f, "the private key format is unknown")
67            }
68            TlsConfigError::EmptyKey => write!(f, "the key provided is probably missing or empty"),
69            TlsConfigError::InvalidKey(err) => write!(f, "the key provided is invalid, {err}"),
70            TlsConfigError::IllegalSectionStart(line) => {
71                let line = String::from_utf8(line.clone()).unwrap_or_default();
72                write!(f, "illegal section start in PEM at '{line}'")
73            }
74            TlsConfigError::IllegalSectionEnd(end_marker) => {
75                let end_marker = String::from_utf8(end_marker.clone()).unwrap_or_default();
76                write!(f, "illegal section end in PEM at '{end_marker}'")
77            }
78        }
79    }
80}
81
82impl std::error::Error for TlsConfigError {}
83
/// Collects the certificate and private-key sources used to build a rustls
/// server configuration.
///
/// Both sources are plain readers; nothing is read from them until
/// `TlsConfigBuilder::build` is called.
pub struct TlsConfigBuilder {
    /// Source of the PEM-encoded certificate chain.
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM-encoded private key.
    key: Box<dyn Read + Send + Sync>,
}

impl std::fmt::Debug for TlsConfigBuilder {
    /// Prints only the type name: the boxed readers are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
        f.write_str("TlsConfigBuilder")
    }
}
95
96impl TlsConfigBuilder {
97    /// Create a new TlsConfigBuilder
98    pub fn new() -> TlsConfigBuilder {
99        TlsConfigBuilder {
100            key: Box::new(io::empty()),
101            cert: Box::new(io::empty()),
102        }
103    }
104
105    /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
106    pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
107        self.key = Box::new(LazyFile {
108            path: path.as_ref().into(),
109            file: None,
110        });
111        self
112    }
113
114    /// sets the Tls key via bytes slice
115    pub fn key(mut self, key: &[u8]) -> Self {
116        self.key = Box::new(Cursor::new(Vec::from(key)));
117        self
118    }
119
120    /// Specify the file path for the TLS certificate to use.
121    pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
122        self.cert = Box::new(LazyFile {
123            path: path.as_ref().into(),
124            file: None,
125        });
126        self
127    }
128
129    /// sets the Tls certificate via bytes slice
130    pub fn cert(mut self, cert: &[u8]) -> Self {
131        self.cert = Box::new(Cursor::new(Vec::from(cert)));
132        self
133    }
134
135    /// Builds TLS configuration.
136    pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
137        let mut cert_rdr = BufReader::new(self.cert);
138        let cert = CertificateDer::pem_reader_iter(&mut cert_rdr)
139            .collect::<Result<Vec<_>, _>>()
140            .map_err(|_e| TlsConfigError::CertParseError)?;
141
142        // convert it to Vec<u8> to allow reading it again if key is RSA
143        let mut key_buf = Vec::new();
144        self.key
145            .read_to_end(&mut key_buf)
146            .map_err(TlsConfigError::Io)?;
147
148        if key_buf.is_empty() {
149            return Err(TlsConfigError::EmptyKey);
150        }
151
152        let reader = Cursor::new(key_buf);
153        let key = PrivateKeyDer::from_pem_reader(reader).map_err(|err| match err {
154            rustls_pki_types::pem::Error::Base64Decode(_) => TlsConfigError::InvalidIdentityPem,
155            rustls_pki_types::pem::Error::NoItemsFound => TlsConfigError::EmptyKey,
156            rustls_pki_types::pem::Error::IllegalSectionStart { line } => {
157                TlsConfigError::IllegalSectionStart(line)
158            }
159            rustls_pki_types::pem::Error::MissingSectionEnd { end_marker } => {
160                TlsConfigError::IllegalSectionEnd(end_marker)
161            }
162            rustls_pki_types::pem::Error::Io(err) => {
163                TlsConfigError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
164            }
165            _ => TlsConfigError::InvalidIdentityPem,
166        })?;
167
168        let mut config = ServerConfig::builder()
169            .with_no_client_auth()
170            .with_single_cert(cert, key)
171            .map_err(TlsConfigError::InvalidKey)?;
172        config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
173        Ok(config)
174    }
175}
176
177impl Default for TlsConfigBuilder {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
/// A reader over a file that is only opened on the first read.
///
/// Deferring the open lets callers supply a path up front while open/read failures
/// surface from the code that actually consumes the data.
struct LazyFile {
    /// Location of the file to open on first read.
    path: PathBuf,
    /// Open handle; `None` until the first successful open.
    file: Option<File>,
}

impl LazyFile {
    /// Reads into `buf`, opening the file at `self.path` first if needed.
    ///
    /// If opening fails, `self.file` stays `None`, so a later call retries the open.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if let Some(file) = self.file.as_mut() {
            return file.read(buf);
        }
        // `Option::insert` returns a reference to the stored handle, so there is no
        // unreachable "handle missing after open" branch to handle.
        let file = self.file.insert(File::open(&self.path)?);
        file.read(buf)
    }
}

impl Read for LazyFile {
    /// Reads from the lazily opened file. Errors are re-wrapped with the file path in
    /// the message while keeping the original [`io::ErrorKind`].
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.lazy_read(buf).map_err(|err| {
            let kind = err.kind();
            io::Error::new(
                kind,
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )
        })
    }
}
212
213impl Transport for TlsStream {
214    fn remote_addr(&self) -> Option<SocketAddr> {
215        Some(self.remote_addr)
216    }
217}
218
219enum State {
220    Handshaking(tokio_rustls::Accept<AddrStream>),
221    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
222}
223
224/// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first.
225///
226/// tokio_rustls::server::TlsStream doesn't expose constructor methods,
227/// so we have to TlsAcceptor::accept and handshake to have access to it.
228pub struct TlsStream {
229    state: State,
230    remote_addr: SocketAddr,
231}
232
233impl TlsStream {
234    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
235        let remote_addr = stream.remote_addr();
236        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
237        TlsStream {
238            state: State::Handshaking(accept),
239            remote_addr,
240        }
241    }
242}
243
244impl AsyncRead for TlsStream {
245    fn poll_read(
246        self: Pin<&mut Self>,
247        cx: &mut Context<'_>,
248        buf: &mut ReadBuf<'_>,
249    ) -> Poll<io::Result<()>> {
250        let pin = self.get_mut();
251        match pin.state {
252            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
253                Ok(mut stream) => {
254                    let result = Pin::new(&mut stream).poll_read(cx, buf);
255                    pin.state = State::Streaming(stream);
256                    result
257                }
258                Err(err) => Poll::Ready(Err(err)),
259            },
260            State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
261        }
262    }
263}
264
265impl AsyncWrite for TlsStream {
266    fn poll_write(
267        self: Pin<&mut Self>,
268        cx: &mut Context<'_>,
269        buf: &[u8],
270    ) -> Poll<io::Result<usize>> {
271        let pin = self.get_mut();
272        match pin.state {
273            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
274                Ok(mut stream) => {
275                    let result = Pin::new(&mut stream).poll_write(cx, buf);
276                    pin.state = State::Streaming(stream);
277                    result
278                }
279                Err(err) => Poll::Ready(Err(err)),
280            },
281            State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
282        }
283    }
284
285    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286        match self.state {
287            State::Handshaking(_) => Poll::Ready(Ok(())),
288            State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
289        }
290    }
291
292    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293        match self.state {
294            State::Handshaking(_) => Poll::Ready(Ok(())),
295            State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
296        }
297    }
298}
299
300/// Type to intercept Tls incoming connections.
301pub struct TlsAcceptor {
302    config: Arc<ServerConfig>,
303    incoming: AddrIncoming,
304}
305
306impl TlsAcceptor {
307    /// Creates a new Tls interceptor.
308    pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
309        TlsAcceptor {
310            config: Arc::new(config),
311            incoming,
312        }
313    }
314}
315
316impl Accept for TlsAcceptor {
317    type Conn = TlsStream;
318    type Error = io::Error;
319
320    fn poll_accept(
321        self: Pin<&mut Self>,
322        cx: &mut Context<'_>,
323    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
324        let pin = self.get_mut();
325        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
326            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
327            Some(Err(e)) => Poll::Ready(Some(Err(e))),
328            None => Poll::Ready(None),
329        }
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.rsa_pkcs1.pem")
            .key_path("tests/tls/local.dev_key.rsa_pkcs1.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        let cert = include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem");
        let key = include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_pkcs8() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .key_path("tests/tls/local.dev_key.pkcs8.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let key = include_str!("../tests/tls/local.dev_key.pkcs8.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.sec1_ec.pem")
            .key_path("tests/tls/local.dev_key.sec1_ec.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        let cert = include_str!("../tests/tls/local.dev_cert.sec1_ec.pem");
        let key = include_str!("../tests/tls/local.dev_key.sec1_ec.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    /// Building without any private key must fail with `EmptyKey`.
    #[test]
    fn missing_key_reports_empty_key() {
        let err = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .build()
            .err()
            .expect("build should fail when no key is provided");
        assert!(
            matches!(err, TlsConfigError::EmptyKey),
            "expected EmptyKey, got: {err}"
        );
    }

    /// A key path that cannot be opened must surface as an I/O error that keeps
    /// its original `NotFound` kind.
    #[test]
    fn missing_key_file_reports_io_not_found() {
        let err = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .key_path("tests/tls/this_key_does_not_exist.pem")
            .build()
            .err()
            .expect("build should fail when the key file is missing");
        match err {
            TlsConfigError::Io(io_err) => {
                assert_eq!(io_err.kind(), std::io::ErrorKind::NotFound)
            }
            other => panic!("expected an I/O error, got: {other}"),
        }
    }

    #[cfg(feature = "http2-fips")]
    #[test]
    fn fips_mode_is_active() {
        aws_lc_rs::try_fips_mode()
            .expect("FIPS mode should be active when http2-fips feature is enabled");
    }
}