xitca_web/middleware/
limit.rs1use std::{
4 cell::RefCell,
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9use pin_project_lite::pin_project;
10use xitca_http::Request;
11
12use crate::{
13 body::{Body, BodyStream, Frame},
14 context::WebContext,
15 error::{BodyError, BodyOverFlow},
16 service::{Service, ready::ReadyService},
17};
18
/// Middleware that enforces size limits on incoming requests.
///
/// Currently only the request body size can be limited (see [`Limit::set_request_body_max_size`]).
/// The request body handed to the wrapped service produces a [`BodyOverFlow`] error once the limit
/// is hit. By default no limit is applied (`usize::MAX`).
#[derive(Clone, Copy, Debug)]
pub struct Limit {
    // maximum number of request body bytes allowed to be read.
    request_body_size: usize,
}
31
32impl Default for Limit {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl Limit {
39 pub const fn new() -> Self {
40 Self {
41 request_body_size: usize::MAX,
42 }
43 }
44
45 pub fn set_request_body_max_size(mut self, size: usize) -> Self {
47 self.request_body_size = size;
48 self
49 }
50}
51
52impl<S, E> Service<Result<S, E>> for Limit {
53 type Response = LimitService<S>;
54 type Error = E;
55
56 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
57 res.map(|service| LimitService { service, limit: *self })
58 }
59}
60
61pub struct LimitService<S> {
62 service: S,
63 limit: Limit,
64}
65
66impl<'r, S, C, B, Res, Err> Service<WebContext<'r, C, B>> for LimitService<S>
67where
68 B: BodyStream + Default,
69 S: for<'r2> Service<WebContext<'r2, C, LimitBody<B>>, Response = Res, Error = Err>,
70{
71 type Response = Res;
72 type Error = Err;
73
74 async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
75 let (parts, ext) = ctx.take_request().into_parts();
76 let state = ctx.ctx;
77 let (ext, body) = ext.replace_body(());
78 let mut body = RefCell::new(LimitBody::new(body, self.limit.request_body_size));
79 let mut req = Request::from_parts(parts, ext);
80
81 self.service
82 .call(WebContext::new(&mut req, &mut body, state))
83 .await
84 .inspect_err(|_| {
85 let body = body.into_inner().into_inner();
86 *ctx.body_borrow_mut() = body;
87 })
88 }
89}
90
91impl<S> ReadyService for LimitService<S>
92where
93 S: ReadyService,
94{
95 type Ready = S::Ready;
96
97 #[inline]
98 async fn ready(&self) -> Self::Ready {
99 self.service.ready().await
100 }
101}
102
103pin_project! {
104 pub struct LimitBody<B> {
105 limit: usize,
106 record: usize,
107 #[pin]
108 body: B
109 }
110}
111
112impl<B: Default> Default for LimitBody<B> {
113 fn default() -> Self {
114 Self {
115 limit: 0,
116 record: 0,
117 body: B::default(),
118 }
119 }
120}
121
122impl<B> LimitBody<B> {
123 const fn new(body: B, limit: usize) -> Self {
124 Self { limit, record: 0, body }
125 }
126
127 fn into_inner(self) -> B {
128 self.body
129 }
130}
131
132impl<B> Body for LimitBody<B>
133where
134 B: BodyStream,
135{
136 type Data = B::Data;
137 type Error = BodyError;
138
139 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
140 let this = self.project();
141
142 if *this.record >= *this.limit {
143 return Poll::Ready(Some(Err(BodyError::from(BodyOverFlow { limit: *this.limit }))));
146 }
147
148 match ready!(Body::poll_frame(this.body, cx)) {
149 Some(res) => {
150 let frame = res.map_err(Into::into)?;
151
152 if let Some(data) = frame.data_ref() {
153 *this.record += data.as_ref().len();
154 }
155 Poll::Ready(Some(Ok(frame)))
157 }
158 None => Poll::Ready(None),
159 }
160 }
161}
162
#[cfg(test)]
mod test {
    use core::pin::pin;

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App,
        body::{BodyExt, BoxBody, Full},
        bytes::Bytes,
        handler::{body::Body, handler_service},
        http::{StatusCode, WebRequest},
        test::collect_body,
    };

    use super::*;

    const CHUNK: &[u8] = b"hello,world!";

    // Reads the first chunk successfully, then checks that the next frame is the overflow error,
    // that it renders the expected message, and that it maps to a 400 response.
    // Echoes the first chunk back as the response body.
    async fn handler<B: BodyStream>(Body(body): Body<B>) -> String {
        let mut stream = pin!(body);

        let first = stream
            .as_mut()
            .frame()
            .await
            .unwrap()
            .ok()
            .unwrap()
            .into_data()
            .ok()
            .unwrap();

        let overflow = stream.as_mut().frame().await.unwrap().err().unwrap();
        let overflow = crate::error::Error::from(overflow.into());
        let expected_msg = format!("body size reached limit: {} bytes", CHUNK.len());
        assert_eq!(overflow.to_string(), expected_msg);

        let mut test_ctx = WebContext::new_test(());
        let response = overflow.call(test_ctx.as_web_ctx()).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        std::str::from_utf8(first.as_ref()).unwrap().to_owned()
    }

    #[test]
    fn request_body_over_limit() {
        let two_chunks = Full::new(Bytes::from_static(CHUNK)).chain(Full::new(Bytes::from_static(CHUNK)));
        let req = WebRequest::default().map(|ext| ext.map_body(|_: ()| BoxBody::new(two_chunks).into()));

        let service = App::new()
            .at("/", handler_service(handler))
            .enclosed(Limit::new().set_request_body_max_size(CHUNK.len()))
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let response_body = service.call(req).now_or_panic().ok().unwrap().into_body();
        let collected = collect_body(response_body).now_or_panic().unwrap();

        assert_eq!(collected, CHUNK);
    }
}