volo_http/server/utils/
multipart.rs

//! Multipart implementation for server.
//!
//! This module provides utilities for extracting `multipart/form-data` formatted data from HTTP
//! requests.
//!
//! # Example
//!
//! ```rust
//! use http::StatusCode;
//! use volo_http::{
//!     Router,
//!     response::Response,
//!     server::{
//!         route::post,
//!         utils::multipart::{Multipart, MultipartRejectionError},
//!     },
//! };
//!
//! async fn upload(mut multipart: Multipart) -> Result<StatusCode, MultipartRejectionError> {
//!     while let Some(field) = multipart.next_field().await? {
//!         let name = field.name().unwrap().to_string();
//!         let value = field.bytes().await?;
//!
//!         println!("The field {} has {} bytes", name, value.len());
//!     }
//!
//!     Ok(StatusCode::OK)
//! }
//!
//! let app: Router = Router::new().route("/upload", post(upload));
//! ```
//!
//! See [`Multipart`] for more details.
34
35use std::{error::Error, fmt};
36
37use http::{StatusCode, request::Parts};
38use http_body_util::BodyExt;
39use multer::Field;
40
41use crate::{
42    context::ServerContext,
43    server::{IntoResponse, extract::FromRequest},
44};
45
46/// Extract a type from `multipart/form-data` HTTP requests.
47///
48/// [`Multipart`] can be passed as an argument to a handler, which can be used to extract each
49/// `multipart/form-data` field by calling [`Multipart::next_field`].
50///
51/// **Notice**
52///
53/// Extracting `multipart/form-data` data will consume the body, hence [`Multipart`] must be the
54/// last argument from the handler.
55///
56/// # Example
57///
58/// ```rust
59/// use http::StatusCode;
60/// use volo_http::{
61///     response::Response,
62///     server::utils::multipart::{Multipart, MultipartRejectionError},
63/// };
64///
65/// async fn upload(mut multipart: Multipart) -> Result<StatusCode, MultipartRejectionError> {
66///     while let Some(field) = multipart.next_field().await? {
67///         todo!()
68///     }
69///
70///     Ok(StatusCode::OK)
71/// }
72/// ```
73///
74/// # Body Limitation
75///
76/// Since the body is unlimited, so it is recommended to use
77/// [`BodyLimitLayer`](crate::server::layer::BodyLimitLayer) to limit the size of the body.
78///
79/// ```rust
80/// use http::StatusCode;
81/// use volo_http::{
82///     Router,
83///     server::{
84///         layer::BodyLimitLayer,
85///         route::post,
86///         utils::multipart::{Multipart, MultipartRejectionError},
87///     },
88/// };
89///
90/// async fn upload_handler(
91///     mut multipart: Multipart,
92/// ) -> Result<StatusCode, MultipartRejectionError> {
93///     Ok(StatusCode::OK)
94/// }
95///
96/// let app: Router<_> = Router::new()
97///     .route("/", post(upload_handler))
98///     .layer(BodyLimitLayer::new(1024));
99/// ```
100#[must_use]
101pub struct Multipart {
102    inner: multer::Multipart<'static>,
103}
104
105impl Multipart {
106    /// Iterate over all [`Field`] in [`Multipart`]
107    ///
108    /// # Example
109    ///
110    /// ```rust
111    /// # use volo_http::server::utils::multipart::Multipart;
112    /// # let mut multipart: Multipart;
113    /// // Extract each field from multipart by using while loop
114    /// # async fn upload(mut multipart: Multipart) {
115    /// while let Some(field) = multipart.next_field().await.unwrap() {
116    ///     let name = field.name().unwrap().to_string(); // Get field name
117    ///     let data = field.bytes().await.unwrap(); // Get field data
118    /// }
119    /// # }
120    /// ```
121    pub async fn next_field(&mut self) -> Result<Option<Field<'static>>, MultipartRejectionError> {
122        Ok(self.inner.next_field().await?)
123    }
124}
125
126impl FromRequest<crate::body::Body> for Multipart {
127    type Rejection = MultipartRejectionError;
128    async fn from_request(
129        _: &mut ServerContext,
130        parts: Parts,
131        body: crate::body::Body,
132    ) -> Result<Self, Self::Rejection> {
133        let boundary = multer::parse_boundary(
134            parts
135                .headers
136                .get(http::header::CONTENT_TYPE)
137                .ok_or(multer::Error::NoMultipart)?
138                .to_str()
139                .map_err(|_| multer::Error::NoBoundary)?,
140        )?;
141
142        let multipart = multer::Multipart::new(body.into_data_stream(), boundary);
143
144        Ok(Self { inner: multipart })
145    }
146}
147
148/// [`Error`]s while extracting [`Multipart`].
149///
150/// [`Error`]: Error
151#[derive(Debug)]
152pub struct MultipartRejectionError {
153    inner: multer::Error,
154}
155
156impl From<multer::Error> for MultipartRejectionError {
157    fn from(err: multer::Error) -> Self {
158        Self { inner: err }
159    }
160}
161
162fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
163    match err {
164        multer::Error::UnknownField { .. }
165        | multer::Error::IncompleteFieldData { .. }
166        | multer::Error::IncompleteHeaders
167        | multer::Error::ReadHeaderFailed(..)
168        | multer::Error::DecodeHeaderName { .. }
169        | multer::Error::DecodeContentType(..)
170        | multer::Error::NoBoundary
171        | multer::Error::DecodeHeaderValue { .. }
172        | multer::Error::NoMultipart
173        | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
174        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
175            StatusCode::PAYLOAD_TOO_LARGE
176        }
177        multer::Error::StreamReadFailed(_) => StatusCode::INTERNAL_SERVER_ERROR,
178        _ => StatusCode::INTERNAL_SERVER_ERROR,
179    }
180}
181
182impl MultipartRejectionError {
183    /// Convert the [`MultipartRejectionError`] into a [`http::StatusCode`].
184    pub fn to_status_code(&self) -> http::StatusCode {
185        status_code_from_multer_error(&self.inner)
186    }
187}
188
189impl Error for MultipartRejectionError {
190    fn source(&self) -> Option<&(dyn Error + 'static)> {
191        Some(&self.inner)
192    }
193}
194
195impl fmt::Display for MultipartRejectionError {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        std::fmt::Display::fmt(&self.inner, f)
198    }
199}
200
201impl IntoResponse for MultipartRejectionError {
202    fn into_response(self) -> http::Response<crate::body::Body> {
203        self.to_status_code().into_response()
204    }
205}
206
#[cfg(test)]
mod multipart_tests {
    use std::{
        convert::Infallible,
        net::{IpAddr, Ipv4Addr, SocketAddr},
    };

