1use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9pub struct TlsConfig {
11 cert: std::path::PathBuf,
12 pk: std::path::PathBuf,
13 config: Arc<Mutex<Arc<rustls::server::ServerConfig>>>,
14}
15
16impl TlsConfig {
17 pub async fn new(
19 cert: &std::path::Path,
20 pk: &std::path::Path,
21 ) -> std::io::Result<Self> {
22 let cert = cert.to_owned();
23 let pk = pk.to_owned();
24 let config = Self::load(&cert, &pk).await?;
25 Ok(Self {
26 cert,
27 pk,
28 config: Arc::new(Mutex::new(config)),
29 })
30 }
31
32 pub fn config(&self) -> Arc<rustls::server::ServerConfig> {
34 self.config.lock().unwrap().clone()
35 }
36
37 #[allow(dead_code)] pub async fn reload(&self) -> std::io::Result<()> {
40 let new_config = Self::load(&self.cert, &self.pk).await?;
41 *self.config.lock().unwrap() = new_config;
42 Ok(())
43 }
44
45 async fn load(
46 cert: &std::path::Path,
47 pk: &std::path::Path,
48 ) -> std::io::Result<Arc<rustls::server::ServerConfig>> {
49 let cert = tokio::fs::read(cert).await?;
50 let pk = tokio::fs::read(pk).await?;
51
52 let mut certs = Vec::new();
53 for cert in rustls_pemfile::certs(&mut std::io::Cursor::new(&cert)) {
54 certs.push(cert?);
55 }
56
57 let pk = rustls_pemfile::private_key(&mut std::io::Cursor::new(&pk))?
58 .ok_or_else(|| {
59 std::io::Error::other("error reading priv key")
60 })?;
61
62 Ok(Arc::new(
63 rustls::server::ServerConfig::builder()
64 .with_no_client_auth()
65 .with_single_cert(certs, pk)
66 .map_err(std::io::Error::other)?,
67 ))
68 }
69}
70
71#[non_exhaustive]
73#[derive(Debug)]
74pub enum MaybeTlsStream {
75 Tcp(tokio::net::TcpStream),
77
78 Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
80}
81
82impl MaybeTlsStream {
83 pub async fn tls(
85 tls_config: &TlsConfig,
86 tcp: tokio::net::TcpStream,
87 ) -> std::io::Result<Self> {
88 let config = tls_config.config();
89
90 let tls = tokio_rustls::TlsAcceptor::from(config).accept(tcp).await?;
91
92 Ok(Self::Tls(tls))
93 }
94}
95
96impl AsyncRead for MaybeTlsStream {
97 fn poll_read(
98 self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 buf: &mut ReadBuf<'_>,
101 ) -> Poll<std::io::Result<()>> {
102 match self.get_mut() {
103 MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_read(cx, buf),
104 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
105 }
106 }
107}
108
109impl AsyncWrite for MaybeTlsStream {
110 fn poll_write(
111 self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 buf: &[u8],
114 ) -> Poll<Result<usize, std::io::Error>> {
115 match self.get_mut() {
116 MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_write(cx, buf),
117 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
118 }
119 }
120
121 fn poll_flush(
122 self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 ) -> Poll<Result<(), std::io::Error>> {
125 match self.get_mut() {
126 MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_flush(cx),
127 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
128 }
129 }
130
131 fn poll_shutdown(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 ) -> Poll<Result<(), std::io::Error>> {
135 match self.get_mut() {
136 MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_shutdown(cx),
137 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_shutdown(cx),
138 }
139 }
140}