1use crate::{
2 header,
3 types::{PayloadError, RealIp},
4 Body, BodyState, Bytes, FromRequest, Future, Request, Result,
5};
6use headers::HeaderMapExt;
7use http_body_util::{BodyExt, Collected};
8
9#[cfg(any(feature = "params", feature = "multipart"))]
10use std::sync::Arc;
11
12#[cfg(feature = "limits")]
13use crate::types::Limits;
14#[cfg(feature = "limits")]
15use http_body_util::{LengthLimitError, Limited};
16
17#[cfg(any(feature = "form", feature = "json", feature = "multipart"))]
18use crate::types::Payload;
19
20#[cfg(feature = "form")]
21use crate::types::Form;
22
23#[cfg(feature = "json")]
24use crate::types::Json;
25
26#[cfg(feature = "multipart")]
27use crate::types::Multipart;
28
29#[cfg(feature = "cookie")]
30use crate::types::{Cookie, Cookies, CookiesError};
31
32#[cfg(feature = "session")]
33use crate::types::Session;
34
35#[cfg(feature = "params")]
36use crate::types::{ParamsError, PathDeserializer, RouteInfo};
37
38pub trait RequestExt: private::Sealed + Sized {
40 fn schema(&self) -> Option<&http::uri::Scheme>;
42
43 fn path(&self) -> &str;
45
46 fn query_string(&self) -> Option<&str>;
48
49 #[cfg(feature = "query")]
55 fn query<T>(&self) -> Result<T, PayloadError>
56 where
57 T: serde::de::DeserializeOwned;
58
59 fn header<K, T>(&self, key: K) -> Option<T>
61 where
62 K: header::AsHeaderName,
63 T: std::str::FromStr;
64
65 fn header_typed<H>(&self) -> Option<H>
67 where
68 H: headers::Header;
69
70 fn content_length(&self) -> Option<u64>;
72
73 fn content_type(&self) -> Option<mime::Mime>;
75
76 fn extract<T>(&mut self) -> impl Future<Output = Result<T, T::Error>> + Send
78 where
79 T: FromRequest;
80
81 fn incoming(&mut self) -> Result<Body, PayloadError>;
88
89 fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;
93
94 fn text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;
98
99 #[cfg(feature = "form")]
104 fn form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
105 where
106 T: serde::de::DeserializeOwned;
107
108 #[cfg(feature = "json")]
112 fn json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
113 where
114 T: serde::de::DeserializeOwned;
115
116 #[cfg(feature = "multipart")]
121 fn multipart(&mut self) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
122
123 #[cfg(feature = "state")]
125 fn state<T>(&self) -> Option<T>
126 where
127 T: Clone + Send + Sync + 'static;
128
129 #[cfg(feature = "state")]
131 fn set_state<T>(&mut self, t: T) -> Option<T>
132 where
133 T: Clone + Send + Sync + 'static;
134
135 #[cfg(feature = "cookie")]
141 fn cookies(&self) -> Result<Cookies, CookiesError>;
142
143 #[cfg(feature = "cookie")]
145 fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
146 where
147 S: AsRef<str>;
148
149 #[cfg(feature = "session")]
151 fn session(&self) -> &Session;
152
153 #[cfg(feature = "params")]
159 fn params<T>(&self) -> Result<T, ParamsError>
160 where
161 T: serde::de::DeserializeOwned;
162
163 #[cfg(feature = "params")]
169 fn param<T>(&self, name: &str) -> Result<T, ParamsError>
170 where
171 T: std::str::FromStr,
172 T::Err: std::fmt::Display;
173
174 #[cfg(feature = "params")]
176 fn route_info(&self) -> &Arc<RouteInfo>;
177
178 fn remote_addr(&self) -> Option<&std::net::SocketAddr>;
180
181 fn realip(&self) -> Option<RealIp>;
183}
184
185impl RequestExt for Request {
186 fn schema(&self) -> Option<&http::uri::Scheme> {
187 self.uri().scheme()
188 }
189
190 fn path(&self) -> &str {
191 self.uri().path()
192 }
193
194 fn query_string(&self) -> Option<&str> {
195 self.uri().query()
196 }
197
198 #[cfg(feature = "query")]
199 fn query<T>(&self) -> Result<T, PayloadError>
200 where
201 T: serde::de::DeserializeOwned,
202 {
203 serde_urlencoded::from_str(self.query_string().unwrap_or_default())
204 .map_err(PayloadError::UrlDecode)
205 }
206
207 fn header<K, T>(&self, key: K) -> Option<T>
208 where
209 K: header::AsHeaderName,
210 T: std::str::FromStr,
211 {
212 self.headers()
213 .get(key)
214 .map(header::HeaderValue::to_str)
215 .and_then(Result::ok)
216 .map(str::parse)
217 .and_then(Result::ok)
218 }
219
220 fn header_typed<H>(&self) -> Option<H>
221 where
222 H: headers::Header,
223 {
224 self.headers().typed_get()
225 }
226
227 fn content_length(&self) -> Option<u64> {
228 self.header(header::CONTENT_LENGTH)
229 }
230
231 fn content_type(&self) -> Option<mime::Mime> {
232 self.header(header::CONTENT_TYPE)
233 }
234
235 async fn extract<T>(&mut self) -> Result<T, T::Error>
236 where
237 T: FromRequest,
238 {
239 T::extract(self).await
240 }
241
242 fn incoming(&mut self) -> Result<Body, PayloadError> {
243 if let Some(state) = self.extensions().get::<BodyState>() {
244 match state {
245 BodyState::Empty => Err(PayloadError::Empty)?,
246 BodyState::Used => Err(PayloadError::Used)?,
247 BodyState::Normal => {}
248 }
249 }
250
251 let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) {
252 Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)),
253 body => (BodyState::Used, Ok(body)),
254 };
255
256 self.extensions_mut().insert(state);
257 result
258 }
259
260 async fn bytes(&mut self) -> Result<Bytes, PayloadError> {
261 self.incoming()?
262 .collect()
263 .await
264 .map_err(|err| {
265 #[cfg(feature = "limits")]
266 if err.is::<LengthLimitError>() {
267 return PayloadError::TooLarge;
268 }
269 if let Ok(err) = err.downcast::<hyper::Error>() {
270 return PayloadError::Hyper(err);
271 }
272 PayloadError::Read
273 })
274 .map(Collected::to_bytes)
275 }
276
277 async fn text(&mut self) -> Result<String, PayloadError> {
278 let bytes = self.bytes().await?;
279 String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
280 }
281
282 #[cfg(feature = "form")]
283 async fn form<T>(&mut self) -> Result<T, PayloadError>
284 where
285 T: serde::de::DeserializeOwned,
286 {
287 <Form as Payload>::check_type(self.content_type())?;
288 let bytes = self.bytes().await?;
289 serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
290 }
291
292 #[cfg(feature = "json")]
293 async fn json<T>(&mut self) -> Result<T, PayloadError>
294 where
295 T: serde::de::DeserializeOwned,
296 {
297 <Json as Payload>::check_type(self.content_type())?;
298 let bytes = self.bytes().await?;
299 serde_json::from_slice(&bytes).map_err(PayloadError::Json)
300 }
301
302 #[cfg(feature = "multipart")]
303 async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
304 let m = <Multipart as Payload>::check_type(self.content_type())?;
305
306 let boundary = m
307 .get_param(mime::BOUNDARY)
308 .ok_or(PayloadError::MissingBoundary)?
309 .as_str();
310
311 Ok(Multipart::new(self.incoming()?, boundary))
312 }
313
314 #[cfg(feature = "state")]
315 fn state<T>(&self) -> Option<T>
316 where
317 T: Clone + Send + Sync + 'static,
318 {
319 self.extensions().get().cloned()
320 }
321
322 #[cfg(feature = "state")]
323 fn set_state<T>(&mut self, t: T) -> Option<T>
324 where
325 T: Clone + Send + Sync + 'static,
326 {
327 self.extensions_mut().insert(t)
328 }
329
330 #[cfg(feature = "cookie")]
331 fn cookies(&self) -> Result<Cookies, CookiesError> {
332 self.extensions()
333 .get::<Cookies>()
334 .cloned()
335 .ok_or(CookiesError::Read)
336 }
337
338 #[cfg(feature = "cookie")]
339 fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
340 where
341 S: AsRef<str>,
342 {
343 self.extensions().get::<Cookies>()?.get(name.as_ref())
344 }
345
346 #[cfg(feature = "session")]
347 fn session(&self) -> &Session {
348 self.extensions().get().expect("should get a session")
349 }
350
351 #[cfg(feature = "params")]
352 fn params<T>(&self) -> Result<T, ParamsError>
353 where
354 T: serde::de::DeserializeOwned,
355 {
356 T::deserialize(PathDeserializer::new(&self.route_info().params)).map_err(ParamsError::Parse)
357 }
358
359 #[cfg(feature = "params")]
360 fn param<T>(&self, name: &str) -> Result<T, ParamsError>
361 where
362 T: std::str::FromStr,
363 T::Err: std::fmt::Display,
364 {
365 self.route_info().params.find(name)
366 }
367
368 fn remote_addr(&self) -> Option<&std::net::SocketAddr> {
369 self.extensions().get()
370 }
371
372 #[cfg(feature = "params")]
373 fn route_info(&self) -> &Arc<RouteInfo> {
374 self.extensions().get().expect("should get current route")
375 }
376
377 fn realip(&self) -> Option<RealIp> {
378 RealIp::parse(self)
379 }
380}
381
/// Body extractors that enforce size limits taken from the [`Limits`] stored on the request.
///
/// Like [`RequestExt`], the trait is sealed through `private::Sealed`.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
    /// Returns the [`Limits`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics when no [`Limits`] value is stored on the request.
    fn limits(&self) -> &Limits;

    /// Collects the body into [`Bytes`], reading at most `limit` bytes, or `max` bytes when
    /// `limit` is `None`.
    ///
    /// # Errors
    ///
    /// Returns [`PayloadError::TooLarge`] when the length limit is exceeded, or another
    /// [`PayloadError`] when the body is unavailable or reading it fails.
    fn bytes_with(
        &mut self,
        limit: Option<u64>,
        max: u64,
    ) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Collects the body as UTF-8 text, bounded by the `"text"` entry of [`Limits`] and
    /// falling back to `Limits::NORMAL`.
    ///
    /// # Errors
    ///
    /// Returns a [`PayloadError`] when reading fails, the limit is exceeded, or the body is
    /// not valid UTF-8.
    fn text_with_limit(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    /// Decodes a URL-encoded form body into `T`, bounded by the form entry of [`Limits`].
    ///
    /// # Errors
    ///
    /// Returns a [`PayloadError`] when the header check fails, reading fails, the limit is
    /// exceeded, or decoding into `T` fails.
    #[cfg(feature = "form")]
    fn form_with_limit<T: serde::de::DeserializeOwned>(
        &mut self,
    ) -> impl Future<Output = Result<T, PayloadError>> + Send;

    /// Decodes a JSON body into `T`, bounded by the JSON entry of [`Limits`].
    ///
    /// # Errors
    ///
    /// Returns a [`PayloadError`] when the header check fails, reading fails, the limit is
    /// exceeded, or decoding into `T` fails.
    #[cfg(feature = "json")]
    fn json_with_limit<T: serde::de::DeserializeOwned>(
        &mut self,
    ) -> impl Future<Output = Result<T, PayloadError>> + Send;

    /// Wraps the body in a [`Multipart`] after running the header check against the
    /// multipart entry of [`Limits`].
    ///
    /// # Errors
    ///
    /// Returns a [`PayloadError`] when the header check fails, the content type has no
    /// `boundary` parameter, or the body is unavailable.
    #[cfg(feature = "multipart")]
    fn multipart_with_limit(
        &mut self,
    ) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}
