watermelon_mini/proto/connection/compression.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9#[cfg(feature = "non-standard-zstd")]
10use crate::non_standard_zstd::ZstdStream;
11
/// A connection stream that is either used as-is or wrapped in a compression layer.
///
/// `AsyncRead` and `AsyncWrite` are implemented for this type by forwarding every
/// call to the stream held by the active variant.
#[derive(Debug)]
pub enum ConnectionCompression<S> {
    /// The underlying stream, read from and written to directly.
    Plain(S),
    /// The underlying stream wrapped in `ZstdStream` (only with the
    /// `non-standard-zstd` feature). Presumably `ZstdStream` performs the
    /// (de)compression — see its own docs to confirm.
    #[cfg(feature = "non-standard-zstd")]
    Zstd(ZstdStream<S>),
}
18
19impl<S> ConnectionCompression<S>
20where
21    S: AsyncRead + AsyncWrite + Unpin,
22{
23    #[cfg(feature = "non-standard-zstd")]
24    pub(crate) fn upgrade_zstd(self, compression_level: u8) -> Self {
25        let Self::Plain(socket) = self else {
26            unreachable!()
27        };
28
29        Self::Zstd(ZstdStream::new(socket, compression_level))
30    }
31
32    #[cfg(feature = "non-standard-zstd")]
33    pub fn is_zstd_compressed(&self) -> bool {
34        matches!(self, Self::Zstd(_))
35    }
36}
37
38impl<S> AsyncRead for ConnectionCompression<S>
39where
40    S: AsyncRead + AsyncWrite + Unpin,
41{
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut ReadBuf<'_>,
46    ) -> Poll<io::Result<()>> {
47        match self.get_mut() {
48            Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf),
49            #[cfg(feature = "non-standard-zstd")]
50            Self::Zstd(conn) => Pin::new(conn).poll_read(cx, buf),
51        }
52    }
53}
54
55impl<S> AsyncWrite for ConnectionCompression<S>
56where
57    S: AsyncRead + AsyncWrite + Unpin,
58{
59    fn poll_write(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &[u8],
63    ) -> Poll<io::Result<usize>> {
64        match self.get_mut() {
65            Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf),
66            #[cfg(feature = "non-standard-zstd")]
67            Self::Zstd(conn) => Pin::new(conn).poll_write(cx, buf),
68        }
69    }
70
71    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72        match self.get_mut() {
73            Self::Plain(conn) => Pin::new(conn).poll_flush(cx),
74            #[cfg(feature = "non-standard-zstd")]
75            Self::Zstd(conn) => Pin::new(conn).poll_flush(cx),
76        }
77    }
78
79    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80        match self.get_mut() {
81            Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx),
82            #[cfg(feature = "non-standard-zstd")]
83            Self::Zstd(conn) => Pin::new(conn).poll_shutdown(cx),
84        }
85    }
86
87    fn poll_write_vectored(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90        bufs: &[io::IoSlice<'_>],
91    ) -> Poll<io::Result<usize>> {
92        match self.get_mut() {
93            Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
94            #[cfg(feature = "non-standard-zstd")]
95            Self::Zstd(conn) => Pin::new(conn).poll_write_vectored(cx, bufs),
96        }
97    }
98
99    fn is_write_vectored(&self) -> bool {
100        match self {
101            Self::Plain(conn) => conn.is_write_vectored(),
102            #[cfg(feature = "non-standard-zstd")]
103            Self::Zstd(conn) => conn.is_write_vectored(),
104        }
105    }
106}