1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use http_body::{Body, Frame, SizeHint};
8use http_body_util::combinators::UnsyncBoxBody;
9use http_body_util::BodyExt;
10use http_body_util::Full;
11
/// Type-erased, thread-safe error used by every body type in this module.
///
/// Anything implementing `std::error::Error + Send + Sync` (including plain
/// `&str`/`String` messages via `.into()`) can be converted into it.
pub type BoxError = Box<dyn std::error::Error + Sync + Send + 'static>;
14
15pub type ReqBody = UnsyncBoxBody<Bytes, BoxError>;
21
22type BoxStreamBody = Pin<Box<dyn Body<Data = Bytes, Error = BoxError> + Send>>;
24
25pub struct RespBody {
32 kind: BodyKind,
33}
34
35enum BodyKind {
37 Full(Full<Bytes>),
39 Stream(BoxStreamBody),
41}
42
43impl RespBody {
44 pub fn new(body: Bytes) -> Self {
46 Self {
47 kind: BodyKind::Full(Full::new(body)),
48 }
49 }
50
51 pub fn stream<B>(body: B) -> Self
56 where
57 B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
58 {
59 Self {
60 kind: BodyKind::Stream(Box::pin(body)),
61 }
62 }
63
64 pub fn stream_capped<B>(body: B, max_bytes: u64) -> Self
73 where
74 B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
75 {
76 Self {
77 kind: BodyKind::Stream(Box::pin(CappedBody {
78 inner: Box::pin(body),
79 emitted: 0,
80 limit: max_bytes,
81 })),
82 }
83 }
84}
85
86struct CappedBody {
88 inner: BoxStreamBody,
89 emitted: u64,
90 limit: u64,
91}
92
93impl Body for CappedBody {
94 type Data = Bytes;
95 type Error = BoxError;
96
97 fn poll_frame(
98 self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
101 let this = self.get_mut();
102 match this.inner.as_mut().poll_frame(cx) {
103 Poll::Ready(Some(Ok(frame))) => {
104 if let Some(data) = frame.data_ref() {
105 this.emitted = this.emitted.saturating_add(data.len() as u64);
106 if this.emitted > this.limit {
107 return Poll::Ready(Some(Err(format!(
108 "response body exceeded the {}-byte limit",
109 this.limit
110 )
111 .into())));
112 }
113 }
114 Poll::Ready(Some(Ok(frame)))
115 }
116 other => other,
117 }
118 }
119
120 fn is_end_stream(&self) -> bool {
121 self.inner.is_end_stream()
122 }
123
124 fn size_hint(&self) -> SizeHint {
125 self.inner.size_hint()
126 }
127}
128
129impl Body for RespBody {
130 type Data = Bytes;
131 type Error = BoxError;
132
133 fn poll_frame(
134 self: Pin<&mut Self>,
135 cx: &mut Context<'_>,
136 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
137 match &mut self.get_mut().kind {
140 BodyKind::Full(full) => Pin::new(full)
142 .poll_frame(cx)
143 .map_err(|never| match never {}),
144 BodyKind::Stream(stream) => stream.as_mut().poll_frame(cx),
145 }
146 }
147
148 fn is_end_stream(&self) -> bool {
149 match &self.kind {
150 BodyKind::Full(full) => full.is_end_stream(),
151 BodyKind::Stream(stream) => stream.is_end_stream(),
152 }
153 }
154
155 fn size_hint(&self) -> SizeHint {
156 match &self.kind {
157 BodyKind::Full(full) => full.size_hint(),
158 BodyKind::Stream(stream) => stream.size_hint(),
159 }
160 }
161}
162
163pub fn box_body<B>(body: B) -> ReqBody
165where
166 B: hyper::body::Body<Data = Bytes> + Send + 'static,
167 B::Error: Into<BoxError>,
168{
169 body.map_err(Into::into).boxed_unsync()
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::StreamBody;

    /// Concrete type of the in-memory streaming source built by [`data_stream`].
    type DataStream = StreamBody<
        futures_util::stream::Iter<std::vec::IntoIter<Result<Frame<Bytes>, BoxError>>>,
    >;

    /// Builds a streaming body that yields each chunk as its own data frame.
    fn data_stream(chunks: &[&'static str]) -> DataStream {
        let frames: Vec<Result<Frame<Bytes>, BoxError>> = chunks
            .iter()
            .map(|&chunk| Ok(Frame::data(Bytes::from_static(chunk.as_bytes()))))
            .collect();
        StreamBody::new(futures_util::stream::iter(frames))
    }

    /// Polls `body` frame by frame, returning every data chunk delivered and the
    /// first error encountered (polling stops at that error).
    async fn drain(mut body: RespBody) -> (Vec<Bytes>, Option<BoxError>) {
        let mut chunks = Vec::new();
        while let Some(frame) = body.frame().await {
            match frame {
                Ok(frame) => {
                    if let Ok(data) = frame.into_data() {
                        chunks.push(data);
                    }
                }
                Err(error) => return (chunks, Some(error)),
            }
        }
        (chunks, None)
    }

    #[tokio::test]
    async fn full_body_yields_its_buffer() {
        let (chunks, error) = drain(RespBody::new(Bytes::from_static(b"hello"))).await;
        assert!(error.is_none(), "a buffered body never errors");
        assert_eq!(chunks, vec![Bytes::from_static(b"hello")]);
    }

    #[tokio::test]
    async fn streaming_body_yields_each_frame() {
        let (chunks, error) = drain(RespBody::stream(data_stream(&["a", "b", "c"]))).await;
        assert!(error.is_none(), "the source stream never errors");
        assert_eq!(
            chunks,
            vec![
                Bytes::from_static(b"a"),
                Bytes::from_static(b"b"),
                Bytes::from_static(b"c"),
            ]
        );
    }

    #[tokio::test]
    async fn capped_stream_errors_once_it_exceeds_the_limit() {
        let body = RespBody::stream_capped(data_stream(&["aaaa", "bbbb", "cccc"]), 10);
        let (chunks, error) = drain(body).await;

        let error = error.expect("the body should error once it exceeds the cap");
        assert!(
            error.to_string().contains("10-byte limit"),
            "unexpected error: {error}"
        );
        let delivered: usize = chunks.iter().map(Bytes::len).sum();
        assert_eq!(delivered, 8, "only the frames within the cap are delivered");
    }

    #[tokio::test]
    async fn capped_stream_allows_exactly_the_limit() {
        let body = RespBody::stream_capped(data_stream(&["aaaa", "bbbb"]), 8);
        let (chunks, error) = drain(body).await;

        assert!(error.is_none(), "reaching the cap exactly is not an error");
        assert_eq!(
            chunks,
            vec![Bytes::from_static(b"aaaa"), Bytes::from_static(b"bbbb")]
        );
    }
}