1use crate::{
2 Body, BodyState, Bytes, FromRequest, Future, Request, Result, header,
3 types::{PayloadError, RealIp},
4};
5use headers::HeaderMapExt;
6use http_body_util::{BodyExt, Collected};
7
8#[cfg(any(feature = "params", feature = "multipart"))]
9use std::sync::Arc;
10
11#[cfg(feature = "limits")]
12use crate::types::Limits;
13#[cfg(feature = "limits")]
14use http_body_util::{LengthLimitError, Limited};
15
16#[cfg(any(feature = "form", feature = "json", feature = "multipart"))]
17use crate::types::Payload;
18
19#[cfg(feature = "form")]
20use crate::types::Form;
21
22#[cfg(feature = "json")]
23use crate::types::Json;
24
25#[cfg(feature = "multipart")]
26use crate::types::Multipart;
27
28#[cfg(feature = "cookie")]
29use crate::types::{Cookie, Cookies, CookiesError};
30
31#[cfg(feature = "session")]
32use crate::types::Session;
33
34#[cfg(feature = "params")]
35use crate::types::{ParamsError, PathDeserializer, RouteInfo};
36
/// Extension methods for inspecting a [`Request`] and reading its body.
///
/// The trait is sealed through `private::Sealed`; in this module it is implemented
/// only for [`Request`].
pub trait RequestExt: private::Sealed + Sized {
    /// Returns the scheme of the request URI, if the URI carries one.
    fn schema(&self) -> Option<&http::uri::Scheme>;

    /// Returns the path component of the request URI.
    fn path(&self) -> &str;

    /// Returns the query component of the request URI, if any.
    fn query_string(&self) -> Option<&str>;

    #[cfg(feature = "query")]
    /// Deserializes the query string into `T` using `serde_urlencoded`.
    ///
    /// A request without a query string is deserialized as if the query were empty.
    ///
    /// # Errors
    ///
    /// Returns [`PayloadError::UrlDecode`] if the query cannot be deserialized into `T`.
    fn query<T>(&self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned;

    /// Returns the header named `key`, parsed into `T` via [`std::str::FromStr`].
    ///
    /// Returns `None` in three cases:
    /// - the header is absent;
    /// - its value cannot be viewed as a string by `HeaderValue::to_str`;
    /// - parsing into `T` fails.
    fn header<K, T>(&self, key: K) -> Option<T>
    where
        K: header::AsHeaderName,
        T: std::str::FromStr;

    /// Returns the typed header `H`, decoded with `headers::HeaderMapExt::typed_get`.
    fn header_typed<H>(&self) -> Option<H>
    where
        H: headers::Header;

    /// Returns the `Content-Length` header parsed as a `u64`, if present and parsable.
    fn content_length(&self) -> Option<u64>;

    /// Returns the `Content-Type` header parsed as a [`mime::Mime`], if present and parsable.
    fn content_type(&self) -> Option<mime::Mime>;

    /// Extracts a `T` from this request by delegating to `T::extract`.
    fn extract<T>(&mut self) -> impl Future<Output = Result<T, T::Error>> + Send
    where
        T: FromRequest;

    /// Takes the body out of the request, leaving `Body::Empty` in its place.
    ///
    /// The outcome is recorded as a [`BodyState`] in the request extensions, so the
    /// body can be handed out at most once.
    ///
    /// # Errors
    ///
    /// - [`PayloadError::Empty`] if the body is empty, or if an earlier call found it empty.
    /// - [`PayloadError::Used`] if an earlier call already took the body.
    fn incoming(&mut self) -> Result<Body, PayloadError>;

    /// Reads the entire body into [`Bytes`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`RequestExt::incoming`]. Failures while reading map to:
    /// - [`PayloadError::TooLarge`] for a `LengthLimitError` (only checked when the
    ///   `limits` feature is enabled);
    /// - [`PayloadError::Hyper`] for a `hyper::Error`;
    /// - [`PayloadError::Read`] for anything else.
    fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Reads the entire body and decodes it as UTF-8.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`RequestExt::bytes`]. Returns [`PayloadError::Utf8`] if the
    /// body is not valid UTF-8.
    fn text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    #[cfg(feature = "form")]
    /// Deserializes a form body into `T`.
    ///
    /// The `Content-Type` is first validated by `<Form as Payload>::check_type`. The whole
    /// body is then read and decoded with `serde_urlencoded`.
    ///
    /// # Errors
    ///
    /// Returns one of:
    /// - the error from the content-type check;
    /// - any error from [`RequestExt::bytes`];
    /// - [`PayloadError::UrlDecode`] if decoding fails.
    fn form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    #[cfg(feature = "json")]
    /// Deserializes a JSON body into `T`.
    ///
    /// The `Content-Type` is first validated by `<Json as Payload>::check_type`. The whole
    /// body is then read and decoded with `serde_json`.
    ///
    /// # Errors
    ///
    /// Returns one of:
    /// - the error from the content-type check;
    /// - any error from [`RequestExt::bytes`];
    /// - [`PayloadError::Json`] if decoding fails.
    fn json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    #[cfg(feature = "multipart")]
    /// Builds a [`Multipart`] over the request body.
    ///
    /// The `Content-Type` is validated by `<Multipart as Payload>::check_type` and must
    /// carry a `boundary` parameter. The body is taken with [`RequestExt::incoming`] and
    /// handed to `Multipart::new` together with that boundary.
    ///
    /// # Errors
    ///
    /// Returns one of:
    /// - the error from the content-type check;
    /// - [`PayloadError::MissingBoundary`] if there is no `boundary` parameter;
    /// - any error from [`RequestExt::incoming`].
    fn multipart(&mut self) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;

