// tide_rustls/tls_listener.rs

1use crate::custom_tls_acceptor::StandardTlsAcceptor;
2use crate::{
3    CustomTlsAcceptor, TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper,
4};
5
6use tide::listener::ListenInfo;
7use tide::listener::{Listener, ToListener};
8use tide::Server;
9
10use async_std::net::{TcpListener, TcpStream};
11use async_std::prelude::*;
12use async_std::{io, task};
13
14use async_rustls::TlsAcceptor;
15use rustls::internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
16use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
17
18use std::fmt::{self, Debug, Display, Formatter};
19use std::fs::File;
20use std::io::{BufReader, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Arc;
23use std::time::Duration;
24
25/// The primary type for this crate
26pub struct TlsListener<State> {
27    connection: TcpConnection,
28    config: TlsListenerConfig,
29    server: Option<Server<State>>,
30}
31
32impl<State> Debug for TlsListener<State> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34        f.debug_struct("TlsListener")
35            .field(&"connection", &self.connection)
36            .field(&"config", &self.config)
37            .field(
38                &"server",
39                if self.server.is_some() {
40                    &"Some(Server<State>)"
41                } else {
42                    &"None"
43                },
44            )
45            .finish()
46    }
47}
48
49impl<State> TlsListener<State> {
50    pub(crate) fn new(connection: TcpConnection, config: TlsListenerConfig) -> Self {
51        Self {
52            connection,
53            config,
54            server: None,
55        }
56    }
57    /// The primary entrypoint to create a TlsListener. See
58    /// [TlsListenerBuilder](crate::TlsListenerBuilder) for more
59    /// configuration options.
60    ///
61    /// # Example
62    ///
63    /// ```rust
64    /// # use tide_rustls::TlsListener;
65    /// let listener = TlsListener::<()>::build()
66    ///     .addrs("localhost:4433")
67    ///     .cert("./tls/localhost-4433.cert")
68    ///     .key("./tls/localhost-4433.key")
69    ///     .finish();
70    /// ```
71    pub fn build() -> TlsListenerBuilder<State> {
72        TlsListenerBuilder::new()
73    }
74
75    async fn configure(&mut self) -> io::Result<()> {
76        self.config = match std::mem::take(&mut self.config) {
77            TlsListenerConfig::Paths { cert, key } => {
78                let certs = load_certs(&cert)?;
79                let mut keys = load_keys(&key)?;
80                let mut config = ServerConfig::new(NoClientAuth::new());
81                config
82                    .set_single_cert(certs, keys.remove(0))
83                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
84
85                TlsListenerConfig::Acceptor(Arc::new(StandardTlsAcceptor(TlsAcceptor::from(
86                    Arc::new(config),
87                ))))
88            }
89
90            TlsListenerConfig::ServerConfig(config) => TlsListenerConfig::Acceptor(Arc::new(
91                StandardTlsAcceptor(TlsAcceptor::from(Arc::new(config))),
92            )),
93
94            other @ TlsListenerConfig::Acceptor(_) => other,
95
96            TlsListenerConfig::Unconfigured => {
97                return Err(io::Error::new(
98                    io::ErrorKind::Other,
99                    "could not configure tlslistener",
100                ));
101            }
102        };
103
104        Ok(())
105    }
106
107    fn acceptor(&self) -> Option<&Arc<dyn CustomTlsAcceptor>> {
108        match self.config {
109            TlsListenerConfig::Acceptor(ref a) => Some(a),
110            _ => None,
111        }
112    }
113
114    fn tcp(&self) -> Option<&TcpListener> {
115        match self.connection {
116            TcpConnection::Connected(ref t) => Some(t),
117            _ => None,
118        }
119    }
120
121    async fn connect(&mut self) -> io::Result<()> {
122        if let TcpConnection::Addrs(addrs) = &self.connection {
123            let tcp = TcpListener::bind(&addrs[..]).await?;
124            self.connection = TcpConnection::Connected(tcp);
125        }
126        Ok(())
127    }
128}
129
130fn handle_tls<State: Clone + Send + Sync + 'static>(
131    app: Server<State>,
132    stream: TcpStream,
133    acceptor: Arc<dyn CustomTlsAcceptor>,
134) {
135    task::spawn(async move {
136        let local_addr = stream.local_addr().ok();
137        let peer_addr = stream.peer_addr().ok();
138
139        match acceptor.accept(stream).await {
140            Ok(None) => {}
141
142            Ok(Some(tls_stream)) => {
143                let stream = TlsStreamWrapper::new(tls_stream);
144                let fut = async_h1::accept(stream, |mut req| async {
145                    if req.url_mut().set_scheme("https").is_err() {
146                        tide::log::error!("unable to set https scheme on url", { url: req.url().to_string() });
147                    }
148
149                    req.set_local_addr(local_addr);
150                    req.set_peer_addr(peer_addr);
151                    app.respond(req).await
152                });
153
154                if let Err(error) = fut.await {
155                    tide::log::error!("async-h1 error", { error: error.to_string() });
156                }
157            }
158
159            Err(tls_error) => {
160                tide::log::error!("tls error", { error: tls_error.to_string() });
161            }
162        }
163    });
164}
165
166impl<State: Clone + Send + Sync + 'static> ToListener<State> for TlsListener<State> {
167    type Listener = Self;
168    fn to_listener(self) -> io::Result<Self::Listener> {
169        Ok(self)
170    }
171}
172
173impl<State: Clone + Send + Sync + 'static> ToListener<State> for TlsListenerBuilder<State> {
174    type Listener = TlsListener<State>;
175    fn to_listener(self) -> io::Result<Self::Listener> {
176        self.finish()
177    }
178}
179
180#[tide::utils::async_trait]
181impl<State: Clone + Send + Sync + 'static> Listener<State> for TlsListener<State> {
182    async fn bind(&mut self, server: Server<State>) -> io::Result<()> {
183        self.configure().await?;
184        self.connect().await?;
185        self.server = Some(server);
186        Ok(())
187    }
188
189    async fn accept(&mut self) -> io::Result<()> {
190        let listener = self.tcp().unwrap();
191        let mut incoming = listener.incoming();
192        let acceptor = self.acceptor().unwrap();
193        let server = self.server.as_ref().unwrap();
194
195        while let Some(stream) = incoming.next().await {
196            match stream {
197                Err(ref e) if is_transient_error(e) => continue,
198
199                Err(error) => {
200                    let delay = Duration::from_millis(500);
201                    tide::log::error!("Error: {}. Pausing for {:?}.", error, delay);
202                    task::sleep(delay).await;
203                    continue;
204                }
205
206                Ok(stream) => handle_tls(server.clone(), stream, acceptor.clone()),
207            };
208        }
209        Ok(())
210    }
211
212    fn info(&self) -> Vec<ListenInfo> {
213        vec![ListenInfo::new(
214            self.connection.to_string(),
215            String::from("tcp"),
216            true,
217        )]
218    }
219}
220
/// Returns `true` for connection-level accept errors (refused, aborted, or
/// reset). The accept loop skips these immediately, without backing off.
fn is_transient_error(e: &io::Error) -> bool {
    const TRANSIENT: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    TRANSIENT.contains(&e.kind())
}
228
229impl<State> Display for TlsListener<State> {
230    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231        write!(f, "{}", self.connection)
232    }
233}
234
235fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
236    certs(&mut BufReader::new(File::open(path)?))
237        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
238}
239
240fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
241    let mut bufreader = BufReader::new(File::open(path)?);
242    if let Ok(pkcs8) = pkcs8_private_keys(&mut bufreader) {
243        if !pkcs8.is_empty() {
244            return Ok(pkcs8);
245        }
246    }
247
248    bufreader.seek(SeekFrom::Start(0))?;
249
250    if let Ok(rsa) = rsa_private_keys(&mut bufreader) {
251        if !rsa.is_empty() {
252            return Ok(rsa);
253        }
254    }
255
256    Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
257}