// rustapi_core/middleware/body_limit.rs
use std::future::Future;
use std::num::IntErrorKind;
use std::pin::Pin;

use http::StatusCode;

use super::{BoxedNext, MiddlewareLayer};
use crate::error::ApiError;
use crate::request::Request;
use crate::response::{IntoResponse, Response};

/// Default maximum request body size used by [`BodyLimitLayer::default`]:
/// 1 MiB (1,048,576 bytes).
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;

/// Middleware layer that rejects requests whose body exceeds a byte limit.
///
/// Requests over the limit are answered with `413 Payload Too Large`; a body
/// of exactly `limit` bytes is still accepted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BodyLimitLayer {
    /// Maximum accepted body size, in bytes.
    limit: usize,
}

39impl BodyLimitLayer {
40 pub fn new(limit: usize) -> Self {
53 Self { limit }
54 }
55
56 pub fn default_limit() -> Self {
58 Self::new(DEFAULT_BODY_LIMIT)
59 }
60
61 pub fn limit(&self) -> usize {
63 self.limit
64 }
65}
66
67impl Default for BodyLimitLayer {
68 fn default() -> Self {
69 Self::default_limit()
70 }
71}
72
73impl MiddlewareLayer for BodyLimitLayer {
74 fn call(
75 &self,
76 req: Request,
77 next: BoxedNext,
78 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
79 let limit = self.limit;
80
81 Box::pin(async move {
82 if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) {
84 if let Ok(length_str) = content_length.to_str() {
85 if let Ok(length) = length_str.parse::<usize>() {
86 if length > limit {
87 return ApiError::new(
88 StatusCode::PAYLOAD_TOO_LARGE,
89 "payload_too_large",
90 format!("Request body exceeds limit of {} bytes", limit),
91 )
92 .into_response();
93 }
94 }
95 }
96 }
97
98 if let crate::request::BodyVariant::Buffered(bytes) = &req.body {
101 if bytes.len() > limit {
102 return ApiError::new(
103 StatusCode::PAYLOAD_TOO_LARGE,
104 "payload_too_large",
105 format!("Request body exceeds limit of {} bytes", limit),
106 )
107 .into_response();
108 }
109 }
110
111 next(req).await
113 })
114 }
115
116 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
117 Box::new(self.clone())
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use crate::request::Request;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use std::sync::Arc;

    /// Builds a `POST /test` request with a buffered `body` and, when
    /// `content_length` is `Some`, exactly that `Content-Length` header value.
    fn build_request(body: Bytes, content_length: Option<&str>) -> Request {
        let mut builder = http::Request::builder().method(Method::POST).uri("/test");
        if let Some(value) = content_length {
            builder = builder.header(http::header::CONTENT_LENGTH, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(body),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Request whose `Content-Length` header matches the body length.
    fn create_test_request_with_body(body: Bytes) -> Request {
        let declared = body.len().to_string();
        build_request(body, Some(declared.as_str()))
    }

    /// Request with a buffered body but no `Content-Length` header, so only the
    /// buffered-body check can reject it.
    fn create_test_request_without_content_length(body: Bytes) -> Request {
        build_request(body, None)
    }

    /// Downstream handler that always answers `200 OK`.
    fn ok_handler() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(http_body_util::Full::new(Bytes::from("ok")))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Runs `request` through a `BodyLimitLayer` capped at `limit` and returns
    /// the resulting status code.
    async fn status_with_limit(limit: usize, request: Request) -> StatusCode {
        BodyLimitLayer::new(limit)
            .call(request, ok_handler())
            .await
            .status()
    }

    /// Status the layer must produce for a body of `body_size` bytes under `limit`.
    fn expected_status(body_size: usize, limit: usize) -> StatusCode {
        if body_size > limit {
            StatusCode::PAYLOAD_TOO_LARGE
        } else {
            StatusCode::OK
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_body_size_limit_enforcement(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let body_size = ((limit as f64) * body_size_factor) as usize;
            let rt = tokio::runtime::Runtime::new().unwrap();
            let status = rt.block_on(async {
                let request = create_test_request_with_body(Bytes::from(vec![b'x'; body_size]));
                status_with_limit(limit, request).await
            });

            prop_assert_eq!(
                status,
                expected_status(body_size, limit),
                "body size {} vs limit {}",
                body_size,
                limit
            );
        }

        #[test]
        fn prop_body_limit_without_content_length_header(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let body_size = ((limit as f64) * body_size_factor) as usize;
            let rt = tokio::runtime::Runtime::new().unwrap();
            let status = rt.block_on(async {
                let request =
                    create_test_request_without_content_length(Bytes::from(vec![b'x'; body_size]));
                status_with_limit(limit, request).await
            });

            prop_assert_eq!(
                status,
                expected_status(body_size, limit),
                "body size {} vs limit {} (no Content-Length)",
                body_size,
                limit
            );
        }
    }

    #[tokio::test]
    async fn test_body_at_exact_limit() {
        let limit: usize = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit]));
        assert_eq!(status_with_limit(limit, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_body_one_byte_over_limit() {
        let limit: usize = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit + 1]));
        assert_eq!(
            status_with_limit(limit, request).await,
            StatusCode::PAYLOAD_TOO_LARGE
        );
    }

    #[tokio::test]
    async fn test_body_one_byte_under_limit() {
        let limit: usize = 100;
        let request = create_test_request_with_body(Bytes::from(vec![b'x'; limit - 1]));
        assert_eq!(status_with_limit(limit, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_empty_body() {
        let limit: usize = 100;
        let request = create_test_request_with_body(Bytes::new());
        assert_eq!(status_with_limit(limit, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_declared_content_length_over_limit_is_rejected() {
        // The header alone declares an oversized payload; a small buffered body
        // must not let the request through.
        let limit: usize = 100;
        let declared = (limit + 1).to_string();
        let request = build_request(Bytes::from_static(b"tiny"), Some(declared.as_str()));
        assert_eq!(
            status_with_limit(limit, request).await,
            StatusCode::PAYLOAD_TOO_LARGE
        );
    }

    #[tokio::test]
    async fn test_unparseable_content_length_falls_back_to_body_check() {
        let limit: usize = 100;

        let small = build_request(Bytes::from_static(b"tiny"), Some("not-a-number"));
        assert_eq!(status_with_limit(limit, small).await, StatusCode::OK);

        let large = build_request(Bytes::from(vec![b'x'; limit + 1]), Some("not-a-number"));
        assert_eq!(
            status_with_limit(limit, large).await,
            StatusCode::PAYLOAD_TOO_LARGE
        );
    }

    #[test]
    fn test_default_limit() {
        let layer = BodyLimitLayer::default();
        assert_eq!(layer.limit(), DEFAULT_BODY_LIMIT);
    }

    #[test]
    fn test_clone() {
        let layer = BodyLimitLayer::new(1024);
        let cloned = layer.clone();
        assert_eq!(layer.limit(), cloned.limit());
    }
}