xitca_web/middleware/
limit.rs1use std::{
4 cell::RefCell,
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9use futures_core::stream::Stream;
10use pin_project_lite::pin_project;
11use xitca_http::Request;
12
13use crate::{
14 body::BodyStream,
15 context::WebContext,
16 error::{BodyError, BodyOverFlow},
17 service::{Service, ready::ReadyService},
18};
19
/// Middleware limiting how many request body bytes downstream services may read.
///
/// The request body is wrapped in a [`LimitBody`] that counts the bytes it yields; reading
/// past the configured maximum produces a [`BodyOverFlow`] error instead of more data.
/// Without [`Limit::set_request_body_max_size`] the maximum is `usize::MAX`, i.e. effectively
/// unlimited.
#[derive(Copy, Clone, Debug)]
pub struct Limit {
    // maximum number of request body bytes handed to the wrapped service.
    request_body_size: usize,
}
32
33impl Default for Limit {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl Limit {
40 pub const fn new() -> Self {
41 Self {
42 request_body_size: usize::MAX,
43 }
44 }
45
46 pub fn set_request_body_max_size(mut self, size: usize) -> Self {
48 self.request_body_size = size;
49 self
50 }
51}
52
53impl<S, E> Service<Result<S, E>> for Limit {
54 type Response = LimitService<S>;
55 type Error = E;
56
57 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
58 res.map(|service| LimitService { service, limit: *self })
59 }
60}
61
62pub struct LimitService<S> {
63 service: S,
64 limit: Limit,
65}
66
67impl<'r, S, C, B, Res, Err> Service<WebContext<'r, C, B>> for LimitService<S>
68where
69 B: BodyStream + Default,
70 S: for<'r2> Service<WebContext<'r2, C, LimitBody<B>>, Response = Res, Error = Err>,
71{
72 type Response = Res;
73 type Error = Err;
74
75 async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
76 let (parts, ext) = ctx.take_request().into_parts();
77 let state = ctx.ctx;
78 let (ext, body) = ext.replace_body(());
79 let mut body = RefCell::new(LimitBody::new(body, self.limit.request_body_size));
80 let mut req = Request::from_parts(parts, ext);
81
82 self.service
83 .call(WebContext::new(&mut req, &mut body, state))
84 .await
85 .inspect_err(|_| {
86 let body = body.into_inner().into_inner();
87 *ctx.body_borrow_mut() = body;
88 })
89 }
90}
91
92impl<S> ReadyService for LimitService<S>
93where
94 S: ReadyService,
95{
96 type Ready = S::Ready;
97
98 #[inline]
99 async fn ready(&self) -> Self::Ready {
100 self.service.ready().await
101 }
102}
103
104pin_project! {
105 pub struct LimitBody<B> {
106 limit: usize,
107 record: usize,
108 #[pin]
109 body: B
110 }
111}
112
113impl<B: Default> Default for LimitBody<B> {
114 fn default() -> Self {
115 Self {
116 limit: 0,
117 record: 0,
118 body: B::default(),
119 }
120 }
121}
122
123impl<B> LimitBody<B> {
124 const fn new(body: B, limit: usize) -> Self {
125 Self { limit, record: 0, body }
126 }
127
128 fn into_inner(self) -> B {
129 self.body
130 }
131}
132
133impl<B> Stream for LimitBody<B>
134where
135 B: BodyStream,
136{
137 type Item = Result<B::Chunk, BodyError>;
138
139 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 let this = self.project();
141
142 if *this.record >= *this.limit {
143 return Poll::Ready(Some(Err(BodyError::from(BodyOverFlow { limit: *this.limit }))));
146 }
147
148 match ready!(this.body.poll_next(cx)) {
149 Some(res) => {
150 let chunk = res.map_err(Into::into)?;
151 *this.record += chunk.as_ref().len();
152 Poll::Ready(Some(Ok(chunk)))
154 }
155 None => Poll::Ready(None),
156 }
157 }
158}
159
#[cfg(test)]
mod test {
    use core::{future::poll_fn, pin::pin};

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App,
        body::BoxBody,
        bytes::Bytes,
        handler::{body::Body, handler_service},
        http::{StatusCode, WebRequest},
        test::collect_body,
    };

    use super::*;

    // Payload of every chunk in the test body; its length is also the configured limit.
    const CHUNK: &[u8] = b"hello,world!";

    // Reads the limited body: the first chunk must arrive intact, and the second poll must
    // surface the overflow error, which renders as `400 Bad Request`.
    async fn handler<B: BodyStream>(Body(body): Body<B>) -> String {
        let mut stream = pin!(body);

        // First poll: one whole chunk, which fits within the configured limit.
        let first = poll_fn(|cx| stream.as_mut().poll_next(cx)).await.unwrap();
        let chunk = first.ok().unwrap();

        // Second poll: the next chunk would exceed the limit, so an overflow is reported.
        let second = poll_fn(|cx| stream.as_mut().poll_next(cx)).await.unwrap();
        let err = crate::error::Error::from(second.err().unwrap().into());
        assert_eq!(
            err.to_string(),
            format!("body size reached limit: {} bytes", CHUNK.len())
        );

        let mut test_ctx = WebContext::new_test(());
        let response = err.call(test_ctx.as_web_ctx()).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        String::from(std::str::from_utf8(chunk.as_ref()).unwrap())
    }

    #[test]
    fn request_body_over_limit() {
        use futures_util::stream::{self, StreamExt};

        let limit = Limit::new().set_request_body_max_size(CHUNK.len());

        // Two chunks of `CHUNK.len()` bytes each: twice the configured limit.
        let make_chunk = || async { Ok::<_, BodyError>(Bytes::from_static(CHUNK)) };
        let payload = stream::once(make_chunk()).chain(stream::once(make_chunk()));
        let req =
            WebRequest::default().map(|ext| ext.map_body(|_: ()| BoxBody::new(payload).into()));

        let service = App::new()
            .at("/", handler_service(handler))
            .enclosed(limit)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let response = service.call(req).now_or_panic().ok().unwrap();

        let collected = collect_body(response.into_body()).now_or_panic().unwrap();
        assert_eq!(collected, CHUNK);
    }
}