tokio_snappy/
lib.rs

//! Wraps rust-snappy's framed encoder/decoder so it can be used through
//! `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`.
//!
3use std::pin::Pin;
4use std::io::{Read, Write, Result};
5use std::task::{Context, Poll};
6use std::mem::MaybeUninit;
7
8use pin_project::pin_project;
9use tokio::io::{
10    AsyncRead,
11    AsyncWrite,
12    ReadBuf,
13};
14use bytes::{
15    buf::{
16        Reader,
17        Writer,
18    },
19    Buf,
20    BufMut,
21    BytesMut,
22};
23use futures::ready;
24
// Largest amount of uncompressed data carried by a single block (64 KiB).
const MAX_BLOCK_SIZE: usize = 65_536;

// Upper bound on the compressed size of one full block, i.e.
// max_compress_len(MAX_BLOCK_SIZE) = 32 + n + n / 6 = 76490.
const MAX_COMPRESSED_SIZE: usize = 32 + MAX_BLOCK_SIZE + MAX_BLOCK_SIZE / 6;
30
31/// Async implementation for the snappy framed encoder/decoder.
32#[pin_project]
33#[derive(Debug)]
34pub struct SnappyIO<T> {
35    #[pin] inner: T,
36    read_buf: BytesMut,
37    decoder: snap::read::FrameDecoder<Reader<BytesMut>>,
38    encoder: snap::write::FrameEncoder<Writer<BytesMut>>,
39}
40
41impl<T> SnappyIO<T> {
42
43    /// Create a new SnappyIO, wrapped the given io object.
44    pub fn new(io: T) -> Self {
45        let encoder_writer = BytesMut::with_capacity(MAX_BLOCK_SIZE);
46        let decoder_reader = BytesMut::with_capacity(MAX_COMPRESSED_SIZE);
47        Self {
48            inner: io,
49            read_buf: BytesMut::with_capacity(MAX_BLOCK_SIZE),
50            decoder: snap::read::FrameDecoder::new(decoder_reader.reader()),
51            encoder: snap::write::FrameEncoder::new(encoder_writer.writer()),
52        }
53    }
54
55    /// Consume this `SnappyIO`, returning the underlying value.
56    pub fn into_inner(self) -> T {
57        self.inner
58    }
59}
60
61impl<T: AsyncRead + Unpin> AsyncRead for SnappyIO<T> {
62    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
63        let mut this = self.project();
64
65        loop {
66            if this.read_buf.remaining() > 0 {
67                let amt = std::cmp::min(this.read_buf.remaining(), buf.remaining());
68                let slice = this.read_buf.split_to(amt);
69                buf.put_slice(&slice);
70                return Poll::Ready(Ok(()));
71            }
72
73            let decoder_reader = this.decoder.get_mut();
74            let decoder_buf: &mut BytesMut = decoder_reader.get_mut();
75            let buf_len = decoder_buf.len();
76            if buf_len < 4 {
77                decoder_buf.reserve(4 - buf_len);
78
79                let n = {
80                    let dst = decoder_buf.chunk_mut();
81                    let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
82                    let mut buf = ReadBuf::uninit(&mut dst[..4 - buf_len]);
83                    let ptr = buf.filled().as_ptr();
84                    let inner = this.inner.as_mut();
85                    ready!(inner.poll_read(cx, &mut buf)?);
86
87                    // Ensure the pointer does not change from under us
88                    assert_eq!(ptr, buf.filled().as_ptr());
89                    buf.filled().len()
90                };
91                if n == 0 {
92                    return Poll::Ready(Ok(()));
93                }
94                // Safety: This is guaranteed to be the number of initialized (and read)
95                // bytes due to the invariants provided by `ReadBuf::filled`.
96                unsafe {
97                    decoder_buf.advance_mut(n);
98                }
99
100                continue;
101            }
102
103            let mut chunk_len_buf = &decoder_buf.as_ref()[1..];
104            let chunk_len = chunk_len_buf.get_uint_le(3) as usize;
105
106            let buf_len = decoder_buf.len();
107            // Read the entire chunk into buf
108            if buf_len < chunk_len + 4 {
109                decoder_buf.reserve(chunk_len + 4 - buf_len);
110                let n = {
111                    let dst = decoder_buf.chunk_mut();
112                    let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
113                    let mut buf = ReadBuf::uninit(&mut dst[..chunk_len + 4 - buf_len]);
114                    let ptr = buf.filled().as_ptr();
115                    ready!(this.inner.as_mut().poll_read(cx, &mut buf)?);
116
117                    assert_eq!(ptr, buf.filled().as_ptr());
118                    buf.filled().len()
119                };
120                if n == 0 {
121                    return Poll::Ready(Ok(()));
122                }
123
124                unsafe {
125                    decoder_buf.advance_mut(n);
126                }
127
128                continue;
129            }
130
131            if decoder_buf.len() == chunk_len + 4 {
132                let dst = this.read_buf.chunk_mut();
133                let mut dst = unsafe { &mut *(dst as *mut _ as *mut [u8]) };
134                let _decoded = this.decoder.read(&mut dst)?;
135                unsafe {
136                    this.read_buf.advance_mut(_decoded);
137                }
138            }
139        }
140    }
141}
142
143impl<T: AsyncWrite + Unpin> AsyncWrite for SnappyIO<T> {
144    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
145        if buf.is_empty() {
146            return Poll::Ready(Ok(0));
147        }
148
149        let len = std::cmp::min(buf.len(), MAX_BLOCK_SIZE);
150
151        let mut this = self.project();
152        loop {
153            let output_buf = this.encoder.get_mut().get_mut();
154            if output_buf.has_remaining() {
155                let n = ready!(this.inner.as_mut().poll_write(cx, output_buf.chunk())?);
156                output_buf.advance(n);
157                return Poll::Ready(Ok(len));
158            }
159
160            let _ = this.encoder.write(&buf[..len])?;
161            let output_buf = this.encoder.get_mut().get_mut();
162
163            if output_buf.is_empty() {
164                this.encoder.flush()?;
165            }
166        }
167    }
168
169    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
170        let mut this = self.project();
171
172        this.encoder.flush()?;
173        let output_buf = this.encoder.get_mut().get_mut();
174        while output_buf.has_remaining() {
175            let n = ready!(this.inner.as_mut().poll_write(cx, output_buf.chunk())?);
176            output_buf.advance(n);
177        }
178        this.inner.poll_flush(cx)
179    }
180
181    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
182        self.project().inner.poll_shutdown(cx)
183    }
184
185}