1use futures_util::ready;
12use hyper::server::accept::Accept;
13use hyper::server::conn::{AddrIncoming, AddrStream};
14use rustls_pki_types::pem::PemObject;
15use rustls_pki_types::{CertificateDer, PrivateKeyDer};
16use std::fs::File;
17use std::future::Future;
18use std::io::{self, BufReader, Cursor, Read};
19use std::net::SocketAddr;
20use std::path::{Path, PathBuf};
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio_rustls::rustls::{Error as TlsError, ServerConfig};
26
27use crate::transport::Transport;
28
29#[derive(Debug)]
31pub enum TlsConfigError {
32 Io(io::Error),
34 CertParseError,
36 InvalidIdentityPem,
38 EmptyKey,
40 UnknownPrivateKeyFormat,
42 InvalidKey(TlsError),
44 IllegalSectionStart(Vec<u8>),
46 IllegalSectionEnd(Vec<u8>),
48}
49
50impl std::fmt::Display for TlsConfigError {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 TlsConfigError::Io(err) => err.fmt(f),
54 TlsConfigError::CertParseError => write!(f, "failed to parse certificate"),
55 TlsConfigError::InvalidIdentityPem => write!(f, "the identity PEM provided is invalid"),
56 TlsConfigError::UnknownPrivateKeyFormat => {
57 write!(f, "the private key format is unknown")
58 }
59 TlsConfigError::EmptyKey => write!(f, "the key provided is probably missing or empty"),
60 TlsConfigError::InvalidKey(err) => write!(f, "the key provided is invalid, {err}"),
61 TlsConfigError::IllegalSectionStart(line) => {
62 let line = String::from_utf8(line.clone()).unwrap_or_default();
63 write!(f, "illegal section start in PEM at '{line}'")
64 }
65 TlsConfigError::IllegalSectionEnd(end_marker) => {
66 let end_marker = String::from_utf8(end_marker.clone()).unwrap_or_default();
67 write!(f, "illegal section end in PEM at '{end_marker}'")
68 }
69 }
70 }
71}
72
73impl std::error::Error for TlsConfigError {}
74
/// Collects the certificate and private-key sources needed to build a TLS
/// server configuration.
///
/// Both inputs are stored as boxed readers so that in-memory bytes and
/// files on disk can be handled uniformly.
pub struct TlsConfigBuilder {
    /// Source of the PEM-encoded certificate(s).
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM-encoded private key.
    key: Box<dyn Read + Send + Sync>,
}
80
81impl std::fmt::Debug for TlsConfigBuilder {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
83 f.debug_struct("TlsConfigBuilder").finish()
84 }
85}
86
87impl TlsConfigBuilder {
88 pub fn new() -> TlsConfigBuilder {
90 TlsConfigBuilder {
91 key: Box::new(io::empty()),
92 cert: Box::new(io::empty()),
93 }
94 }
95
96 pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
98 self.key = Box::new(LazyFile {
99 path: path.as_ref().into(),
100 file: None,
101 });
102 self
103 }
104
105 pub fn key(mut self, key: &[u8]) -> Self {
107 self.key = Box::new(Cursor::new(Vec::from(key)));
108 self
109 }
110
111 pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
113 self.cert = Box::new(LazyFile {
114 path: path.as_ref().into(),
115 file: None,
116 });
117 self
118 }
119
120 pub fn cert(mut self, cert: &[u8]) -> Self {
122 self.cert = Box::new(Cursor::new(Vec::from(cert)));
123 self
124 }
125
126 pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
128 let mut cert_rdr = BufReader::new(self.cert);
129 let cert = CertificateDer::pem_reader_iter(&mut cert_rdr)
130 .collect::<Result<Vec<_>, _>>()
131 .map_err(|_e| TlsConfigError::CertParseError)?;
132
133 let mut key_buf = Vec::new();
135 self.key
136 .read_to_end(&mut key_buf)
137 .map_err(TlsConfigError::Io)?;
138
139 if key_buf.is_empty() {
140 return Err(TlsConfigError::EmptyKey);
141 }
142
143 let reader = Cursor::new(key_buf);
144 let key = PrivateKeyDer::from_pem_reader(reader).map_err(|err| match err {
145 rustls_pki_types::pem::Error::Base64Decode(_) => TlsConfigError::InvalidIdentityPem,
146 rustls_pki_types::pem::Error::NoItemsFound => TlsConfigError::EmptyKey,
147 rustls_pki_types::pem::Error::IllegalSectionStart { line } => {
148 TlsConfigError::IllegalSectionStart(line)
149 }
150 rustls_pki_types::pem::Error::MissingSectionEnd { end_marker } => {
151 TlsConfigError::IllegalSectionEnd(end_marker)
152 }
153 rustls_pki_types::pem::Error::Io(err) => {
154 TlsConfigError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
155 }
156 _ => TlsConfigError::InvalidIdentityPem,
157 })?;
158
159 let mut config = ServerConfig::builder()
160 .with_no_client_auth()
161 .with_single_cert(cert, key)
162 .map_err(TlsConfigError::InvalidKey)?;
163 config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
164 Ok(config)
165 }
166}
167
168impl Default for TlsConfigBuilder {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
/// A [`Read`] source backed by a file that is opened only when the first
/// read is attempted.
struct LazyFile {
    /// Location of the file to read.
    path: PathBuf,
    /// Open handle; `None` until the first successful open.
    file: Option<File>,
}

impl LazyFile {
    /// Opens the file if it is not open yet, then reads into `buf`.
    ///
    /// Returns the number of bytes read, or the raw I/O error from opening
    /// or reading the file.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Take the handle out (or open a fresh one) and store it straight
        // back. If opening fails, `self.file` is left as `None`, so a later
        // read simply retries the open.
        let file = match self.file.take() {
            Some(file) => file,
            None => File::open(&self.path)?,
        };
        self.file.insert(file).read(buf)
    }
}

