viz_core/
request.rs

1use crate::{
2    header,
3    types::{PayloadError, RealIp},
4    Body, BodyState, Bytes, FromRequest, Future, Request, Result,
5};
6use headers::HeaderMapExt;
7use http_body_util::{BodyExt, Collected};
8
9#[cfg(any(feature = "params", feature = "multipart"))]
10use std::sync::Arc;
11
12#[cfg(feature = "limits")]
13use crate::types::Limits;
14#[cfg(feature = "limits")]
15use http_body_util::{LengthLimitError, Limited};
16
17#[cfg(any(feature = "form", feature = "json", feature = "multipart"))]
18use crate::types::Payload;
19
20#[cfg(feature = "form")]
21use crate::types::Form;
22
23#[cfg(feature = "json")]
24use crate::types::Json;
25
26#[cfg(feature = "multipart")]
27use crate::types::Multipart;
28
29#[cfg(feature = "cookie")]
30use crate::types::{Cookie, Cookies, CookiesError};
31
32#[cfg(feature = "session")]
33use crate::types::Session;
34
35#[cfg(feature = "params")]
36use crate::types::{ParamsError, PathDeserializer, RouteInfo};
37
38/// The [`Request`] Extension.
39pub trait RequestExt: private::Sealed + Sized {
40    /// Get URL's schema of this request.
41    fn schema(&self) -> Option<&http::uri::Scheme>;
42
43    /// Get URL's path of this request.
44    fn path(&self) -> &str;
45
46    /// Get URL's query string of this request.
47    fn query_string(&self) -> Option<&str>;
48
49    /// Get query data by type.
50    ///
51    /// # Errors
52    ///
53    /// Will return [`PayloadError::UrlDecode`] if decoding the query string fails.
54    #[cfg(feature = "query")]
55    fn query<T>(&self) -> Result<T, PayloadError>
56    where
57        T: serde::de::DeserializeOwned;
58
59    /// Get a header with the key.
60    fn header<K, T>(&self, key: K) -> Option<T>
61    where
62        K: header::AsHeaderName,
63        T: std::str::FromStr;
64
65    /// Get a header with the specified type.
66    fn header_typed<H>(&self) -> Option<H>
67    where
68        H: headers::Header;
69
70    /// Get the size of this request's body.
71    fn content_length(&self) -> Option<u64>;
72
73    /// Get the media type of this request.
74    fn content_type(&self) -> Option<mime::Mime>;
75
76    /// Extract the data from this request by the specified type.
77    fn extract<T>(&mut self) -> impl Future<Output = Result<T, T::Error>> + Send
78    where
79        T: FromRequest;
80
81    /// Get an incoming body.
82    ///
83    /// # Errors
84    ///
85    /// Will return [`PayloadError::Empty`] or [`PayloadError::Used`] if the incoming does not
86    /// exist or be used.
87    fn incoming(&mut self) -> Result<Body, PayloadError>;
88
89    /// Return with a [Bytes][mdn] representation of the request body.
90    ///
91    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
92    fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;
93
94    /// Return with a [Text][mdn] representation of the request body.
95    ///
96    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
97    fn text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;
98
99    /// Return with a `application/x-www-form-urlencoded` [FormData][mdn] by the specified type
100    /// representation of the request body.
101    ///
102    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
103    #[cfg(feature = "form")]
104    fn form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
105    where
106        T: serde::de::DeserializeOwned;
107
108    /// Return with a [JSON][mdn] by the specified type representation of the request body.
109    ///
110    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
111    #[cfg(feature = "json")]
112    fn json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
113    where
114        T: serde::de::DeserializeOwned;
115
116    /// Return with a `multipart/form-data` [FormData][mdn] by the specified type
117    /// representation of the request body.
118    ///
119    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
120    #[cfg(feature = "multipart")]
121    fn multipart(&mut self) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
122
123    /// Return a shared state by the specified type.
124    #[cfg(feature = "state")]
125    fn state<T>(&self) -> Option<T>
126    where
127        T: Clone + Send + Sync + 'static;
128
129    /// Store a shared state.
130    #[cfg(feature = "state")]
131    fn set_state<T>(&mut self, t: T) -> Option<T>
132    where
133        T: Clone + Send + Sync + 'static;
134
135    /// Get a wrapper of `cookie-jar` for managing cookies.
136    ///
137    /// # Errors
138    ///
139    /// Will return [`CookiesError`] if getting cookies fails.
140    #[cfg(feature = "cookie")]
141    fn cookies(&self) -> Result<Cookies, CookiesError>;
142
143    /// Get a cookie by the specified name.
144    #[cfg(feature = "cookie")]
145    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
146    where
147        S: AsRef<str>;
148
149    /// Get current session.
150    #[cfg(feature = "session")]
151    fn session(&self) -> &Session;
152
153    /// Get all parameters.
154    ///
155    /// # Errors
156    ///
157    /// Will return [`ParamsError`] if deserializer the parameters fails.
158    #[cfg(feature = "params")]
159    fn params<T>(&self) -> Result<T, ParamsError>
160    where
161        T: serde::de::DeserializeOwned;
162
163    /// Get single parameter by name.
164    ///
165    /// # Errors
166    ///
167    /// Will return [`ParamsError`] if deserializer the single parameter fails.
168    #[cfg(feature = "params")]
169    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
170    where
171        T: std::str::FromStr,
172        T::Err: std::fmt::Display;
173
174    /// Get current route.
175    #[cfg(feature = "params")]
176    fn route_info(&self) -> &Arc<RouteInfo>;
177
178    /// Get remote addr.
179    fn remote_addr(&self) -> Option<&std::net::SocketAddr>;
180
181    /// Get realip.
182    fn realip(&self) -> Option<RealIp>;
183}
184
185impl RequestExt for Request {
186    fn schema(&self) -> Option<&http::uri::Scheme> {
187        self.uri().scheme()
188    }
189
190    fn path(&self) -> &str {
191        self.uri().path()
192    }
193
194    fn query_string(&self) -> Option<&str> {
195        self.uri().query()
196    }
197
198    #[cfg(feature = "query")]
199    fn query<T>(&self) -> Result<T, PayloadError>
200    where
201        T: serde::de::DeserializeOwned,
202    {
203        serde_urlencoded::from_str(self.query_string().unwrap_or_default())
204            .map_err(PayloadError::UrlDecode)
205    }
206
207    fn header<K, T>(&self, key: K) -> Option<T>
208    where
209        K: header::AsHeaderName,
210        T: std::str::FromStr,
211    {
212        self.headers()
213            .get(key)
214            .map(header::HeaderValue::to_str)
215            .and_then(Result::ok)
216            .map(str::parse)
217            .and_then(Result::ok)
218    }
219
220    fn header_typed<H>(&self) -> Option<H>
221    where
222        H: headers::Header,
223    {
224        self.headers().typed_get()
225    }
226
227    fn content_length(&self) -> Option<u64> {
228        self.header(header::CONTENT_LENGTH)
229    }
230
231    fn content_type(&self) -> Option<mime::Mime> {
232        self.header(header::CONTENT_TYPE)
233    }
234
235    async fn extract<T>(&mut self) -> Result<T, T::Error>
236    where
237        T: FromRequest,
238    {
239        T::extract(self).await
240    }
241
242    fn incoming(&mut self) -> Result<Body, PayloadError> {
243        if let Some(state) = self.extensions().get::<BodyState>() {
244            match state {
245                BodyState::Empty => Err(PayloadError::Empty)?,
246                BodyState::Used => Err(PayloadError::Used)?,
247                BodyState::Normal => {}
248            }
249        }
250
251        let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) {
252            Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)),
253            body => (BodyState::Used, Ok(body)),
254        };
255
256        self.extensions_mut().insert(state);
257        result
258    }
259
260    async fn bytes(&mut self) -> Result<Bytes, PayloadError> {
261        self.incoming()?
262            .collect()
263            .await
264            .map_err(|err| {
265                #[cfg(feature = "limits")]
266                if err.is::<LengthLimitError>() {
267                    return PayloadError::TooLarge;
268                }
269                if let Ok(err) = err.downcast::<hyper::Error>() {
270                    return PayloadError::Hyper(err);
271                }
272                PayloadError::Read
273            })
274            .map(Collected::to_bytes)
275    }
276
277    async fn text(&mut self) -> Result<String, PayloadError> {
278        let bytes = self.bytes().await?;
279        String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
280    }
281
282    #[cfg(feature = "form")]
283    async fn form<T>(&mut self) -> Result<T, PayloadError>
284    where
285        T: serde::de::DeserializeOwned,
286    {
287        <Form as Payload>::check_type(self.content_type())?;
288        let bytes = self.bytes().await?;
289        serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
290    }
291
292    #[cfg(feature = "json")]
293    async fn json<T>(&mut self) -> Result<T, PayloadError>
294    where
295        T: serde::de::DeserializeOwned,
296    {
297        <Json as Payload>::check_type(self.content_type())?;
298        let bytes = self.bytes().await?;
299        serde_json::from_slice(&bytes).map_err(PayloadError::Json)
300    }
301
302    #[cfg(feature = "multipart")]
303    async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
304        let m = <Multipart as Payload>::check_type(self.content_type())?;
305
306        let boundary = m
307            .get_param(mime::BOUNDARY)
308            .ok_or(PayloadError::MissingBoundary)?
309            .as_str();
310
311        Ok(Multipart::new(self.incoming()?, boundary))
312    }
313
314    #[cfg(feature = "state")]
315    fn state<T>(&self) -> Option<T>
316    where
317        T: Clone + Send + Sync + 'static,
318    {
319        self.extensions().get().cloned()
320    }
321
322    #[cfg(feature = "state")]
323    fn set_state<T>(&mut self, t: T) -> Option<T>
324    where
325        T: Clone + Send + Sync + 'static,
326    {
327        self.extensions_mut().insert(t)
328    }
329
330    #[cfg(feature = "cookie")]
331    fn cookies(&self) -> Result<Cookies, CookiesError> {
332        self.extensions()
333            .get::<Cookies>()
334            .cloned()
335            .ok_or(CookiesError::Read)
336    }
337
338    #[cfg(feature = "cookie")]
339    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
340    where
341        S: AsRef<str>,
342    {
343        self.extensions().get::<Cookies>()?.get(name.as_ref())
344    }
345
346    #[cfg(feature = "session")]
347    fn session(&self) -> &Session {
348        self.extensions().get().expect("should get a session")
349    }
350
351    #[cfg(feature = "params")]
352    fn params<T>(&self) -> Result<T, ParamsError>
353    where
354        T: serde::de::DeserializeOwned,
355    {
356        T::deserialize(PathDeserializer::new(&self.route_info().params)).map_err(ParamsError::Parse)
357    }
358
359    #[cfg(feature = "params")]
360    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
361    where
362        T: std::str::FromStr,
363        T::Err: std::fmt::Display,
364    {
365        self.route_info().params.find(name)
366    }
367
368    fn remote_addr(&self) -> Option<&std::net::SocketAddr> {
369        self.extensions().get()
370    }
371
372    #[cfg(feature = "params")]
373    fn route_info(&self) -> &Arc<RouteInfo> {
374        self.extensions().get().expect("should get current route")
375    }
376
377    fn realip(&self) -> Option<RealIp> {
378        RealIp::parse(self)
379    }
380}
381
/// Size-limited body readers for [`Request`].
///
/// The trait is sealed: it is implemented for [`Request`] only.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
    /// Returns the limits configured for this request.
    ///
    /// # Panics
    ///
    /// Panics when no [`Limits`] have been attached to the request.
    fn limits(&self) -> &Limits;

    /// Reads the request body as [Bytes][mdn], capped at `limit` bytes, or at `max` bytes
    /// when `limit` is `None`.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
    fn bytes_with(
        &mut self,
        limit: Option<u64>,
        max: u64,
    ) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Reads the request body as size-limited UTF-8 [Text][mdn].
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
    fn text_with_limit(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    /// Reads a size-limited `application/x-www-form-urlencoded` [FormData][mdn] body and
    /// deserializes it into `T`.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
    #[cfg(feature = "form")]
    fn form_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    /// Reads a size-limited [JSON][mdn] body and deserializes it into `T`.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
    #[cfg(feature = "json")]
    fn json_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    /// Wraps a `multipart/form-data` [FormData][mdn] body in a size-limited [`Multipart`]
    /// reader.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
    #[cfg(feature = "multipart")]
    fn multipart_with_limit(
        &mut self,
    ) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}
