watermelon_mini/non_standard_zstd.rs

1use std::{
2    fmt::{self, Debug, Formatter},
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use async_compression::{
9    Level,
10    tokio::{bufread::ZstdDecoder, write::ZstdEncoder},
11};
12use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
13
14use crate::util::MaybeConnection;
15
16pub struct ZstdStream<S> {
17    decoder: ZstdDecoder<BufReader<MaybeConnection<S>>>,
18    encoder: ZstdEncoder<MaybeConnection<S>>,
19}
20
21impl<S> ZstdStream<S>
22where
23    S: AsyncRead + AsyncWrite + Unpin,
24{
25    #[must_use]
26    pub fn new(stream: S, compression_level: u8) -> Self {
27        Self {
28            decoder: ZstdDecoder::new(BufReader::new(MaybeConnection(Some(stream)))),
29            encoder: ZstdEncoder::with_quality(
30                MaybeConnection(None),
31                Level::Precise(compression_level.into()),
32            ),
33        }
34    }
35}
36
37impl<S> AsyncRead for ZstdStream<S>
38where
39    S: AsyncRead + AsyncWrite + Unpin,
40{
41    fn poll_read(
42        mut self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44        buf: &mut ReadBuf<'_>,
45    ) -> Poll<io::Result<()>> {
46        if let Some(stream) = self.encoder.get_mut().0.take() {
47            self.decoder.get_mut().get_mut().0 = Some(stream);
48        }
49
50        Pin::new(&mut self.decoder).poll_read(cx, buf)
51    }
52}
53
54impl<S> AsyncWrite for ZstdStream<S>
55where
56    S: AsyncRead + AsyncWrite + Unpin,
57{
58    fn poll_write(
59        mut self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61        buf: &[u8],
62    ) -> Poll<io::Result<usize>> {
63        if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
64            self.encoder.get_mut().0 = Some(stream);
65        }
66
67        Pin::new(&mut self.encoder).poll_write(cx, buf)
68    }
69
70    fn poll_write_vectored(
71        mut self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73        bufs: &[io::IoSlice<'_>],
74    ) -> Poll<io::Result<usize>> {
75        if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
76            self.encoder.get_mut().0 = Some(stream);
77        }
78
79        Pin::new(&mut self.encoder).poll_write_vectored(cx, bufs)
80    }
81
82    fn is_write_vectored(&self) -> bool {
83        if let Some(stream) = &self.encoder.get_ref().0 {
84            stream.is_write_vectored()
85        } else if let Some(stream) = &self.decoder.get_ref().get_ref().0 {
86            stream.is_write_vectored()
87        } else {
88            unreachable!()
89        }
90    }
91
92    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93        if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
94            self.encoder.get_mut().0 = Some(stream);
95        }
96
97        Pin::new(&mut self.encoder).poll_flush(cx)
98    }
99
100    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101        if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
102            self.encoder.get_mut().0 = Some(stream);
103        }
104
105        Pin::new(&mut self.encoder).poll_shutdown(cx)
106    }
107}
108
109impl<S> Debug for ZstdStream<S> {
110    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
111        f.debug_struct("ZstdStream").finish_non_exhaustive()
112    }
113}