watermelon_mini/proto/connection/
security.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio_rustls::{TlsConnector, client::TlsStream, rustls::pki_types::ServerName};
9
10#[derive(Debug)]
11#[expect(
12 clippy::large_enum_variant,
13 reason = "using TLS is the recommended thing, we do not want to affect it"
14)]
15pub enum ConnectionSecurity<S> {
16 Plain(S),
17 Tls(TlsStream<S>),
18}
19
20impl<S> ConnectionSecurity<S>
21where
22 S: AsyncRead + AsyncWrite + Unpin,
23{
24 pub(crate) async fn upgrade_tls(
25 self,
26 connector: &TlsConnector,
27 domain: ServerName<'static>,
28 ) -> io::Result<Self> {
29 let conn = match self {
30 Self::Plain(conn) => conn,
31 Self::Tls(_) => unreachable!("trying to upgrade to Tls a Tls connection"),
32 };
33
34 let conn = connector.connect(domain, conn).await?;
35 Ok(Self::Tls(conn))
36 }
37}
38
39impl<S> AsyncRead for ConnectionSecurity<S>
40where
41 S: AsyncRead + AsyncWrite + Unpin,
42{
43 fn poll_read(
44 self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 buf: &mut ReadBuf<'_>,
47 ) -> Poll<io::Result<()>> {
48 match self.get_mut() {
49 Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf),
50 Self::Tls(conn) => Pin::new(conn).poll_read(cx, buf),
51 }
52 }
53}
54
55impl<S> AsyncWrite for ConnectionSecurity<S>
56where
57 S: AsyncRead + AsyncWrite + Unpin,
58{
59 fn poll_write(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &[u8],
63 ) -> Poll<io::Result<usize>> {
64 match self.get_mut() {
65 Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf),
66 Self::Tls(conn) => Pin::new(conn).poll_write(cx, buf),
67 }
68 }
69
70 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71 match self.get_mut() {
72 Self::Plain(conn) => Pin::new(conn).poll_flush(cx),
73 Self::Tls(conn) => Pin::new(conn).poll_flush(cx),
74 }
75 }
76
77 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
78 match self.get_mut() {
79 Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx),
80 Self::Tls(conn) => Pin::new(conn).poll_shutdown(cx),
81 }
82 }
83
84 fn poll_write_vectored(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 bufs: &[io::IoSlice<'_>],
88 ) -> Poll<io::Result<usize>> {
89 match self.get_mut() {
90 Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
91 Self::Tls(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
92 }
93 }
94
95 fn is_write_vectored(&self) -> bool {
96 match self {
97 Self::Plain(conn) => conn.is_write_vectored(),
98 Self::Tls(conn) => conn.is_write_vectored(),
99 }
100 }
101}