xitca_web/middleware/limit.rs

1//! limitation middleware.
2
3use std::{
4    cell::RefCell,
5    pin::Pin,
6    task::{Context, Poll, ready},
7};
8
9use futures_core::stream::Stream;
10use pin_project_lite::pin_project;
11use xitca_http::Request;
12
13use crate::{
14    body::BodyStream,
15    context::WebContext,
16    error::{BodyError, BodyOverFlow},
17    service::{Service, ready::ReadyService},
18};
19
/// General-purpose limitation middleware. Currently it limits the request body size.
///
/// # Type mutation
/// [`Limit`] mutates the request body type from `B` to [`LimitBody<B>`]. The service enclosed by it
/// must be able to handle this mutation, or use [`TypeEraser`] to erase it.
/// For more explanation please reference [`type mutation`](crate::middleware#type-mutation).
///
/// [`TypeEraser`]: crate::middleware::eraser::TypeEraser
#[derive(Copy, Clone)]
pub struct Limit {
    // Maximum request body size in bytes; `Limit::new` initializes it to `usize::MAX`.
    request_body_size: usize,
}
32
33impl Default for Limit {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl Limit {
40    pub const fn new() -> Self {
41        Self {
42            request_body_size: usize::MAX,
43        }
44    }
45
46    /// Set max size in byte unit the request body can be.
47    pub fn set_request_body_max_size(mut self, size: usize) -> Self {
48        self.request_body_size = size;
49        self
50    }
51}
52
53impl<S, E> Service<Result<S, E>> for Limit {
54    type Response = LimitService<S>;
55    type Error = E;
56
57    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
58        res.map(|service| LimitService { service, limit: *self })
59    }
60}
61
62pub struct LimitService<S> {
63    service: S,
64    limit: Limit,
65}
66
67impl<'r, S, C, B, Res, Err> Service<WebContext<'r, C, B>> for LimitService<S>
68where
69    B: BodyStream + Default,
70    S: for<'r2> Service<WebContext<'r2, C, LimitBody<B>>, Response = Res, Error = Err>,
71{
72    type Response = Res;
73    type Error = Err;
74
75    async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
76        let (parts, ext) = ctx.take_request().into_parts();
77        let state = ctx.ctx;
78        let (ext, body) = ext.replace_body(());
79        let mut body = RefCell::new(LimitBody::new(body, self.limit.request_body_size));
80        let mut req = Request::from_parts(parts, ext);
81
82        self.service
83            .call(WebContext::new(&mut req, &mut body, state))
84            .await
85            .inspect_err(|_| {
86                let body = body.into_inner().into_inner();
87                *ctx.body_borrow_mut() = body;
88            })
89    }
90}
91
92impl<S> ReadyService for LimitService<S>
93where
94    S: ReadyService,
95{
96    type Ready = S::Ready;
97
98    #[inline]
99    async fn ready(&self) -> Self::Ready {
100        self.service.ready().await
101    }
102}
103
pin_project! {
    /// Request body wrapper that counts the bytes read from the inner `body` and returns a
    /// [`BodyOverFlow`] error when that count violates `limit`. See the `Stream` impl for the
    /// exact rule.
    pub struct LimitBody<B> {
        // Maximum number of bytes allowed, as passed to `LimitBody::new`.
        limit: usize,
        // Running total of bytes read from `body` so far.
        record: usize,
        #[pin]
        body: B
    }
}
112
113impl<B: Default> Default for LimitBody<B> {
114    fn default() -> Self {
115        Self {
116            limit: 0,
117            record: 0,
118            body: B::default(),
119        }
120    }
121}
122
123impl<B> LimitBody<B> {
124    const fn new(body: B, limit: usize) -> Self {
125        Self { limit, record: 0, body }
126    }
127
128    fn into_inner(self) -> B {
129        self.body
130    }
131}
132
133impl<B> Stream for LimitBody<B>
134where
135    B: BodyStream,
136{
137    type Item = Result<B::Chunk, BodyError>;
138
139    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        let this = self.project();
141
142        if *this.record >= *this.limit {
143            // search error module for downcast_ref::<BodyOverFlow>() before considering change the
144            // error type.
145            return Poll::Ready(Some(Err(BodyError::from(BodyOverFlow { limit: *this.limit }))));
146        }
147
148        match ready!(this.body.poll_next(cx)) {
149            Some(res) => {
150                let chunk = res.map_err(Into::into)?;
151                *this.record += chunk.as_ref().len();
152                // TODO: for now there is no way to split a chunk if it goes beyond body limit.
153                Poll::Ready(Some(Ok(chunk)))
154            }
155            None => Poll::Ready(None),
156        }
157    }
158}
159
160#[cfg(test)]
161mod test {
162    use core::{future::poll_fn, pin::pin};
163
164    use xitca_unsafe_collection::futures::NowOrPanic;
165
166    use crate::{
167        App,
168        body::BoxBody,
169        bytes::Bytes,
170        handler::{body::Body, handler_service},
171        http::{StatusCode, WebRequest},
172        test::collect_body,
173    };
174
175    use super::*;
176
177    const CHUNK: &[u8] = b"hello,world!";
178
179    async fn handler<B: BodyStream>(Body(body): Body<B>) -> String {
180        let mut body = pin!(body);
181
182        let chunk = poll_fn(|cx| body.as_mut().poll_next(cx)).await.unwrap().ok().unwrap();
183
184        let err = poll_fn(|cx| body.as_mut().poll_next(cx)).await.unwrap().err().unwrap();
185        let err = crate::error::Error::from(err.into());
186        assert_eq!(
187            err.to_string(),
188            format!("body size reached limit: {} bytes", CHUNK.len())
189        );
190
191        let mut ctx = WebContext::new_test(());
192        let res = err.call(ctx.as_web_ctx()).await.unwrap();
193        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
194
195        std::str::from_utf8(chunk.as_ref()).unwrap().to_string()
196    }
197
198    #[test]
199    fn request_body_over_limit() {
200        use futures_util::stream::{self, StreamExt};
201
202        let item = || async { Ok::<_, BodyError>(Bytes::from_static(CHUNK)) };
203
204        let body = stream::once(item()).chain(stream::once(item()));
205        let req = WebRequest::default().map(|ext| ext.map_body(|_: ()| BoxBody::new(body).into()));
206
207        let body = App::new()
208            .at("/", handler_service(handler))
209            .enclosed(Limit::new().set_request_body_max_size(CHUNK.len()))
210            .finish()
211            .call(())
212            .now_or_panic()
213            .unwrap()
214            .call(req)
215            .now_or_panic()
216            .ok()
217            .unwrap()
218            .into_body();
219
220        let body = collect_body(body).now_or_panic().unwrap();
221
222        assert_eq!(body, CHUNK);
223    }
224}