1use std::io::{Read, Seek, SeekFrom, Write};
43use std::pin::Pin;
44use std::task::{Context, Poll};
45use std::time::Duration;
46use std::{io, mem};
47
48use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
49
50#[cfg(test)]
51mod tests;
52
/// Adapter that exposes a synchronous `std::io` object through Tokio's async I/O traits.
///
/// Every poll performs the underlying `Read`/`Write`/`Seek` call synchronously. A
/// `WouldBlock` error from the inner object is turned into `Poll::Pending` plus a scheduled
/// wake-up; other results are returned as-is. Because the call is made inline, an inner
/// object that genuinely blocks will block the polling thread.
#[derive(Debug)]
pub struct AsyncIoCompat<T> {
    /// The wrapped synchronous I/O object.
    inner: T,
    /// Outcome of the most recent seek that `poll_complete` has not reported yet
    /// (`Ok(0)` when nothing is outstanding).
    last_seek: io::Result<u64>,
    /// Target of the most recent `start_seek`, kept so a `WouldBlock` seek can be retried.
    last_seek_position: SeekFrom,
    /// How long to wait before re-waking a task after `WouldBlock`; zero wakes immediately.
    wake_delay: Duration,
}
61
62impl<T> AsyncIoCompat<T> {
63 pub const fn new(inner: T) -> Self {
65 Self {
66 inner,
67 last_seek: Ok(0),
68 last_seek_position: SeekFrom::Start(0),
69 wake_delay: Duration::ZERO,
70 }
71 }
72 pub const fn new_with_delay(inner: T, delay: Duration) -> Self {
74 Self {
75 inner,
76 last_seek: Ok(0),
77 last_seek_position: SeekFrom::Start(0),
78 wake_delay: delay,
79 }
80 }
81 #[allow(clippy::missing_const_for_fn)] pub fn into_inner(self) -> T {
84 self.inner
85 }
86
87 fn schedule_wake(&self, ctx: &Context<'_>) {
88 if self.wake_delay.is_zero() {
89 ctx.waker().wake_by_ref();
90 } else {
91 let waker = ctx.waker().clone();
92 let delay = self.wake_delay;
93 tokio::spawn(async move {
94 tokio::time::sleep(delay).await;
95 waker.wake();
96 });
97 }
98 }
99
100 fn no_blocking<F, O>(&mut self, ctx: &Context<'_>, f: F) -> Poll<io::Result<O>>
101 where
102 F: for<'a> FnOnce(&'a mut Self) -> io::Result<O>,
103 {
104 match f(self) {
105 Ok(t) => Poll::Ready(Ok(t)),
106 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
107 self.schedule_wake(ctx);
108 Poll::Pending
109 }
110 Err(e) => Poll::Ready(Err(e)),
111 }
112 }
113}
114
115impl<T: Read + Unpin> AsyncRead for AsyncIoCompat<T> {
116 fn poll_read(
117 mut self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 buf: &mut ReadBuf<'_>,
120 ) -> Poll<io::Result<()>> {
121 self.no_blocking(cx, |this| {
122 this.inner.read(buf.initialize_unfilled()).map(|filled| {
123 buf.advance(filled);
124 })
125 })
126 }
127}
128
129impl<T: Write + Unpin> AsyncWrite for AsyncIoCompat<T> {
130 fn poll_write(
131 mut self: Pin<&mut Self>,
132 cx: &mut Context<'_>,
133 buf: &[u8],
134 ) -> Poll<io::Result<usize>> {
135 self.no_blocking(cx, |this| this.inner.write(buf))
136 }
137
138 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
139 self.no_blocking(cx, |this| this.inner.flush())
140 }
141
142 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143 Poll::Ready(Ok(()))
144 }
145}
146
147impl<T: Seek + Unpin> AsyncSeek for AsyncIoCompat<T> {
148 fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
149 self.last_seek_position = position;
150 self.last_seek = self.inner.seek(position);
151 Ok(())
152 }
153
154 fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
155 match self.last_seek {
156 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
157 let position = self.last_seek_position;
158 let res = self.inner.seek(position);
159 match res {
160 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
161 self.last_seek = res;
162 self.schedule_wake(cx);
163 Poll::Pending
164 }
165 _ => {
166 self.last_seek = Ok(0);
167 Poll::Ready(res)
168 }
169 }
170 }
171 _ => Poll::Ready(mem::replace(&mut self.last_seek, Ok(0))),
172 }
173 }
174}
175
176pub trait CompatHelperTrait {
178 fn tokio_io(self) -> AsyncIoCompat<Self>
180 where
181 Self: Sized;
182 fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self>;
184}
185
186impl<T> CompatHelperTrait for T {
187 fn tokio_io(self) -> AsyncIoCompat<Self>
188 where
189 Self: Sized,
190 {
191 AsyncIoCompat::new(self)
192 }
193
194 fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self> {
195 AsyncIoCompat::new(self)
196 }
197}