Skip to main content

rustapi_core/
multipart.rs

1//! Multipart form data extractor for file uploads
2//!
3//! This module provides types for handling `multipart/form-data` requests,
4//! commonly used for file uploads.
5//!
6//! # Example
7//!
8//! ```rust,ignore
//! use rustapi_core::multipart::Multipart;
10//!
11//! async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
12//!     while let Some(field) = multipart.next_field().await? {
13//!         let name = field.name().unwrap_or("unknown");
14//!         let filename = field.file_name().map(|s| s.to_string());
15//!         let data = field.bytes().await?;
16//!         
17//!         println!("Field: {}, File: {:?}, Size: {} bytes", name, filename, data.len());
18//!     }
19//!     Ok("Upload successful".to_string())
20//! }
21//! ```
22
23use crate::error::{ApiError, Result};
24use crate::extract::FromRequest;
25use crate::request::Request;
26use crate::stream::StreamingBody;
27use bytes::Bytes;
28use futures_util::stream;
29use http::StatusCode;
30use std::error::Error as _;
31use std::path::Path;
32use tokio::io::AsyncWriteExt;
33
/// Default size limit of 10 MiB (10 × 1024 × 1024 bytes).
///
/// `MultipartConfig::default()` uses it both as the per-field limit and as the total body limit.
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 << 20;

