tokio_io_rewind/
lib.rs

//! This crate provides a `Rewind` struct that wraps any type implementing
//! `AsyncRead` and/or `AsyncWrite` and lets you prepend bytes that are returned
//! by reads before any data from the wrapped type. This is particularly useful when
//! you have already read bytes from a stream but need to "put them back" so they
//! are read again, effectively rewinding the stream.
6//!
7//! # Examples
8//!
//! Basic usage of `Rewind` wrapping a `std::io::Cursor`, for which Tokio implements `AsyncRead`:
10//!
11//! ```
12//! use bytes::Bytes;
13//! use tokio::io::AsyncReadExt;
14//! use tokio_io_rewind::Rewind;
15//! use std::io::Cursor;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     // Create a new Rewind instance with a Cursor wrapped inside.
20//!     let mut rw = Rewind::new(Cursor::new(b"world".to_vec()));
21//!
22//!     // Prepend "hello " to the stream.
23//!     rw.rewind(Bytes::from_static(b"hello "));
24//!
//!     // Read all bytes from the rewound stream.
26//!     let mut buf = Vec::new();
27//!     rw.read_to_end(&mut buf).await.unwrap();
28//!
29//!     assert_eq!(buf, b"hello world");
30//! }
31//! ```
32//!
//! `Rewind` also implements `AsyncWrite` when the wrapped type does; writes are
//! forwarded to the wrapped type and are not affected by any prepended bytes.
35//!
//! # API overview
37//!
38//! - `Rewind::new(inner)`: Create a new `Rewind` instance with no pre-buffered bytes.
39//! - `Rewind::new_buffered(inner, pre)`: Create a new `Rewind` instance with an initial buffer of bytes to prepend.
40//! - `Rewind::rewind(pre)`: Prepend bytes to the current buffer. If there's already a buffer, the new bytes are prepended before the existing ones.
//! - `Rewind::into_inner()`: Consumes the `Rewind`, returning the inner value and any pre-buffered bytes that have not been read yet.
42//!
43//! `Rewind` can be especially useful in protocols or situations where a piece of data is read to determine what comes next,
44//! but where that data also needs to be part of the eventual input stream.
45
46use bytes::{Buf, Bytes, BytesMut};
47use tokio::io::{AsyncRead, AsyncWrite};
48
49/// Wraps an `AsyncRead` and/or `AsyncWrite` implementor, allowing bytes to be prepended to the stream.
50///
51/// This is useful for situations where bytes are read from a stream to make decisions and then need to be
52/// "unread", making them available for future read operations.
53pub struct Rewind<T> {
54    inner: T,
55    pre: Option<Bytes>,
56}
57
58impl<T> Rewind<T> {
59    /// Creates a new `Rewind` instance without any pre-buffered bytes.
60    ///
61    /// # Arguments
62    ///
63    /// * `inner` - The inner type that implements `AsyncRead` and/or `AsyncWrite`.
64    pub fn new(inner: T) -> Self {
65        Rewind { inner, pre: None }
66    }
67
68    /// Creates a new `Rewind` instance with pre-buffered bytes.
69    ///
70    /// # Arguments
71    ///
72    /// * `inner` - The inner type that implements `AsyncRead` and/or `AsyncWrite`.
73    /// * `pre` - Initial bytes to prepend to the stream.
74    pub fn new_buffered(inner: T, pre: Bytes) -> Self {
75        Rewind {
76            inner,
77            pre: Some(pre),
78        }
79    }
80
81    /// Prepends bytes to the stream. If there are already pre-buffered bytes,
82    /// the new bytes are added before the existing ones.
83    ///
84    /// # Arguments
85    ///
86    /// * `pre` - Bytes to prepend.
87    pub fn rewind(&mut self, pre: Bytes) {
88        match self.pre {
89            Some(ref mut old_pre) => {
90                let mut new_pre = BytesMut::with_capacity(old_pre.len() + pre.len());
91                new_pre.extend_from_slice(&pre);
92                new_pre.extend_from_slice(old_pre);
93                self.pre = Some(new_pre.freeze());
94            }
95            None => {
96                self.pre = Some(pre);
97            }
98        }
99    }
100
101    /// Consumes the `Rewind`, returning the inner type and any un-read pre-buffered bytes.
102    pub fn into_inner(self) -> (T, Bytes) {
103        (self.inner, self.pre.unwrap_or_default())
104    }
105}
106
107impl<T> AsRef<T> for Rewind<T> {
108    fn as_ref(&self) -> &T {
109        &self.inner
110    }
111}
112
113impl<T> AsMut<T> for Rewind<T> {
114    fn as_mut(&mut self) -> &mut T {
115        &mut self.inner
116    }
117}
118
119impl<T> AsyncRead for Rewind<T>
120where
121    T: AsyncRead + Unpin,
122{
123    fn poll_read(
124        mut self: std::pin::Pin<&mut Self>,
125        cx: &mut std::task::Context<'_>,
126        buf: &mut tokio::io::ReadBuf<'_>,
127    ) -> std::task::Poll<std::io::Result<()>> {
128        if let Some(mut pre) = self.pre.take() {
129            let copy_len = std::cmp::min(pre.len(), buf.remaining());
130            buf.put_slice(&pre[..copy_len]);
131            pre.advance(copy_len);
132            if !pre.is_empty() {
133                self.pre = Some(pre);
134            }
135            return std::task::Poll::Ready(Ok(()));
136        }
137        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
138    }
139}
140
141impl<T> AsyncWrite for Rewind<T>
142where
143    T: AsyncWrite + Unpin,
144{
145    fn poll_write(
146        self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148        buf: &[u8],
149    ) -> std::task::Poll<std::io::Result<usize>> {
150        std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
151    }
152
153    fn poll_flush(
154        self: std::pin::Pin<&mut Self>,
155        cx: &mut std::task::Context<'_>,
156    ) -> std::task::Poll<std::io::Result<()>> {
157        std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
158    }
159
160    fn poll_shutdown(
161        self: std::pin::Pin<&mut Self>,
162        cx: &mut std::task::Context<'_>,
163    ) -> std::task::Poll<std::io::Result<()>> {
164        std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Prepended bytes are read first, then the inner stream; afterwards the stream is at EOF.
    #[tokio::test]
    async fn test_rewind() {
        let mut rw = Rewind::new(Cursor::new(b"world".to_vec()));
        rw.rewind(Bytes::from_static(b"hello "));
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"hello world");
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"");
    }

    /// Rewinding over an existing prefix places the new bytes before the old ones.
    #[tokio::test]
    async fn test_rewind_prepends_before_existing_prefix() {
        let mut rw = Rewind::new_buffered(Cursor::new(b"c".to_vec()), Bytes::from_static(b"b"));
        rw.rewind(Bytes::from_static(b"a"));
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"abc");
    }

    /// A read smaller than the prefix leaves the rest of the prefix buffered,
    /// and `into_inner` hands it back without touching the inner reader.
    #[tokio::test]
    async fn test_partial_read_then_into_inner() {
        let mut rw = Rewind::new_buffered(
            Cursor::new(b"world".to_vec()),
            Bytes::from_static(b"hello "),
        );
        let mut small = [0u8; 2];
        let n = rw.read(&mut small).await.unwrap();
        assert_eq!(n, 2);
        assert_eq!(&small, b"he");

        let (inner, rest) = rw.into_inner();
        assert_eq!(rest, Bytes::from_static(b"llo "));
        assert_eq!(inner.position(), 0);
    }

    /// Without a prefix, `into_inner` returns empty bytes.
    #[tokio::test]
    async fn test_into_inner_without_prefix() {
        let rw = Rewind::new(Cursor::new(vec![1u8, 2, 3]));
        let (inner, rest) = rw.into_inner();
        assert!(rest.is_empty());
        assert_eq!(inner.into_inner(), vec![1u8, 2, 3]);
    }

    /// Writes go straight to the inner writer and leave the prefix intact.
    #[tokio::test]
    async fn test_write_passes_through() {
        let mut rw = Rewind::new_buffered(Cursor::new(Vec::new()), Bytes::from_static(b"pre"));
        rw.write_all(b"data").await.unwrap();
        rw.flush().await.unwrap();
        let (inner, rest) = rw.into_inner();
        assert_eq!(inner.into_inner(), b"data");
        assert_eq!(rest, Bytes::from_static(b"pre"));
    }
}
185}