1use std::pin::Pin;
4use std::io::{Read, Write, Result};
5use std::task::{Context, Poll};
6use std::mem::MaybeUninit;
7
8use pin_project::pin_project;
9use tokio::io::{
10 AsyncRead,
11 AsyncWrite,
12 ReadBuf,
13};
14use bytes::{
15 buf::{
16 Reader,
17 Writer,
18 },
19 Buf,
20 BufMut,
21 BytesMut,
22};
23use futures::ready;
24
/// Upper bound on uncompressed bytes per snappy frame chunk (64 KiB).
///
/// `poll_write` compresses at most this many bytes per call, and the snappy
/// framing format never decodes a single chunk to more than this.
const MAX_BLOCK_SIZE: usize = 64 * 1024;
27
/// Worst-case compressed size of one `MAX_BLOCK_SIZE` (65536-byte) block,
/// using snappy's bound `32 + len + len / 6`.
const MAX_COMPRESSED_SIZE: usize = 32 + 65_536 + 65_536 / 6;
30
31#[pin_project]
33#[derive(Debug)]
34pub struct SnappyIO<T> {
35 #[pin] inner: T,
36 read_buf: BytesMut,
37 decoder: snap::read::FrameDecoder<Reader<BytesMut>>,
38 encoder: snap::write::FrameEncoder<Writer<BytesMut>>,
39}
40
41impl<T> SnappyIO<T> {
42
43 pub fn new(io: T) -> Self {
45 let encoder_writer = BytesMut::with_capacity(MAX_BLOCK_SIZE);
46 let decoder_reader = BytesMut::with_capacity(MAX_COMPRESSED_SIZE);
47 Self {
48 inner: io,
49 read_buf: BytesMut::with_capacity(MAX_BLOCK_SIZE),
50 decoder: snap::read::FrameDecoder::new(decoder_reader.reader()),
51 encoder: snap::write::FrameEncoder::new(encoder_writer.writer()),
52 }
53 }
54
55 pub fn into_inner(self) -> T {
57 self.inner
58 }
59}
60
61impl<T: AsyncRead + Unpin> AsyncRead for SnappyIO<T> {
62 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
63 let mut this = self.project();
64
65 loop {
66 if this.read_buf.remaining() > 0 {
67 let amt = std::cmp::min(this.read_buf.remaining(), buf.remaining());
68 let slice = this.read_buf.split_to(amt);
69 buf.put_slice(&slice);
70 return Poll::Ready(Ok(()));
71 }
72
73 let decoder_reader = this.decoder.get_mut();
74 let decoder_buf: &mut BytesMut = decoder_reader.get_mut();
75 let buf_len = decoder_buf.len();
76 if buf_len < 4 {
77 decoder_buf.reserve(4 - buf_len);
78
79 let n = {
80 let dst = decoder_buf.chunk_mut();
81 let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
82 let mut buf = ReadBuf::uninit(&mut dst[..4 - buf_len]);
83 let ptr = buf.filled().as_ptr();
84 let inner = this.inner.as_mut();
85 ready!(inner.poll_read(cx, &mut buf)?);
86
87 assert_eq!(ptr, buf.filled().as_ptr());
89 buf.filled().len()
90 };
91 if n == 0 {
92 return Poll::Ready(Ok(()));
93 }
94 unsafe {
97 decoder_buf.advance_mut(n);
98 }
99
100 continue;
101 }
102
103 let mut chunk_len_buf = &decoder_buf.as_ref()[1..];
104 let chunk_len = chunk_len_buf.get_uint_le(3) as usize;
105
106 let buf_len = decoder_buf.len();
107 if buf_len < chunk_len + 4 {
109 decoder_buf.reserve(chunk_len + 4 - buf_len);
110 let n = {
111 let dst = decoder_buf.chunk_mut();
112 let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
113 let mut buf = ReadBuf::uninit(&mut dst[..chunk_len + 4 - buf_len]);
114 let ptr = buf.filled().as_ptr();
115 ready!(this.inner.as_mut().poll_read(cx, &mut buf)?);
116
117 assert_eq!(ptr, buf.filled().as_ptr());
118 buf.filled().len()
119 };
120 if n == 0 {
121 return Poll::Ready(Ok(()));
122 }
123
124 unsafe {
125 decoder_buf.advance_mut(n);
126 }
127
128 continue;
129 }
130
131 if decoder_buf.len() == chunk_len + 4 {
132 let dst = this.read_buf.chunk_mut();
133 let mut dst = unsafe { &mut *(dst as *mut _ as *mut [u8]) };
134 let _decoded = this.decoder.read(&mut dst)?;
135 unsafe {
136 this.read_buf.advance_mut(_decoded);
137 }
138 }
139 }
140 }
141}
142
143impl<T: AsyncWrite + Unpin> AsyncWrite for SnappyIO<T> {
144 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
145 if buf.is_empty() {
146 return Poll::Ready(Ok(0));
147 }
148
149 let len = std::cmp::min(buf.len(), MAX_BLOCK_SIZE);
150
151 let mut this = self.project();
152 loop {
153 let output_buf = this.encoder.get_mut().get_mut();
154 if output_buf.has_remaining() {
155 let n = ready!(this.inner.as_mut().poll_write(cx, output_buf.chunk())?);
156 output_buf.advance(n);
157 return Poll::Ready(Ok(len));
158 }
159
160 let _ = this.encoder.write(&buf[..len])?;
161 let output_buf = this.encoder.get_mut().get_mut();
162
163 if output_buf.is_empty() {
164 this.encoder.flush()?;
165 }
166 }
167 }
168
169 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
170 let mut this = self.project();
171
172 this.encoder.flush()?;
173 let output_buf = this.encoder.get_mut().get_mut();
174 while output_buf.has_remaining() {
175 let n = ready!(this.inner.as_mut().poll_write(cx, output_buf.chunk())?);
176 output_buf.advance(n);
177 }
178 this.inner.poll_flush(cx)
179 }
180
181 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
182 self.project().inner.poll_shutdown(cx)
183 }
184
185}