Skip to main content

veilid_tools/
async_peek_stream.rs

1use super::*;
2
3use std::io;
4use task::{Context, Poll};
5
6////////
7trait SendStream: AsyncRead + AsyncWrite + Send + Unpin {}
8impl<S> SendStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
9
10////////
11
12pub struct Peek<'a> {
13    aps: AsyncPeekStream,
14    buf: &'a mut [u8],
15}
16
17impl Unpin for Peek<'_> {}
18
19impl Future for Peek<'_> {
20    type Output = std::io::Result<usize>;
21
22    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23        let this = &mut *self;
24
25        let mut inner = this.aps.inner.lock();
26        let inner = &mut *inner;
27        //
28        let buf_len = this.buf.len();
29        let mut copy_len = buf_len;
30        if buf_len > inner.peekbuf_len {
31            //
32            inner.peekbuf.resize(buf_len, 0u8);
33            let read_len = match Pin::new(&mut inner.stream).poll_read(
34                cx,
35                &mut inner.peekbuf.as_mut_slice()[inner.peekbuf_len..buf_len],
36            ) {
37                Poll::Pending => {
38                    inner.peekbuf.resize(inner.peekbuf_len, 0u8);
39                    return Poll::Pending;
40                }
41                Poll::Ready(Err(e)) => {
42                    return Poll::Ready(Err(e));
43                }
44                Poll::Ready(Ok(v)) => v,
45            };
46            inner.peekbuf_len += read_len;
47            inner.peekbuf.resize(inner.peekbuf_len, 0u8);
48            copy_len = inner.peekbuf_len;
49        }
50        this.buf[..copy_len].copy_from_slice(&inner.peekbuf[..copy_len]);
51        Poll::Ready(Ok(copy_len))
52    }
53}
54
55////////
56
57pub struct PeekExact<'a> {
58    aps: AsyncPeekStream,
59    buf: &'a mut [u8],
60    cur_read: usize,
61}
62
63impl Unpin for PeekExact<'_> {}
64
65impl Future for PeekExact<'_> {
66    type Output = std::io::Result<usize>;
67
68    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69        let this = &mut *self;
70
71        let mut inner = this.aps.inner.lock();
72        let inner = &mut *inner;
73        //
74        let buf_len = this.buf.len();
75        let mut copy_len = buf_len;
76        if buf_len > inner.peekbuf_len {
77            //
78            inner.peekbuf.resize(buf_len, 0u8);
79            let read_len = match Pin::new(&mut inner.stream).poll_read(
80                cx,
81                &mut inner.peekbuf.as_mut_slice()[inner.peekbuf_len..buf_len],
82            ) {
83                Poll::Pending => {
84                    inner.peekbuf.resize(inner.peekbuf_len, 0u8);
85                    return Poll::Pending;
86                }
87                Poll::Ready(Err(e)) => {
88                    return Poll::Ready(Err(e));
89                }
90                Poll::Ready(Ok(v)) => v,
91            };
92            inner.peekbuf_len += read_len;
93            inner.peekbuf.resize(inner.peekbuf_len, 0u8);
94            copy_len = inner.peekbuf_len;
95        }
96        this.buf[this.cur_read..copy_len].copy_from_slice(&inner.peekbuf[this.cur_read..copy_len]);
97        this.cur_read = copy_len;
98        if this.cur_read == buf_len {
99            Poll::Ready(Ok(buf_len))
100        } else {
101            Poll::Pending
102        }
103    }
104}
105/////////
106struct AsyncPeekStreamInner {
107    stream: Box<dyn SendStream>,
108    peekbuf: Vec<u8>,
109    peekbuf_len: usize,
110}
111
112#[derive(Clone)]
113pub struct AsyncPeekStream
114where
115    Self: AsyncRead + AsyncWrite + Send + Unpin,
116{
117    inner: Arc<Mutex<AsyncPeekStreamInner>>,
118}
119
120impl AsyncPeekStream {
121    pub fn new<S>(stream: S) -> Self
122    where
123        S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
124    {
125        Self {
126            inner: Arc::new(Mutex::new(AsyncPeekStreamInner {
127                stream: Box::new(stream),
128                peekbuf: Vec::new(),
129                peekbuf_len: 0,
130            })),
131        }
132    }
133
134    pub fn peek<'a>(&'a self, buf: &'a mut [u8]) -> Peek<'a> {
135        Peek::<'a> {
136            aps: self.clone(),
137            buf,
138        }
139    }
140
141    pub fn peek_exact<'a>(&'a self, buf: &'a mut [u8]) -> PeekExact<'a> {
142        PeekExact::<'a> {
143            aps: self.clone(),
144            buf,
145            cur_read: 0,
146        }
147    }
148}
149
150impl AsyncRead for AsyncPeekStream {
151    fn poll_read(
152        self: Pin<&mut Self>,
153        cx: &mut Context<'_>,
154        buf: &mut [u8],
155    ) -> Poll<io::Result<usize>> {
156        let mut inner = self.inner.lock();
157        //
158        let buflen = buf.len();
159        let bufcopylen = core::cmp::min(buflen, inner.peekbuf_len);
160        let bufreadlen = buflen.saturating_sub(inner.peekbuf_len);
161
162        if bufreadlen > 0 {
163            match Pin::new(&mut inner.stream).poll_read(cx, &mut buf[bufcopylen..buflen]) {
164                Poll::Ready(res) => {
165                    let readlen = res?;
166                    buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
167                    inner.peekbuf_len = 0;
168                    Poll::Ready(Ok(bufcopylen + readlen))
169                }
170                Poll::Pending => {
171                    if bufcopylen > 0 {
172                        buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
173                        inner.peekbuf_len = 0;
174                        Poll::Ready(Ok(bufcopylen))
175                    } else {
176                        Poll::Pending
177                    }
178                }
179            }
180        } else {
181            buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
182            if bufcopylen == inner.peekbuf_len {
183                inner.peekbuf_len = 0;
184            } else if bufcopylen != 0 {
185                // slide buffer over by bufcopylen
186                let tail = inner.peekbuf.split_off(bufcopylen);
187                inner.peekbuf = tail;
188                inner.peekbuf_len -= bufcopylen;
189            }
190            Poll::Ready(Ok(bufcopylen))
191        }
192    }
193}
194
195impl AsyncWrite for AsyncPeekStream {
196    fn poll_write(
197        self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        buf: &[u8],
200    ) -> Poll<io::Result<usize>> {
201        let mut inner = self.inner.lock();
202        Pin::new(&mut inner.stream).poll_write(cx, buf)
203    }
204    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205        let mut inner = self.inner.lock();
206        Pin::new(&mut inner.stream).poll_flush(cx)
207    }
208    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209        let mut inner = self.inner.lock();
210        Pin::new(&mut inner.stream).poll_close(cx)
211    }
212}
213
214impl core::marker::Unpin for AsyncPeekStream {}