Skip to main content

rustapi_core/
multipart.rs

1//! Multipart form data extractor for file uploads
2//!
3//! This module provides types for handling `multipart/form-data` requests,
4//! commonly used for file uploads.
5//!
6//! # Example
7//!
8//! ```rust,ignore
//! use rustapi_core::multipart::Multipart;
10//!
11//! async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
12//!     while let Some(field) = multipart.next_field().await? {
13//!         let name = field.name().unwrap_or("unknown");
14//!         let filename = field.file_name().map(|s| s.to_string());
15//!         let data = field.bytes().await?;
16//!         
17//!         println!("Field: {}, File: {:?}, Size: {} bytes", name, filename, data.len());
18//!     }
19//!     Ok("Upload successful".to_string())
20//! }
21//! ```
22
23use crate::error::{ApiError, Result};
24use crate::extract::FromRequest;
25use crate::request::Request;
26use crate::stream::StreamingBody;
27use bytes::Bytes;
28use futures_util::stream;
29use http::StatusCode;
30use std::error::Error as _;
31use std::path::Path;
32use tokio::io::AsyncWriteExt;
33
/// Default upper bound, in bytes, for a single uploaded file: 10 MiB (`10 * 1024 * 1024`).
///
/// [`MultipartConfig::default`] also uses this value as the default total form size.
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 << 20;