    #[cfg(feature = "state")]
    /// Returns a clone of the value of type `T` stored in the request extensions, if any.
    fn state<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static;

    #[cfg(feature = "state")]
    /// Stores `t` in the request extensions.
    ///
    /// Returns the value of type `T` that it replaced, if there was one.
    fn set_state<T>(&mut self, t: T) -> Option<T>
    where
        T: Clone + Send + Sync + 'static;

    #[cfg(feature = "cookie")]
    /// Returns a clone of the [`Cookies`] stored in the request extensions.
    ///
    /// # Errors
    ///
    /// Returns [`CookiesError::Read`] if no [`Cookies`] value is present. Presumably it is
    /// inserted by a cookie middleware — TODO confirm.
    fn cookies(&self) -> Result<Cookies, CookiesError>;

    #[cfg(feature = "cookie")]
    /// Returns the cookie called `name`.
    ///
    /// Returns `None` if there is no [`Cookies`] value in the extensions, or if it holds
    /// no cookie with that name.
    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
    where
        S: AsRef<str>;

    #[cfg(feature = "session")]
    /// Returns the [`Session`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if no [`Session`] is present.
    fn session(&self) -> &Session;

    #[cfg(feature = "params")]
    /// Deserializes the route parameters of the current [`RouteInfo`] into `T`.
    ///
    /// Deserialization goes through a `PathDeserializer`.
    ///
    /// # Errors
    ///
    /// Returns [`ParamsError::Parse`] if deserialization fails.
    ///
    /// # Panics
    ///
    /// Panics if no route information is present; see [`RequestExt::route_info`].
    fn params<T>(&self) -> Result<T, ParamsError>
    where
        T: serde::de::DeserializeOwned;

    #[cfg(feature = "params")]
    /// Looks up the single route parameter `name` and converts it into `T`.
    ///
    /// This delegates to `params.find(name)` on the current [`RouteInfo`]. The `FromStr`
    /// and `Display` bounds suggest the raw value is parsed with `FromStr`; confirm this
    /// in `find`.
    ///
    /// # Panics
    ///
    /// Panics if no route information is present; see [`RequestExt::route_info`].
    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display;

