tokio_async_write_utility/
lib.rs

1use core::cell::UnsafeCell;
2use core::future::Future;
3use core::marker::Sized;
4use core::pin::Pin;
5use core::slice;
6use core::task::{Context, Poll};
7
8use std::io::{IoSlice, Result};
9use tokio::io::AsyncWrite;
10
11#[cfg(test)]
12extern crate tokio_pipe;
13
14pub struct WriteVectorizedAll<'a, 'b, 'c, T: AsyncWriteUtility + ?Sized>(
15    UnsafeCell<&'a mut T>,
16    &'b mut [IoSlice<'c>],
17);
18
19impl<T: AsyncWriteUtility + ?Sized> Future for WriteVectorizedAll<'_, '_, '_, T> {
20    type Output = Result<()>;
21
22    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23        AsyncWriteUtility::poll_write_vectored_all(
24            unsafe { Pin::new_unchecked(*(self.0.get())) },
25            cx,
26            self.1,
27        )
28    }
29}
30
31pub trait AsyncWriteUtility: AsyncWrite {
32    fn poll_write_vectored_all(
33        mut self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        mut bufs: &mut [IoSlice<'_>],
36    ) -> Poll<Result<()>> {
37        if bufs.is_empty() {
38            return Poll::Ready(Ok(()));
39        }
40
41        // Loop Invariant: bufs must not be empty
42        'outer: loop {
43            // bytes must be greater than 0
44            let mut bytes = match self.as_mut().poll_write_vectored(cx, bufs) {
45                Poll::Ready(res) => res?,
46                Poll::Pending => return Poll::Pending,
47            };
48
49            while bufs[0].len() <= bytes {
50                bytes -= bufs[0].len();
51                bufs = &mut bufs[1..];
52
53                if bufs.is_empty() {
54                    return Poll::Ready(Ok(()));
55                }
56
57                if bytes == 0 {
58                    continue 'outer;
59                }
60            }
61
62            let buf = &bufs[0][bytes..];
63            bufs[0] = IoSlice::new(unsafe { slice::from_raw_parts(buf.as_ptr(), buf.len()) });
64        }
65    }
66
67    /// Equivalent to:
68    ///
69    /// ```ignore
70    /// async fn write_vectored_all(&mut self, bufs: &mut [IoSlice<'_>]) -> Result<()>;
71    /// ```
72    fn write_vectored_all<'a, 'b, 'c>(
73        &'a mut self,
74        bufs: &'b mut [IoSlice<'c>],
75    ) -> WriteVectorizedAll<'a, 'b, 'c, Self> {
76        WriteVectorizedAll(UnsafeCell::new(self), bufs)
77    }
78}
79
80impl<T: AsyncWrite + ?Sized> AsyncWriteUtility for T {}
81
#[cfg(test)]
mod tests {
    use super::AsyncWriteUtility;

    use std::io::IoSlice;
    use std::pin::Pin;
    use std::slice::from_raw_parts;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncReadExt, AsyncWrite};

    /// Views the elements of `v` as one contiguous run of raw bytes.
    ///
    /// NOTE(review): only sound for element types without padding bytes (these tests use
    /// `u32`); a padded `T` would expose uninitialised memory.
    fn as_ioslice<T>(v: &[T]) -> IoSlice<'_> {
        // SAFETY: the pointer and byte length cover exactly the memory of `v`, which stays
        // borrowed for the lifetime of the returned `IoSlice`.
        IoSlice::new(unsafe {
            from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::<T>())
        })
    }

    /// In-memory writer that accepts at most `max_chunk` bytes per `poll_write` call and
    /// returns `Poll::Pending` exactly once, on the call numbered `pend_on_call` (1-based).
    struct ChunkedWriter {
        written: Vec<u8>,
        max_chunk: usize,
        calls: usize,
        pend_on_call: usize,
    }

    impl AsyncWrite for ChunkedWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.calls += 1;
            if self.calls == self.pend_on_call {
                // Ask to be polled again, as a real writer does once it becomes ready.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            let n = buf.len().min(self.max_chunk);
            self.written.extend_from_slice(&buf[..n]);
            Poll::Ready(Ok(n))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn test() {
        let (mut r, mut w) = tokio_pipe::pipe().unwrap();

        let w_task = tokio::spawn(async move {
            let buffer: Vec<u32> = (0..1024).collect();
            w.write_vectored_all(&mut [as_ioslice(&buffer), as_ioslice(&buffer)])
                .await
                .unwrap();
        });

        let r_task = tokio::spawn(async move {
            let mut buf = [0; 4 * 128];
            // The writer sends the sequence 0..1024 twice, once per IoSlice.
            for _ in 0..2 {
                let mut n = 0u32;
                while n < 1024 {
                    r.read_exact(&mut buf).await.unwrap();
                    for x in buf.chunks_exact(4) {
                        assert_eq!(x, n.to_ne_bytes());
                        n += 1;
                    }
                }
            }
        });
        tokio::try_join!(w_task, r_task).unwrap();
    }

    /// Regression test: bytes written before the writer returns `Pending` must not be
    /// written again when the future is re-polled.
    #[tokio::test]
    async fn resumes_after_pending_without_rewriting() {
        let mut w = ChunkedWriter {
            written: Vec::new(),
            max_chunk: 3,
            calls: 0,
            pend_on_call: 3,
        };
        w.write_vectored_all(&mut [
            IoSlice::new(b"hello"),
            IoSlice::new(b""),
            IoSlice::new(b"world"),
        ])
        .await
        .unwrap();
        assert_eq!(w.written, b"helloworld");
    }
}