veilid_tools/
async_peek_stream.rs1use super::*;
2
3use std::io;
4use task::{Context, Poll};
5
6trait SendStream: AsyncRead + AsyncWrite + Send + Unpin {}
8impl<S> SendStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
9
10pub struct Peek<'a> {
13 aps: AsyncPeekStream,
14 buf: &'a mut [u8],
15}
16
17impl Unpin for Peek<'_> {}
18
19impl Future for Peek<'_> {
20 type Output = std::io::Result<usize>;
21
22 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23 let this = &mut *self;
24
25 let mut inner = this.aps.inner.lock();
26 let inner = &mut *inner;
27 let buf_len = this.buf.len();
29 let mut copy_len = buf_len;
30 if buf_len > inner.peekbuf_len {
31 inner.peekbuf.resize(buf_len, 0u8);
33 let read_len = match Pin::new(&mut inner.stream).poll_read(
34 cx,
35 &mut inner.peekbuf.as_mut_slice()[inner.peekbuf_len..buf_len],
36 ) {
37 Poll::Pending => {
38 inner.peekbuf.resize(inner.peekbuf_len, 0u8);
39 return Poll::Pending;
40 }
41 Poll::Ready(Err(e)) => {
42 return Poll::Ready(Err(e));
43 }
44 Poll::Ready(Ok(v)) => v,
45 };
46 inner.peekbuf_len += read_len;
47 inner.peekbuf.resize(inner.peekbuf_len, 0u8);
48 copy_len = inner.peekbuf_len;
49 }
50 this.buf[..copy_len].copy_from_slice(&inner.peekbuf[..copy_len]);
51 Poll::Ready(Ok(copy_len))
52 }
53}
54
55pub struct PeekExact<'a> {
58 aps: AsyncPeekStream,
59 buf: &'a mut [u8],
60 cur_read: usize,
61}
62
63impl Unpin for PeekExact<'_> {}
64
65impl Future for PeekExact<'_> {
66 type Output = std::io::Result<usize>;
67
68 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69 let this = &mut *self;
70
71 let mut inner = this.aps.inner.lock();
72 let inner = &mut *inner;
73 let buf_len = this.buf.len();
75 let mut copy_len = buf_len;
76 if buf_len > inner.peekbuf_len {
77 inner.peekbuf.resize(buf_len, 0u8);
79 let read_len = match Pin::new(&mut inner.stream).poll_read(
80 cx,
81 &mut inner.peekbuf.as_mut_slice()[inner.peekbuf_len..buf_len],
82 ) {
83 Poll::Pending => {
84 inner.peekbuf.resize(inner.peekbuf_len, 0u8);
85 return Poll::Pending;
86 }
87 Poll::Ready(Err(e)) => {
88 return Poll::Ready(Err(e));
89 }
90 Poll::Ready(Ok(v)) => v,
91 };
92 inner.peekbuf_len += read_len;
93 inner.peekbuf.resize(inner.peekbuf_len, 0u8);
94 copy_len = inner.peekbuf_len;
95 }
96 this.buf[this.cur_read..copy_len].copy_from_slice(&inner.peekbuf[this.cur_read..copy_len]);
97 this.cur_read = copy_len;
98 if this.cur_read == buf_len {
99 Poll::Ready(Ok(buf_len))
100 } else {
101 Poll::Pending
102 }
103 }
104}
105struct AsyncPeekStreamInner {
107 stream: Box<dyn SendStream>,
108 peekbuf: Vec<u8>,
109 peekbuf_len: usize,
110}
111
112#[derive(Clone)]
113pub struct AsyncPeekStream
114where
115 Self: AsyncRead + AsyncWrite + Send + Unpin,
116{
117 inner: Arc<Mutex<AsyncPeekStreamInner>>,
118}
119
120impl AsyncPeekStream {
121 pub fn new<S>(stream: S) -> Self
122 where
123 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
124 {
125 Self {
126 inner: Arc::new(Mutex::new(AsyncPeekStreamInner {
127 stream: Box::new(stream),
128 peekbuf: Vec::new(),
129 peekbuf_len: 0,
130 })),
131 }
132 }
133
134 pub fn peek<'a>(&'a self, buf: &'a mut [u8]) -> Peek<'a> {
135 Peek::<'a> {
136 aps: self.clone(),
137 buf,
138 }
139 }
140
141 pub fn peek_exact<'a>(&'a self, buf: &'a mut [u8]) -> PeekExact<'a> {
142 PeekExact::<'a> {
143 aps: self.clone(),
144 buf,
145 cur_read: 0,
146 }
147 }
148}
149
150impl AsyncRead for AsyncPeekStream {
151 fn poll_read(
152 self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 buf: &mut [u8],
155 ) -> Poll<io::Result<usize>> {
156 let mut inner = self.inner.lock();
157 let buflen = buf.len();
159 let bufcopylen = core::cmp::min(buflen, inner.peekbuf_len);
160 let bufreadlen = buflen.saturating_sub(inner.peekbuf_len);
161
162 if bufreadlen > 0 {
163 match Pin::new(&mut inner.stream).poll_read(cx, &mut buf[bufcopylen..buflen]) {
164 Poll::Ready(res) => {
165 let readlen = res?;
166 buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
167 inner.peekbuf_len = 0;
168 Poll::Ready(Ok(bufcopylen + readlen))
169 }
170 Poll::Pending => {
171 if bufcopylen > 0 {
172 buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
173 inner.peekbuf_len = 0;
174 Poll::Ready(Ok(bufcopylen))
175 } else {
176 Poll::Pending
177 }
178 }
179 }
180 } else {
181 buf[0..bufcopylen].copy_from_slice(&inner.peekbuf[0..bufcopylen]);
182 if bufcopylen == inner.peekbuf_len {
183 inner.peekbuf_len = 0;
184 } else if bufcopylen != 0 {
185 let tail = inner.peekbuf.split_off(bufcopylen);
187 inner.peekbuf = tail;
188 inner.peekbuf_len -= bufcopylen;
189 }
190 Poll::Ready(Ok(bufcopylen))
191 }
192 }
193}
194
195impl AsyncWrite for AsyncPeekStream {
196 fn poll_write(
197 self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &[u8],
200 ) -> Poll<io::Result<usize>> {
201 let mut inner = self.inner.lock();
202 Pin::new(&mut inner.stream).poll_write(cx, buf)
203 }
204 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205 let mut inner = self.inner.lock();
206 Pin::new(&mut inner.stream).poll_flush(cx)
207 }
208 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209 let mut inner = self.inner.lock();
210 Pin::new(&mut inner.stream).poll_close(cx)
211 }
212}
213
214impl core::marker::Unpin for AsyncPeekStream {}