tokio_io_rewind/lib.rs
//! This crate provides a `Rewind` struct that wraps any type implementing the
//! `AsyncRead` and/or `AsyncWrite` traits and lets you prepend bytes that are
//! returned by reads before any data from the wrapped value. This is particularly
//! useful when you have read bytes from a stream but need to "put them back" so
//! they are read again, effectively rewinding the stream.
//!
//! # Examples
//!
//! Basic usage of `Rewind` with an `AsyncRead` implementor (`std::io::Cursor`):
//!
//! ```
//! use bytes::Bytes;
//! use tokio::io::AsyncReadExt;
//! use tokio_io_rewind::Rewind;
//! use std::io::Cursor;
//!
//! #[tokio::main]
//! async fn main() {
//!     // Create a new Rewind instance with a Cursor wrapped inside.
//!     let mut rw = Rewind::new(Cursor::new(b"world".to_vec()));
//!
//!     // Prepend "hello " to the stream.
//!     rw.rewind(Bytes::from_static(b"hello "));
//!
//!     // Read all bytes from the rewound stream.
//!     let mut buf = Vec::new();
//!     rw.read_to_end(&mut buf).await.unwrap();
//!
//!     assert_eq!(buf, b"hello world");
//! }
//! ```
//!
//! `Rewind` also implements `AsyncWrite` when the underlying type does; writes,
//! flushes, and shutdowns are passed straight through to the wrapped value.
//!
//! # Features
//!
//! - `Rewind::new(inner)`: Creates a new `Rewind` with no pre-buffered bytes.
//! - `Rewind::new_buffered(inner, pre)`: Creates a new `Rewind` whose reads return `pre` before any data from `inner`.
//! - `Rewind::rewind(pre)`: Prepends bytes to the pending buffer. If bytes are already pending, the new bytes are placed before them.
//! - `Rewind::into_inner()`: Consumes the `Rewind`, returning the inner value and any unread pre-buffered bytes.
//!
//! `Rewind` is especially useful in protocols where a piece of data must be read to
//! decide what comes next, but that same data must also remain part of the input stream.
45
46use bytes::{Buf, Bytes, BytesMut};
47use tokio::io::{AsyncRead, AsyncWrite};
48
49/// Wraps an `AsyncRead` and/or `AsyncWrite` implementor, allowing bytes to be prepended to the stream.
50///
51/// This is useful for situations where bytes are read from a stream to make decisions and then need to be
52/// "unread", making them available for future read operations.
53pub struct Rewind<T> {
54 inner: T,
55 pre: Option<Bytes>,
56}
57
58impl<T> Rewind<T> {
59 /// Creates a new `Rewind` instance without any pre-buffered bytes.
60 ///
61 /// # Arguments
62 ///
63 /// * `inner` - The inner type that implements `AsyncRead` and/or `AsyncWrite`.
64 pub fn new(inner: T) -> Self {
65 Rewind { inner, pre: None }
66 }
67
68 /// Creates a new `Rewind` instance with pre-buffered bytes.
69 ///
70 /// # Arguments
71 ///
72 /// * `inner` - The inner type that implements `AsyncRead` and/or `AsyncWrite`.
73 /// * `pre` - Initial bytes to prepend to the stream.
74 pub fn new_buffered(inner: T, pre: Bytes) -> Self {
75 Rewind {
76 inner,
77 pre: Some(pre),
78 }
79 }
80
81 /// Prepends bytes to the stream. If there are already pre-buffered bytes,
82 /// the new bytes are added before the existing ones.
83 ///
84 /// # Arguments
85 ///
86 /// * `pre` - Bytes to prepend.
87 pub fn rewind(&mut self, pre: Bytes) {
88 match self.pre {
89 Some(ref mut old_pre) => {
90 let mut new_pre = BytesMut::with_capacity(old_pre.len() + pre.len());
91 new_pre.extend_from_slice(&pre);
92 new_pre.extend_from_slice(old_pre);
93 self.pre = Some(new_pre.freeze());
94 }
95 None => {
96 self.pre = Some(pre);
97 }
98 }
99 }
100
101 /// Consumes the `Rewind`, returning the inner type and any un-read pre-buffered bytes.
102 pub fn into_inner(self) -> (T, Bytes) {
103 (self.inner, self.pre.unwrap_or_default())
104 }
105}
106
107impl<T> AsRef<T> for Rewind<T> {
108 fn as_ref(&self) -> &T {
109 &self.inner
110 }
111}
112
113impl<T> AsMut<T> for Rewind<T> {
114 fn as_mut(&mut self) -> &mut T {
115 &mut self.inner
116 }
117}
118
119impl<T> AsyncRead for Rewind<T>
120where
121 T: AsyncRead + Unpin,
122{
123 fn poll_read(
124 mut self: std::pin::Pin<&mut Self>,
125 cx: &mut std::task::Context<'_>,
126 buf: &mut tokio::io::ReadBuf<'_>,
127 ) -> std::task::Poll<std::io::Result<()>> {
128 if let Some(mut pre) = self.pre.take() {
129 let copy_len = std::cmp::min(pre.len(), buf.remaining());
130 buf.put_slice(&pre[..copy_len]);
131 pre.advance(copy_len);
132 if !pre.is_empty() {
133 self.pre = Some(pre);
134 }
135 return std::task::Poll::Ready(Ok(()));
136 }
137 std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
138 }
139}
140
141impl<T> AsyncWrite for Rewind<T>
142where
143 T: AsyncWrite + Unpin,
144{
145 fn poll_write(
146 self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 buf: &[u8],
149 ) -> std::task::Poll<std::io::Result<usize>> {
150 std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
151 }
152
153 fn poll_flush(
154 self: std::pin::Pin<&mut Self>,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<std::io::Result<()>> {
157 std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
158 }
159
160 fn poll_shutdown(
161 self: std::pin::Pin<&mut Self>,
162 cx: &mut std::task::Context<'_>,
163 ) -> std::task::Poll<std::io::Result<()>> {
164 std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// A single rewind is read back before the inner data, then the stream ends.
    #[tokio::test]
    async fn test_rewind() {
        let mut rw = Rewind::new(Cursor::new(b"world".to_vec()));
        rw.rewind(Bytes::from_static(b"hello "));
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"hello world");
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"");
    }

    /// Rewinding onto an existing prefix places the new bytes in front of it.
    #[tokio::test]
    async fn test_rewind_prepends_before_existing_prefix() {
        let mut rw = Rewind::new_buffered(
            Cursor::new(b"!".to_vec()),
            Bytes::from_static(b"world"),
        );
        rw.rewind(Bytes::from_static(b"hello "));
        let mut buf = Vec::new();
        rw.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"hello world!");
    }

    /// A partial read leaves the remainder of the prefix for `into_inner`,
    /// and the inner reader has not been touched.
    #[tokio::test]
    async fn test_into_inner_returns_unread_prefix() {
        let mut rw = Rewind::new_buffered(
            Cursor::new(b"rest".to_vec()),
            Bytes::from_static(b"abcdef"),
        );
        let mut head = [0u8; 2];
        rw.read_exact(&mut head).await.unwrap();
        assert_eq!(&head, b"ab");
        let (inner, pre) = rw.into_inner();
        assert_eq!(pre, Bytes::from_static(b"cdef"));
        assert_eq!(inner.position(), 0);
    }

    /// Writes go straight to the inner writer.
    #[tokio::test]
    async fn test_write_passthrough() {
        let mut rw = Rewind::new(Cursor::new(Vec::<u8>::new()));
        rw.write_all(b"abc").await.unwrap();
        rw.flush().await.unwrap();
        let (inner, pre) = rw.into_inner();
        assert!(pre.is_empty());
        assert_eq!(inner.into_inner(), b"abc");
    }
}