tokio_io_compat/
lib.rs

1//! [![GitHub Workflow Status](https://img.shields.io/github/workflow/status/PhotonQuantum/tokio-io-compat/Test?style=flat-square)](https://github.com/PhotonQuantum/tokio-io-compat/actions/workflows/test.yml)
2//! [![crates.io](https://img.shields.io/crates/v/tokio-io-compat?style=flat-square)](https://crates.io/crates/tokio-io-compat)
3//! [![Documentation](https://img.shields.io/docsrs/tokio-io-compat?style=flat-square)](https://docs.rs/tokio-io-compat)
4//!
5//! Compatibility wrapper around `std::io::{Read, Write, Seek}` traits that implements `tokio::io::{AsyncRead, AsyncWrite, AsyncSeek}`.
6//!
7//! Beware that this won't magically make your IO operations asynchronous.
//! You should still consider making your code asynchronous or moving the IO operations to a blocking thread if the cost is high.
9//!
10//! ## Deal with `WouldBlock`
11//!
12//! If you are trying to wrap a non-blocking IO, it may yield [`WouldBlock`](std::io::ErrorKind::WouldBlock) errors when data
13//! is not ready.
14//! This wrapper will automatically convert [`WouldBlock`](std::io::ErrorKind::WouldBlock) into `Poll::Pending`.
15//!
//! However, the waker must be woken later, otherwise the future would never be polled again.
//! By default, it is woken immediately. This may waste CPU cycles, especially when the operation
//! is slow.
19//!
20//! You may add a delay before each wake by creating the wrapper with [`AsyncIoCompat::new_with_delay`](AsyncIoCompat::new_with_delay).
21//! If your underlying non-blocking IO has a native poll complete notification mechanism, consider
22//! writing your own wrapper instead of using this crate.
23//!
24//! For reference please see [tokio-tls](https://github.com/tokio-rs/tls/blob/master/tokio-native-tls/src/lib.rs).
25//!
26//! ## Example
27//!
28//! ```rust
29//! use std::io::Cursor;
30//! use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
31//! use tokio_io_compat::CompatHelperTrait;
32//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
35//! let mut data = Cursor::new(vec![]);
36//! data.tokio_io_mut().write_all(&vec![0, 1, 2, 3, 4]).await.unwrap();
37//! data.tokio_io_mut().seek(SeekFrom::Start(2)).await.unwrap();
38//! assert_eq!(data.tokio_io_mut().read_u8().await.unwrap(), 2);
39//! # }
40//! ```
41
42use std::io::{Read, Seek, SeekFrom, Write};
43use std::pin::Pin;
44use std::task::{Context, Poll};
45use std::time::Duration;
46use std::{io, mem};
47
48use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
49
50#[cfg(test)]
51mod tests;
52
/// Adapter that exposes a synchronous IO object through Tokio's asynchronous IO traits.
///
/// The wrapped value is driven directly from the `poll_*` methods; a
/// [`WouldBlock`](std::io::ErrorKind::WouldBlock) error from the inner value is reported
/// as `Poll::Pending` and the task is re-woken (optionally after `wake_delay`).
#[derive(Debug)]
pub struct AsyncIoCompat<T> {
    /// The wrapped synchronous IO object.
    inner: T,
    /// Outcome of the most recent seek attempt; reset to `Ok(0)` once it has been reported.
    last_seek: io::Result<u64>,
    /// Target of the most recent `start_seek`, kept so a `WouldBlock` seek can be retried.
    last_seek_position: SeekFrom,
    /// Delay applied before re-waking the task after `WouldBlock`; zero means wake at once.
    wake_delay: Duration,
}
61
62impl<T> AsyncIoCompat<T> {
63    /// Create a new wrapper.
64    pub const fn new(inner: T) -> Self {
65        Self {
66            inner,
67            last_seek: Ok(0),
68            last_seek_position: SeekFrom::Start(0),
69            wake_delay: Duration::ZERO,
70        }
71    }
72    /// Create a new wrapper with given [`WouldBlock`](std::io::ErrorKind::WouldBlock) wake delay.
73    pub const fn new_with_delay(inner: T, delay: Duration) -> Self {
74        Self {
75            inner,
76            last_seek: Ok(0),
77            last_seek_position: SeekFrom::Start(0),
78            wake_delay: delay,
79        }
80    }
81    /// Get the inner type.
82    #[allow(clippy::missing_const_for_fn)] // false positive
83    pub fn into_inner(self) -> T {
84        self.inner
85    }
86
87    fn schedule_wake(&self, ctx: &Context<'_>) {
88        if self.wake_delay.is_zero() {
89            ctx.waker().wake_by_ref();
90        } else {
91            let waker = ctx.waker().clone();
92            let delay = self.wake_delay;
93            tokio::spawn(async move {
94                tokio::time::sleep(delay).await;
95                waker.wake();
96            });
97        }
98    }
99
100    fn no_blocking<F, O>(&mut self, ctx: &Context<'_>, f: F) -> Poll<io::Result<O>>
101    where
102        F: for<'a> FnOnce(&'a mut Self) -> io::Result<O>,
103    {
104        match f(self) {
105            Ok(t) => Poll::Ready(Ok(t)),
106            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
107                self.schedule_wake(ctx);
108                Poll::Pending
109            }
110            Err(e) => Poll::Ready(Err(e)),
111        }
112    }
113}
114
115impl<T: Read + Unpin> AsyncRead for AsyncIoCompat<T> {
116    fn poll_read(
117        mut self: Pin<&mut Self>,
118        cx: &mut Context<'_>,
119        buf: &mut ReadBuf<'_>,
120    ) -> Poll<io::Result<()>> {
121        self.no_blocking(cx, |this| {
122            this.inner.read(buf.initialize_unfilled()).map(|filled| {
123                buf.advance(filled);
124            })
125        })
126    }
127}
128
129impl<T: Write + Unpin> AsyncWrite for AsyncIoCompat<T> {
130    fn poll_write(
131        mut self: Pin<&mut Self>,
132        cx: &mut Context<'_>,
133        buf: &[u8],
134    ) -> Poll<io::Result<usize>> {
135        self.no_blocking(cx, |this| this.inner.write(buf))
136    }
137
138    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
139        self.no_blocking(cx, |this| this.inner.flush())
140    }
141
142    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143        Poll::Ready(Ok(()))
144    }
145}
146
147impl<T: Seek + Unpin> AsyncSeek for AsyncIoCompat<T> {
148    fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
149        self.last_seek_position = position;
150        self.last_seek = self.inner.seek(position);
151        Ok(())
152    }
153
154    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
155        match self.last_seek {
156            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
157                let position = self.last_seek_position;
158                let res = self.inner.seek(position);
159                match res {
160                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
161                        self.last_seek = res;
162                        self.schedule_wake(cx);
163                        Poll::Pending
164                    }
165                    _ => {
166                        self.last_seek = Ok(0);
167                        Poll::Ready(res)
168                    }
169                }
170            }
171            _ => Poll::Ready(mem::replace(&mut self.last_seek, Ok(0))),
172        }
173    }
174}
175
176/// Helper trait that applies [`AsyncIoCompat`](AsyncIoCompat) wrapper to std types.
177pub trait CompatHelperTrait {
178    /// Applies the [`AsyncIoCompat`](AsyncIoCompat) wrapper by value.
179    fn tokio_io(self) -> AsyncIoCompat<Self>
180    where
181        Self: Sized;
182    /// Applies the [`AsyncIoCompat`](AsyncIoCompat) wrapper by mutable reference.
183    fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self>;
184}
185
186impl<T> CompatHelperTrait for T {
187    fn tokio_io(self) -> AsyncIoCompat<Self>
188    where
189        Self: Sized,
190    {
191        AsyncIoCompat::new(self)
192    }
193
194    fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self> {
195        AsyncIoCompat::new(self)
196    }
197}