viz_core/
request.rs

1use crate::{
2    Body, BodyState, Bytes, FromRequest, Future, Request, Result, header,
3    types::{PayloadError, RealIp},
4};
5use headers::HeaderMapExt;
6use http_body_util::{BodyExt, Collected};
7
8#[cfg(any(feature = "params", feature = "multipart"))]
9use std::sync::Arc;
10
11#[cfg(feature = "limits")]
12use crate::types::Limits;
13#[cfg(feature = "limits")]
14use http_body_util::{LengthLimitError, Limited};
15
16#[cfg(any(feature = "form", feature = "json", feature = "multipart"))]
17use crate::types::Payload;
18
19#[cfg(feature = "form")]
20use crate::types::Form;
21
22#[cfg(feature = "json")]
23use crate::types::Json;
24
25#[cfg(feature = "multipart")]
26use crate::types::Multipart;
27
28#[cfg(feature = "cookie")]
29use crate::types::{Cookie, Cookies, CookiesError};
30
31#[cfg(feature = "session")]
32use crate::types::Session;
33
34#[cfg(feature = "params")]
35use crate::types::{ParamsError, PathDeserializer, RouteInfo};
36
37/// The [`Request`] Extension.
38pub trait RequestExt: private::Sealed + Sized {
39    /// Get URL's schema of this request.
40    fn schema(&self) -> Option<&http::uri::Scheme>;
41
42    /// Get URL's path of this request.
43    fn path(&self) -> &str;
44
45    /// Get URL's query string of this request.
46    fn query_string(&self) -> Option<&str>;
47
48    /// Get query data by type.
49    ///
50    /// # Errors
51    ///
52    /// Will return [`PayloadError::UrlDecode`] if decoding the query string fails.
53    #[cfg(feature = "query")]
54    fn query<T>(&self) -> Result<T, PayloadError>
55    where
56        T: serde::de::DeserializeOwned;
57
58    /// Get a header with the key.
59    fn header<K, T>(&self, key: K) -> Option<T>
60    where
61        K: header::AsHeaderName,
62        T: std::str::FromStr;
63
64    /// Get a header with the specified type.
65    fn header_typed<H>(&self) -> Option<H>
66    where
67        H: headers::Header;
68
69    /// Get the size of this request's body.
70    fn content_length(&self) -> Option<u64>;
71
72    /// Get the media type of this request.
73    fn content_type(&self) -> Option<mime::Mime>;
74
75    /// Extract the data from this request by the specified type.
76    fn extract<T>(&mut self) -> impl Future<Output = Result<T, T::Error>> + Send
77    where
78        T: FromRequest;
79
80    /// Get an incoming body.
81    ///
82    /// # Errors
83    ///
84    /// Will return [`PayloadError::Empty`] or [`PayloadError::Used`] if the incoming does not
85    /// exist or be used.
86    fn incoming(&mut self) -> Result<Body, PayloadError>;
87
88    /// Return with a [Bytes][mdn] representation of the request body.
89    ///
90    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
91    fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;
92
93    /// Return with a [Text][mdn] representation of the request body.
94    ///
95    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
96    fn text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;
97
98    /// Return with a `application/x-www-form-urlencoded` [FormData][mdn] by the specified type
99    /// representation of the request body.
100    ///
101    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
102    #[cfg(feature = "form")]
103    fn form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
104    where
105        T: serde::de::DeserializeOwned;
106
107    /// Return with a [JSON][mdn] by the specified type representation of the request body.
108    ///
109    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
110    #[cfg(feature = "json")]
111    fn json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
112    where
113        T: serde::de::DeserializeOwned;
114
115    /// Return with a `multipart/form-data` [FormData][mdn] by the specified type
116    /// representation of the request body.
117    ///
118    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
119    #[cfg(feature = "multipart")]
120    fn multipart(&mut self) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
121
122    /// Return a shared state by the specified type.
123    #[cfg(feature = "state")]
124    fn state<T>(&self) -> Option<T>
125    where
126        T: Clone + Send + Sync + 'static;
127
128    /// Store a shared state.
129    #[cfg(feature = "state")]
130    fn set_state<T>(&mut self, t: T) -> Option<T>
131    where
132        T: Clone + Send + Sync + 'static;
133
134    /// Get a wrapper of `cookie-jar` for managing cookies.
135    ///
136    /// # Errors
137    ///
138    /// Will return [`CookiesError`] if getting cookies fails.
139    #[cfg(feature = "cookie")]
140    fn cookies(&self) -> Result<Cookies, CookiesError>;
141
142    /// Get a cookie by the specified name.
143    #[cfg(feature = "cookie")]
144    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
145    where
146        S: AsRef<str>;
147
148    /// Get current session.
149    #[cfg(feature = "session")]
150    fn session(&self) -> &Session;
151
152    /// Get all parameters.
153    ///
154    /// # Errors
155    ///
156    /// Will return [`ParamsError`] if deserializer the parameters fails.
157    #[cfg(feature = "params")]
158    fn params<T>(&self) -> Result<T, ParamsError>
159    where
160        T: serde::de::DeserializeOwned;
161
162    /// Get single parameter by name.
163    ///
164    /// # Errors
165    ///
166    /// Will return [`ParamsError`] if deserializer the single parameter fails.
167    #[cfg(feature = "params")]
168    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
169    where
170        T: std::str::FromStr,
171        T::Err: std::fmt::Display;
172
173    /// Get current route.
174    #[cfg(feature = "params")]
175    fn route_info(&self) -> &Arc<RouteInfo>;
176
177    /// Get remote addr.
178    fn remote_addr(&self) -> Option<&std::net::SocketAddr>;
179
180    /// Get realip.
181    fn realip(&self) -> Option<RealIp>;
182}
183
184impl RequestExt for Request {
185    fn schema(&self) -> Option<&http::uri::Scheme> {
186        self.uri().scheme()
187    }
188
189    fn path(&self) -> &str {
190        self.uri().path()
191    }
192
193    fn query_string(&self) -> Option<&str> {
194        self.uri().query()
195    }
196
197    #[cfg(feature = "query")]
198    fn query<T>(&self) -> Result<T, PayloadError>
199    where
200        T: serde::de::DeserializeOwned,
201    {
202        serde_urlencoded::from_str(self.query_string().unwrap_or_default())
203            .map_err(PayloadError::UrlDecode)
204    }
205
206    fn header<K, T>(&self, key: K) -> Option<T>
207    where
208        K: header::AsHeaderName,
209        T: std::str::FromStr,
210    {
211        self.headers()
212            .get(key)
213            .map(header::HeaderValue::to_str)
214            .and_then(Result::ok)
215            .map(str::parse)
216            .and_then(Result::ok)
217    }
218
219    fn header_typed<H>(&self) -> Option<H>
220    where
221        H: headers::Header,
222    {
223        self.headers().typed_get()
224    }
225
226    fn content_length(&self) -> Option<u64> {
227        self.header(header::CONTENT_LENGTH)
228    }
229
230    fn content_type(&self) -> Option<mime::Mime> {
231        self.header(header::CONTENT_TYPE)
232    }
233
234    async fn extract<T>(&mut self) -> Result<T, T::Error>
235    where
236        T: FromRequest,
237    {
238        T::extract(self).await
239    }
240
241    fn incoming(&mut self) -> Result<Body, PayloadError> {
242        if let Some(state) = self.extensions().get::<BodyState>() {
243            match state {
244                BodyState::Empty => Err(PayloadError::Empty)?,
245                BodyState::Used => Err(PayloadError::Used)?,
246                BodyState::Normal => {}
247            }
248        }
249
250        let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) {
251            Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)),
252            body => (BodyState::Used, Ok(body)),
253        };
254
255        self.extensions_mut().insert(state);
256        result
257    }
258
259    async fn bytes(&mut self) -> Result<Bytes, PayloadError> {
260        self.incoming()?
261            .collect()
262            .await
263            .map_err(|err| {
264                #[cfg(feature = "limits")]
265                if err.is::<LengthLimitError>() {
266                    return PayloadError::TooLarge;
267                }
268                if let Ok(err) = err.downcast::<hyper::Error>() {
269                    return PayloadError::Hyper(err);
270                }
271                PayloadError::Read
272            })
273            .map(Collected::to_bytes)
274    }
275
276    async fn text(&mut self) -> Result<String, PayloadError> {
277        let bytes = self.bytes().await?;
278        String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
279    }
280
281    #[cfg(feature = "form")]
282    async fn form<T>(&mut self) -> Result<T, PayloadError>
283    where
284        T: serde::de::DeserializeOwned,
285    {
286        <Form as Payload>::check_type(self.content_type())?;
287        let bytes = self.bytes().await?;
288        serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
289    }
290
291    #[cfg(feature = "json")]
292    async fn json<T>(&mut self) -> Result<T, PayloadError>
293    where
294        T: serde::de::DeserializeOwned,
295    {
296        <Json as Payload>::check_type(self.content_type())?;
297        let bytes = self.bytes().await?;
298        serde_json::from_slice(&bytes).map_err(PayloadError::Json)
299    }
300
301    #[cfg(feature = "multipart")]
302    async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
303        let m = <Multipart as Payload>::check_type(self.content_type())?;
304
305        let boundary = m
306            .get_param(mime::BOUNDARY)
307            .ok_or(PayloadError::MissingBoundary)?
308            .as_str();
309
310        Ok(Multipart::new(self.incoming()?, boundary))
311    }
312
313    #[cfg(feature = "state")]
314    fn state<T>(&self) -> Option<T>
315    where
316        T: Clone + Send + Sync + 'static,
317    {
318        self.extensions().get().cloned()
319    }
320
321    #[cfg(feature = "state")]
322    fn set_state<T>(&mut self, t: T) -> Option<T>
323    where
324        T: Clone + Send + Sync + 'static,
325    {
326        self.extensions_mut().insert(t)
327    }
328
329    #[cfg(feature = "cookie")]
330    fn cookies(&self) -> Result<Cookies, CookiesError> {
331        self.extensions()
332            .get::<Cookies>()
333            .cloned()
334            .ok_or(CookiesError::Read)
335    }
336
337    #[cfg(feature = "cookie")]
338    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
339    where
340        S: AsRef<str>,
341    {
342        self.extensions().get::<Cookies>()?.get(name.as_ref())
343    }
344
345    #[cfg(feature = "session")]
346    fn session(&self) -> &Session {
347        self.extensions().get().expect("should get a session")
348    }
349
350    #[cfg(feature = "params")]
351    fn params<T>(&self) -> Result<T, ParamsError>
352    where
353        T: serde::de::DeserializeOwned,
354    {
355        T::deserialize(PathDeserializer::new(&self.route_info().params)).map_err(ParamsError::Parse)
356    }
357
358    #[cfg(feature = "params")]
359    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
360    where
361        T: std::str::FromStr,
362        T::Err: std::fmt::Display,
363    {
364        self.route_info().params.find(name)
365    }
366
367    fn remote_addr(&self) -> Option<&std::net::SocketAddr> {
368        self.extensions().get()
369    }
370
371    #[cfg(feature = "params")]
372    fn route_info(&self) -> &Arc<RouteInfo> {
373        self.extensions().get().expect("should get current route")
374    }
375
376    fn realip(&self) -> Option<RealIp> {
377        RealIp::parse(self)
378    }
379}
380
/// The [`Request`] Extension with a limited body.
///
/// Body readers that stop collecting once a size limit is exceeded. Limits are looked up
/// in the [`Limits`] stored in the request extensions.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
    /// Get the limits settings stored in the request extensions.
    ///
    /// # Panics
    ///
    /// The implementation for [`Request`] panics if no [`Limits`] value has been stored
    /// (the limits middleware is required).
    fn limits(&self) -> &Limits;

    /// Collect the request body into [`Bytes`], like [`arrayBuffer()`][mdn], reading at
    /// most `limit` bytes, or at most `max` bytes when `limit` is `None`.
    ///
    /// # Errors
    ///
    /// Will return the errors of [`RequestExt::incoming`], `PayloadError::TooLarge` if the
    /// body exceeds the limit, or another [`PayloadError`] if reading the body fails.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
    fn bytes_with(
        &mut self,
        limit: Option<u64>,
        max: u64,
    ) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Collect the request body as UTF-8 [text][mdn], bounded by the `text` limit
    /// (falling back to `Limits::NORMAL`).
    ///
    /// # Errors
    ///
    /// Will return the errors of [`RequestLimitsExt::bytes_with`], or `PayloadError::Utf8`
    /// if the body is not valid UTF-8.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
    fn text_with_limit(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    /// Deserialize a size-limited `application/x-www-form-urlencoded` [FormData][mdn]
    /// body into `T`.
    ///
    /// The `Content-Type` and `Content-Length` headers are checked before the body is read.
    ///
    /// # Errors
    ///
    /// Will return an error if the header check fails, the errors of
    /// [`RequestLimitsExt::bytes_with`], or `PayloadError::UrlDecode` if decoding fails.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
    #[cfg(feature = "form")]
    fn form_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    /// Deserialize a size-limited [JSON][mdn] body into `T`.
    ///
    /// The `Content-Type` and `Content-Length` headers are checked before the body is read.
    ///
    /// # Errors
    ///
    /// Will return an error if the header check fails, the errors of
    /// [`RequestLimitsExt::bytes_with`], or `PayloadError::Json` if deserializing fails.
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
    #[cfg(feature = "json")]
    fn json_with_limit<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;

    /// Wrap a `multipart/form-data` [FormData][mdn] body in a [`Multipart`] reader.
    ///
    /// The `Content-Type` and `Content-Length` headers are checked against the multipart
    /// limit. The reader uses the `MultipartLimits` stored in the request extensions, or
    /// their defaults.
    ///
    /// # Errors
    ///
    /// Will return an error if the header check fails, `PayloadError::MissingBoundary` if
    /// the media type has no `boundary` parameter, or the errors of
    /// [`RequestExt::incoming`].
    ///
    /// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
    #[cfg(feature = "multipart")]
    fn multipart_with_limit(
        &mut self,
    ) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}