/// Default upper bound on how many fields a single multipart form may contain.
pub const DEFAULT_MAX_FIELDS: usize = 100;
39
40/// Multipart form data extractor
41///
42/// Parses `multipart/form-data` requests, commonly used for file uploads.
43///
44/// # Example
45///
46/// ```rust,ignore
47/// use rustapi_core::multipart::Multipart;
48///
49/// async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
50///     while let Some(field) = multipart.next_field().await? {
51///         let name = field.name().unwrap_or("unknown").to_string();
52///         let data = field.bytes().await?;
53///         println!("Received field '{}' with {} bytes", name, data.len());
54///     }
55///     Ok("Upload complete".to_string())
56/// }
57/// ```
58pub struct Multipart {
59    fields: Vec<MultipartField>,
60    current_index: usize,
61}
62
63impl Multipart {
64    /// Create a new Multipart from raw data
65    fn new(fields: Vec<MultipartField>) -> Self {
66        Self {
67            fields,
68            current_index: 0,
69        }
70    }
71
72    /// Get the next field from the multipart form
73    pub async fn next_field(&mut self) -> Result<Option<MultipartField>> {
74        if self.current_index >= self.fields.len() {
75            return Ok(None);
76        }
77        let field = self.fields.get(self.current_index).cloned();
78        self.current_index += 1;
79        Ok(field)
80    }
81
82    /// Collect all fields into a vector
83    pub fn into_fields(self) -> Vec<MultipartField> {
84        self.fields
85    }
86
87    /// Get the number of fields
88    pub fn field_count(&self) -> usize {
89        self.fields.len()
90    }
91}
92
93/// Streaming multipart extractor for large file uploads.
94///
95/// Unlike [`Multipart`], this extractor does not buffer the entire request body in memory before
96/// parsing. It consumes the request body as a stream and yields one field at a time.
97///
98/// If a [`MultipartConfig`] is present in app state, its size and content-type limits are applied.
99pub struct StreamingMultipart {
100    inner: multer::Multipart<'static>,
101    config: MultipartConfig,
102    field_count: usize,
103}
104
105impl StreamingMultipart {
106    fn new(stream: StreamingBody, boundary: String, config: MultipartConfig) -> Self {
107        Self {
108            inner: multer::Multipart::new(stream, boundary),
109            config,
110            field_count: 0,
111        }
112    }
113
114    /// Get the next field from the multipart stream.
115    ///
116    /// Consume or drop the previously returned field before calling this again.
117    pub async fn next_field(&mut self) -> Result<Option<StreamingMultipartField<'static>>> {
118        let field = self.inner.next_field().await.map_err(map_multer_error)?;
119        let Some(field) = field else {
120            return Ok(None);
121        };
122
123        self.field_count += 1;
124        if self.field_count > self.config.max_fields {
125            return Err(ApiError::bad_request(format!(
126                "Multipart field count exceeded limit of {}",
127                self.config.max_fields
128            )));
129        }
130
131        validate_streaming_field(&field, &self.config)?;
132
133        Ok(Some(StreamingMultipartField::new(
134            field,
135            self.config.max_file_size,
136        )))
137    }
138
139    /// Number of fields yielded so far.
140    pub fn field_count(&self) -> usize {
141        self.field_count
142    }
143}
144
145impl FromRequest for StreamingMultipart {
146    async fn from_request(req: &mut Request) -> Result<Self> {
147        let content_type = req
148            .headers()
149            .get(http::header::CONTENT_TYPE)
150            .and_then(|v| v.to_str().ok())
151            .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
152
153        if !content_type.starts_with("multipart/form-data") {
154            return Err(ApiError::bad_request(format!(
155                "Expected multipart/form-data, got: {}",
156                content_type
157            )));
158        }
159
160        let boundary = extract_boundary(content_type)
161            .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
162
163        let config = req
164            .state()
165            .get::<MultipartConfig>()
166            .cloned()
167            .unwrap_or_default();
168
169        let stream = request_body_stream(req, config.max_size)?;
170        Ok(Self::new(stream, boundary, config))
171    }
172}
173
174/// A single streaming field from a multipart form.
175///
176/// This field is one-shot: once you call [`chunk`](Self::chunk), [`bytes`](Self::bytes),
177/// [`text`](Self::text), or one of the save helpers, the underlying stream is consumed.
178pub struct StreamingMultipartField<'a> {
179    inner: multer::Field<'a>,
180    max_file_size: usize,
181    bytes_read: usize,
182}
183
184impl<'a> StreamingMultipartField<'a> {
185    fn new(inner: multer::Field<'a>, max_file_size: usize) -> Self {
186        Self {
187            inner,
188            max_file_size,
189            bytes_read: 0,
190        }
191    }
192
193    /// Get the field name.
194    pub fn name(&self) -> Option<&str> {
195        self.inner.name()
196    }
197
198    /// Get the original filename when this field is a file upload.
199    pub fn file_name(&self) -> Option<&str> {
200        self.inner.file_name()
201    }
202
203    /// Get the content type of the field.
204    pub fn content_type(&self) -> Option<&str> {
205        self.inner.content_type().map(|mime| mime.essence_str())
206    }
207
208    /// Check whether this field represents a file upload.
209    pub fn is_file(&self) -> bool {
210        self.file_name().is_some()
211    }
212
213    /// Number of bytes consumed from this field so far.
214    pub fn bytes_read(&self) -> usize {
215        self.bytes_read
216    }
217
218    /// Read the next chunk from the field stream.
219    pub async fn chunk(&mut self) -> Result<Option<Bytes>> {
220        let chunk = self.inner.chunk().await.map_err(map_multer_error)?;
221        let Some(chunk) = chunk else {
222            return Ok(None);
223        };
224
225        self.bytes_read += chunk.len();
226        if self.bytes_read > self.max_file_size {
227            return Err(file_size_limit_error(self.max_file_size));
228        }
229
230        Ok(Some(chunk))
231    }
232
233    /// Collect the full field into memory.
234    pub async fn bytes(&mut self) -> Result<Bytes> {
235        let mut buffer = bytes::BytesMut::new();
236        while let Some(chunk) = self.chunk().await? {
237            buffer.extend_from_slice(&chunk);
238        }
239        Ok(buffer.freeze())
240    }
241
242    /// Collect the field as UTF-8 text.
243    pub async fn text(&mut self) -> Result<String> {
244        String::from_utf8(self.bytes().await?.to_vec())
245            .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
246    }
247
248    /// Save the field to a directory using either the provided filename or the uploaded name.
249    pub async fn save_to(
250        &mut self,
251        dir: impl AsRef<Path>,
252        filename: Option<&str>,
253    ) -> Result<String> {
254        let dir = dir.as_ref();
255
256        tokio::fs::create_dir_all(dir)
257            .await
258            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
259
260        let final_filename = filename
261            .map(|value| value.to_string())
262            .or_else(|| self.file_name().map(|value| value.to_string()))
263            .ok_or_else(|| {
264                ApiError::bad_request("No filename provided and field has no filename")
265            })?;
266
267        let safe_filename = sanitize_filename(&final_filename);
268        let file_path = dir.join(&safe_filename);
269        self.save_as(&file_path).await?;
270
271        Ok(file_path.to_string_lossy().to_string())
272    }
273
274    /// Save the field contents to an explicit file path without buffering the full file in memory.
275    pub async fn save_as(&mut self, path: impl AsRef<Path>) -> Result<()> {
276        let path = path.as_ref();
277
278        if let Some(parent) = path.parent() {
279            tokio::fs::create_dir_all(parent)
280                .await
281                .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
282        }
283
284        let mut file = tokio::fs::File::create(path)
285            .await
286            .map_err(|e| ApiError::internal(format!("Failed to create file: {}", e)))?;
287
288        while let Some(chunk) = self.chunk().await? {
289            file.write_all(&chunk)
290                .await
291                .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
292        }
293
294        file.flush()
295            .await
296            .map_err(|e| ApiError::internal(format!("Failed to flush file: {}", e)))?;
297
298        Ok(())
299    }
300
301    /// Collect the field into an [`UploadedFile`] for APIs that still expect the buffered wrapper.
302    pub async fn into_uploaded_file(mut self) -> Result<UploadedFile> {
303        let filename = self
304            .file_name()
305            .ok_or_else(|| ApiError::bad_request("Field is not a file upload"))?
306            .to_string();
307        let content_type = self.content_type().map(|value| value.to_string());
308        let data = self.bytes().await?;
309
310        Ok(UploadedFile {
311            filename,
312            content_type,
313            data,
314        })
315    }
316}
317
318/// A single field from a multipart form
319#[derive(Clone)]
320pub struct MultipartField {
321    name: Option<String>,
322    file_name: Option<String>,
323    content_type: Option<String>,
324    data: Bytes,
325}
326
327impl MultipartField {
328    /// Create a new multipart field
329    pub fn new(
330        name: Option<String>,
331        file_name: Option<String>,
332        content_type: Option<String>,
333        data: Bytes,
334    ) -> Self {
335        Self {
336            name,
337            file_name,
338            content_type,
339            data,
340        }
341    }
342
343    /// Get the field name
344    pub fn name(&self) -> Option<&str> {
345        self.name.as_deref()
346    }
347
348    /// Get the original filename (if this is a file upload)
349    pub fn file_name(&self) -> Option<&str> {
350        self.file_name.as_deref()
351    }
352
353    /// Get the content type of the field
354    pub fn content_type(&self) -> Option<&str> {
355        self.content_type.as_deref()
356    }
357
358    /// Check if this field is a file upload
359    pub fn is_file(&self) -> bool {
360        self.file_name.is_some()
361    }
362
363    /// Get the field data as bytes
364    pub async fn bytes(&self) -> Result<Bytes> {
365        Ok(self.data.clone())
366    }
367
368    /// Get the field data as a string (UTF-8)
369    pub async fn text(&self) -> Result<String> {
370        String::from_utf8(self.data.to_vec())
371            .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
372    }
373
374    /// Get the size of the field data in bytes
375    pub fn size(&self) -> usize {
376        self.data.len()
377    }
378
379    /// Save the file to disk
380    ///
381    /// # Arguments
382    ///
383    /// * `path` - The directory to save the file to
384    /// * `filename` - Optional custom filename, uses original filename if None
385    ///
386    /// # Example
387    ///
388    /// ```rust,ignore
389    /// field.save_to("./uploads", None).await?;
390    /// // or with custom filename
391    /// field.save_to("./uploads", Some("custom_name.txt")).await?;
392    /// ```
393    pub async fn save_to(&self, dir: impl AsRef<Path>, filename: Option<&str>) -> Result<String> {
394        let dir = dir.as_ref();
395
396        // Ensure directory exists
397        tokio::fs::create_dir_all(dir)
398            .await
399            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
400
401        // Determine filename
402        let final_filename = filename
403            .map(|s| s.to_string())
404            .or_else(|| self.file_name.clone())
405            .ok_or_else(|| {
406                ApiError::bad_request("No filename provided and field has no filename")
407            })?;
408
409        // Sanitize filename to prevent path traversal
410        let safe_filename = sanitize_filename(&final_filename);
411        let file_path = dir.join(&safe_filename);
412
413        // Write file
414        tokio::fs::write(&file_path, &self.data)
415            .await
416            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
417
418        Ok(file_path.to_string_lossy().to_string())
419    }
420}
421
/// Sanitize a client-supplied filename so it can be safely joined onto an upload directory.
///
/// - Path separators (`/`, `\`), the drive/stream separator (`:`), and control characters
///   (including NUL) are each replaced with `_`.
/// - Every `..` is then replaced with `_`.
/// - Leading dots are stripped, so the result is never a hidden or relative component.
///
/// `:` matters on Windows. Joining a prefix-only path such as `C:evil` onto a directory
/// replaces the directory entirely, and `name:stream` addresses an NTFS alternate data stream.
/// The result may be empty (for example for `"."`).
fn sanitize_filename(filename: &str) -> String {
    filename
        .replace(|c: char| matches!(c, '/' | '\\' | ':') || c.is_control(), "_")
        .replace("..", "_")
        .trim_start_matches('.')
        .to_string()
}
431
432impl FromRequest for Multipart {
433    async fn from_request(req: &mut Request) -> Result<Self> {
434        // Check content type
435        let content_type = req
436            .headers()
437            .get(http::header::CONTENT_TYPE)
438            .and_then(|v| v.to_str().ok())
439            .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
440
441        if !content_type.starts_with("multipart/form-data") {
442            return Err(ApiError::bad_request(format!(
443                "Expected multipart/form-data, got: {}",
444                content_type
445            )));
446        }
447
448        // Extract boundary
449        let boundary = extract_boundary(content_type)
450            .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
451
452        // Buffer streaming bodies from live HTTP connections before parsing.
453        req.load_body().await?;
454
455        let body = req
456            .take_body()
457            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
458
459        // Parse multipart
460        let fields = parse_multipart(&body, &boundary)?;
461
462        Ok(Multipart::new(fields))
463    }
464}
465
466impl rustapi_openapi::OperationModifier for Multipart {
467    fn update_operation(op: &mut rustapi_openapi::Operation) {
468        use rustapi_openapi::{MediaType, RequestBody, SchemaRef};
469        use std::collections::BTreeMap;
470
471        let mut content = BTreeMap::new();
472        content.insert(
473            "multipart/form-data".to_string(),
474            MediaType {
475                schema: Some(SchemaRef::Inline(serde_json::json!({ "type": "object" }))),
476                example: None,
477            },
478        );
479
480        op.request_body = Some(RequestBody {
481            description: None,
482            required: Some(true),
483            content,
484        });
485    }
486}
487
488impl rustapi_openapi::OperationModifier for StreamingMultipart {
489    fn update_operation(op: &mut rustapi_openapi::Operation) {
490        Multipart::update_operation(op);
491    }
492}
493
494fn request_body_stream(req: &mut Request, limit: usize) -> Result<StreamingBody> {
495    if let Some(stream) = req.take_stream() {
496        return Ok(StreamingBody::new(stream, Some(limit)));
497    }
498
499    if let Some(body) = req.take_body() {
500        let stream = stream::once(async move { Ok::<Bytes, ApiError>(body) });
501        return Ok(StreamingBody::from_stream(stream, Some(limit)));
502    }
503
504    Err(ApiError::internal("Body already consumed"))
505}
506
507fn validate_streaming_field(field: &multer::Field<'_>, config: &MultipartConfig) -> Result<()> {
508    if field.file_name().is_none() || config.allowed_content_types.is_empty() {
509        return Ok(());
510    }
511
512    let content_type = field
513        .content_type()
514        .map(|mime| mime.essence_str().to_string())
515        .ok_or_else(|| ApiError::bad_request("Uploaded file is missing Content-Type"))?;
516
517    if config
518        .allowed_content_types
519        .iter()
520        .any(|allowed| allowed.eq_ignore_ascii_case(&content_type))
521    {
522        return Ok(());
523    }
524
525    Err(ApiError::bad_request(format!(
526        "Unsupported content type '{}'",
527        content_type
528    )))
529}
530
531fn file_size_limit_error(limit: usize) -> ApiError {
532    ApiError::new(
533        StatusCode::PAYLOAD_TOO_LARGE,
534        "payload_too_large",
535        format!("Multipart field exceeded limit of {} bytes", limit),
536    )
537}
538
539fn map_multer_error(error: multer::Error) -> ApiError {
540    if let Some(source) = error.source() {
541        if let Some(api_error) = source.downcast_ref::<ApiError>() {
542            return api_error.clone();
543        }
544    }
545
546    let message = error.to_string();
547    if message.to_ascii_lowercase().contains("size limit") {
548        return ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "payload_too_large", message);
549    }
550
551    ApiError::bad_request(format!("Invalid multipart body: {}", message))
552}
553
/// Extract the `boundary` parameter from a `Content-Type` header value.
///
/// Parameter names are matched case-insensitively (RFC 2045 §5.1), and whitespace around the
/// name and value is ignored. Surrounding double quotes are stripped. A missing or empty
/// boundary yields `None`, because an empty boundary would make every `--` look like a
/// delimiter.
fn extract_boundary(content_type: &str) -> Option<String> {
    content_type
        .split(';')
        .filter_map(|param| param.split_once('='))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("boundary"))
        .map(|(_, value)| value.trim().trim_matches('"'))
        .filter(|boundary| !boundary.is_empty())
        .map(str::to_string)
}
566
/// Find the first occurrence of `needle` in `haystack` at or after byte offset `from`.
///
/// Returns the absolute offset of the match. If `from` is past the end of `haystack`, the
/// result is `None` (the old version panicked on the slice index). An empty needle matches
/// immediately at `from`; `windows(0)` would panic.
fn find_subsequence(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
    let tail = haystack.get(from..)?;
    if needle.is_empty() {
        return Some(from);
    }
    tail.windows(needle.len())
        .position(|window| window == needle)
        .map(|pos| from + pos)
}
573
/// Strip the single line break that separates a part's body from the next delimiter.
///
/// Exactly one trailing `\r\n` is removed, or one bare `\n` for lenient LF-only clients.
/// Any further trailing newlines belong to the payload and are preserved. Examples are a text
/// file ending in `\n`, or binary data ending in `0D 0A`.
fn trim_trailing_crlf(mut data: Vec<u8>) -> Vec<u8> {
    if data.ends_with(b"\r\n") {
        data.truncate(data.len() - 2);
    } else if data.ends_with(b"\n") {
        data.truncate(data.len() - 1);
    }
    data
}
583
584fn parse_multipart_part(part: &[u8]) -> Option<MultipartField> {
585    let (header_end, body_start) = if let Some(pos) = find_subsequence(part, b"\r\n\r\n", 0) {
586        (pos, pos + 4)
587    } else if let Some(pos) = find_subsequence(part, b"\n\n", 0) {
588        (pos, pos + 2)
589    } else {
590        return None;
591    };
592
593    let headers_section = String::from_utf8_lossy(&part[..header_end]);
594    let body_section = trim_trailing_crlf(part[body_start..].to_vec());
595
596    let mut name = None;
597    let mut filename = None;
598    let mut content_type = None;
599
600    for header_line in headers_section.lines() {
601        let header_line = header_line.trim();
602        if header_line.is_empty() {
603            continue;
604        }
605
606        if let Some((key, value)) = header_line.split_once(':') {
607            let key = key.trim().to_lowercase();
608            let value = value.trim();
609
610            match key.as_str() {
611                "content-disposition" => {
612                    for segment in value.split(';') {
613                        let segment = segment.trim();
614                        if segment.starts_with("name=") {
615                            name = Some(
616                                segment
617                                    .trim_start_matches("name=")
618                                    .trim_matches('"')
619                                    .to_string(),
620                            );
621                        } else if segment.starts_with("filename=") {
622                            filename = Some(
623                                segment
624                                    .trim_start_matches("filename=")
625                                    .trim_matches('"')
626                                    .to_string(),
627                            );
628                        }
629                    }
630                }
631                "content-type" => {
632                    content_type = Some(value.to_string());
633                }
634                _ => {}
635            }
636        }
637    }
638
639    Some(MultipartField::new(
640        name,
641        filename,
642        content_type,
643        Bytes::from(body_section),
644    ))
645}
646
647/// Parse multipart form data from raw bytes (binary-safe).
648fn parse_multipart(body: &Bytes, boundary: &str) -> Result<Vec<MultipartField>> {
649    let delimiter = format!("--{}", boundary);
650    let delim = delimiter.as_bytes();
651
652    let first = find_subsequence(body, delim, 0)
653        .ok_or_else(|| ApiError::bad_request("No multipart boundary found"))?;
654
655    let mut fields = Vec::new();
656    let mut cursor = first + delim.len();
657
658    if body[cursor..].starts_with(b"--") {
659        return Ok(fields);
660    }
661
662    if body[cursor..].starts_with(b"\r\n") {
663        cursor += 2;
664    } else if body.get(cursor) == Some(&b'\n') {
665        cursor += 1;
666    }
667
668    while cursor < body.len() {
669        let next = find_subsequence(body, delim, cursor);
670        let part_end = next.unwrap_or(body.len());
671        let part = &body[cursor..part_end];
672
673        if !part.is_empty() {
674            if let Some(field) = parse_multipart_part(part) {
675                fields.push(field);
676            }
677        }
678
679        let Some(next_pos) = next else {
680            break;
681        };
682
683        cursor = next_pos + delim.len();
684        if body[cursor..].starts_with(b"--") {
685            break;
686        }
687        if body[cursor..].starts_with(b"\r\n") {
688            cursor += 2;
689        } else if body.get(cursor) == Some(&b'\n') {
690            cursor += 1;
691        }
692    }
693
694    Ok(fields)
695}
696
697/// Configuration for multipart form handling
698#[derive(Clone)]
699pub struct MultipartConfig {
700    /// Maximum total size of the multipart form (default: 10MB)
701    pub max_size: usize,
702    /// Maximum number of fields (default: 100)
703    pub max_fields: usize,
704    /// Maximum size per file (default: 10MB)
705    pub max_file_size: usize,
706    /// Allowed content types for files (empty = all allowed)
707    pub allowed_content_types: Vec<String>,
708}
709
710impl Default for MultipartConfig {
711    fn default() -> Self {
712        Self {
713            max_size: DEFAULT_MAX_FILE_SIZE,
714            max_fields: DEFAULT_MAX_FIELDS,
715            max_file_size: DEFAULT_MAX_FILE_SIZE,
716            allowed_content_types: Vec::new(),
717        }
718    }
719}
720
721impl MultipartConfig {
722    /// Create a new multipart config with default values
723    pub fn new() -> Self {
724        Self::default()
725    }
726
727    /// Set the maximum total size
728    pub fn max_size(mut self, size: usize) -> Self {
729        self.max_size = size;
730        self
731    }
732
733    /// Set the maximum number of fields
734    pub fn max_fields(mut self, count: usize) -> Self {
735        self.max_fields = count;
736        self
737    }
738
739    /// Set the maximum file size
740    pub fn max_file_size(mut self, size: usize) -> Self {
741        self.max_file_size = size;
742        self
743    }
744
745    /// Set allowed content types for file uploads
746    pub fn allowed_content_types(mut self, types: Vec<String>) -> Self {
747        self.allowed_content_types = types;
748        self
749    }
750
751    /// Add an allowed content type
752    pub fn allow_content_type(mut self, content_type: impl Into<String>) -> Self {
753        self.allowed_content_types.push(content_type.into());
754        self
755    }
756}
757
758/// File data wrapper for convenient access to uploaded files
759#[derive(Clone)]
760pub struct UploadedFile {
761    /// Original filename
762    pub filename: String,
763    /// Content type (MIME type)
764    pub content_type: Option<String>,
765    /// File data
766    pub data: Bytes,
767}
768
769impl UploadedFile {
770    /// Create from a multipart field
771    pub fn from_field(field: &MultipartField) -> Option<Self> {
772        field.file_name().map(|filename| Self {
773            filename: filename.to_string(),
774            content_type: field.content_type().map(|s| s.to_string()),
775            data: field.data.clone(),
776        })
777    }
778
779    /// Get file size in bytes
780    pub fn size(&self) -> usize {
781        self.data.len()
782    }
783
784    /// Get file extension
785    pub fn extension(&self) -> Option<&str> {
786        self.filename.rsplit('.').next()
787    }
788
789    /// Save to disk with original filename
790    pub async fn save_to(&self, dir: impl AsRef<Path>) -> Result<String> {
791        let dir = dir.as_ref();
792
793        tokio::fs::create_dir_all(dir)
794            .await
795            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
796
797        let safe_filename = sanitize_filename(&self.filename);
798        let file_path = dir.join(&safe_filename);
799
800        tokio::fs::write(&file_path, &self.data)
801            .await
802            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
803
804        Ok(file_path.to_string_lossy().to_string())
805    }
806
807    /// Save with a custom filename
808    pub async fn save_as(&self, path: impl AsRef<Path>) -> Result<()> {
809        let path = path.as_ref();
810
811        if let Some(parent) = path.parent() {
812            tokio::fs::create_dir_all(parent)
813                .await
814                .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
815        }
816
817        tokio::fs::write(path, &self.data)
818            .await
819            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
820
821        Ok(())
822    }
823}
824
825#[cfg(test)]
826mod tests {
827    use super::*;
828    use futures_util::stream;
829
830    fn chunked_body_stream(
831        body: Bytes,
832        chunk_size: usize,
833    ) -> impl futures_util::Stream<Item = Result<Bytes>> + Send + 'static {
834        let chunks = body
835            .chunks(chunk_size)
836            .map(Bytes::copy_from_slice)
837            .map(Ok)
838            .collect::<Vec<_>>();
839        stream::iter(chunks)
840    }
841
842    fn streaming_multipart_from_body(
843        body: Bytes,
844        boundary: &str,
845        config: MultipartConfig,
846    ) -> StreamingMultipart {
847        let stream =
848            StreamingBody::from_stream(chunked_body_stream(body, 7), Some(config.max_size));
849        StreamingMultipart::new(stream, boundary.to_string(), config)
850    }
851
852    #[test]
853    fn test_extract_boundary() {
854        let ct = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
855        assert_eq!(
856            extract_boundary(ct),
857            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
858        );
859
860        let ct_quoted = "multipart/form-data; boundary=\"----WebKitFormBoundary\"";
861        assert_eq!(
862            extract_boundary(ct_quoted),
863            Some("----WebKitFormBoundary".to_string())
864        );
865    }
866
867    #[test]
868    fn test_sanitize_filename() {
869        assert_eq!(sanitize_filename("test.txt"), "test.txt");
870        assert_eq!(sanitize_filename("../../../etc/passwd"), "______etc_passwd");
871        // ..\..\windows\system32 -> .._.._windows_system32 -> ____windows_system32
872        assert_eq!(
873            sanitize_filename("..\\..\\windows\\system32"),
874            "____windows_system32"
875        );
876        assert_eq!(sanitize_filename(".hidden"), "hidden");
877    }
878
879    #[test]
880    fn test_parse_simple_multipart() {
881        let boundary = "----WebKitFormBoundary";
882        let body = "------WebKitFormBoundary\r\n\
883             Content-Disposition: form-data; name=\"field1\"\r\n\
884             \r\n\
885             value1\r\n\
886             ------WebKitFormBoundary\r\n\
887             Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
888             Content-Type: text/plain\r\n\
889             \r\n\
890             file content\r\n\
891             ------WebKitFormBoundary--\r\n"
892            .to_string();
893
894        let fields = parse_multipart(&Bytes::from(body), boundary).unwrap();
895        assert_eq!(fields.len(), 2);
896
897        assert_eq!(fields[0].name(), Some("field1"));
898        assert!(!fields[0].is_file());
899
900        assert_eq!(fields[1].name(), Some("file"));
901        assert_eq!(fields[1].file_name(), Some("test.txt"));
902        assert_eq!(fields[1].content_type(), Some("text/plain"));
903        assert!(fields[1].is_file());
904    }
905
906    #[tokio::test]
907    async fn test_parse_multipart_preserves_binary_payload() {
908        let boundary = "test-boundary";
909        let binary = vec![0x4D, 0x5A, 0x90, 0x00, 0xFF, 0xFE, 0x00, 0x00];
910        let mut body = format!(
911            "--{boundary}\r\nContent-Disposition: form-data; name=\"project_name\"\r\n\r\napp\r\n\
912             --{boundary}\r\nContent-Disposition: form-data; name=\"binary\"; filename=\"app.bin\"\r\n\
913             Content-Type: application/octet-stream\r\n\r\n"
914        )
915        .into_bytes();
916        body.extend_from_slice(&binary);
917        body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
918
919        let fields = parse_multipart(&Bytes::from(body), boundary).unwrap();
920        assert_eq!(fields.len(), 2);
921        assert_eq!(
922            fields[1].bytes().await.expect("binary field"),
923            binary.as_slice()
924        );
925    }
926
927    #[test]
928    fn test_multipart_config() {
929        let config = MultipartConfig::new()
930            .max_size(20 * 1024 * 1024)
931            .max_fields(50)
932            .max_file_size(5 * 1024 * 1024)
933            .allow_content_type("image/png")
934            .allow_content_type("image/jpeg");
935
936        assert_eq!(config.max_size, 20 * 1024 * 1024);
937        assert_eq!(config.max_fields, 50);
938        assert_eq!(config.max_file_size, 5 * 1024 * 1024);
939        assert_eq!(config.allowed_content_types.len(), 2);
940    }
941
942    #[tokio::test]
943    async fn streaming_multipart_reads_chunked_body() {
944        let boundary = "----RustApiBoundary";
945        let body = format!(
946            "--{boundary}\r\n\
947             Content-Disposition: form-data; name=\"title\"\r\n\
948             \r\n\
949             hello\r\n\
950             --{boundary}\r\n\
951             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
952             Content-Type: text/plain\r\n\
953             \r\n\
954             streamed-content\r\n\
955             --{boundary}--\r\n"
956        );
957
958        let mut multipart = streaming_multipart_from_body(
959            Bytes::from(body),
960            boundary,
961            MultipartConfig::new().max_size(1024).max_file_size(1024),
962        );
963
964        let mut title = multipart.next_field().await.unwrap().unwrap();
965        assert_eq!(title.name(), Some("title"));
966        assert_eq!(title.text().await.unwrap(), "hello");
967        drop(title);
968
969        let mut file = multipart.next_field().await.unwrap().unwrap();
970        assert_eq!(file.file_name(), Some("demo.txt"));
971        assert_eq!(file.content_type(), Some("text/plain"));
972        assert_eq!(file.bytes().await.unwrap(), Bytes::from("streamed-content"));
973        drop(file);
974
975        assert!(multipart.next_field().await.unwrap().is_none());
976        assert_eq!(multipart.field_count(), 2);
977    }
978
979    #[tokio::test]
980    async fn streaming_multipart_enforces_per_file_limit() {
981        let boundary = "----RustApiBoundary";
982        let body = format!(
983            "--{boundary}\r\n\
984             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
985             Content-Type: text/plain\r\n\
986             \r\n\
987             way-too-large\r\n\
988             --{boundary}--\r\n"
989        );
990
991        let mut multipart = streaming_multipart_from_body(
992            Bytes::from(body),
993            boundary,
994            MultipartConfig::new().max_size(1024).max_file_size(4),
995        );
996
997        let mut file = multipart.next_field().await.unwrap().unwrap();
998        let error = file.bytes().await.unwrap_err();
999        assert_eq!(error.status, StatusCode::PAYLOAD_TOO_LARGE);
1000        assert!(error.message.contains("4"));
1001    }
1002
1003    #[tokio::test]
1004    async fn streaming_multipart_enforces_field_count_limit() {
1005        let boundary = "----RustApiBoundary";
1006        let body = format!(
1007            "--{boundary}\r\n\
1008             Content-Disposition: form-data; name=\"first\"\r\n\
1009             \r\n\
1010             one\r\n\
1011             --{boundary}\r\n\
1012             Content-Disposition: form-data; name=\"second\"\r\n\
1013             \r\n\
1014             two\r\n\
1015             --{boundary}--\r\n"
1016        );
1017
1018        let mut multipart = streaming_multipart_from_body(
1019            Bytes::from(body),
1020            boundary,
1021            MultipartConfig::new().max_size(1024).max_fields(1),
1022        );
1023
1024        assert!(multipart.next_field().await.unwrap().is_some());
1025        let next = multipart.next_field().await;
1026        assert!(next.is_err());
1027        let error = next.err().unwrap();
1028        assert_eq!(error.status, StatusCode::BAD_REQUEST);
1029        assert!(error.message.contains("field count exceeded"));
1030    }
1031
1032    #[tokio::test]
1033    async fn streaming_multipart_save_to_writes_incrementally() {
1034        let boundary = "----RustApiBoundary";
1035        let body = format!(
1036            "--{boundary}\r\n\
1037             Content-Disposition: form-data; name=\"file\"; filename=\"demo.txt\"\r\n\
1038             Content-Type: text/plain\r\n\
1039             \r\n\
1040             persisted\r\n\
1041             --{boundary}--\r\n"
1042        );
1043
1044        let mut multipart = streaming_multipart_from_body(
1045            Bytes::from(body),
1046            boundary,
1047            MultipartConfig::new().max_size(1024).max_file_size(1024),
1048        );
1049
1050        let mut file = multipart.next_field().await.unwrap().unwrap();
1051        let temp_dir =
1052            std::env::temp_dir().join(format!("rustapi-streaming-upload-{}", uuid::Uuid::new_v4()));
1053        let saved_path = file.save_to(&temp_dir, None).await.unwrap();
1054        let saved = tokio::fs::read_to_string(&saved_path).await.unwrap();
1055
1056        assert_eq!(saved, "persisted");
1057
1058        tokio::fs::remove_dir_all(&temp_dir).await.unwrap();
1059    }
1060}