// trillium_http/received_body.rs

1use crate::{copy, http_config::DEFAULT_CONFIG, Body, Buffer, HttpConfig, MutCow};
2use encoding_rs::Encoding;
3use futures_lite::{ready, AsyncRead, AsyncReadExt, AsyncWrite, Stream};
4use httparse::{InvalidChunkSize, Status};
5use std::{
6    fmt::{self, Debug, Formatter},
7    future::{Future, IntoFuture},
8    io::{self, ErrorKind},
9    iter,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use Poll::{Pending, Ready};
14use ReceivedBodyState::{Chunked, End, FixedLength, PartialChunkSize, Start};
15
16mod chunked;
17mod fixed_length;
18
19/** A received http body
20
21This type represents a body that will be read from the underlying
22transport, which it may either borrow from a [`Conn`](crate::Conn) or
23own.
24
25```rust
26# trillium_testing::block_on(async {
27# use trillium_http::{Method, Conn};
28let mut conn = Conn::new_synthetic(Method::Get, "/", "hello");
29let body = conn.request_body().await;
30assert_eq!(body.read_string().await?, "hello");
31# trillium_http::Result::Ok(()) }).unwrap();
32```
33
34## Bounds checking
35
36Every `ReceivedBody` has a maximum length beyond which it will return an error, expressed as a
37u64. To override this on the specific `ReceivedBody`, use [`ReceivedBody::with_max_len`] or
38[`ReceivedBody::set_max_len`]
39
40The default maximum length is currently set to 500mb. In the next semver-minor release, this value
41will decrease substantially.
42
43## Large chunks, small read buffers
44
45Attempting to read a chunked body with a buffer that is shorter than the chunk size in hex will
46result in an error. This limitation is temporary.
47*/
48
49pub struct ReceivedBody<'conn, Transport> {
50    content_length: Option<u64>,
51    buffer: MutCow<'conn, Buffer>,
52    transport: Option<MutCow<'conn, Transport>>,
53    state: MutCow<'conn, ReceivedBodyState>,
54    on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
55    encoding: &'static Encoding,
56    max_len: u64,
57    initial_len: usize,
58    copy_loops_per_yield: usize,
59    max_preallocate: usize,
60}
61
/// Returns the non-empty tail of `buf` starting at byte offset `min`.
///
/// Yields `None` when `min` is past the end of `buf` (including offsets that do not fit in a
/// `usize`) or when the remaining tail would be empty.
fn slice_from(min: u64, buf: &[u8]) -> Option<&[u8]> {
    // An offset that overflows usize can never be a valid index into a slice.
    let start = usize::try_from(min).ok()?;
    match buf.get(start..) {
        Some(tail) if !tail.is_empty() => Some(tail),
        _ => None,
    }
}
66
67impl<'conn, Transport> ReceivedBody<'conn, Transport>
68where
69    Transport: AsyncRead + Unpin + Send + Sync + 'static,
70{
71    #[allow(missing_docs)]
72    #[doc(hidden)]
73    pub fn new(
74        content_length: Option<u64>,
75        buffer: impl Into<MutCow<'conn, Buffer>>,
76        transport: impl Into<MutCow<'conn, Transport>>,
77        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
78        on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
79        encoding: &'static Encoding,
80    ) -> Self {
81        Self::new_with_config(
82            content_length,
83            buffer,
84            transport,
85            state,
86            on_completion,
87            encoding,
88            &DEFAULT_CONFIG,
89        )
90    }
91
92    #[allow(missing_docs)]
93    #[doc(hidden)]
94    pub(crate) fn new_with_config(
95        content_length: Option<u64>,
96        buffer: impl Into<MutCow<'conn, Buffer>>,
97        transport: impl Into<MutCow<'conn, Transport>>,
98        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
99        on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
100        encoding: &'static Encoding,
101        config: &HttpConfig,
102    ) -> Self {
103        Self {
104            content_length,
105            buffer: buffer.into(),
106            transport: Some(transport.into()),
107            state: state.into(),
108            on_completion,
109            encoding,
110            max_len: config.received_body_max_len,
111            initial_len: config.received_body_initial_len,
112            copy_loops_per_yield: config.copy_loops_per_yield,
113            max_preallocate: config.received_body_max_preallocate,
114        }
115    }
116
117    /**
118    Returns the content-length of this body, if available. This
119    usually is derived from the content-length header. If the http
120    request or response that this body is attached to uses
121    transfer-encoding chunked, this will be None.
122
123    ```rust
124    # trillium_testing::block_on(async {
125    # use trillium_http::{Method, Conn};
126    let mut conn = Conn::new_synthetic(Method::Get, "/", "hello");
127    let body = conn.request_body().await;
128    assert_eq!(body.content_length(), Some(5));
129    # trillium_http::Result::Ok(()) }).unwrap();
130    ```
131    */
132    pub fn content_length(&self) -> Option<u64> {
133        self.content_length
134    }
135
136    /// # Reads entire body to String.
137    ///
138    /// This uses the encoding determined by the content-type (mime)
139    /// charset. If an encoding problem is encountered, the String
140    /// returned by [`ReceivedBody::read_string`] will contain utf8
141    /// replacement characters.
142    ///
143    /// Note that this can only be performed once per Conn, as the
144    /// underlying data is not cached anywhere. This is the only copy of
145    /// the body contents.
146    ///
147    /// # Errors
148    ///
149    /// This will return an error if there is an IO error on the
150    /// underlying transport such as a disconnect
151    ///
152    /// This will also return an error if the length exceeds the maximum length. To override this
153    /// value on this specific body, use [`ReceivedBody::with_max_len`] or
154    /// [`ReceivedBody::set_max_len`]
155    pub async fn read_string(self) -> crate::Result<String> {
156        let encoding = self.encoding();
157        let bytes = self.read_bytes().await?;
158        let (s, _, _) = encoding.decode(&bytes);
159        Ok(s.to_string())
160    }
161
162    fn owns_transport(&self) -> bool {
163        self.transport.as_ref().is_some_and(MutCow::is_owned)
164    }
165
166    /// Set the maximum length that can be read from this body before error
167    ///
168    /// See also [`HttpConfig::received_body_max_len`][HttpConfig#received_body_max_len]
169    pub fn set_max_len(&mut self, max_len: u64) {
170        self.max_len = max_len;
171    }
172
173    /// chainable setter for the maximum length that can be read from this body before error
174    ///
175    /// See also [`HttpConfig::received_body_max_len`][HttpConfig#received_body_max_len]
176    #[must_use]
177    pub fn with_max_len(mut self, max_len: u64) -> Self {
178        self.set_max_len(max_len);
179        self
180    }
181
182    /// Similar to [`ReceivedBody::read_string`], but returns the raw bytes. This is useful for
183    /// bodies that are not text.
184    ///
185    /// You can use this in conjunction with `encoding` if you need different handling of malformed
186    /// character encoding than the lossy conversion provided by [`ReceivedBody::read_string`].
187    ///
188    /// # Errors
189    ///
190    /// This will return an error if there is an IO error on the underlying transport such as a
191    /// disconnect
192    ///
193    /// This will also return an error if the length exceeds
194    /// [`received_body_max_len`][HttpConfig::with_received_body_max_len]. To override this value on
195    /// this specific body, use [`ReceivedBody::with_max_len`] or [`ReceivedBody::set_max_len`]
196    pub async fn read_bytes(mut self) -> crate::Result<Vec<u8>> {
197        let mut vec = if let Some(len) = self.content_length {
198            if len > self.max_len {
199                return Err(crate::Error::ReceivedBodyTooLong(self.max_len));
200            }
201
202            let len = usize::try_from(len)
203                .map_err(|_| crate::Error::ReceivedBodyTooLong(self.max_len))?;
204
205            Vec::with_capacity(len.min(self.max_preallocate))
206        } else {
207            Vec::with_capacity(self.initial_len)
208        };
209
210        self.read_to_end(&mut vec).await?;
211        Ok(vec)
212    }
213
214    /**
215    returns the character encoding of this body, usually determined from the content type
216    (mime-type) of the associated Conn.
217    */
218    pub fn encoding(&self) -> &'static Encoding {
219        self.encoding
220    }
221
222    fn read_raw(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
223        if let Some(transport) = self.transport.as_deref_mut() {
224            read_raw(&mut self.buffer, transport, cx, buf)
225        } else {
226            Ready(Err(ErrorKind::NotConnected.into()))
227        }
228    }
229
230    /**
231    Consumes the remainder of this body from the underlying transport by reading it to the end and
232    discarding the contents. This is important for http1.1 keepalive, but most of the time you do
233    not need to directly call this. It returns the number of bytes consumed.
234
235    # Errors
236
237    This will return an [`std::io::Result::Err`] if there is an io error on the underlying
238    transport, such as a disconnect
239    */
240    #[allow(clippy::missing_errors_doc)] // false positive
241    pub async fn drain(self) -> io::Result<u64> {
242        let copy_loops_per_yield = self.copy_loops_per_yield;
243        copy(self, futures_lite::io::sink(), copy_loops_per_yield).await
244    }
245}
246
247impl<'a, Transport> IntoFuture for ReceivedBody<'a, Transport>
248where
249    Transport: AsyncRead + Unpin + Send + Sync + 'static,
250{
251    type Output = crate::Result<String>;
252
253    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
254
255    fn into_future(self) -> Self::IntoFuture {
256        Box::pin(async move { self.read_string().await })
257    }
258}
259
260impl<T> ReceivedBody<'static, T> {
261    /// takes the static transport from this received body
262    pub fn take_transport(&mut self) -> Option<T> {
263        self.transport.take().map(MutCow::unwrap_owned)
264    }
265}
266
267fn read_raw<Transport>(
268    self_buffer: &mut Buffer,
269    transport: &mut Transport,
270    cx: &mut Context<'_>,
271    buf: &mut [u8],
272) -> Poll<io::Result<usize>>
273where
274    Transport: AsyncRead + Unpin + Send + Sync + 'static,
275{
276    if self_buffer.is_empty() {
277        Pin::new(transport).poll_read(cx, buf)
278    } else if self_buffer.len() >= buf.len() {
279        let len = buf.len();
280        buf.copy_from_slice(&self_buffer[..len]);
281        self_buffer.ignore_front(len);
282        Ready(Ok(len))
283    } else {
284        let self_buffer_len = self_buffer.len();
285        buf[..self_buffer_len].copy_from_slice(self_buffer);
286        self_buffer.truncate(0);
287        match Pin::new(transport).poll_read(cx, &mut buf[self_buffer_len..]) {
288            Ready(Ok(additional)) => Ready(Ok(additional + self_buffer_len)),
289            Pending => Ready(Ok(self_buffer_len)),
290            other @ Ready(_) => other,
291        }
292    }
293}
294
295const STREAM_READ_BUF_LENGTH: usize = 128;
296impl<'conn, Transport> Stream for ReceivedBody<'conn, Transport>
297where
298    Transport: AsyncRead + Unpin + Send + Sync + 'static,
299{
300    type Item = Vec<u8>;
301
302    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303        let mut bytes = 0;
304        let mut vec = vec![0; STREAM_READ_BUF_LENGTH];
305        loop {
306            match Pin::new(&mut *self).poll_read(cx, &mut vec[bytes..]) {
307                Pending if bytes == 0 => return Pending,
308                Ready(Ok(0)) if bytes == 0 => return Ready(None),
309                Pending | Ready(Ok(0)) => {
310                    vec.truncate(bytes);
311                    return Ready(Some(vec));
312                }
313                Ready(Ok(new_bytes)) => {
314                    bytes += new_bytes;
315                    vec.extend(iter::repeat(0).take(bytes + STREAM_READ_BUF_LENGTH - vec.len()));
316                }
317                Ready(Err(error)) => {
318                    log::error!("got {error:?} in ReceivedBody stream");
319                    return Ready(None);
320                }
321            }
322        }
323    }
324}
325
326type StateOutput = Poll<io::Result<(ReceivedBodyState, usize)>>;
327
328impl<'conn, Transport> ReceivedBody<'conn, Transport>
329where
330    Transport: AsyncRead + Unpin + Send + Sync + 'static,
331{
332    #[inline]
333    fn handle_start(&mut self) -> StateOutput {
334        Ready(Ok((
335            match self.content_length {
336                Some(0) => End,
337
338                Some(total_length) if total_length < self.max_len => FixedLength {
339                    current_index: 0,
340                    total: total_length,
341                },
342
343                Some(_) => {
344                    return Ready(Err(io::Error::new(
345                        ErrorKind::Unsupported,
346                        "content too long",
347                    )))
348                }
349
350                None => Chunked {
351                    remaining: 0,
352                    total: 0,
353                },
354            },
355            0,
356        )))
357    }
358}
359
360impl<'conn, Transport> AsyncRead for ReceivedBody<'conn, Transport>
361where
362    Transport: AsyncRead + Unpin + Send + Sync + 'static,
363{
364    fn poll_read(
365        mut self: Pin<&mut Self>,
366        cx: &mut Context<'_>,
367        buf: &mut [u8],
368    ) -> Poll<io::Result<usize>> {
369        for _ in 0..self.copy_loops_per_yield {
370            let (new_body_state, bytes) = ready!(match *self.state {
371                Start => self.handle_start(),
372                Chunked { remaining, total } => self.handle_chunked(cx, buf, remaining, total),
373                PartialChunkSize { total } => self.handle_partial(cx, buf, total),
374                FixedLength {
375                    current_index,
376                    total,
377                } => self.handle_fixed_length(cx, buf, current_index, total),
378                End => Ready(Ok((End, 0))),
379            })?;
380
381            *self.state = new_body_state;
382
383            if *self.state == End {
384                if self.on_completion.is_some() && self.owns_transport() {
385                    let transport = self.transport.take().unwrap().unwrap_owned();
386                    let on_completion = self.on_completion.take().unwrap();
387                    on_completion(transport);
388                }
389                return Ready(Ok(bytes));
390            } else if bytes != 0 {
391                return Ready(Ok(bytes));
392            }
393        }
394
395        cx.waker().wake_by_ref();
396        Pending
397    }
398}
399
400impl<'conn, Transport> Debug for ReceivedBody<'conn, Transport> {
401    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
402        f.debug_struct("RequestBody")
403            .field("state", &*self.state)
404            .field("content_length", &self.content_length)
405            .field("buffer", &"..")
406            .field("on_completion", &self.on_completion.is_some())
407            .finish()
408    }
409}
410
/// Position of a [`ReceivedBody`] in its decoding state machine.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
#[allow(missing_docs)]
#[doc(hidden)]
pub enum ReceivedBodyState {
    /// Nothing has been read yet; the next step decides between fixed-length and chunked
    /// decoding.
    #[default]
    Start,

    /// Reading a chunked-encoded body. The number of bytes already read from the current chunk
    /// can be derived from `remaining` and `total`.
    Chunked {
        /// Bytes still unread in the _current chunk_. Starts at zero.
        remaining: u64,

        /// Running count of bytes read across all chunks so far.
        total: u64,
    },

    /// Reading a chunked body whose next chunk-size line has not been fully received. Stepped
    /// by `handle_partial`. NOTE(review): `total` presumably carries the same running count as
    /// in `Chunked` — confirm.
    PartialChunkSize { total: u64 },

    /// Reading a body of known length.
    FixedLength {
        /// Bytes already read. Starts at zero.
        current_index: u64,

        /// Claimed body length, usually taken from the content-length header.
        total: u64,
    },

    /// Terminal state: the body has been fully read.
    End,
}
448
449impl<Transport> From<ReceivedBody<'static, Transport>> for Body
450where
451    Transport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
452{
453    fn from(rb: ReceivedBody<'static, Transport>) -> Self {
454        let len = rb.content_length;
455        Body::new_streaming(rb, len)
456    }
457}