seamless/handler/
request.rs1use futures::{ Stream, TryStreamExt, io::{ AsyncRead, Cursor } };
8use std::pin::Pin;
9use std::task::{ Poll, Context };
10
11pub trait AsyncReadBody: AsyncRead + Send + Unpin {}
13impl <T: AsyncRead + Send + Unpin> AsyncReadBody for T {}
14
15pub struct Bytes {
18 variant: BytesVariant
19}
20
21enum BytesVariant {
24 FromVec(Cursor<Vec<u8>>),
25 FromReader(Box<dyn AsyncReadBody>)
26}
27
28impl std::fmt::Debug for Bytes {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_tuple("ByteStream").finish()
31 }
32}
33
34impl AsyncRead for Bytes {
35 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
36 unsafe {
42 match &mut Pin::get_unchecked_mut(self).variant {
43 BytesVariant::FromVec(v) => {
44 Pin::new_unchecked(v).poll_read(cx, buf)
45 },
46 BytesVariant::FromReader(r) => {
47 Pin::new_unchecked(r).poll_read(cx, buf)
48 }
49 }
50 }
51 }
52}
53
54impl From<Vec<u8>> for Bytes {
57 fn from(bytes: Vec<u8>) -> Self {
58 Bytes::from_vec(bytes)
59 }
60}
61
62impl Bytes {
63 pub fn from_vec(bytes: Vec<u8>) -> Bytes {
65 Bytes {
66 variant: BytesVariant::FromVec(Cursor::new(bytes))
67 }
68 }
69 pub fn from_reader<S: AsyncReadBody + 'static>(reader: S) -> Bytes {
71 Bytes {
72 variant: BytesVariant::FromReader(Box::new(reader))
73 }
74 }
75 pub fn from_stream<S: Stream<Item = std::io::Result<Vec<u8>>> + 'static + Send + Unpin>(stream: S) -> Bytes {
77 Bytes {
78 variant: BytesVariant::FromReader(Box::new(stream.into_async_read()))
79 }
80 }
81}
82
#[cfg(test)]
mod test_bytes {
    use super::*;
    use futures::AsyncReadExt;

    /// Reads `bytes` to completion, returning the reported count and the data read.
    async fn drain(mut bytes: Bytes) -> (usize, Vec<u8>) {
        let mut sink = Vec::new();
        let count = bytes
            .read_to_end(&mut sink)
            .await
            .expect("No error should occur reading back the bytes");
        (count, sink)
    }

    #[tokio::test]
    async fn can_read_from_vec() {
        let (n, output) = drain(Bytes::from_vec(vec![1, 2, 3, 4, 5])).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn can_read_from_reader() {
        let inner = Bytes::from_vec(vec![1, 2, 3, 4, 5]);
        let (n, output) = drain(Bytes::from_reader(inner)).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn can_read_from_stream() {
        // One single-byte chunk per value, so the reader must stitch chunks together.
        let chunks: Vec<std::io::Result<Vec<u8>>> = (1..=5u8).map(|b| Ok(vec![b])).collect();
        let (n, output) = drain(Bytes::from_stream(futures::stream::iter(chunks))).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }
}
128
/// An `AsyncRead` adapter that yields an error once more than `MAX` bytes
/// have been read in total from the wrapped reader.
///
/// The `T: AsyncRead` bound lives on the `impl` blocks rather than on the
/// struct, so naming the type does not force callers to repeat it.
pub(crate) struct CappedAsyncRead<T, const MAX: usize> {
    /// The wrapped reader. The `AsyncRead` impl projects a pin onto it.
    inner: T,
    /// Running total of bytes read from `inner` so far.
    bytes_read: usize
}
136
137impl <T: AsyncRead, const MAX: usize> AsyncRead for CappedAsyncRead<T, MAX> {
138 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
139 let inner = unsafe {
141 self.as_mut().map_unchecked_mut(|lr| &mut lr.inner)
142 };
143
144 let new_bytes_read = match inner.poll_read(cx, buf) {
146 Poll::Ready(Ok(n)) => {
147 n
148 },
149 Poll::Ready(Err(e)) => {
150 return Poll::Ready(Err(e))
151 },
152 Poll::Pending => {
153 return Poll::Pending
154 }
155 };
156
157 let bytes_read = unsafe { &mut self.as_mut().get_unchecked_mut().bytes_read };
160 *bytes_read += new_bytes_read;
161 if *bytes_read > MAX {
162 return Poll::Ready(
163 Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Size limit exceeded"))
164 )
165 }
166
167 Poll::Ready(Ok(new_bytes_read))
169 }
170}
171
172impl <T: AsyncRead, const MAX: usize> CappedAsyncRead<T, MAX> {
173 pub fn new(read: T) -> CappedAsyncRead<T, MAX> {
174 CappedAsyncRead {
175 inner: read,
176 bytes_read: 0
177 }
178 }
179}
180
#[cfg(test)]
mod test_capped_reader {
    use super::*;
    use futures::AsyncReadExt;

    /// Wraps a byte slice in a reader capped at five bytes.
    fn cap_at_5(input: &[u8]) -> CappedAsyncRead<&[u8], 5> {
        CappedAsyncRead::new(input)
    }

    #[tokio::test]
    async fn capped_reader_ok_with_0_bytes() {
        let input: Vec<u8> = Vec::new();
        let mut reader = cap_at_5(input.as_slice());

        let mut output = Vec::new();
        let n = reader
            .read_to_end(&mut output)
            .await
            .expect("No error should occur reading no bytes");
        assert_eq!(n, 0);
        assert!(output.is_empty());
    }

    #[tokio::test]
    async fn capped_reader_errors_if_limit_exceeded() {
        let input: Vec<u8> = (1..=6).collect();
        let mut reader = cap_at_5(&input);

        let mut output = Vec::new();
        let err = reader
            .read_to_end(&mut output)
            .await
            .expect_err("Exceeded limit: error expected");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn capped_reader_ok_if_limit_not_exceeded() {
        let input: Vec<u8> = (1..=5).collect();
        let mut reader = cap_at_5(&input);

        let mut output = Vec::new();
        let n = reader
            .read_to_end(&mut output)
            .await
            .expect("Should successfully read all bytes");
        assert_eq!(n, 5);
        assert_eq!(output, input);
    }
}