trillium_testing/
test_transport.rs1use async_dup::Arc;
2use futures_lite::{AsyncRead, AsyncWrite};
3use std::{
4 fmt::{Debug, Display},
5 future::Future,
6 io,
7 pin::Pin,
8 sync::RwLock,
9 task::{Context, Poll, Waker},
10};
11use trillium_macros::{AsyncRead, AsyncWrite};
12
13#[derive(Default, Clone, Debug, AsyncRead, AsyncWrite)]
15pub struct TestTransport {
16 #[async_read]
18 pub read: Arc<CloseableCursor>,
19
20 #[async_write]
22 pub write: Arc<CloseableCursor>,
23}
24
25impl trillium_http::transport::Transport for TestTransport {}
26
27impl TestTransport {
28 pub fn new() -> (TestTransport, TestTransport) {
33 let a = Arc::new(CloseableCursor::default());
34 let b = Arc::new(CloseableCursor::default());
35
36 (
37 TestTransport {
38 read: a.clone(),
39 write: b.clone(),
40 },
41 TestTransport { read: b, write: a },
42 )
43 }
44
45 pub fn close(&mut self) {
51 self.read.close();
52 self.write.close();
53 }
54
55 pub fn snapshot(&self) -> Vec<u8> {
57 self.read.snapshot()
58 }
59
60 pub fn write_all(&self, bytes: impl AsRef<[u8]>) {
63 io::Write::write_all(&mut &*self.write, bytes.as_ref()).unwrap();
64 }
65
66 pub async fn read_available(&self) -> Vec<u8> {
69 self.read.read_available().await
70 }
71
72 pub async fn read_available_string(&self) -> String {
75 self.read.read_available_string().await
76 }
77}
78
79impl Drop for TestTransport {
80 fn drop(&mut self) {
81 self.close();
82 }
83}
84
/// Mutable state guarded by a [`CloseableCursor`]'s lock.
#[derive(Default)]
struct CloseableCursorInner {
    /// Every byte written so far. Reads never remove bytes; they only
    /// advance `cursor`.
    data: Vec<u8>,
    /// Offset into `data` of the next byte to hand to a reader.
    cursor: usize,
    /// Once `true`, writes are refused and drained reads report EOF.
    closed: bool,
    /// Waker of the task whose last read found nothing, to be woken by the
    /// next write or close.
    waker: Option<Waker>,
}
92
93#[derive(Default)]
94pub struct CloseableCursor(RwLock<CloseableCursorInner>);
95
96impl CloseableCursor {
97 pub fn len(&self) -> usize {
101 self.0.read().unwrap().data.len()
102 }
103
104 pub fn cursor(&self) -> usize {
108 self.0.read().unwrap().cursor
109 }
110
111 pub fn is_empty(&self) -> bool {
115 self.len() == 0
116 }
117
118 pub fn snapshot(&self) -> Vec<u8> {
120 self.0.read().unwrap().data.clone()
121 }
122
123 pub fn current(&self) -> bool {
127 let inner = self.0.read().unwrap();
128 inner.data.len() == inner.cursor
129 }
130
131 pub fn close(&self) {
135 let mut inner = self.0.write().unwrap();
136 inner.closed = true;
137 if let Some(waker) = inner.waker.take() {
138 waker.wake();
139 }
140 }
141
142 pub async fn read_available(&self) -> Vec<u8> {
144 ReadAvailable(self).await.unwrap()
145 }
146
147 pub async fn read_available_string(&self) -> String {
149 String::from_utf8(self.read_available().await).unwrap()
150 }
151}
152
/// Future that resolves to every byte its wrapped reader can supply without
/// waiting; see its `Future` impl for the exact completion rules.
struct ReadAvailable<Reader>(Reader);
154
155impl<T: AsyncRead + Unpin> Future for ReadAvailable<T> {
156 type Output = io::Result<Vec<u8>>;
157
158 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 let mut buf = vec![];
160 let mut bytes_read = 0;
161 loop {
162 if buf.len() == bytes_read {
163 buf.reserve(32);
164 buf.resize(buf.capacity(), 0);
165 }
166 match Pin::new(&mut self.0).poll_read(cx, &mut buf[bytes_read..]) {
167 Poll::Ready(Ok(0)) => break,
168 Poll::Ready(Ok(new_bytes)) => {
169 bytes_read += new_bytes;
170 }
171 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
172 Poll::Pending if bytes_read == 0 => return Poll::Pending,
173 Poll::Pending => break,
174 }
175 }
176
177 buf.truncate(bytes_read);
178 Poll::Ready(Ok(buf))
179 }
180}
181
182impl Display for CloseableCursor {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 let inner = self.0.read().unwrap();
185 write!(f, "{}", String::from_utf8_lossy(&inner.data))
186 }
187}
188
189impl Debug for CloseableCursor {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 let inner = self.0.read().unwrap();
192 f.debug_struct("CloseableCursor")
193 .field(
194 "data",
195 &std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
196 )
197 .field("closed", &inner.closed)
198 .field("cursor", &inner.cursor)
199 .finish()
200 }
201}
202
203impl AsyncRead for CloseableCursor {
204 fn poll_read(
205 self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &mut [u8],
208 ) -> Poll<io::Result<usize>> {
209 Pin::new(&mut &*self).poll_read(cx, buf)
210 }
211}
212
213impl AsyncRead for &CloseableCursor {
214 fn poll_read(
215 self: Pin<&mut Self>,
216 cx: &mut Context<'_>,
217 buf: &mut [u8],
218 ) -> Poll<io::Result<usize>> {
219 let mut inner = self.0.write().unwrap();
220 if inner.cursor < inner.data.len() {
221 let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
222 buf[..bytes_to_copy]
223 .copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
224 inner.cursor += bytes_to_copy;
225 Poll::Ready(Ok(bytes_to_copy))
226 } else if inner.closed {
227 Poll::Ready(Ok(0))
228 } else {
229 inner.waker = Some(cx.waker().clone());
230 Poll::Pending
231 }
232 }
233}
234
235impl AsyncWrite for &CloseableCursor {
236 fn poll_write(
237 self: Pin<&mut Self>,
238 _cx: &mut Context<'_>,
239 buf: &[u8],
240 ) -> Poll<io::Result<usize>> {
241 let mut inner = self.0.write().unwrap();
242 if inner.closed {
243 Poll::Ready(Ok(0))
244 } else {
245 inner.data.extend_from_slice(buf);
246 if let Some(waker) = inner.waker.take() {
247 waker.wake();
248 }
249 Poll::Ready(Ok(buf.len()))
250 }
251 }
252
253 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
254 Poll::Ready(Ok(()))
255 }
256
257 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
258 self.close();
259 Poll::Ready(Ok(()))
260 }
261}
262
263impl io::Write for CloseableCursor {
264 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
265 io::Write::write(&mut &*self, buf)
266 }
267
268 fn flush(&mut self) -> io::Result<()> {
269 Ok(())
270 }
271}
272
273impl io::Write for &CloseableCursor {
274 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
275 let mut inner = self.0.write().unwrap();
276 if inner.closed {
277 Ok(0)
278 } else {
279 inner.data.extend_from_slice(buf);
280 if let Some(waker) = inner.waker.take() {
281 waker.wake();
282 }
283 Ok(buf.len())
284 }
285 }
286
287 fn flush(&mut self) -> io::Result<()> {
288 Ok(())
289 }
290}