tide_rustls/
tls_listener.rs1use crate::custom_tls_acceptor::StandardTlsAcceptor;
2use crate::{
3 CustomTlsAcceptor, TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper,
4};
5
6use tide::listener::ListenInfo;
7use tide::listener::{Listener, ToListener};
8use tide::Server;
9
10use async_std::net::{TcpListener, TcpStream};
11use async_std::prelude::*;
12use async_std::{io, task};
13
14use async_rustls::TlsAcceptor;
15use rustls::internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
16use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
17
18use std::fmt::{self, Debug, Display, Formatter};
19use std::fs::File;
20use std::io::{BufReader, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Arc;
23use std::time::Duration;
24
25pub struct TlsListener<State> {
27 connection: TcpConnection,
28 config: TlsListenerConfig,
29 server: Option<Server<State>>,
30}
31
32impl<State> Debug for TlsListener<State> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34 f.debug_struct("TlsListener")
35 .field(&"connection", &self.connection)
36 .field(&"config", &self.config)
37 .field(
38 &"server",
39 if self.server.is_some() {
40 &"Some(Server<State>)"
41 } else {
42 &"None"
43 },
44 )
45 .finish()
46 }
47}
48
49impl<State> TlsListener<State> {
50 pub(crate) fn new(connection: TcpConnection, config: TlsListenerConfig) -> Self {
51 Self {
52 connection,
53 config,
54 server: None,
55 }
56 }
57 pub fn build() -> TlsListenerBuilder<State> {
72 TlsListenerBuilder::new()
73 }
74
75 async fn configure(&mut self) -> io::Result<()> {
76 self.config = match std::mem::take(&mut self.config) {
77 TlsListenerConfig::Paths { cert, key } => {
78 let certs = load_certs(&cert)?;
79 let mut keys = load_keys(&key)?;
80 let mut config = ServerConfig::new(NoClientAuth::new());
81 config
82 .set_single_cert(certs, keys.remove(0))
83 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
84
85 TlsListenerConfig::Acceptor(Arc::new(StandardTlsAcceptor(TlsAcceptor::from(
86 Arc::new(config),
87 ))))
88 }
89
90 TlsListenerConfig::ServerConfig(config) => TlsListenerConfig::Acceptor(Arc::new(
91 StandardTlsAcceptor(TlsAcceptor::from(Arc::new(config))),
92 )),
93
94 other @ TlsListenerConfig::Acceptor(_) => other,
95
96 TlsListenerConfig::Unconfigured => {
97 return Err(io::Error::new(
98 io::ErrorKind::Other,
99 "could not configure tlslistener",
100 ));
101 }
102 };
103
104 Ok(())
105 }
106
107 fn acceptor(&self) -> Option<&Arc<dyn CustomTlsAcceptor>> {
108 match self.config {
109 TlsListenerConfig::Acceptor(ref a) => Some(a),
110 _ => None,
111 }
112 }
113
114 fn tcp(&self) -> Option<&TcpListener> {
115 match self.connection {
116 TcpConnection::Connected(ref t) => Some(t),
117 _ => None,
118 }
119 }
120
121 async fn connect(&mut self) -> io::Result<()> {
122 if let TcpConnection::Addrs(addrs) = &self.connection {
123 let tcp = TcpListener::bind(&addrs[..]).await?;
124 self.connection = TcpConnection::Connected(tcp);
125 }
126 Ok(())
127 }
128}
129
130fn handle_tls<State: Clone + Send + Sync + 'static>(
131 app: Server<State>,
132 stream: TcpStream,
133 acceptor: Arc<dyn CustomTlsAcceptor>,
134) {
135 task::spawn(async move {
136 let local_addr = stream.local_addr().ok();
137 let peer_addr = stream.peer_addr().ok();
138
139 match acceptor.accept(stream).await {
140 Ok(None) => {}
141
142 Ok(Some(tls_stream)) => {
143 let stream = TlsStreamWrapper::new(tls_stream);
144 let fut = async_h1::accept(stream, |mut req| async {
145 if req.url_mut().set_scheme("https").is_err() {
146 tide::log::error!("unable to set https scheme on url", { url: req.url().to_string() });
147 }
148
149 req.set_local_addr(local_addr);
150 req.set_peer_addr(peer_addr);
151 app.respond(req).await
152 });
153
154 if let Err(error) = fut.await {
155 tide::log::error!("async-h1 error", { error: error.to_string() });
156 }
157 }
158
159 Err(tls_error) => {
160 tide::log::error!("tls error", { error: tls_error.to_string() });
161 }
162 }
163 });
164}
165
166impl<State: Clone + Send + Sync + 'static> ToListener<State> for TlsListener<State> {
167 type Listener = Self;
168 fn to_listener(self) -> io::Result<Self::Listener> {
169 Ok(self)
170 }
171}
172
173impl<State: Clone + Send + Sync + 'static> ToListener<State> for TlsListenerBuilder<State> {
174 type Listener = TlsListener<State>;
175 fn to_listener(self) -> io::Result<Self::Listener> {
176 self.finish()
177 }
178}
179
180#[tide::utils::async_trait]
181impl<State: Clone + Send + Sync + 'static> Listener<State> for TlsListener<State> {
182 async fn bind(&mut self, server: Server<State>) -> io::Result<()> {
183 self.configure().await?;
184 self.connect().await?;
185 self.server = Some(server);
186 Ok(())
187 }
188
189 async fn accept(&mut self) -> io::Result<()> {
190 let listener = self.tcp().unwrap();
191 let mut incoming = listener.incoming();
192 let acceptor = self.acceptor().unwrap();
193 let server = self.server.as_ref().unwrap();
194
195 while let Some(stream) = incoming.next().await {
196 match stream {
197 Err(ref e) if is_transient_error(e) => continue,
198
199 Err(error) => {
200 let delay = Duration::from_millis(500);
201 tide::log::error!("Error: {}. Pausing for {:?}.", error, delay);
202 task::sleep(delay).await;
203 continue;
204 }
205
206 Ok(stream) => handle_tls(server.clone(), stream, acceptor.clone()),
207 };
208 }
209 Ok(())
210 }
211
212 fn info(&self) -> Vec<ListenInfo> {
213 vec![ListenInfo::new(
214 self.connection.to_string(),
215 String::from("tcp"),
216 true,
217 )]
218 }
219}
220
/// Reports whether an accept error is a per-connection failure that the
/// accept loop can skip immediately without backing off.
///
/// Only refused, aborted, and reset connections qualify; every other error
/// kind is treated as non-transient.
fn is_transient_error(e: &io::Error) -> bool {
    let transient_kinds = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    transient_kinds.contains(&e.kind())
}
228
229impl<State> Display for TlsListener<State> {
230 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231 write!(f, "{}", self.connection)
232 }
233}
234
235fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
236 certs(&mut BufReader::new(File::open(path)?))
237 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
238}
239
240fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
241 let mut bufreader = BufReader::new(File::open(path)?);
242 if let Ok(pkcs8) = pkcs8_private_keys(&mut bufreader) {
243 if !pkcs8.is_empty() {
244 return Ok(pkcs8);
245 }
246 }
247
248 bufreader.seek(SeekFrom::Start(0))?;
249
250 if let Ok(rsa) = rsa_private_keys(&mut bufreader) {
251 if !rsa.is_empty() {
252 return Ok(rsa);
253 }
254 }
255
256 Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
257}