seamless/handler/
request.rs

/*!
This module provides helpers for creating a valid request body to send to Seamless as
part of an [`http::Request`]. Essentially, anything implementing `AsyncRead + Send + Unpin` can be
provided as a request body, but this module also provides a `Bytes` struct which can easily convert a
[`Vec<u8>`], an existing [`futures::AsyncRead`] or a [`futures::Stream`] into such a body.
*/
7use futures::{ Stream, TryStreamExt, io::{ AsyncRead, Cursor } };
8use std::pin::Pin;
9use std::task::{ Poll, Context };
10
11/// Any valid body will implement this trait
12pub trait AsyncReadBody: AsyncRead + Send + Unpin {}
13impl <T: AsyncRead + Send + Unpin> AsyncReadBody for T {}
14
15/// A collection of bytes that can be read from asynchronously via the
16/// [`futures::AsyncRead`] trait.
17pub struct Bytes {
18    variant: BytesVariant
19}
20
21/// To avoid an allocation around a vector of bytes, we use an enum which
22/// we'll match on instead, and only box things when we need to.
23enum BytesVariant {
24    FromVec(Cursor<Vec<u8>>),
25    FromReader(Box<dyn AsyncReadBody>)
26}
27
28impl std::fmt::Debug for Bytes {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_tuple("ByteStream").finish()
31    }
32}
33
34impl AsyncRead for Bytes {
35    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
36        // If Bytes is pinned, the enum variant is also pinned, so here
37        // we just want to convert a Pin<&mut Bytes> into the relevant variant
38        // to poll it. Very much the same sort of idea as:
39        // 
40        // https://github.com/rust-lang/futures-rs/blob/0.3.0-alpha.12/futures-core/src/stream/mod.rs#L88-L102
41        unsafe {
42            match &mut Pin::get_unchecked_mut(self).variant {
43                BytesVariant::FromVec(v) => {
44                    Pin::new_unchecked(v).poll_read(cx, buf)
45                },
46                BytesVariant::FromReader(r) => {
47                    Pin::new_unchecked(r).poll_read(cx, buf)
48                } 
49            }
50        }
51    }
52}
53
54// For simple cases, it's easy to create this struct via `bytes.into()` so that one
55// doesn't have to import anything from this file. Prefer to stream, though.
56impl From<Vec<u8>> for Bytes {
57    fn from(bytes: Vec<u8>) -> Self {
58        Bytes::from_vec(bytes)
59    } 
60}
61
62impl Bytes {
63    /// Turn a vector of bytes into a [`Bytes`]. Prefer to stream where possible.
64    pub fn from_vec(bytes: Vec<u8>) -> Bytes {
65        Bytes { 
66            variant: BytesVariant::FromVec(Cursor::new(bytes)) 
67        }
68    }
69    /// Turn a thing implementing [`futures::AsyncRead`] into [`Bytes`].
70    pub fn from_reader<S: AsyncReadBody + 'static>(reader: S) -> Bytes {
71        Bytes { 
72            variant: BytesVariant::FromReader(Box::new(reader)) 
73        }
74    }
75    /// Turn a thing implementing [`futures::Stream`] into [`Bytes`].
76    pub fn from_stream<S: Stream<Item = std::io::Result<Vec<u8>>> + 'static + Send + Unpin>(stream: S) -> Bytes {
77        Bytes { 
78            variant: BytesVariant::FromReader(Box::new(stream.into_async_read())) 
79        }
80    }
81}
82
#[cfg(test)]
mod test_bytes {
    use super::*;
    use futures::AsyncReadExt;

    /// Reads `body` to completion, returning the byte count and the bytes themselves.
    async fn drain(mut body: Bytes) -> (usize, Vec<u8>) {
        let mut collected = Vec::new();
        let count = body
            .read_to_end(&mut collected)
            .await
            .expect("No error should occur reading back the bytes");
        (count, collected)
    }

