watermelon_mini/
non_standard_zstd.rs1use std::{
2 fmt::{self, Debug, Formatter},
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use async_compression::{
9 Level,
10 tokio::{bufread::ZstdDecoder, write::ZstdEncoder},
11};
12use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
13
14use crate::util::MaybeConnection;
15
16pub struct ZstdStream<S> {
17 decoder: ZstdDecoder<BufReader<MaybeConnection<S>>>,
18 encoder: ZstdEncoder<MaybeConnection<S>>,
19}
20
21impl<S> ZstdStream<S>
22where
23 S: AsyncRead + AsyncWrite + Unpin,
24{
25 #[must_use]
26 pub fn new(stream: S, compression_level: u8) -> Self {
27 Self {
28 decoder: ZstdDecoder::new(BufReader::new(MaybeConnection(Some(stream)))),
29 encoder: ZstdEncoder::with_quality(
30 MaybeConnection(None),
31 Level::Precise(compression_level.into()),
32 ),
33 }
34 }
35}
36
37impl<S> AsyncRead for ZstdStream<S>
38where
39 S: AsyncRead + AsyncWrite + Unpin,
40{
41 fn poll_read(
42 mut self: Pin<&mut Self>,
43 cx: &mut Context<'_>,
44 buf: &mut ReadBuf<'_>,
45 ) -> Poll<io::Result<()>> {
46 if let Some(stream) = self.encoder.get_mut().0.take() {
47 self.decoder.get_mut().get_mut().0 = Some(stream);
48 }
49
50 Pin::new(&mut self.decoder).poll_read(cx, buf)
51 }
52}
53
54impl<S> AsyncWrite for ZstdStream<S>
55where
56 S: AsyncRead + AsyncWrite + Unpin,
57{
58 fn poll_write(
59 mut self: Pin<&mut Self>,
60 cx: &mut Context<'_>,
61 buf: &[u8],
62 ) -> Poll<io::Result<usize>> {
63 if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
64 self.encoder.get_mut().0 = Some(stream);
65 }
66
67 Pin::new(&mut self.encoder).poll_write(cx, buf)
68 }
69
70 fn poll_write_vectored(
71 mut self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 bufs: &[io::IoSlice<'_>],
74 ) -> Poll<io::Result<usize>> {
75 if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
76 self.encoder.get_mut().0 = Some(stream);
77 }
78
79 Pin::new(&mut self.encoder).poll_write_vectored(cx, bufs)
80 }
81
82 fn is_write_vectored(&self) -> bool {
83 if let Some(stream) = &self.encoder.get_ref().0 {
84 stream.is_write_vectored()
85 } else if let Some(stream) = &self.decoder.get_ref().get_ref().0 {
86 stream.is_write_vectored()
87 } else {
88 unreachable!()
89 }
90 }
91
92 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93 if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
94 self.encoder.get_mut().0 = Some(stream);
95 }
96
97 Pin::new(&mut self.encoder).poll_flush(cx)
98 }
99
100 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101 if let Some(stream) = self.decoder.get_mut().get_mut().0.take() {
102 self.encoder.get_mut().0 = Some(stream);
103 }
104
105 Pin::new(&mut self.encoder).poll_shutdown(cx)
106 }
107}
108
109impl<S> Debug for ZstdStream<S> {
110 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
111 f.debug_struct("ZstdStream").finish_non_exhaustive()
112 }
113}