1use bytes::Bytes;
2use http::{HeaderValue, Response, StatusCode};
3use http_body::{Body, SizeHint};
4use http_body_util::Full;
5use pin_project_lite::pin_project;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// Response body that either forwards an inner body `B` unchanged or yields
    /// the fixed "length limit exceeded" message of a payload-too-large response.
    pub struct ResponseBody<B> {
        // Pinned so `poll_frame` can project through to the (possibly `!Unpin`)
        // inner body.
        #[pin]
        inner: ResponseBodyInner<B>
    }
}
18
19impl<B> ResponseBody<B> {
20 fn payload_too_large() -> Self {
21 Self {
22 inner: ResponseBodyInner::PayloadTooLarge {
23 body: Full::from(BODY),
24 },
25 }
26 }
27
28 pub(crate) fn new(body: B) -> Self {
29 Self {
30 inner: ResponseBodyInner::Body { body },
31 }
32 }
33}
34
35impl<B> Default for ResponseBody<B>
36where
37 B: Default,
38{
39 fn default() -> Self {
40 Self {
41 inner: ResponseBodyInner::Body { body: B::default() },
42 }
43 }
44}
45
pin_project! {
    // The two states a `ResponseBody` can be in. `BodyProj` is the pinned
    // projection type that `poll_frame` matches on.
    #[project = BodyProj]
    enum ResponseBodyInner<B> {
        // Fixed in-memory error message built by `ResponseBody::payload_too_large`.
        PayloadTooLarge {
            #[pin]
            body: Full<Bytes>,
        },
        // Caller-supplied body built by `ResponseBody::new`, forwarded as-is.
        Body {
            #[pin]
            body: B
        }
    }
}
59
60impl<B> Body for ResponseBody<B>
61where
62 B: Body<Data = Bytes>,
63{
64 type Data = Bytes;
65 type Error = B::Error;
66
67 fn poll_frame(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
71 match self.project().inner.project() {
72 BodyProj::PayloadTooLarge { body } => body.poll_frame(cx).map_err(|err| match err {}),
73 BodyProj::Body { body } => body.poll_frame(cx),
74 }
75 }
76
77 fn is_end_stream(&self) -> bool {
78 match &self.inner {
79 ResponseBodyInner::PayloadTooLarge { body } => body.is_end_stream(),
80 ResponseBodyInner::Body { body } => body.is_end_stream(),
81 }
82 }
83
84 fn size_hint(&self) -> SizeHint {
85 match &self.inner {
86 ResponseBodyInner::PayloadTooLarge { body } => body.size_hint(),
87 ResponseBodyInner::Body { body } => body.size_hint(),
88 }
89 }
90}
91
/// Payload of the `413 Payload Too Large` response built by `create_error_response`.
const BODY: &[u8] = b"length limit exceeded";
93
94pub(crate) fn create_error_response<B>() -> Response<ResponseBody<B>>
95where
96 B: Body,
97{
98 let mut res = Response::new(ResponseBody::payload_too_large());
99 *res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
100
101 #[allow(clippy::declare_interior_mutable_const)]
102 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
103 res.headers_mut()
104 .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
105
106 res
107}