Skip to main content

xitca_web/middleware/
limit.rs

//! Limitation middleware, e.g. capping the size of request bodies.
2
3use std::{
4    cell::RefCell,
5    pin::Pin,
6    task::{Context, Poll, ready},
7};
8
9use pin_project_lite::pin_project;
10use xitca_http::Request;
11
12use crate::{
13    body::{Body, BodyStream, Frame},
14    context::WebContext,
15    error::{BodyError, BodyOverFlow},
16    service::{Service, ready::ReadyService},
17};
18
/// General-purpose limitation middleware. Currently supports limiting the size of request bodies.
///
/// # Type mutation
/// [`Limit`] would mutate request body type from `B` to [`LimitBody<B>`]. Service enclosed by it must be
/// able to handle its mutation or utilize [`TypeEraser`] to erase the mutation.
/// For more explanation please reference [`type mutation`](crate::middleware#type-mutation).
///
/// [`TypeEraser`]: crate::middleware::eraser::TypeEraser
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Limit {
    request_body_size: usize,
}

impl Default for Limit {
    fn default() -> Self {
        Self::new()
    }
}

impl Limit {
    /// Construct a [`Limit`] whose request body size limit is `usize::MAX` bytes,
    /// i.e. no practical restriction.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            request_body_size: usize::MAX,
        }
    }

    /// Set max size in byte unit the request body can be.
    ///
    /// This is a `const fn` (like [`Limit::new`]) so a fully configured [`Limit`] can be
    /// built in a `const` context.
    #[must_use]
    pub const fn set_request_body_max_size(mut self, size: usize) -> Self {
        self.request_body_size = size;
        self
    }
}
51
52impl<S, E> Service<Result<S, E>> for Limit {
53    type Response = LimitService<S>;
54    type Error = E;
55
56    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
57        res.map(|service| LimitService { service, limit: *self })
58    }
59}
60
61pub struct LimitService<S> {
62    service: S,
63    limit: Limit,
64}
65
66impl<'r, S, C, B, Res, Err> Service<WebContext<'r, C, B>> for LimitService<S>
67where
68    B: BodyStream + Default,
69    S: for<'r2> Service<WebContext<'r2, C, LimitBody<B>>, Response = Res, Error = Err>,
70{
71    type Response = Res;
72    type Error = Err;
73
74    async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
75        let (parts, ext) = ctx.take_request().into_parts();
76        let state = ctx.ctx;
77        let (ext, body) = ext.replace_body(());
78        let mut body = RefCell::new(LimitBody::new(body, self.limit.request_body_size));
79        let mut req = Request::from_parts(parts, ext);
80
81        self.service
82            .call(WebContext::new(&mut req, &mut body, state))
83            .await
84            .inspect_err(|_| {
85                let body = body.into_inner().into_inner();
86                *ctx.body_borrow_mut() = body;
87            })
88    }
89}
90
91impl<S> ReadyService for LimitService<S>
92where
93    S: ReadyService,
94{
95    type Ready = S::Ready;
96
97    #[inline]
98    async fn ready(&self) -> Self::Ready {
99        self.service.ready().await
100    }
101}
102
103pin_project! {
104    pub struct LimitBody<B> {
105        limit: usize,
106        record: usize,
107        #[pin]
108        body: B
109    }
110}
111
112impl<B: Default> Default for LimitBody<B> {
113    fn default() -> Self {
114        Self {
115            limit: 0,
116            record: 0,
117            body: B::default(),
118        }
119    }
120}
121
122impl<B> LimitBody<B> {
123    const fn new(body: B, limit: usize) -> Self {
124        Self { limit, record: 0, body }
125    }
126
127    fn into_inner(self) -> B {
128        self.body
129    }
130}
131
132impl<B> Body for LimitBody<B>
133where
134    B: BodyStream,
135{
136    type Data = B::Data;
137    type Error = BodyError;
138
139    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
140        let this = self.project();
141
142        if *this.record >= *this.limit {
143            // search error module for downcast_ref::<BodyOverFlow>() before considering change the
144            // error type.
145            return Poll::Ready(Some(Err(BodyError::from(BodyOverFlow { limit: *this.limit }))));
146        }
147
148        match ready!(Body::poll_frame(this.body, cx)) {
149            Some(res) => {
150                let frame = res.map_err(Into::into)?;
151
152                if let Some(data) = frame.data_ref() {
153                    *this.record += data.as_ref().len();
154                }
155                // TODO: for now there is no way to split a frame if it goes beyond body limit.
156                Poll::Ready(Some(Ok(frame)))
157            }
158            None => Poll::Ready(None),
159        }
160    }
161}
162
#[cfg(test)]
mod test {
    use core::pin::pin;

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App,
        body::{BodyExt, BoxBody, Full},
        bytes::Bytes,
        handler::{body::Body, handler_service},
        http::{StatusCode, WebRequest},
        test::collect_body,
    };

    use super::*;

    const CHUNK: &[u8] = b"hello,world!";

    // Reads one frame successfully, then expects the next read to fail with the body-overflow
    // error, which must render as `400 Bad Request`. Echoes the first frame back as text.
    async fn handler<B: BodyStream>(Body(body): Body<B>) -> String {
        let mut stream = pin!(body);

        let first = stream.as_mut().frame().await.unwrap().ok().unwrap();
        let bytes = first.into_data().ok().unwrap();

        let overflow = stream.as_mut().frame().await.unwrap().err().unwrap();
        let overflow = crate::error::Error::from(overflow.into());
        let expected = format!("body size reached limit: {} bytes", CHUNK.len());
        assert_eq!(overflow.to_string(), expected);

        let mut ctx = WebContext::new_test(());
        let res = overflow.call(ctx.as_web_ctx()).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        std::str::from_utf8(bytes.as_ref()).unwrap().to_owned()
    }

    #[test]
    fn request_body_over_limit() {
        let two_chunks = Full::new(Bytes::from_static(CHUNK)).chain(Full::new(Bytes::from_static(CHUNK)));
        let req = WebRequest::default().map(|ext| ext.map_body(|_: ()| BoxBody::new(two_chunks).into()));

        let service = App::new()
            .at("/", handler_service(handler))
            .enclosed(Limit::new().set_request_body_max_size(CHUNK.len()))
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let res_body = service.call(req).now_or_panic().ok().unwrap().into_body();
        let collected = collect_body(res_body).now_or_panic().unwrap();

        assert_eq!(collected, CHUNK);
    }
}