volo_http/server/utils/
multipart.rs1use std::{error::Error, fmt};
36
37use http::{StatusCode, request::Parts};
38use http_body_util::BodyExt;
39use multer::Field;
40
41use crate::{
42 context::ServerContext,
43 server::{IntoResponse, extract::FromRequest},
44};
45
46#[must_use]
101pub struct Multipart {
102 inner: multer::Multipart<'static>,
103}
104
105impl Multipart {
106 pub async fn next_field(&mut self) -> Result<Option<Field<'static>>, MultipartRejectionError> {
122 Ok(self.inner.next_field().await?)
123 }
124}
125
126impl FromRequest<crate::body::Body> for Multipart {
127 type Rejection = MultipartRejectionError;
128 async fn from_request(
129 _: &mut ServerContext,
130 parts: Parts,
131 body: crate::body::Body,
132 ) -> Result<Self, Self::Rejection> {
133 let boundary = multer::parse_boundary(
134 parts
135 .headers
136 .get(http::header::CONTENT_TYPE)
137 .ok_or(multer::Error::NoMultipart)?
138 .to_str()
139 .map_err(|_| multer::Error::NoBoundary)?,
140 )?;
141
142 let multipart = multer::Multipart::new(body.into_data_stream(), boundary);
143
144 Ok(Self { inner: multipart })
145 }
146}
147
148#[derive(Debug)]
152pub struct MultipartRejectionError {
153 inner: multer::Error,
154}
155
156impl From<multer::Error> for MultipartRejectionError {
157 fn from(err: multer::Error) -> Self {
158 Self { inner: err }
159 }
160}
161
162fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
163 match err {
164 multer::Error::UnknownField { .. }
165 | multer::Error::IncompleteFieldData { .. }
166 | multer::Error::IncompleteHeaders
167 | multer::Error::ReadHeaderFailed(..)
168 | multer::Error::DecodeHeaderName { .. }
169 | multer::Error::DecodeContentType(..)
170 | multer::Error::NoBoundary
171 | multer::Error::DecodeHeaderValue { .. }
172 | multer::Error::NoMultipart
173 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
174 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
175 StatusCode::PAYLOAD_TOO_LARGE
176 }
177 multer::Error::StreamReadFailed(_) => StatusCode::INTERNAL_SERVER_ERROR,
178 _ => StatusCode::INTERNAL_SERVER_ERROR,
179 }
180}
181
182impl MultipartRejectionError {
183 pub fn to_status_code(&self) -> http::StatusCode {
185 status_code_from_multer_error(&self.inner)
186 }
187}
188
189impl Error for MultipartRejectionError {
190 fn source(&self) -> Option<&(dyn Error + 'static)> {
191 Some(&self.inner)
192 }
193}
194
195impl fmt::Display for MultipartRejectionError {
196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197 std::fmt::Display::fmt(&self.inner, f)
198 }
199}
200
201impl IntoResponse for MultipartRejectionError {
202 fn into_response(self) -> http::Response<crate::body::Body> {
203 self.to_status_code().into_response()
204 }
205}
206
#[cfg(test)]
mod multipart_tests {
    use std::{
        convert::Infallible,
        net::{IpAddr, Ipv4Addr, SocketAddr},
    };

    use motore::Service;
    use reqwest::multipart::Form;
    use volo::net::Address;

    use crate::{
        Server,
        context::ServerContext,
        request::Request,
        response::Response,
        server::{
            IntoResponse, test_helpers,
            utils::multipart::{Multipart, MultipartRejectionError},
        },
    };

    /// Type-checks that a handler taking `Multipart` can be turned into a served service.
    ///
    /// Not a `#[test]` and never called; it exists only so this compiles.
    fn _test_compile() {
        async fn handler(_: Multipart) {}
        let app = test_helpers::to_service(handler);
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            25241,
        ));
        let _server = Server::new(app).run(addr);
    }

    /// Spawns `service` on `127.0.0.1:port`, then sleeps one second so it can start listening.
    async fn run_handler<S>(service: S, port: u16)
    where
        S: Service<ServerContext, Request, Response = Response, Error = Infallible>
            + Send
            + Sync
            + 'static,
    {
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            port,
        ));

        tokio::spawn(Server::new(service).run(addr));

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }

    /// POSTs `form` as multipart to `127.0.0.1:port` and asserts a 2xx response.
    ///
    /// The handlers under test make their assertions on the server side. A
    /// rejected extraction (e.g. `400 Bad Request`) means the handler never
    /// runs. A failing handler either breaks the connection (so `send` errors)
    /// or produces an error status, depending on how the server treats it.
    /// Checking the status here makes all of these fail the test.
    async fn post_form(port: u16, form: Form) {
        let url = url::Url::parse(&format!("http://127.0.0.1:{port}")).unwrap();
        let resp = reqwest::Client::new()
            .post(url)
            .multipart(form)
            .send()
            .await
            .unwrap();
        let status = resp.status();
        assert!(status.is_success(), "unexpected response status: {status}");
    }

    #[tokio::test]
    async fn test_single_field_upload() {
        const PORT: u16 = 25241;
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handler(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let form = Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap()
                .headers(reqwest::header::HeaderMap::from_iter([(
                    reqwest::header::HeaderName::from_static("foo"),
                    reqwest::header::HeaderValue::from_static("bar"),
                )])),
        );

        run_handler(test_helpers::to_service(handler), PORT).await;
        post_form(PORT, form).await;
    }

    #[tokio::test]
    async fn test_multiple_field_upload() {
        const PORT: u16 = 25242;
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        const FIELD_NAME1: &str = "file1";
        const FIELD_NAME2: &str = "file2";
        const FILE_NAME1: &str = "index1.html";
        const FILE_NAME2: &str = "index2.html";

        async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> {
            while let Some(field) = multipart.next_field().await? {
                match field.name() {
                    Some(FIELD_NAME1) => {
                        assert_eq!(field.file_name().unwrap(), FILE_NAME1);
                        assert_eq!(field.headers()["foo1"], "bar1");
                    }
                    Some(FIELD_NAME2) => {
                        assert_eq!(field.file_name().unwrap(), FILE_NAME2);
                        assert_eq!(field.headers()["foo2"], "bar2");
                    }
                    _ => unreachable!(),
                }
                assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE);
                assert_eq!(field.bytes().await?, BYTES);
            }

            Ok(())
        }

        let form = Form::new()
            .part(
                FIELD_NAME1,
                reqwest::multipart::Part::bytes(BYTES)
                    .file_name(FILE_NAME1)
                    .mime_str(CONTENT_TYPE)
                    .unwrap()
                    .headers(reqwest::header::HeaderMap::from_iter([(
                        reqwest::header::HeaderName::from_static("foo1"),
                        reqwest::header::HeaderValue::from_static("bar1"),
                    )])),
            )
            .part(
                FIELD_NAME2,
                reqwest::multipart::Part::bytes(BYTES)
                    .file_name(FILE_NAME2)
                    .mime_str(CONTENT_TYPE)
                    .unwrap()
                    .headers(reqwest::header::HeaderMap::from_iter([(
                        reqwest::header::HeaderName::from_static("foo2"),
                        reqwest::header::HeaderValue::from_static("bar2"),
                    )])),
            );

        run_handler(test_helpers::to_service(handler), PORT).await;
        post_form(PORT, form).await;
    }
}