/// Default upper bound on the number of fields a single multipart form may contain.
pub const DEFAULT_MAX_FIELDS: usize = 100;
39
40/// Multipart form data extractor
41///
42/// Parses `multipart/form-data` requests, commonly used for file uploads.
43///
44/// # Example
45///
46/// ```rust,ignore
47/// use rustapi_core::multipart::Multipart;
48///
49/// async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
50///     while let Some(field) = multipart.next_field().await? {
51///         let name = field.name().unwrap_or("unknown").to_string();
52///         let data = field.bytes().await?;
53///         println!("Received field '{}' with {} bytes", name, data.len());
54///     }
55///     Ok("Upload complete".to_string())
56/// }
57/// ```
58pub struct Multipart {
59    fields: Vec<MultipartField>,
60    current_index: usize,
61}
62
63impl Multipart {
64    /// Create a new Multipart from raw data
65    fn new(fields: Vec<MultipartField>) -> Self {
66        Self {
67            fields,
68            current_index: 0,
69        }
70    }
71
72    /// Get the next field from the multipart form
73    pub async fn next_field(&mut self) -> Result<Option<MultipartField>> {
74        if self.current_index >= self.fields.len() {
75            return Ok(None);
76        }
77        let field = self.fields.get(self.current_index).cloned();
78        self.current_index += 1;
79        Ok(field)
80    }
81
82    /// Collect all fields into a vector
83    pub fn into_fields(self) -> Vec<MultipartField> {
84        self.fields
85    }
86
87    /// Get the number of fields
88    pub fn field_count(&self) -> usize {
89        self.fields.len()
90    }
91}
92
93/// Streaming multipart extractor for large file uploads.
94///
95/// Unlike [`Multipart`], this extractor does not buffer the entire request body in memory before
96/// parsing. It consumes the request body as a stream and yields one field at a time.
97///
98/// If a [`MultipartConfig`] is present in app state, its size and content-type limits are applied.
99pub struct StreamingMultipart {
100    inner: multer::Multipart<'static>,
101    config: MultipartConfig,
102    field_count: usize,
103}
104
105impl StreamingMultipart {
106    fn new(stream: StreamingBody, boundary: String, config: MultipartConfig) -> Self {
107        Self {
108            inner: multer::Multipart::new(stream, boundary),
109            config,
110            field_count: 0,
111        }
112    }
113
114    /// Get the next field from the multipart stream.
115    ///
116    /// Consume or drop the previously returned field before calling this again.
117    pub async fn next_field(&mut self) -> Result<Option<StreamingMultipartField<'static>>> {
118        let field = self.inner.next_field().await.map_err(map_multer_error)?;
119        let Some(field) = field else {
120            return Ok(None);
121        };
122
123        self.field_count += 1;
124        if self.field_count > self.config.max_fields {
125            return Err(ApiError::bad_request(format!(
126                "Multipart field count exceeded limit of {}",
127                self.config.max_fields
128            )));
129        }
130
131        validate_streaming_field(&field, &self.config)?;
132
133        Ok(Some(StreamingMultipartField::new(
134            field,
135            self.config.max_file_size,
136        )))
137    }
138
139    /// Number of fields yielded so far.
140    pub fn field_count(&self) -> usize {
141        self.field_count
142    }
143}
144
145impl FromRequest for StreamingMultipart {
146    async fn from_request(req: &mut Request) -> Result<Self> {
147        let content_type = req
148            .headers()
149            .get(http::header::CONTENT_TYPE)
150            .and_then(|v| v.to_str().ok())
151            .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
152
153        if !content_type.starts_with("multipart/form-data") {
154            return Err(ApiError::bad_request(format!(
155                "Expected multipart/form-data, got: {}",
156                content_type
157            )));
158        }
159
160        let boundary = extract_boundary(content_type)
161            .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
162
163        let config = req
164            .state()
165            .get::<MultipartConfig>()
166            .cloned()
167            .unwrap_or_default();
168
169        let stream = request_body_stream(req, config.max_size)?;
170        Ok(Self::new(stream, boundary, config))
171    }
172}
173
174/// A single streaming field from a multipart form.
175///
176/// This field is one-shot: once you call [`chunk`](Self::chunk), [`bytes`](Self::bytes),
177/// [`text`](Self::text), or one of the save helpers, the underlying stream is consumed.
178pub struct StreamingMultipartField<'a> {
179    inner: multer::Field<'a>,
180    max_file_size: usize,
181    bytes_read: usize,
182}
183
184impl<'a> StreamingMultipartField<'a> {
185    fn new(inner: multer::Field<'a>, max_file_size: usize) -> Self {
186        Self {
187            inner,
188            max_file_size,
189            bytes_read: 0,
190        }
191    }
192
193    /// Get the field name.
194    pub fn name(&self) -> Option<&str> {
195        self.inner.name()
196    }
197
198    /// Get the original filename when this field is a file upload.
199    pub fn file_name(&self) -> Option<&str> {
200        self.inner.file_name()
201    }
202
203    /// Get the content type of the field.
204    pub fn content_type(&self) -> Option<&str> {
205        self.inner.content_type().map(|mime| mime.essence_str())
206    }
207
208    /// Check whether this field represents a file upload.
209    pub fn is_file(&self) -> bool {
210        self.file_name().is_some()
211    }
212
213    /// Number of bytes consumed from this field so far.
214    pub fn bytes_read(&self) -> usize {
215        self.bytes_read
216    }
217
218    /// Read the next chunk from the field stream.
219    pub async fn chunk(&mut self) -> Result<Option<Bytes>> {
220        let chunk = self.inner.chunk().await.map_err(map_multer_error)?;
221        let Some(chunk) = chunk else {
222            return Ok(None);
223        };
224
225        self.bytes_read += chunk.len();
226        if self.bytes_read > self.max_file_size {
227            return Err(file_size_limit_error(self.max_file_size));
228        }
229
230        Ok(Some(chunk))
231    }
232
233    /// Collect the full field into memory.
234    pub async fn bytes(&mut self) -> Result<Bytes> {
235        let mut buffer = bytes::BytesMut::new();
236        while let Some(chunk) = self.chunk().await? {
237            buffer.extend_from_slice(&chunk);
238        }
239        Ok(buffer.freeze())
240    }
241
242    /// Collect the field as UTF-8 text.
243    pub async fn text(&mut self) -> Result<String> {
244        String::from_utf8(self.bytes().await?.to_vec())
245            .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
246    }
247
248    /// Save the field to a directory using either the provided filename or the uploaded name.
249    pub async fn save_to(
250        &mut self,
251        dir: impl AsRef<Path>,
252        filename: Option<&str>,
253    ) -> Result<String> {
254        let dir = dir.as_ref();
255
256        tokio::fs::create_dir_all(dir)
257            .await
258            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
259
260        let final_filename = filename
261            .map(|value| value.to_string())
262            .or_else(|| self.file_name().map(|value| value.to_string()))
263            .ok_or_else(|| {
264                ApiError::bad_request("No filename provided and field has no filename")
265            })?;
266
267        let safe_filename = sanitize_filename(&final_filename);
268        let file_path = dir.join(&safe_filename);
269        self.save_as(&file_path).await?;
270
271        Ok(file_path.to_string_lossy().to_string())
272    }
273
274    /// Save the field contents to an explicit file path without buffering the full file in memory.
275    pub async fn save_as(&mut self, path: impl AsRef<Path>) -> Result<()> {
276        let path = path.as_ref();
277
278        if let Some(parent) = path.parent() {
279            tokio::fs::create_dir_all(parent)
280                .await
281                .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
282        }
283
284        let mut file = tokio::fs::File::create(path)
285            .await
286            .map_err(|e| ApiError::internal(format!("Failed to create file: {}", e)))?;
287
288        while let Some(chunk) = self.chunk().await? {
289            file.write_all(&chunk)
290                .await
291                .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
292        }
293
294        file.flush()
295            .await
296            .map_err(|e| ApiError::internal(format!("Failed to flush file: {}", e)))?;
297
298        Ok(())
299    }
300
301    /// Collect the field into an [`UploadedFile`] for APIs that still expect the buffered wrapper.
302    pub async fn into_uploaded_file(mut self) -> Result<UploadedFile> {
303        let filename = self
304            .file_name()
305            .ok_or_else(|| ApiError::bad_request("Field is not a file upload"))?
306            .to_string();
307        let content_type = self.content_type().map(|value| value.to_string());
308        let data = self.bytes().await?;
309
310        Ok(UploadedFile {
311            filename,
312            content_type,
313            data,
314        })
315    }
316}
317
318/// A single field from a multipart form
319#[derive(Clone)]
320pub struct MultipartField {
321    name: Option<String>,
322    file_name: Option<String>,
323    content_type: Option<String>,
324    data: Bytes,
325}
326
327impl MultipartField {
328    /// Create a new multipart field
329    pub fn new(
330        name: Option<String>,
331        file_name: Option<String>,
332        content_type: Option<String>,
333        data: Bytes,
334    ) -> Self {
335        Self {
336            name,
337            file_name,
338            content_type,
339            data,
340        }
341    }
342
343    /// Get the field name
344    pub fn name(&self) -> Option<&str> {
345        self.name.as_deref()
346    }
347
348    /// Get the original filename (if this is a file upload)
349    pub fn file_name(&self) -> Option<&str> {
350        self.file_name.as_deref()
351    }
352
353    /// Get the content type of the field
354    pub fn content_type(&self) -> Option<&str> {
355        self.content_type.as_deref()
356    }
357
358    /// Check if this field is a file upload
359    pub fn is_file(&self) -> bool {
360        self.file_name.is_some()
361    }
362
363    /// Get the field data as bytes
364    pub async fn bytes(&self) -> Result<Bytes> {
365        Ok(self.data.clone())
366    }
367
368    /// Get the field data as a string (UTF-8)
369    pub async fn text(&self) -> Result<String> {
370        String::from_utf8(self.data.to_vec())
371            .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
372    }
373
374    /// Get the size of the field data in bytes
375    pub fn size(&self) -> usize {
376        self.data.len()
377    }
378
379    /// Save the file to disk
380    ///
381    /// # Arguments
382    ///
383    /// * `path` - The directory to save the file to
384    /// * `filename` - Optional custom filename, uses original filename if None
385    ///
386    /// # Example
387    ///
388    /// ```rust,ignore
389    /// field.save_to("./uploads", None).await?;
390    /// // or with custom filename
391    /// field.save_to("./uploads", Some("custom_name.txt")).await?;
392    /// ```
393    pub async fn save_to(&self, dir: impl AsRef<Path>, filename: Option<&str>) -> Result<String> {
394        let dir = dir.as_ref();
395
396        // Ensure directory exists
397        tokio::fs::create_dir_all(dir)
398            .await
399            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
400
401        // Determine filename
402        let final_filename = filename
403            .map(|s| s.to_string())
404            .or_else(|| self.file_name.clone())
405            .ok_or_else(|| {
406                ApiError::bad_request("No filename provided and field has no filename")
407            })?;
408
409        // Sanitize filename to prevent path traversal
410        let safe_filename = sanitize_filename(&final_filename);
411        let file_path = dir.join(&safe_filename);
412
413        // Write file
414        tokio::fs::write(&file_path, &self.data)
415            .await
416            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
417
418        Ok(file_path.to_string_lossy().to_string())
419    }
420}
421
/// Sanitize an untrusted, client-supplied filename so it can be joined onto an upload
/// directory without escaping it.
///
/// - `/`, `\` and `:` become `_`, so the result is a single path component on every
///   platform. On Windows, a name such as `C:evil` would otherwise make `Path::join`
///   discard the target directory entirely; a `:` can also address an NTFS alternate data
///   stream.
/// - Control characters, including NUL (which the OS rejects in paths), become `_`.
/// - Non-overlapping runs of `..` become `_`, and leading dots are stripped, so the result is
///   neither a parent-directory reference nor a hidden dot-file.
///
/// The result may be empty (e.g. for `""` or `"."`).
fn sanitize_filename(filename: &str) -> String {
    let single_component: String = filename
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' => '_',
            c if c.is_control() => '_',
            c => c,
        })
        .collect();

    single_component
        .replace("..", "_")
        .trim_start_matches('.')
        .to_string()
}
431
432impl FromRequest for Multipart {
433    async fn from_request(req: &mut Request) -> Result<Self> {
434        // Check content type
435        let content_type = req
436            .headers()
437            .get(http::header::CONTENT_TYPE)
438            .and_then(|v| v.to_str().ok())
439            .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
440
441        if !content_type.starts_with("multipart/form-data") {
442            return Err(ApiError::bad_request(format!(
443                "Expected multipart/form-data, got: {}",
444                content_type
445            )));
446        }
447
448        // Extract boundary
449        let boundary = extract_boundary(content_type)
450            .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
451
452        // Get body
453        let body = req
454            .take_body()
455            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
456
457        // Parse multipart
458        let fields = parse_multipart(&body, &boundary)?;
459
460        Ok(Multipart::new(fields))
461    }
462}
463
464fn request_body_stream(req: &mut Request, limit: usize) -> Result<StreamingBody> {
465    if let Some(stream) = req.take_stream() {
466        return Ok(StreamingBody::new(stream, Some(limit)));
467    }
468
469    if let Some(body) = req.take_body() {
470        let stream = stream::once(async move { Ok::<Bytes, ApiError>(body) });
471        return Ok(StreamingBody::from_stream(stream, Some(limit)));
472    }
473
474    Err(ApiError::internal("Body already consumed"))
475}
476
477fn validate_streaming_field(field: &multer::Field<'_>, config: &MultipartConfig) -> Result<()> {
478    if field.file_name().is_none() || config.allowed_content_types.is_empty() {
479        return Ok(());
480    }
481
482    let content_type = field
483        .content_type()
484        .map(|mime| mime.essence_str().to_string())
485        .ok_or_else(|| ApiError::bad_request("Uploaded file is missing Content-Type"))?;
486
487    if config
488        .allowed_content_types
489        .iter()
490        .any(|allowed| allowed.eq_ignore_ascii_case(&content_type))
491    {
492        return Ok(());
493    }
494
495    Err(ApiError::bad_request(format!(
496        "Unsupported content type '{}'",
497        content_type
498    )))
499}
500
501fn file_size_limit_error(limit: usize) -> ApiError {
502    ApiError::new(
503        StatusCode::PAYLOAD_TOO_LARGE,
504        "payload_too_large",
505        format!("Multipart field exceeded limit of {} bytes", limit),
506    )
507}
508
509fn map_multer_error(error: multer::Error) -> ApiError {
510    if let Some(source) = error.source() {
511        if let Some(api_error) = source.downcast_ref::<ApiError>() {
512            return api_error.clone();
513        }
514    }
515
516    let message = error.to_string();
517    if message.to_ascii_lowercase().contains("size limit") {
518        return ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "payload_too_large", message);
519    }
520
521    ApiError::bad_request(format!("Invalid multipart body: {}", message))
522}
523
/// Extract the `boundary` parameter from a `Content-Type` header value.
///
/// Parameters are separated by `;`. The parameter name is matched ASCII case-insensitively
/// (HTTP parameter names are case-insensitive, RFC 9110 §5.6.6), and surrounding double
/// quotes are removed from the value.
///
/// Returns `None` when there is no `boundary` parameter, or when its value is empty, since
/// an empty boundary cannot delimit anything.
///
/// ```text
/// multipart/form-data; boundary=abc      -> Some("abc")
/// multipart/form-data; Boundary="abc"    -> Some("abc")
/// multipart/form-data; boundary=         -> None
/// ```
fn extract_boundary(content_type: &str) -> Option<String> {
    content_type.split(';').find_map(|param| {
        // Split on the first `=` only: the boundary value itself may legally contain `=`.
        let (key, value) = param.split_once('=')?;
        if !key.trim().eq_ignore_ascii_case("boundary") {
            return None;
        }
        let value = value.trim().trim_matches('"');
        (!value.is_empty()).then(|| value.to_string())
    })
}
536
537/// Parse multipart form data
538fn parse_multipart(body: &Bytes, boundary: &str) -> Result<Vec<MultipartField>> {
539    let mut fields = Vec::new();
540    let delimiter = format!("--{}", boundary);
541    let end_delimiter = format!("--{}--", boundary);
542
543    // Convert body to string for easier parsing
544    // Note: This is a simplified parser. For production, consider using multer crate.
545    let body_str = String::from_utf8_lossy(body);
546
547    // Split by delimiter
548    let parts: Vec<&str> = body_str.split(&delimiter).collect();
549
550    for part in parts.iter().skip(1) {
551        // Skip empty parts and end delimiter
552        let part = part.trim_start_matches("\r\n").trim_start_matches('\n');
553        if part.is_empty() || part.starts_with("--") {
554            continue;
555        }
556
557        // Find header/body separator (blank line)
558        let header_body_split = if let Some(pos) = part.find("\r\n\r\n") {
559            pos
560        } else if let Some(pos) = part.find("\n\n") {
561            pos
562        } else {
563            continue;
564        };
565
566        let headers_section = &part[..header_body_split];
567        let body_section = &part[header_body_split..]
568            .trim_start_matches("\r\n\r\n")
569            .trim_start_matches("\n\n");
570
571        // Remove trailing boundary markers from body
572        let body_section = body_section
573            .trim_end_matches(&end_delimiter)
574            .trim_end_matches(&delimiter)
575            .trim_end_matches("\r\n")
576            .trim_end_matches('\n');
577
578        // Parse headers
579        let mut name = None;
580        let mut filename = None;
581        let mut content_type = None;
582
583        for header_line in headers_section.lines() {
584            let header_line = header_line.trim();
585            if header_line.is_empty() {
586                continue;
587            }
588
589            if let Some((key, value)) = header_line.split_once(':') {
590                let key = key.trim().to_lowercase();
591                let value = value.trim();
592
593                match key.as_str() {
594                    "content-disposition" => {
595                        // Parse name and filename from Content-Disposition
596                        for part in value.split(';') {
597                            let part = part.trim();
598                            if part.starts_with("name=") {
599                                name = Some(
600                                    part.trim_start_matches("name=")
601                                        .trim_matches('"')
602                                        .to_string(),
603                                );
604                            } else if part.starts_with("filename=") {
605                                filename = Some(
606                                    part.trim_start_matches("filename=")
607                                        .trim_matches('"')
608                                        .to_string(),
609                                );
610                            }
611                        }
612                    }
613                    "content-type" => {
614                        content_type = Some(value.to_string());
615                    }
616                    _ => {}
617                }
618            }
619        }
620
621        fields.push(MultipartField::new(
622            name,
623            filename,
624            content_type,
625            Bytes::copy_from_slice(body_section.as_bytes()),
626        ));
627    }
628
629    Ok(fields)
630}
631
632/// Configuration for multipart form handling
633#[derive(Clone)]
634pub struct MultipartConfig {
635    /// Maximum total size of the multipart form (default: 10MB)
636    pub max_size: usize,
637    /// Maximum number of fields (default: 100)
638    pub max_fields: usize,
639    /// Maximum size per file (default: 10MB)
640    pub max_file_size: usize,
641    /// Allowed content types for files (empty = all allowed)
642    pub allowed_content_types: Vec<String>,
643}
644
645impl Default for MultipartConfig {
646    fn default() -> Self {
647        Self {
648            max_size: DEFAULT_MAX_FILE_SIZE,
649            max_fields: DEFAULT_MAX_FIELDS,
650            max_file_size: DEFAULT_MAX_FILE_SIZE,
651            allowed_content_types: Vec::new(),
652        }
653    }
654}
655
656impl MultipartConfig {
657    /// Create a new multipart config with default values
658    pub fn new() -> Self {
659        Self::default()
660    }
661
662    /// Set the maximum total size
663    pub fn max_size(mut self, size: usize) -> Self {
664        self.max_size = size;
665        self
666    }
667
668    /// Set the maximum number of fields
669    pub fn max_fields(mut self, count: usize) -> Self {
670        self.max_fields = count;
671        self
672    }
673
674    /// Set the maximum file size
675    pub fn max_file_size(mut self, size: usize) -> Self {
676        self.max_file_size = size;
677        self
678    }
679
680    /// Set allowed content types for file uploads
681    pub fn allowed_content_types(mut self, types: Vec<String>) -> Self {
682        self.allowed_content_types = types;
683        self
684    }
685
686    /// Add an allowed content type
687    pub fn allow_content_type(mut self, content_type: impl Into<String>) -> Self {
688        self.allowed_content_types.push(content_type.into());
689        self
690    }
691}
692
693/// File data wrapper for convenient access to uploaded files
694#[derive(Clone)]
695pub struct UploadedFile {
696    /// Original filename
697    pub filename: String,
698    /// Content type (MIME type)
699    pub content_type: Option<String>,
700    /// File data
701    pub data: Bytes,
702}
703
704impl UploadedFile {
705    /// Create from a multipart field
706    pub fn from_field(field: &MultipartField) -> Option<Self> {
707        field.file_name().map(|filename| Self {
708            filename: filename.to_string(),
709            content_type: field.content_type().map(|s| s.to_string()),
710            data: field.data.clone(),
711        })
712    }
713
714    /// Get file size in bytes
715    pub fn size(&self) -> usize {
716        self.data.len()
717    }
718
719    /// Get file extension
720    pub fn extension(&self) -> Option<&str> {
721        self.filename.rsplit('.').next()
722    }
723
724    /// Save to disk with original filename
725    pub async fn save_to(&self, dir: impl AsRef<Path>) -> Result<String> {
726        let dir = dir.as_ref();
727
728        tokio::fs::create_dir_all(dir)
729            .await
730            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
731
732        let safe_filename = sanitize_filename(&self.filename);
733        let file_path = dir.join(&safe_filename);
734
735        tokio::fs::write(&file_path, &self.data)
736            .await
737            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
738
739        Ok(file_path.to_string_lossy().to_string())
740    }
741
742    /// Save with a custom filename
743    pub async fn save_as(&self, path: impl AsRef<Path>) -> Result<()> {
744        let path = path.as_ref();
745
746        if let Some(parent) = path.parent() {
747            tokio::fs::create_dir_all(parent)
748                .await
749                .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
750        }
751
752        tokio::fs::write(path, &self.data)
753            .await
754            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
755
756        Ok(())
757    }
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    /// Split `body` into `chunk_size`-byte pieces and expose them as a fallible stream, so
    /// parsing is exercised across arbitrary chunk boundaries.
    fn chunked_body_stream(
        body: Bytes,
        chunk_size: usize,
    ) -> impl futures_util::Stream<Item = Result<Bytes>> + Send + 'static {
        let mut pieces: Vec<Result<Bytes>> = Vec::new();
        for piece in body.chunks(chunk_size) {
            pieces.push(Ok(Bytes::copy_from_slice(piece)));
        }
        stream::iter(pieces)
    }

    /// Build a `StreamingMultipart` over `body`, fed in 7-byte chunks and limited to
    /// `config.max_size` bytes.
    fn streaming_multipart_from_body(
        body: Bytes,
        boundary: &str,
        config: MultipartConfig,
    ) -> StreamingMultipart {
        let limit = Some(config.max_size);
        let body_stream = StreamingBody::from_stream(chunked_body_stream(body, 7), limit);
        StreamingMultipart::new(body_stream, boundary.to_owned(), config)
    }

    #[test]
    fn test_extract_boundary() {
        let cases = [
            (
                "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
                "----WebKitFormBoundary7MA4YWxkTrZu0gW",
            ),
            (
                "multipart/form-data; boundary=\"----WebKitFormBoundary\"",
                "----WebKitFormBoundary",
            ),
        ];
        for &(header, expected) in cases.iter() {
            assert_eq!(extract_boundary(header), Some(expected.to_string()));
        }
    }

    #[test]
    fn test_sanitize_filename() {
        let expectations = [
            ("test.txt", "test.txt"),
            ("../../../etc/passwd", "______etc_passwd"),
            // ..\..\windows\system32 -> .._.._windows_system32 -> ____windows_system32
            ("..\\..\\windows\\system32", "____windows_system32"),
            (".hidden", "hidden"),
        ];
        for &(input, expected) in expectations.iter() {
            assert_eq!(sanitize_filename(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_simple_multipart() {
        let boundary = "----WebKitFormBoundary";
        let raw = "------WebKitFormBoundary\r\n\
             Content-Disposition: form-data; name=\"field1\"\r\n\
             \r\n\
             value1\r\n\
             ------WebKitFormBoundary\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             file content\r\n\
             ------WebKitFormBoundary--\r\n";

        let parsed = parse_multipart(&Bytes::from(raw.to_string()), boundary).unwrap();
        assert_eq!(parsed.len(), 2);

        let (text_field, file_field) = (&parsed[0], &parsed[1]);

        assert_eq!(text_field.name(), Some("field1"));
        assert!(!text_field.is_file());

        assert_eq!(file_field.name(), Some("file"));
        assert_eq!(file_field.file_name(), Some("test.txt"));
        assert_eq!(file_field.content_type(), Some("text/plain"));
        assert!(file_field.is_file());
    }

    #[test]
    fn test_multipart_config() {
        const MIB: usize = 1024 * 1024;

        let config = MultipartConfig::new()
            .max_size(20 * MIB)
            .max_fields(50)
            .max_file_size(5 * MIB)
            .allow_content_type("image/png")
            .allow_content_type("image/jpeg");

        assert_eq!(config.max_size, 20 * MIB);
        assert_eq!(config.max_fields, 50);
        assert_eq!(config.max_file_size, 5 * MIB);
        assert_eq!(config.allowed_content_types.len(), 2);
    }

    #[tokio::test]
    async fn streaming_multipart_reads_chunked_body() {
        let boundary = "----RustApiBoundary";
        let payload = format!(
            "--{boundary}\r\n\
             Content-Disposition: form-data; name=\"title\"\r\n\
             \r\n\
             hello\r\n\
             --{boundary}\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             streamed-content\r\n\
             --{boundary}--\r\n"
        );
        let config = MultipartConfig::new().max_size(1024).max_file_size(1024);
        let mut multipart = streaming_multipart_from_body(Bytes::from(payload), boundary, config);

        // Each field lives in its own scope so it is dropped before the next one is requested.
        {
            let mut title = multipart.next_field().await.unwrap().expect("title field");
            assert_eq!(title.name(), Some("title"));
            assert_eq!(title.text().await.unwrap(), "hello");
        }

        {
            let mut upload = multipart.next_field().await.unwrap().expect("file field");
            assert_eq!(upload.file_name(), Some("demo.txt"));
            assert_eq!(upload.content_type(), Some("text/plain"));
            assert_eq!(upload.bytes().await.unwrap(), Bytes::from("streamed-content"));
        }

        assert!(multipart.next_field().await.unwrap().is_none());
        assert_eq!(multipart.field_count(), 2);
    }

    #[tokio::test]
    async fn streaming_multipart_enforces_per_file_limit() {
        let boundary = "----RustApiBoundary";
        let payload = format!(
            "--{boundary}\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             way-too-large\r\n\
             --{boundary}--\r\n"
        );
        let config = MultipartConfig::new().max_size(1024).max_file_size(4);
        let mut multipart = streaming_multipart_from_body(Bytes::from(payload), boundary, config);

        let mut upload = multipart.next_field().await.unwrap().unwrap();
        let failure = upload
            .bytes()
            .await
            .expect_err("a 13-byte field must exceed the 4-byte limit");

        assert_eq!(failure.status, StatusCode::PAYLOAD_TOO_LARGE);
        assert!(failure.message.contains('4'));
    }

    #[tokio::test]
    async fn streaming_multipart_enforces_field_count_limit() {
        let boundary = "----RustApiBoundary";
        let payload = format!(
            "--{boundary}\r\n\
             Content-Disposition: form-data; name=\"first\"\r\n\
             \r\n\
             one\r\n\
             --{boundary}\r\n\
             Content-Disposition: form-data; name=\"second\"\r\n\
             \r\n\
             two\r\n\
             --{boundary}--\r\n"
        );
        let config = MultipartConfig::new().max_size(1024).max_fields(1);
        let mut multipart = streaming_multipart_from_body(Bytes::from(payload), boundary, config);

        // The first field is dropped immediately (a temporary) before asking for the second.
        assert!(multipart.next_field().await.unwrap().is_some());

        let failure = match multipart.next_field().await {
            Ok(_) => panic!("the second field must exceed max_fields = 1"),
            Err(failure) => failure,
        };
        assert_eq!(failure.status, StatusCode::BAD_REQUEST);
        assert!(failure.message.contains("field count exceeded"));
    }

    #[tokio::test]
    async fn streaming_multipart_save_to_writes_incrementally() {
        let boundary = "----RustApiBoundary";
        let payload = format!(
            "--{boundary}\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             persisted\r\n\
             --{boundary}--\r\n"
        );
        let config = MultipartConfig::new().max_size(1024).max_file_size(1024);
        let mut multipart = streaming_multipart_from_body(Bytes::from(payload), boundary, config);

        let mut upload = multipart.next_field().await.unwrap().unwrap();
        let upload_dir = std::env::temp_dir()
            .join(format!("rustapi-streaming-upload-{}", uuid::Uuid::new_v4()));

        let saved_path = upload.save_to(&upload_dir, None).await.unwrap();
        let contents = tokio::fs::read_to_string(&saved_path).await.unwrap();
        assert_eq!(contents, "persisted");

        tokio::fs::remove_dir_all(&upload_dir).await.unwrap();
    }
}