// tokio_async_write_utility/lib.rs
use core::cell::UnsafeCell;
use core::future::Future;
use core::marker::Sized;
use core::pin::Pin;
use core::slice;
use core::task::{Context, Poll};

use std::io::{Error, ErrorKind, IoSlice, Result};
use tokio::io::AsyncWrite;

#[cfg(test)]
extern crate tokio_pipe;

14pub struct WriteVectorizedAll<'a, 'b, 'c, T: AsyncWriteUtility + ?Sized>(
15 UnsafeCell<&'a mut T>,
16 &'b mut [IoSlice<'c>],
17);
18
19impl<T: AsyncWriteUtility + ?Sized> Future for WriteVectorizedAll<'_, '_, '_, T> {
20 type Output = Result<()>;
21
22 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23 AsyncWriteUtility::poll_write_vectored_all(
24 unsafe { Pin::new_unchecked(*(self.0.get())) },
25 cx,
26 self.1,
27 )
28 }
29}
30
31pub trait AsyncWriteUtility: AsyncWrite {
32 fn poll_write_vectored_all(
33 mut self: Pin<&mut Self>,
34 cx: &mut Context<'_>,
35 mut bufs: &mut [IoSlice<'_>],
36 ) -> Poll<Result<()>> {
37 if bufs.is_empty() {
38 return Poll::Ready(Ok(()));
39 }
40
41 'outer: loop {
43 let mut bytes = match self.as_mut().poll_write_vectored(cx, bufs) {
45 Poll::Ready(res) => res?,
46 Poll::Pending => return Poll::Pending,
47 };
48
49 while bufs[0].len() <= bytes {
50 bytes -= bufs[0].len();
51 bufs = &mut bufs[1..];
52
53 if bufs.is_empty() {
54 return Poll::Ready(Ok(()));
55 }
56
57 if bytes == 0 {
58 continue 'outer;
59 }
60 }
61
62 let buf = &bufs[0][bytes..];
63 bufs[0] = IoSlice::new(unsafe { slice::from_raw_parts(buf.as_ptr(), buf.len()) });
64 }
65 }
66
67 fn write_vectored_all<'a, 'b, 'c>(
73 &'a mut self,
74 bufs: &'b mut [IoSlice<'c>],
75 ) -> WriteVectorizedAll<'a, 'b, 'c, Self> {
76 WriteVectorizedAll(UnsafeCell::new(self), bufs)
77 }
78}
79
80impl<T: AsyncWrite + ?Sized> AsyncWriteUtility for T {}
81
#[cfg(test)]
mod tests {
    use super::AsyncWriteUtility;

    use std::io::IoSlice;
    use std::slice::from_raw_parts;
    use tokio::io::AsyncReadExt;

    /// Views the memory of `items` as raw bytes wrapped in an `IoSlice`.
    fn as_ioslice<T>(items: &[T]) -> IoSlice<'_> {
        let byte_len = items.len() * std::mem::size_of::<T>();
        // SAFETY: the pointer and length cover exactly the memory of `items`,
        // and `u8` has no alignment requirement. Only meaningful for `T` with
        // no padding bytes (the test below uses `u32`).
        let bytes = unsafe { from_raw_parts(items.as_ptr() as *const u8, byte_len) };
        IoSlice::new(bytes)
    }

    #[tokio::test]
    async fn test() {
        let (mut reader, mut writer) = tokio_pipe::pipe().unwrap();

        let writer_task = tokio::spawn(async move {
            let numbers: Vec<u32> = (0..1024).collect();
            let mut slices = [as_ioslice(&numbers), as_ioslice(&numbers)];
            writer.write_vectored_all(&mut slices).await.unwrap();
        });

        let reader_task = tokio::spawn(async move {
            let mut chunk = [0; 4 * 128];
            // The writer sends the sequence 0..1024 twice, back to back.
            for _round in 0..2 {
                let mut expected = 0u32;
                while expected < 1024 {
                    reader.read_exact(&mut chunk).await.unwrap();
                    for word in chunk.chunks(4) {
                        assert_eq!(word, expected.to_ne_bytes());
                        expected += 1;
                    }
                }
            }
        });

        tokio::try_join!(writer_task, reader_task).unwrap();
    }
}