    #[tokio::test]
    async fn can_read_from_vec() {
        let (n, output) = drain(Bytes::from_vec(vec![1, 2, 3, 4, 5])).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn can_read_from_reader() {
        // Another `Bytes` doubles as the inner reader here.
        let inner = Bytes::from_vec(vec![1, 2, 3, 4, 5]);
        let (n, output) = drain(Bytes::from_reader(inner)).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn can_read_from_stream() {
        // Five single-byte chunks: [1], [2], ..., [5].
        let chunks = (1u8..=5).map(|b| -> std::io::Result<Vec<u8>> { Ok(vec![b]) });
        let (n, output) = drain(Bytes::from_stream(futures::stream::iter(chunks))).await;

        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }
}
128
129/// This wraps other `AsyncRead` impls and caps how many bytes can be
130/// read from the underlying reader before an `UnexpectedEof` error is
131/// returned. The cap exists at the type level via `const MAX`.
132pub (crate) struct CappedAsyncRead<T: AsyncRead, const MAX: usize> {
133    inner: T,
134    bytes_read: usize
135}
136
137impl <T: AsyncRead, const MAX: usize> AsyncRead for CappedAsyncRead<T, MAX> {
138    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {        
139        // Structural projection; Pin<CappedAsyncRead> to Pin<T>. Must not access the field in any other way.
140        let inner = unsafe { 
141            self.as_mut().map_unchecked_mut(|lr| &mut lr.inner) 
142        };
143
144        // Read some bytes into the provided buffer:
145        let new_bytes_read = match inner.poll_read(cx, buf) {
146            Poll::Ready(Ok(n)) => {
147                n
148            },
149            Poll::Ready(Err(e)) => {
150                return Poll::Ready(Err(e))
151            },
152            Poll::Pending => {
153                return Poll::Pending
154            }
155        };
156
157        // Bail if we've read more bytes than our limit allows. Non-structural projection here;
158        // Pin<CappedAsyncRead> to &mut usize.
159        let bytes_read = unsafe { &mut self.as_mut().get_unchecked_mut().bytes_read };
160        *bytes_read += new_bytes_read;
161        if *bytes_read > MAX {
162            return Poll::Ready(
163                Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Size limit exceeded"))
164            )
165        }
166
167        // Return the number of bytes written on this run:
168        Poll::Ready(Ok(new_bytes_read))
169    }
170}
171
172impl <T: AsyncRead, const MAX: usize> CappedAsyncRead<T, MAX> {
173    pub fn new(read: T) -> CappedAsyncRead<T, MAX> {
174        CappedAsyncRead {
175            inner: read,
176            bytes_read: 0
177        }
178    }
179}
180
#[cfg(test)]
mod test_capped_reader {
    use super::*;
    use futures::AsyncReadExt;

    /// Every test here reads from a byte slice with a 5-byte cap.
    type CappedAt5<'a> = CappedAsyncRead<&'a [u8], 5>;

    #[tokio::test]
    async fn capped_reader_ok_with_0_bytes() {
        // Nothing at all to read.
        let input: &[u8] = &[];
        let mut reader = CappedAt5::new(input);

        let mut output = Vec::new();
        let n = reader
            .read_to_end(&mut output)
            .await
            .expect("No error should occur reading no bytes");
        assert_eq!(n, 0);
        assert!(output.is_empty());
    }

    #[tokio::test]
    async fn capped_reader_errors_if_limit_exceeded() {
        // Six bytes against a five-byte cap.
        let input: &[u8] = &[1, 2, 3, 4, 5, 6];
        let mut reader = CappedAt5::new(input);

        let mut output = Vec::new();
        let err = reader
            .read_to_end(&mut output)
            .await
            .expect_err("Exceeded limit: error expected");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn capped_reader_ok_if_limit_not_exceeded() {
        // Exactly five bytes, which the five-byte cap permits.
        let input: &[u8] = &[1, 2, 3, 4, 5];
        let mut reader = CappedAt5::new(input);

        let mut output = Vec::new();
        let n = reader
            .read_to_end(&mut output)
            .await
            .expect("Should successfully read all bytes");
        assert_eq!(n, 5);
        assert_eq!(output, vec![1, 2, 3, 4, 5]);
    }
}