    use motore::Service;
    use reqwest::multipart::Form;
    use volo::net::Address;

    use crate::{
        Server,
        context::ServerContext,
        request::Request,
        response::Response,
        server::{
            IntoResponse, test_helpers,
            utils::multipart::{Multipart, MultipartRejectionError},
        },
    };

    /// Port of the server started by `test_single_field_upload`.
    const SINGLE_FIELD_PORT: u16 = 25241;
    /// Port of the server started by `test_multiple_field_upload`.
    const MULTIPLE_FIELD_PORT: u16 = 25242;

    /// Compile-only check that a handler taking [`Multipart`] can be served.
    ///
    /// This is not a `#[test]`; the server future is built but never polled.
    fn _test_compile() {
        async fn handler(_: Multipart) {}
        let app = test_helpers::to_service(handler);
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            SINGLE_FIELD_PORT,
        ));
        let _server = Server::new(app).run(addr);
    }

    /// Spawns `service` on `127.0.0.1:port`, then sleeps one second so the listener is
    /// (presumably) ready before the test sends requests.
    async fn run_handler<S>(service: S, port: u16)
    where
        S: Service<ServerContext, Request, Response = Response, Error = Infallible>
            + Send
            + Sync
            + 'static,
    {
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            port,
        ));

        tokio::spawn(Server::new(service).run(addr));

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }

    /// Posts `form` to the test server on `port` and asserts the handler answered `200 OK`.
    ///
    /// The status check matters: a handler returning `Err(MultipartRejectionError)` still
    /// produces a response, so only unwrapping `send()` would let a rejection pass silently.
    async fn post_form(port: u16, form: Form) {
        let url = url::Url::parse(&format!("http://127.0.0.1:{port}")).unwrap();

        let response = reqwest::Client::new()
            .post(url)
            .multipart(form)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), reqwest::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_single_field_upload() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handler(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let form = Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap()
                .headers(reqwest::header::HeaderMap::from_iter([(
                    reqwest::header::HeaderName::from_static("foo"),
                    reqwest::header::HeaderValue::from_static("bar"),
                )])),
        );

        run_handler(test_helpers::to_service(handler), SINGLE_FIELD_PORT).await;

        post_form(SINGLE_FIELD_PORT, form).await;
    }

    #[tokio::test]
    async fn test_multiple_field_upload() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        const FIELD_NAME1: &str = "file1";
        const FIELD_NAME2: &str = "file2";
        const FILE_NAME1: &str = "index1.html";
        const FILE_NAME2: &str = "index2.html";

        async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> {
            while let Some(field) = multipart.next_field().await? {
                match field.name() {
                    Some(FIELD_NAME1) => {
                        assert_eq!(field.file_name().unwrap(), FILE_NAME1);
                        assert_eq!(field.headers()["foo1"], "bar1");
                    }
                    Some(FIELD_NAME2) => {
                        assert_eq!(field.file_name().unwrap(), FILE_NAME2);
                        assert_eq!(field.headers()["foo2"], "bar2");
                    }
                    _ => unreachable!(),
                }
                assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE);
                assert_eq!(field.bytes().await?, BYTES);
            }

            Ok(())
        }

        let form = Form::new()
            .part(
                FIELD_NAME1,
                reqwest::multipart::Part::bytes(BYTES)
                    .file_name(FILE_NAME1)
                    .mime_str(CONTENT_TYPE)
                    .unwrap()
                    .headers(reqwest::header::HeaderMap::from_iter([(
                        reqwest::header::HeaderName::from_static("foo1"),
                        reqwest::header::HeaderValue::from_static("bar1"),
                    )])),
            )
            .part(
                FIELD_NAME2,
                reqwest::multipart::Part::bytes(BYTES)
                    .file_name(FILE_NAME2)
                    .mime_str(CONTENT_TYPE)
                    .unwrap()
                    .headers(reqwest::header::HeaderMap::from_iter([(
                        reqwest::header::HeaderName::from_static("foo2"),
                        reqwest::header::HeaderValue::from_static("bar2"),
                    )])),
            );

        run_handler(test_helpers::to_service(handler), MULTIPLE_FIELD_PORT).await;

        post_form(MULTIPLE_FIELD_PORT, form).await;
    }
}