Skip to main content

tork_core/
body.rs

//! Request and response body types shared across the runtime.
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use http_body::{Body, Frame, SizeHint};
8use http_body_util::combinators::UnsyncBoxBody;
9use http_body_util::BodyExt;
10use http_body_util::Full;
11
/// Type-erased error carried by an erased request body.
///
/// Any error type can be boxed into this; the `Send + Sync` bounds let it move
/// and be shared across threads.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
14
15/// The inbound request body type.
16///
17/// The concrete body produced by Hyper is erased into this boxed body so the
18/// runtime is agnostic to where a request body comes from. This also makes the
19/// request context easy to construct in tests.
20pub type ReqBody = UnsyncBoxBody<Bytes, BoxError>;
21
22/// A boxed streaming response body.
23type BoxStreamBody = Pin<Box<dyn Body<Data = Bytes, Error = BoxError> + Send>>;
24
25/// The outbound response body.
26///
27/// A response body is either fully buffered (the common case: a JSON payload, an
28/// error body, a static asset) or streaming (Server-Sent Events, and other
29/// frame-at-a-time responses). Both share this type, so handler and middleware
30/// signatures do not change between them.
31pub struct RespBody {
32    kind: BodyKind,
33}
34
35/// The two shapes a [`RespBody`] can take.
36enum BodyKind {
37    /// A single, contiguous, already-available buffer.
38    Full(Full<Bytes>),
39    /// A body that yields frames over time.
40    Stream(BoxStreamBody),
41}
42
43impl RespBody {
44    /// Builds a fully-buffered body from a contiguous buffer.
45    pub fn new(body: Bytes) -> Self {
46        Self {
47            kind: BodyKind::Full(Full::new(body)),
48        }
49    }
50
51    /// Builds a streaming body that yields frames over time.
52    ///
53    /// Used by streaming responses such as Server-Sent Events, and available for
54    /// returning a custom frame-at-a-time body.
55    pub fn stream<B>(body: B) -> Self
56    where
57        B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
58    {
59        Self {
60            kind: BodyKind::Stream(Box::pin(body)),
61        }
62    }
63
64    /// Builds a streaming body that fails once it has emitted more than
65    /// `max_bytes`.
66    ///
67    /// Unbounded streaming responses (a file download backed by a generator, say)
68    /// have no inherent size limit; a runaway or buggy producer can stream without
69    /// end. Wrapping the body caps the total bytes it may emit, erroring the
70    /// response stream past the limit so it cannot run forever. Server-Sent Events
71    /// are intentionally open-ended and use [`stream`](RespBody::stream) instead.
72    pub fn stream_capped<B>(body: B, max_bytes: u64) -> Self
73    where
74        B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
75    {
76        Self {
77            kind: BodyKind::Stream(Box::pin(CappedBody {
78                inner: Box::pin(body),
79                emitted: 0,
80                limit: max_bytes,
81            })),
82        }
83    }
84}
85
86/// A streaming body that errors once it has emitted more than its byte limit.
87struct CappedBody {
88    inner: BoxStreamBody,
89    emitted: u64,
90    limit: u64,
91}
92
93impl Body for CappedBody {
94    type Data = Bytes;
95    type Error = BoxError;
96
97    fn poll_frame(
98        self: Pin<&mut Self>,
99        cx: &mut Context<'_>,
100    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
101        let this = self.get_mut();
102        match this.inner.as_mut().poll_frame(cx) {
103            Poll::Ready(Some(Ok(frame))) => {
104                if let Some(data) = frame.data_ref() {
105                    this.emitted = this.emitted.saturating_add(data.len() as u64);
106                    if this.emitted > this.limit {
107                        return Poll::Ready(Some(Err(format!(
108                            "response body exceeded the {}-byte limit",
109                            this.limit
110                        )
111                        .into())));
112                    }
113                }
114                Poll::Ready(Some(Ok(frame)))
115            }
116            other => other,
117        }
118    }
119
120    fn is_end_stream(&self) -> bool {
121        self.inner.is_end_stream()
122    }
123
124    fn size_hint(&self) -> SizeHint {
125        self.inner.size_hint()
126    }
127}
128
129impl Body for RespBody {
130    type Data = Bytes;
131    type Error = BoxError;
132
133    fn poll_frame(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
137        // `RespBody` is `Unpin` (its variants are), so projecting through `get_mut`
138        // is sound and keeps the delegation simple.
139        match &mut self.get_mut().kind {
140            // `Full`'s error type is `Infallible`, so it never yields an error.
141            BodyKind::Full(full) => Pin::new(full)
142                .poll_frame(cx)
143                .map_err(|never| match never {}),
144            BodyKind::Stream(stream) => stream.as_mut().poll_frame(cx),
145        }
146    }
147
148    fn is_end_stream(&self) -> bool {
149        match &self.kind {
150            BodyKind::Full(full) => full.is_end_stream(),
151            BodyKind::Stream(stream) => stream.is_end_stream(),
152        }
153    }
154
155    fn size_hint(&self) -> SizeHint {
156        match &self.kind {
157            BodyKind::Full(full) => full.size_hint(),
158            BodyKind::Stream(stream) => stream.size_hint(),
159        }
160    }
161}
162
163/// Erases any compatible HTTP body into the runtime's [`ReqBody`] type.
164pub fn box_body<B>(body: B) -> ReqBody
165where
166    B: hyper::body::Body<Data = Bytes> + Send + 'static,
167    B::Error: Into<BoxError>,
168{
169    body.map_err(Into::into).boxed_unsync()
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::StreamBody;

    /// Builds a streaming body that yields one data frame per entry of `chunks`.
    fn data_stream(
        chunks: &'static [&'static str],
    ) -> impl Body<Data = Bytes, Error = BoxError> + Send + 'static {
        let frames: Vec<Result<Frame<Bytes>, BoxError>> = chunks
            .iter()
            .copied()
            .map(|chunk| Ok(Frame::data(Bytes::from_static(chunk.as_bytes()))))
            .collect();
        StreamBody::new(futures_util::stream::iter(frames))
    }

    /// Polls `body` frame by frame until it ends or errors.
    ///
    /// Returns every data chunk yielded before that point, in order. It also
    /// returns the error that stopped the stream, or `None` if the stream ended
    /// cleanly.
    async fn drain(mut body: RespBody) -> (Vec<Bytes>, Option<BoxError>) {
        let mut chunks = Vec::new();
        while let Some(frame) = body.frame().await {
            match frame {
                Ok(frame) => {
                    if let Ok(data) = frame.into_data() {
                        chunks.push(data);
                    }
                }
                Err(error) => return (chunks, Some(error)),
            }
        }
        (chunks, None)
    }

    #[tokio::test]
    async fn full_body_yields_its_buffer() {
        let (chunks, error) = drain(RespBody::new(Bytes::from_static(b"hello"))).await;
        assert!(error.is_none(), "a buffered body never errors");
        assert_eq!(chunks, vec![Bytes::from_static(b"hello")]);
    }

    #[test]
    fn full_body_reports_its_exact_size() {
        let body = RespBody::new(Bytes::from_static(b"hello"));
        assert_eq!(body.size_hint().exact(), Some(5));
        assert!(!body.is_end_stream());
    }

    #[tokio::test]
    async fn streaming_body_yields_each_frame() {
        let (chunks, error) = drain(RespBody::stream(data_stream(&["a", "b", "c"]))).await;
        assert!(error.is_none(), "unexpected body error: {error:?}");
        assert_eq!(
            chunks,
            vec![
                Bytes::from_static(b"a"),
                Bytes::from_static(b"b"),
                Bytes::from_static(b"c"),
            ]
        );
    }

    #[tokio::test]
    async fn capped_stream_errors_once_it_exceeds_the_limit() {
        // Three 4-byte frames (12 bytes total) under a 10-byte cap. The first two
        // pass; the frame that would push the total over the limit errors instead.
        let body = RespBody::stream_capped(data_stream(&["aaaa", "bbbb", "cccc"]), 10);
        let (chunks, error) = drain(body).await;

        assert!(error.is_some(), "the body should error once it exceeds the cap");
        let delivered: usize = chunks.iter().map(Bytes::len).sum();
        assert_eq!(delivered, 8, "only the frames within the cap are delivered");
    }

    #[tokio::test]
    async fn capped_stream_at_exactly_the_limit_completes() {
        let body = RespBody::stream_capped(data_stream(&["aaaa", "bbbb"]), 8);
        let (chunks, error) = drain(body).await;

        assert!(error.is_none(), "a body of exactly the cap must not error");
        assert_eq!(chunks.iter().map(Bytes::len).sum::<usize>(), 8);
    }

    #[tokio::test]
    async fn box_body_preserves_the_payload() {
        let body = box_body(Full::new(Bytes::from_static(b"request")));
        let collected = body.collect().await.expect("body collects");
        assert_eq!(collected.to_bytes(), Bytes::from_static(b"request"));
    }
}