427
#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
    fn limits(&self) -> &Limits {
        let stored: Option<&Limits> = self.extensions().get();
        stored.expect("Limits middleware is required")
    }

    async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
        let body = self.incoming()?;
        // Saturate instead of failing on targets where `usize` is narrower than `u64`.
        let cap = usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX);

        let collected = Limited::new(body, cap).collect().await.map_err(|err| {
            if err.is::<LengthLimitError>() {
                PayloadError::TooLarge
            } else {
                match err.downcast::<hyper::Error>() {
                    Ok(hyper_err) => PayloadError::Hyper(*hyper_err),
                    Err(_) => PayloadError::Read,
                }
            }
        })?;

        Ok(collected.to_bytes())
    }

    async fn text_with_limit(&mut self) -> Result<String, PayloadError> {
        let text_limit = self.limits().get("text");
        let raw = self.bytes_with(text_limit, Limits::NORMAL).await?;
        String::from_utf8(raw.to_vec()).map_err(PayloadError::Utf8)
    }

    #[cfg(feature = "form")]
    async fn form_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let form_limit = self.limits().get(<Form as Payload>::NAME);
        let declared_type = self.content_type();
        let declared_len = self.content_length();
        <Form as Payload>::check_header(declared_type, declared_len, form_limit)?;

        let raw = self.bytes_with(form_limit, <Form as Payload>::LIMIT).await?;
        serde_urlencoded::from_reader(bytes::Buf::reader(raw)).map_err(PayloadError::UrlDecode)
    }

    #[cfg(feature = "json")]
    async fn json_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_limit = self.limits().get(<Json as Payload>::NAME);
        let declared_type = self.content_type();
        let declared_len = self.content_length();
        <Json as Payload>::check_header(declared_type, declared_len, json_limit)?;

        let raw = self.bytes_with(json_limit, <Json as Payload>::LIMIT).await?;
        serde_json::from_slice(&raw).map_err(PayloadError::Json)
    }

    #[cfg(feature = "multipart")]
    async fn multipart_with_limit(&mut self) -> Result<Multipart, PayloadError> {
        let multipart_limit = self.limits().get(<Multipart as Payload>::NAME);
        let declared_type = self.content_type();
        let declared_len = self.content_length();
        let media =
            <Multipart as Payload>::check_header(declared_type, declared_len, multipart_limit)?;

        let Some(boundary) = media.get_param(mime::BOUNDARY) else {
            return Err(PayloadError::MissingBoundary);
        };

        let body = self.incoming()?;
        // Per-request multipart limits are optional; fall back to the defaults.
        let part_limits = self
            .extensions()
            .get::<std::sync::Arc<crate::types::MultipartLimits>>()
            .map(|shared| (**shared).clone())
            .unwrap_or_default();

        Ok(Multipart::with_limits(body, boundary.as_str(), part_limits))
    }
}
507
508mod private {
509    pub trait Sealed {}
510    impl Sealed for super::Request {}
511}