impl Read for LazyFile {
    /// Reads from the underlying file, opening it on first use.
    ///
    /// Any error keeps its original [`io::ErrorKind`] but its message is
    /// prefixed with the file path so the caller can tell which file failed.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.lazy_read(buf) {
            Ok(read) => Ok(read),
            Err(err) => Err(io::Error::new(
                err.kind(),
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )),
        }
    }
}
200
201impl Transport for TlsStream {
202 fn remote_addr(&self) -> Option<SocketAddr> {
203 Some(self.remote_addr)
204 }
205}
206
207enum State {
208 Handshaking(tokio_rustls::Accept<AddrStream>),
209 Streaming(tokio_rustls::server::TlsStream<AddrStream>),
210}
211
212pub struct TlsStream {
217 state: State,
218 remote_addr: SocketAddr,
219}
220
221impl TlsStream {
222 fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
223 let remote_addr = stream.remote_addr();
224 let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
225 TlsStream {
226 state: State::Handshaking(accept),
227 remote_addr,
228 }
229 }
230}
231
232impl AsyncRead for TlsStream {
233 fn poll_read(
234 self: Pin<&mut Self>,
235 cx: &mut Context<'_>,
236 buf: &mut ReadBuf<'_>,
237 ) -> Poll<io::Result<()>> {
238 let pin = self.get_mut();
239 match pin.state {
240 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
241 Ok(mut stream) => {
242 let result = Pin::new(&mut stream).poll_read(cx, buf);
243 pin.state = State::Streaming(stream);
244 result
245 }
246 Err(err) => Poll::Ready(Err(err)),
247 },
248 State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
249 }
250 }
251}
252
253impl AsyncWrite for TlsStream {
254 fn poll_write(
255 self: Pin<&mut Self>,
256 cx: &mut Context<'_>,
257 buf: &[u8],
258 ) -> Poll<io::Result<usize>> {
259 let pin = self.get_mut();
260 match pin.state {
261 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
262 Ok(mut stream) => {
263 let result = Pin::new(&mut stream).poll_write(cx, buf);
264 pin.state = State::Streaming(stream);
265 result
266 }
267 Err(err) => Poll::Ready(Err(err)),
268 },
269 State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
270 }
271 }
272
273 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
274 match self.state {
275 State::Handshaking(_) => Poll::Ready(Ok(())),
276 State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
277 }
278 }
279
280 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281 match self.state {
282 State::Handshaking(_) => Poll::Ready(Ok(())),
283 State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
284 }
285 }
286}
287
288pub struct TlsAcceptor {
290 config: Arc<ServerConfig>,
291 incoming: AddrIncoming,
292}
293
294impl TlsAcceptor {
295 pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
297 TlsAcceptor {
298 config: Arc::new(config),
299 incoming,
300 }
301 }
302}
303
304impl Accept for TlsAcceptor {
305 type Conn = TlsStream;
306 type Error = io::Error;
307
308 fn poll_accept(
309 self: Pin<&mut Self>,
310 cx: &mut Context<'_>,
311 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
312 let pin = self.get_mut();
313 match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
314 Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
315 Some(Err(e)) => Poll::Ready(Some(Err(e))),
316 None => Poll::Ready(None),
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from PEM files on disk; panics if the build fails.
    fn build_from_files(cert_path: &str, key_path: &str) {
        TlsConfigBuilder::new()
            .cert_path(cert_path)
            .key_path(key_path)
            .build()
            .unwrap();
    }

    /// Builds a config from in-memory PEM text; panics if the build fails.
    fn build_from_bytes(cert: &str, key: &str) {
        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        build_from_files(
            "tests/tls/local.dev_cert.rsa_pkcs1.pem",
            "tests/tls/local.dev_key.rsa_pkcs1.pem",
        );
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        build_from_bytes(
            include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem"),
            include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem"),
        );
    }

    #[test]
    fn file_cert_key_pkcs8() {
        build_from_files(
            "tests/tls/local.dev_cert.pkcs8.pem",
            "tests/tls/local.dev_key.pkcs8.pem",
        );
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        build_from_bytes(
            include_str!("../tests/tls/local.dev_cert.pkcs8.pem"),
            include_str!("../tests/tls/local.dev_key.pkcs8.pem"),
        );
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        build_from_files(
            "tests/tls/local.dev_cert.sec1_ec.pem",
            "tests/tls/local.dev_key.sec1_ec.pem",
        );
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        build_from_bytes(
            include_str!("../tests/tls/local.dev_cert.sec1_ec.pem"),
            include_str!("../tests/tls/local.dev_key.sec1_ec.pem"),
        );
    }
}