428
#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
    /// Returns the [`Limits`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics with "Limits middleware is required" when no [`Limits`] value is stored.
    fn limits(&self) -> &Limits {
        let stored = self.extensions().get::<Limits>();
        stored.expect("Limits middleware is required")
    }

    /// Collects the body through a `Limited` wrapper capped at `limit`, or `max` when
    /// `limit` is `None`.
    async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
        let body = self.incoming()?;
        // A cap that does not fit in `usize` saturates to `usize::MAX`.
        let cap = usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX);

        match Limited::new(body, cap).collect().await {
            Ok(collected) => Ok(collected.to_bytes()),
            Err(cause) => {
                if cause.is::<LengthLimitError>() {
                    return Err(PayloadError::TooLarge);
                }
                // Anything that is not a `hyper::Error` is reported as a plain read failure.
                Err(match cause.downcast::<hyper::Error>() {
                    Ok(hyper_err) => PayloadError::Hyper(*hyper_err),
                    Err(_) => PayloadError::Read,
                })
            }
        }
    }

    /// Collects the body as text, capped by the `"text"` limit or `Limits::NORMAL`.
    async fn text_with_limit(&mut self) -> Result<String, PayloadError> {
        let limit = self.limits().get("text");
        let raw = self.bytes_with(limit, Limits::NORMAL).await?;
        String::from_utf8(raw.to_vec()).map_err(PayloadError::Utf8)
    }

    /// Runs `Form`'s `Payload::check_header`, then decodes the size-capped body into `T`.
    #[cfg(feature = "form")]
    async fn form_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Form as Payload>::NAME);
        let content_type = self.content_type();
        let content_length = self.content_length();
        <Form as Payload>::check_header(content_type, content_length, limit)?;

        let raw = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
        let reader = bytes::Buf::reader(raw);
        serde_urlencoded::from_reader(reader).map_err(PayloadError::UrlDecode)
    }

    /// Runs `Json`'s `Payload::check_header`, then decodes the size-capped body into `T`.
    #[cfg(feature = "json")]
    async fn json_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Json as Payload>::NAME);
        let content_type = self.content_type();
        let content_length = self.content_length();
        <Json as Payload>::check_header(content_type, content_length, limit)?;

        let raw = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
        serde_json::from_slice(&raw).map_err(PayloadError::Json)
    }

    /// Runs `Multipart`'s `Payload::check_header`, then wraps the body in a [`Multipart`]
    /// configured with the `MultipartLimits` stored on the request (or their default).
    #[cfg(feature = "multipart")]
    async fn multipart_with_limit(&mut self) -> Result<Multipart, PayloadError> {
        let limit = self.limits().get(<Multipart as Payload>::NAME);
        let content_type = self.content_type();
        let content_length = self.content_length();
        let media_type =
            <Multipart as Payload>::check_header(content_type, content_length, limit)?;

        let Some(boundary_param) = media_type.get_param(mime::BOUNDARY) else {
            return Err(PayloadError::MissingBoundary);
        };
        let boundary = boundary_param.as_str();

        let body = self.incoming()?;
        // The limits are shared behind an `Arc`; clone the inner value for this parser.
        let part_limits = self
            .extensions()
            .get::<std::sync::Arc<crate::types::MultipartLimits>>()
            .map(|shared| (**shared).clone())
            .unwrap_or_default();

        Ok(Multipart::with_limits(body, boundary, part_limits))
    }
}
508
509mod private {
510 pub trait Sealed {}
511 impl Sealed for super::Request {}
512}