rustapi_core/middleware/
body_limit.rs1use super::{BoxedNext, MiddlewareLayer};
20use crate::error::ApiError;
21use crate::request::Request;
22use crate::response::{IntoResponse, Response};
23use http::StatusCode;
24use std::future::Future;
25use std::pin::Pin;
26
/// Default maximum request body size, in bytes, used by `BodyLimitLayer::default()`:
/// 1 MiB (1,048,576 bytes).
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;

/// Middleware layer that rejects requests whose body exceeds a byte limit.
///
/// Rejected requests receive a `413 Payload Too Large` error response and are not
/// forwarded to the next handler; a body of exactly `limit` bytes is still accepted.
#[derive(Clone, Debug)]
pub struct BodyLimitLayer {
    /// Maximum number of body bytes accepted; anything larger is rejected.
    limit: usize,
}

39impl BodyLimitLayer {
40 pub fn new(limit: usize) -> Self {
53 Self { limit }
54 }
55
56 pub fn default_limit() -> Self {
58 Self::new(DEFAULT_BODY_LIMIT)
59 }
60
61 pub fn limit(&self) -> usize {
63 self.limit
64 }
65}
66
67impl Default for BodyLimitLayer {
68 fn default() -> Self {
69 Self::default_limit()
70 }
71}
72
73impl MiddlewareLayer for BodyLimitLayer {
74 fn call(
75 &self,
76 req: Request,
77 next: BoxedNext,
78 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
79 let limit = self.limit;
80
81 Box::pin(async move {
82 if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) {
84 if let Ok(length_str) = content_length.to_str() {
85 if let Ok(length) = length_str.parse::<usize>() {
86 if length > limit {
87 return ApiError::new(
88 StatusCode::PAYLOAD_TOO_LARGE,
89 "payload_too_large",
90 format!("Request body exceeds limit of {} bytes", limit),
91 )
92 .into_response();
93 }
94 }
95 }
96 }
97
98 if let Some(body) = &req.body {
101 if body.len() > limit {
102 return ApiError::new(
103 StatusCode::PAYLOAD_TOO_LARGE,
104 "payload_too_large",
105 format!("Request body exceeds limit of {} bytes", limit),
106 )
107 .into_response();
108 }
109 }
110
111 next(req).await
113 })
114 }
115
116 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
117 Box::new(self.clone())
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a `POST /test` request carrying `body`. When `content_length` is given, it
    /// is sent verbatim as the `Content-Length` header and need not match the body.
    fn build_request(body: Bytes, content_length: Option<&str>) -> Request {
        let uri: http::Uri = "/test".parse().unwrap();
        let mut builder = http::Request::builder().method(Method::POST).uri(uri);
        if let Some(value) = content_length {
            builder = builder.header(http::header::CONTENT_LENGTH, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();

        Request::new(parts, body, Arc::new(Extensions::new()), HashMap::new())
    }

    /// Request whose `Content-Length` header matches the body length.
    fn create_test_request_with_body(body: Bytes) -> Request {
        let length = body.len().to_string();
        build_request(body, Some(&length))
    }

    /// Request without a `Content-Length` header, so only the buffered body is checked.
    fn create_test_request_without_content_length(body: Bytes) -> Request {
        build_request(body, None)
    }

    /// Downstream handler that always answers `200 OK`.
    fn ok_handler() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(http_body_util::Full::new(Bytes::from("ok")))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Status the layer must produce for a `body_size`-byte body under `limit`.
    fn expected_status(body_size: usize, limit: usize) -> StatusCode {
        if body_size > limit {
            StatusCode::PAYLOAD_TOO_LARGE
        } else {
            StatusCode::OK
        }
    }

    /// Drives `layer` over `request` on a lightweight single-threaded runtime.
    /// Only for synchronous tests: it must not be called from inside another runtime.
    fn run(layer: &BodyLimitLayer, request: Request) -> Response {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime")
            .block_on(layer.call(request, ok_handler()))
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Bodies above the limit are rejected and all others pass (matching header).
        #[test]
        fn prop_body_size_limit_enforcement(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let body_size = ((limit as f64) * body_size_factor) as usize;
            let request = create_test_request_with_body(Bytes::from(vec![b'x'; body_size]));

            let response = run(&BodyLimitLayer::new(limit), request);

            prop_assert_eq!(
                response.status(),
                expected_status(body_size, limit),
                "unexpected status for body size {} and limit {}",
                body_size,
                limit
            );
        }

        // Same property with no Content-Length header: the body check alone must hold.
        #[test]
        fn prop_body_limit_without_content_length_header(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let body_size = ((limit as f64) * body_size_factor) as usize;
            let request =
                create_test_request_without_content_length(Bytes::from(vec![b'x'; body_size]));

            let response = run(&BodyLimitLayer::new(limit), request);

            prop_assert_eq!(
                response.status(),
                expected_status(body_size, limit),
                "unexpected status for body size {} and limit {} (no Content-Length)",
                body_size,
                limit
            );
        }
    }

    #[tokio::test]
    async fn test_body_at_exact_limit() {
        let limit = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit]));

        let response = BodyLimitLayer::new(limit).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_body_one_byte_over_limit() {
        let limit = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit + 1]));

        let response = BodyLimitLayer::new(limit).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_body_one_byte_under_limit() {
        let limit = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit - 1]));

        let response = BodyLimitLayer::new(limit).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_empty_body() {
        let request = create_test_request_with_body(Bytes::new());

        let response = BodyLimitLayer::new(100).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_oversized_content_length_rejected_even_with_small_body() {
        let request = build_request(Bytes::from_static(b"tiny"), Some("1000"));

        let response = BodyLimitLayer::new(100).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_understated_content_length_does_not_bypass_body_check() {
        let request = build_request(Bytes::from(vec![b'x'; 200]), Some("10"));

        let response = BodyLimitLayer::new(100).call(request, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_unparseable_content_length_falls_back_to_body_check() {
        let layer = BodyLimitLayer::new(100);

        let small = build_request(Bytes::from(vec![b'x'; 10]), Some("not-a-number"));
        let response = layer.call(small, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::OK);

        let large = build_request(Bytes::from(vec![b'x'; 200]), Some("not-a-number"));
        let response = layer.call(large, ok_handler()).await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn test_default_limit() {
        let layer = BodyLimitLayer::default();
        assert_eq!(layer.limit(), DEFAULT_BODY_LIMIT);
    }

    #[test]
    fn test_clone() {
        let layer = BodyLimitLayer::new(1024);
        let cloned = layer.clone();
        assert_eq!(layer.limit(), cloned.limit());
    }
}