    #[cfg(feature = "params")]
    /// Returns the `Arc<RouteInfo>` stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if no `Arc<RouteInfo>` is present.
    fn route_info(&self) -> &Arc<RouteInfo>;

    /// Returns the [`std::net::SocketAddr`] stored in the request extensions, if any.
    fn remote_addr(&self) -> Option<&std::net::SocketAddr>;

    /// Returns the result of [`RealIp::parse`] for this request.
    fn realip(&self) -> Option<RealIp>;
}
183
184impl RequestExt for Request {
185 fn schema(&self) -> Option<&http::uri::Scheme> {
186 self.uri().scheme()
187 }
188
189 fn path(&self) -> &str {
190 self.uri().path()
191 }
192
193 fn query_string(&self) -> Option<&str> {
194 self.uri().query()
195 }
196
197 #[cfg(feature = "query")]
198 fn query<T>(&self) -> Result<T, PayloadError>
199 where
200 T: serde::de::DeserializeOwned,
201 {
202 serde_urlencoded::from_str(self.query_string().unwrap_or_default())
203 .map_err(PayloadError::UrlDecode)
204 }
205
206 fn header<K, T>(&self, key: K) -> Option<T>
207 where
208 K: header::AsHeaderName,
209 T: std::str::FromStr,
210 {
211 self.headers()
212 .get(key)
213 .map(header::HeaderValue::to_str)
214 .and_then(Result::ok)
215 .map(str::parse)
216 .and_then(Result::ok)
217 }
218
219 fn header_typed<H>(&self) -> Option<H>
220 where
221 H: headers::Header,
222 {
223 self.headers().typed_get()
224 }
225
226 fn content_length(&self) -> Option<u64> {
227 self.header(header::CONTENT_LENGTH)
228 }
229
230 fn content_type(&self) -> Option<mime::Mime> {
231 self.header(header::CONTENT_TYPE)
232 }
233
234 async fn extract<T>(&mut self) -> Result<T, T::Error>
235 where
236 T: FromRequest,
237 {
238 T::extract(self).await
239 }
240
241 fn incoming(&mut self) -> Result<Body, PayloadError> {
242 if let Some(state) = self.extensions().get::<BodyState>() {
243 match state {
244 BodyState::Empty => Err(PayloadError::Empty)?,
245 BodyState::Used => Err(PayloadError::Used)?,
246 BodyState::Normal => {}
247 }
248 }
249
250 let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) {
251 Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)),
252 body => (BodyState::Used, Ok(body)),
253 };
254
255 self.extensions_mut().insert(state);
256 result
257 }
258
259 async fn bytes(&mut self) -> Result<Bytes, PayloadError> {
260 self.incoming()?
261 .collect()
262 .await
263 .map_err(|err| {
264 #[cfg(feature = "limits")]
265 if err.is::<LengthLimitError>() {
266 return PayloadError::TooLarge;
267 }
268 if let Ok(err) = err.downcast::<hyper::Error>() {
269 return PayloadError::Hyper(err);
270 }
271 PayloadError::Read
272 })
273 .map(Collected::to_bytes)
274 }
275
276 async fn text(&mut self) -> Result<String, PayloadError> {
277 let bytes = self.bytes().await?;
278 String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
279 }
280
281 #[cfg(feature = "form")]
282 async fn form<T>(&mut self) -> Result<T, PayloadError>
283 where
284 T: serde::de::DeserializeOwned,
285 {
286 <Form as Payload>::check_type(self.content_type())?;
287 let bytes = self.bytes().await?;
288 serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
289 }
290
291 #[cfg(feature = "json")]
292 async fn json<T>(&mut self) -> Result<T, PayloadError>
293 where
294 T: serde::de::DeserializeOwned,
295 {
296 <Json as Payload>::check_type(self.content_type())?;
297 let bytes = self.bytes().await?;
298 serde_json::from_slice(&bytes).map_err(PayloadError::Json)
299 }
300
301 #[cfg(feature = "multipart")]
302 async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
303 let m = <Multipart as Payload>::check_type(self.content_type())?;
304
305 let boundary = m
306 .get_param(mime::BOUNDARY)
307 .ok_or(PayloadError::MissingBoundary)?
308 .as_str();
309
310 Ok(Multipart::new(self.incoming()?, boundary))
311 }
312
313 #[cfg(feature = "state")]
314 fn state<T>(&self) -> Option<T>
315 where
316 T: Clone + Send + Sync + 'static,
317 {
318 self.extensions().get().cloned()
319 }
320
321 #[cfg(feature = "state")]
322 fn set_state<T>(&mut self, t: T) -> Option<T>
323 where
324 T: Clone + Send + Sync + 'static,
325 {
326 self.extensions_mut().insert(t)
327 }
328
329 #[cfg(feature = "cookie")]
330 fn cookies(&self) -> Result<Cookies, CookiesError> {
331 self.extensions()
332 .get::<Cookies>()
333 .cloned()
334 .ok_or(CookiesError::Read)
335 }
336
337 #[cfg(feature = "cookie")]
338 fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
339 where
340 S: AsRef<str>,
341 {
342 self.extensions().get::<Cookies>()?.get(name.as_ref())
343 }
344
345 #[cfg(feature = "session")]
346 fn session(&self) -> &Session {
347 self.extensions().get().expect("should get a session")
348 }
349
350 #[cfg(feature = "params")]
351 fn params<T>(&self) -> Result<T, ParamsError>
352 where
353 T: serde::de::DeserializeOwned,
354 {
355 T::deserialize(PathDeserializer::new(&self.route_info().params)).map_err(ParamsError::Parse)
356 }
357
358 #[cfg(feature = "params")]
359 fn param<T>(&self, name: &str) -> Result<T, ParamsError>
360 where
361 T: std::str::FromStr,
362 T::Err: std::fmt::Display,
363 {
364 self.route_info().params.find(name)
365 }
366
367 fn remote_addr(&self) -> Option<&std::net::SocketAddr> {
368 self.extensions().get()
369 }
370
371 #[cfg(feature = "params")]
372 fn route_info(&self) -> &Arc<RouteInfo> {
373 self.extensions().get().expect("should get current route")
374 }
375
376 fn realip(&self) -> Option<RealIp> {
377 RealIp::parse(self)
378 }
379}
380
#[cfg(feature = "limits")]
/// Size-limited variants of the body readers in [`RequestExt`].
///
/// Limits are looked up from the [`Limits`] value stored in the request extensions.
/// The trait is sealed through `private::Sealed`.
pub trait RequestLimitsExt: private::Sealed + Sized {
    /// Returns the [`Limits`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics with "Limits middleware is required" if no [`Limits`] value is present.
    fn limits(&self) -> &Limits;

