// watermelon_mini/proto/connection/security.rs

use std::{
    io,
    pin::Pin,
    task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::{TlsConnector, client::TlsStream, rustls::pki_types::ServerName};

10#[derive(Debug)]
11#[expect(
12    clippy::large_enum_variant,
13    reason = "using TLS is the recommended thing, we do not want to affect it"
14)]
15pub enum ConnectionSecurity<S> {
16    Plain(S),
17    Tls(TlsStream<S>),
18}
19
20impl<S> ConnectionSecurity<S>
21where
22    S: AsyncRead + AsyncWrite + Unpin,
23{
24    pub(crate) async fn upgrade_tls(
25        self,
26        connector: &TlsConnector,
27        domain: ServerName<'static>,
28    ) -> io::Result<Self> {
29        let conn = match self {
30            Self::Plain(conn) => conn,
31            Self::Tls(_) => unreachable!("trying to upgrade to Tls a Tls connection"),
32        };
33
34        let conn = connector.connect(domain, conn).await?;
35        Ok(Self::Tls(conn))
36    }
37}
38
39impl<S> AsyncRead for ConnectionSecurity<S>
40where
41    S: AsyncRead + AsyncWrite + Unpin,
42{
43    fn poll_read(
44        self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46        buf: &mut ReadBuf<'_>,
47    ) -> Poll<io::Result<()>> {
48        match self.get_mut() {
49            Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf),
50            Self::Tls(conn) => Pin::new(conn).poll_read(cx, buf),
51        }
52    }
53}
54
55impl<S> AsyncWrite for ConnectionSecurity<S>
56where
57    S: AsyncRead + AsyncWrite + Unpin,
58{
59    fn poll_write(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &[u8],
63    ) -> Poll<io::Result<usize>> {
64        match self.get_mut() {
65            Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf),
66            Self::Tls(conn) => Pin::new(conn).poll_write(cx, buf),
67        }
68    }
69
70    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71        match self.get_mut() {
72            Self::Plain(conn) => Pin::new(conn).poll_flush(cx),
73            Self::Tls(conn) => Pin::new(conn).poll_flush(cx),
74        }
75    }
76
77    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
78        match self.get_mut() {
79            Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx),
80            Self::Tls(conn) => Pin::new(conn).poll_shutdown(cx),
81        }
82    }
83
84    fn poll_write_vectored(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87        bufs: &[io::IoSlice<'_>],
88    ) -> Poll<io::Result<usize>> {
89        match self.get_mut() {
90            Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
91            Self::Tls(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
92        }
93    }
94
95    fn is_write_vectored(&self) -> bool {
96        match self {
97            Self::Plain(conn) => conn.is_write_vectored(),
98            Self::Tls(conn) => conn.is_write_vectored(),
99        }
100    }
101}