sbd_server/
maybe_tls.rs

//! TLS server configuration and a maybe-TLS stream wrapper, taken and altered from `tokio_tungstenite`.
2
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9/// RustTLS config plus cert and pk paths.
10pub struct TlsConfig {
11    cert: std::path::PathBuf,
12    pk: std::path::PathBuf,
13    config: Arc<Mutex<Arc<rustls::server::ServerConfig>>>,
14}
15
16impl TlsConfig {
17    /// Load a new TlsConfig from a cert and pk path.
18    pub async fn new(
19        cert: &std::path::Path,
20        pk: &std::path::Path,
21    ) -> std::io::Result<Self> {
22        let cert = cert.to_owned();
23        let pk = pk.to_owned();
24        let config = Self::load(&cert, &pk).await?;
25        Ok(Self {
26            cert,
27            pk,
28            config: Arc::new(Mutex::new(config)),
29        })
30    }
31
32    /// Get the current rustls::server::ServerConfig.
33    pub fn config(&self) -> Arc<rustls::server::ServerConfig> {
34        self.config.lock().unwrap().clone()
35    }
36
37    /// Reload the cert and pk.
38    #[allow(dead_code)] // watch reload tls
39    pub async fn reload(&self) -> std::io::Result<()> {
40        let new_config = Self::load(&self.cert, &self.pk).await?;
41        *self.config.lock().unwrap() = new_config;
42        Ok(())
43    }
44
45    async fn load(
46        cert: &std::path::Path,
47        pk: &std::path::Path,
48    ) -> std::io::Result<Arc<rustls::server::ServerConfig>> {
49        let cert = tokio::fs::read(cert).await?;
50        let pk = tokio::fs::read(pk).await?;
51
52        let mut certs = Vec::new();
53        for cert in rustls_pemfile::certs(&mut std::io::Cursor::new(&cert)) {
54            certs.push(cert?);
55        }
56
57        let pk = rustls_pemfile::private_key(&mut std::io::Cursor::new(&pk))?
58            .ok_or_else(|| {
59            std::io::Error::other("error reading priv key")
60        })?;
61
62        Ok(Arc::new(
63            rustls::server::ServerConfig::builder()
64                .with_no_client_auth()
65                .with_single_cert(certs, pk)
66                .map_err(std::io::Error::other)?,
67        ))
68    }
69}
70
71/// A stream that might be protected with TLS.
72#[non_exhaustive]
73#[derive(Debug)]
74pub enum MaybeTlsStream {
75    /// Tcp.
76    Tcp(tokio::net::TcpStream),
77
78    /// Tls.
79    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
80}
81
82impl MaybeTlsStream {
83    /// Wrap a TcpStream in a MaybeTlsStream, configuring TLS
84    pub async fn tls(
85        tls_config: &TlsConfig,
86        tcp: tokio::net::TcpStream,
87    ) -> std::io::Result<Self> {
88        let config = tls_config.config();
89
90        let tls = tokio_rustls::TlsAcceptor::from(config).accept(tcp).await?;
91
92        Ok(Self::Tls(tls))
93    }
94}
95
96impl AsyncRead for MaybeTlsStream {
97    fn poll_read(
98        self: Pin<&mut Self>,
99        cx: &mut Context<'_>,
100        buf: &mut ReadBuf<'_>,
101    ) -> Poll<std::io::Result<()>> {
102        match self.get_mut() {
103            MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_read(cx, buf),
104            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
105        }
106    }
107}
108
109impl AsyncWrite for MaybeTlsStream {
110    fn poll_write(
111        self: Pin<&mut Self>,
112        cx: &mut Context<'_>,
113        buf: &[u8],
114    ) -> Poll<Result<usize, std::io::Error>> {
115        match self.get_mut() {
116            MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_write(cx, buf),
117            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
118        }
119    }
120
121    fn poll_flush(
122        self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124    ) -> Poll<Result<(), std::io::Error>> {
125        match self.get_mut() {
126            MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_flush(cx),
127            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
128        }
129    }
130
131    fn poll_shutdown(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134    ) -> Poll<Result<(), std::io::Error>> {
135        match self.get_mut() {
136            MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_shutdown(cx),
137            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_shutdown(cx),
138        }
139    }
140}