    /// Reads the entire body, capped at `limit` bytes, or at `max` when `limit` is `None`.
    ///
    /// The cap is converted to `usize`, saturating at `usize::MAX` if it does not fit.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`RequestExt::incoming`]. Failures while reading map to:
    /// - [`PayloadError::TooLarge`] for a `LengthLimitError` from the cap;
    /// - [`PayloadError::Hyper`] for a `hyper::Error`;
    /// - [`PayloadError::Read`] for anything else.
    fn bytes_with(
        &mut self,
        limit: Option<u64>,
        max: u64,
    ) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Like [`RequestExt::text`], but caps the body.
    ///
    /// The cap is the `"text"` entry of [`Limits`], falling back to `Limits::NORMAL`.
    fn text_with_limit(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    #[cfg(feature = "form")]
    /// Like [`RequestExt::form`], but size-limited.
    ///
    /// Headers are checked with `<Form as Payload>::check_header` against the configured
    /// form limit. The body is capped by that limit, falling back to
    /// `<Form as Payload>::LIMIT`.
    fn form_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    #[cfg(feature = "json")]
    /// Like [`RequestExt::json`], but size-limited.
    ///
    /// Headers are checked with `<Json as Payload>::check_header` against the configured
    /// JSON limit. The body is capped by that limit, falling back to
    /// `<Json as Payload>::LIMIT`.
    fn json_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    #[cfg(feature = "multipart")]
    /// Like [`RequestExt::multipart`], but with limits applied.
    ///
    /// Headers are checked with `<Multipart as Payload>::check_header` against the
    /// configured multipart limit. The body itself is not byte-capped here. Instead, the
    /// returned [`Multipart`] is given the `Arc<MultipartLimits>` found in the request
    /// extensions, or `MultipartLimits::default()` when none is present.
    fn multipart_with_limit(
        &mut self,
    ) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}
427
#[cfg(feature = "limits")]
/// Limit-aware body readers driven by the [`Limits`] value in the request extensions.
impl RequestLimitsExt for Request {
    fn limits(&self) -> &Limits {
        let extensions = self.extensions();
        extensions
            .get::<Limits>()
            .expect("Limits middleware is required")
    }