428
// Every reader below looks up its cap in the `Limits` extension and then
// collects the body through `bytes_with`, which enforces that cap.
#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
    fn limits(&self) -> &Limits {
        let configured = self.extensions().get::<Limits>();
        configured.expect("Limits middleware is required")
    }

    async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
        let body = self.incoming()?;
        // A cap that does not fit into `usize` saturates to `usize::MAX`.
        let cap = usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX);

        match Limited::new(body, cap).collect().await {
            Ok(collected) => Ok(collected.to_bytes()),
            Err(err) if err.is::<LengthLimitError>() => Err(PayloadError::TooLarge),
            Err(err) => Err(match err.downcast::<hyper::Error>() {
                Ok(hyper_err) => PayloadError::Hyper(*hyper_err),
                Err(_) => PayloadError::Read,
            }),
        }
    }

    async fn text_with_limit(&mut self) -> Result<String, PayloadError> {
        let limit = self.limits().get("text");
        let raw = self.bytes_with(limit, Limits::NORMAL).await?;
        String::from_utf8(raw.to_vec()).map_err(PayloadError::Utf8)
    }

    #[cfg(feature = "form")]
    async fn form_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Form as Payload>::NAME);
        // Validate the headers first so a bad request is rejected before the body is read.
        <Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

        let raw = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
        let reader = bytes::Buf::reader(raw);
        serde_urlencoded::from_reader(reader).map_err(PayloadError::UrlDecode)
    }

    #[cfg(feature = "json")]
    async fn json_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Json as Payload>::NAME);
        // Validate the headers first so a bad request is rejected before the body is read.
        <Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

        let raw = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
        serde_json::from_slice(&raw).map_err(PayloadError::Json)
    }

    #[cfg(feature = "multipart")]
    async fn multipart_with_limit(&mut self) -> Result<Multipart, PayloadError> {
        let limit = self.limits().get(<Multipart as Payload>::NAME);
        let media_type = <Multipart as Payload>::check_header(
            self.content_type(),
            self.content_length(),
            limit,
        )?;

        // The part delimiter is carried as the `boundary` parameter of the content type.
        let Some(boundary) = media_type.get_param(mime::BOUNDARY) else {
            return Err(PayloadError::MissingBoundary);
        };

        let body = self.incoming()?;
        // Use the shared multipart limits when present, otherwise the defaults.
        let part_limits = self
            .extensions()
            .get::<std::sync::Arc<crate::types::MultipartLimits>>()
            .map(|shared| (**shared).clone())
            .unwrap_or_default();

        Ok(Multipart::with_limits(body, boundary.as_str(), part_limits))
    }
}
508
509mod private {
510    pub trait Sealed {}
511    impl Sealed for super::Request {}
512}