    async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
        // Take the body out of the request; this fails if it was already consumed.
        let body = self.incoming()?;
        // An explicit limit wins over the fallback `max`. Saturate if it exceeds `usize`.
        let cap = usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX);

        let collected = Limited::new(body, cap).collect().await.map_err(|err| {
            if err.is::<LengthLimitError>() {
                PayloadError::TooLarge
            } else {
                match err.downcast::<hyper::Error>() {
                    Ok(inner) => PayloadError::Hyper(*inner),
                    Err(_) => PayloadError::Read,
                }
            }
        })?;

        Ok(collected.to_bytes())
    }

    async fn text_with_limit(&mut self) -> Result<String, PayloadError> {
        let limit = self.limits().get("text");
        let raw = self.bytes_with(limit, Limits::NORMAL).await?;
        String::from_utf8(raw.to_vec()).map_err(PayloadError::Utf8)
    }

    #[cfg(feature = "form")]
    async fn form_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let configured = self.limits().get(<Form as Payload>::NAME);
        let (content_type, content_length) = (self.content_type(), self.content_length());
        // Reject a bad Content-Type or an oversized declared length up front.
        <Form as Payload>::check_header(content_type, content_length, configured)?;

        let payload = self.bytes_with(configured, <Form as Payload>::LIMIT).await?;
        serde_urlencoded::from_reader(bytes::Buf::reader(payload)).map_err(PayloadError::UrlDecode)
    }

    #[cfg(feature = "json")]
    async fn json_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let configured = self.limits().get(<Json as Payload>::NAME);
        let (content_type, content_length) = (self.content_type(), self.content_length());
        // Reject a bad Content-Type or an oversized declared length up front.
        <Json as Payload>::check_header(content_type, content_length, configured)?;

        let payload = self.bytes_with(configured, <Json as Payload>::LIMIT).await?;
        serde_json::from_slice(&payload).map_err(PayloadError::Json)
    }

    #[cfg(feature = "multipart")]
    async fn multipart_with_limit(&mut self) -> Result<Multipart, PayloadError> {
        let configured = self.limits().get(<Multipart as Payload>::NAME);
        let (content_type, content_length) = (self.content_type(), self.content_length());
        let media =
            <Multipart as Payload>::check_header(content_type, content_length, configured)?;

        let Some(boundary) = media.get_param(mime::BOUNDARY) else {
            return Err(PayloadError::MissingBoundary);
        };
        let boundary = boundary.as_str();

        let body = self.incoming()?;
        // Streaming limits come from the extensions, or the defaults when none are set.
        let multipart_limits = self
            .extensions()
            .get::<std::sync::Arc<crate::types::MultipartLimits>>()
            .map(|shared| (**shared).clone())
            .unwrap_or_default();

        Ok(Multipart::with_limits(body, boundary, multipart_limits))
    }
}
507
508mod private {
509 pub trait Sealed {}
510 impl Sealed for super::Request {}
511}