rustapi_core/
multipart.rs

1//! Multipart form data extractor for file uploads
2//!
3//! This module provides types for handling `multipart/form-data` requests,
4//! commonly used for file uploads.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use rustapi_core::multipart::{Multipart, FieldData};
10//!
11//! async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
12//!     while let Some(field) = multipart.next_field().await? {
13//!         let name = field.name().unwrap_or("unknown");
14//!         let filename = field.file_name().map(|s| s.to_string());
15//!         let data = field.bytes().await?;
16//!         
17//!         println!("Field: {}, File: {:?}, Size: {} bytes", name, filename, data.len());
18//!     }
19//!     Ok("Upload successful".to_string())
20//! }
21//! ```
22
23use crate::error::{ApiError, Result};
24use crate::extract::FromRequest;
25use crate::request::Request;
26use bytes::Bytes;
27use std::path::Path;
28
/// Default upper bound, in bytes, for an uploaded file (10 MiB).
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * (1 << 20);

/// Default upper bound on the number of fields in one multipart form.
pub const DEFAULT_MAX_FIELDS: usize = 100;
34
35/// Multipart form data extractor
36///
37/// Parses `multipart/form-data` requests, commonly used for file uploads.
38///
39/// # Example
40///
41/// ```rust,ignore
42/// use rustapi_core::multipart::Multipart;
43///
44/// async fn upload(mut multipart: Multipart) -> Result<String, ApiError> {
45///     while let Some(field) = multipart.next_field().await? {
46///         let name = field.name().unwrap_or("unknown").to_string();
47///         let data = field.bytes().await?;
48///         println!("Received field '{}' with {} bytes", name, data.len());
49///     }
50///     Ok("Upload complete".to_string())
51/// }
52/// ```
53pub struct Multipart {
54    fields: Vec<MultipartField>,
55    current_index: usize,
56}
57
58impl Multipart {
59    /// Create a new Multipart from raw data
60    fn new(fields: Vec<MultipartField>) -> Self {
61        Self {
62            fields,
63            current_index: 0,
64        }
65    }
66
67    /// Get the next field from the multipart form
68    pub async fn next_field(&mut self) -> Result<Option<MultipartField>> {
69        if self.current_index >= self.fields.len() {
70            return Ok(None);
71        }
72        let field = self.fields.get(self.current_index).cloned();
73        self.current_index += 1;
74        Ok(field)
75    }
76
77    /// Collect all fields into a vector
78    pub fn into_fields(self) -> Vec<MultipartField> {
79        self.fields
80    }
81
82    /// Get the number of fields
83    pub fn field_count(&self) -> usize {
84        self.fields.len()
85    }
86}
87
88/// A single field from a multipart form
89#[derive(Clone)]
90pub struct MultipartField {
91    name: Option<String>,
92    file_name: Option<String>,
93    content_type: Option<String>,
94    data: Bytes,
95}
96
97impl MultipartField {
98    /// Create a new multipart field
99    pub fn new(
100        name: Option<String>,
101        file_name: Option<String>,
102        content_type: Option<String>,
103        data: Bytes,
104    ) -> Self {
105        Self {
106            name,
107            file_name,
108            content_type,
109            data,
110        }
111    }
112
113    /// Get the field name
114    pub fn name(&self) -> Option<&str> {
115        self.name.as_deref()
116    }
117
118    /// Get the original filename (if this is a file upload)
119    pub fn file_name(&self) -> Option<&str> {
120        self.file_name.as_deref()
121    }
122
123    /// Get the content type of the field
124    pub fn content_type(&self) -> Option<&str> {
125        self.content_type.as_deref()
126    }
127
128    /// Check if this field is a file upload
129    pub fn is_file(&self) -> bool {
130        self.file_name.is_some()
131    }
132
133    /// Get the field data as bytes
134    pub async fn bytes(&self) -> Result<Bytes> {
135        Ok(self.data.clone())
136    }
137
138    /// Get the field data as a string (UTF-8)
139    pub async fn text(&self) -> Result<String> {
140        String::from_utf8(self.data.to_vec())
141            .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
142    }
143
144    /// Get the size of the field data in bytes
145    pub fn size(&self) -> usize {
146        self.data.len()
147    }
148
149    /// Save the file to disk
150    ///
151    /// # Arguments
152    ///
153    /// * `path` - The directory to save the file to
154    /// * `filename` - Optional custom filename, uses original filename if None
155    ///
156    /// # Example
157    ///
158    /// ```rust,ignore
159    /// field.save_to("./uploads", None).await?;
160    /// // or with custom filename
161    /// field.save_to("./uploads", Some("custom_name.txt")).await?;
162    /// ```
163    pub async fn save_to(&self, dir: impl AsRef<Path>, filename: Option<&str>) -> Result<String> {
164        let dir = dir.as_ref();
165
166        // Ensure directory exists
167        tokio::fs::create_dir_all(dir)
168            .await
169            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
170
171        // Determine filename
172        let final_filename = filename
173            .map(|s| s.to_string())
174            .or_else(|| self.file_name.clone())
175            .ok_or_else(|| {
176                ApiError::bad_request("No filename provided and field has no filename")
177            })?;
178
179        // Sanitize filename to prevent path traversal
180        let safe_filename = sanitize_filename(&final_filename);
181        let file_path = dir.join(&safe_filename);
182
183        // Write file
184        tokio::fs::write(&file_path, &self.data)
185            .await
186            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
187
188        Ok(file_path.to_string_lossy().to_string())
189    }
190}
191
/// Sanitize a client-supplied filename so it is safe to join onto an upload
/// directory.
///
/// Steps, in order:
/// 1. path separators (`/`, `\`) and control characters (including NUL) are
///    replaced with `_`;
/// 2. every `..` sequence is replaced with `_`;
/// 3. leading dots are removed, so the result is never a hidden file or a
///    relative-directory reference;
/// 4. if nothing is left, the placeholder `unnamed` is returned so callers
///    never end up writing to the directory path itself.
fn sanitize_filename(filename: &str) -> String {
    // Control characters (notably NUL) are rejected by the OS when used in a
    // path, so neutralise them together with the separators.
    let flattened: String = filename
        .chars()
        .map(|c| {
            if c == '/' || c == '\\' || c.is_control() {
                '_'
            } else {
                c
            }
        })
        .collect();

    let without_parent_refs = flattened.replace("..", "_");
    let cleaned = without_parent_refs.trim_start_matches('.');

    if cleaned.is_empty() {
        // Inputs such as "" or "." would otherwise produce an empty name.
        "unnamed".to_string()
    } else {
        cleaned.to_string()
    }
}
201
202impl FromRequest for Multipart {
203    async fn from_request(req: &mut Request) -> Result<Self> {
204        // Check content type
205        let content_type = req
206            .headers()
207            .get(http::header::CONTENT_TYPE)
208            .and_then(|v| v.to_str().ok())
209            .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
210
211        if !content_type.starts_with("multipart/form-data") {
212            return Err(ApiError::bad_request(format!(
213                "Expected multipart/form-data, got: {}",
214                content_type
215            )));
216        }
217
218        // Extract boundary
219        let boundary = extract_boundary(content_type)
220            .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
221
222        // Get body
223        let body = req
224            .take_body()
225            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
226
227        // Parse multipart
228        let fields = parse_multipart(&body, &boundary)?;
229
230        Ok(Multipart::new(fields))
231    }
232}
233
/// Extract the `boundary` parameter from a `Content-Type` header value.
///
/// The parameter name is matched case-insensitively and surrounding double
/// quotes are removed from the value. Only the first `=` separates name from
/// value, so a boundary that itself contains `=` is returned intact.
///
/// Returns `None` when there is no `boundary` parameter or its value is
/// empty (an empty boundary cannot delimit anything).
fn extract_boundary(content_type: &str) -> Option<String> {
    content_type.split(';').find_map(|param| {
        let (key, value) = param.split_once('=')?;
        if !key.trim().eq_ignore_ascii_case("boundary") {
            return None;
        }

        let boundary = value.trim().trim_matches('"');
        if boundary.is_empty() {
            None
        } else {
            Some(boundary.to_string())
        }
    })
}
246
247/// Parse multipart form data
248fn parse_multipart(body: &Bytes, boundary: &str) -> Result<Vec<MultipartField>> {
249    let mut fields = Vec::new();
250    let delimiter = format!("--{}", boundary);
251    let end_delimiter = format!("--{}--", boundary);
252
253    // Convert body to string for easier parsing
254    // Note: This is a simplified parser. For production, consider using multer crate.
255    let body_str = String::from_utf8_lossy(body);
256
257    // Split by delimiter
258    let parts: Vec<&str> = body_str.split(&delimiter).collect();
259
260    for part in parts.iter().skip(1) {
261        // Skip empty parts and end delimiter
262        let part = part.trim_start_matches("\r\n").trim_start_matches('\n');
263        if part.is_empty() || part.starts_with("--") {
264            continue;
265        }
266
267        // Find header/body separator (blank line)
268        let header_body_split = if let Some(pos) = part.find("\r\n\r\n") {
269            pos
270        } else if let Some(pos) = part.find("\n\n") {
271            pos
272        } else {
273            continue;
274        };
275
276        let headers_section = &part[..header_body_split];
277        let body_section = &part[header_body_split..]
278            .trim_start_matches("\r\n\r\n")
279            .trim_start_matches("\n\n");
280
281        // Remove trailing boundary markers from body
282        let body_section = body_section
283            .trim_end_matches(&end_delimiter)
284            .trim_end_matches(&delimiter)
285            .trim_end_matches("\r\n")
286            .trim_end_matches('\n');
287
288        // Parse headers
289        let mut name = None;
290        let mut filename = None;
291        let mut content_type = None;
292
293        for header_line in headers_section.lines() {
294            let header_line = header_line.trim();
295            if header_line.is_empty() {
296                continue;
297            }
298
299            if let Some((key, value)) = header_line.split_once(':') {
300                let key = key.trim().to_lowercase();
301                let value = value.trim();
302
303                match key.as_str() {
304                    "content-disposition" => {
305                        // Parse name and filename from Content-Disposition
306                        for part in value.split(';') {
307                            let part = part.trim();
308                            if part.starts_with("name=") {
309                                name = Some(
310                                    part.trim_start_matches("name=")
311                                        .trim_matches('"')
312                                        .to_string(),
313                                );
314                            } else if part.starts_with("filename=") {
315                                filename = Some(
316                                    part.trim_start_matches("filename=")
317                                        .trim_matches('"')
318                                        .to_string(),
319                                );
320                            }
321                        }
322                    }
323                    "content-type" => {
324                        content_type = Some(value.to_string());
325                    }
326                    _ => {}
327                }
328            }
329        }
330
331        fields.push(MultipartField::new(
332            name,
333            filename,
334            content_type,
335            Bytes::copy_from_slice(body_section.as_bytes()),
336        ));
337    }
338
339    Ok(fields)
340}
341
342/// Configuration for multipart form handling
343#[derive(Clone)]
344pub struct MultipartConfig {
345    /// Maximum total size of the multipart form (default: 10MB)
346    pub max_size: usize,
347    /// Maximum number of fields (default: 100)
348    pub max_fields: usize,
349    /// Maximum size per file (default: 10MB)
350    pub max_file_size: usize,
351    /// Allowed content types for files (empty = all allowed)
352    pub allowed_content_types: Vec<String>,
353}
354
355impl Default for MultipartConfig {
356    fn default() -> Self {
357        Self {
358            max_size: DEFAULT_MAX_FILE_SIZE,
359            max_fields: DEFAULT_MAX_FIELDS,
360            max_file_size: DEFAULT_MAX_FILE_SIZE,
361            allowed_content_types: Vec::new(),
362        }
363    }
364}
365
366impl MultipartConfig {
367    /// Create a new multipart config with default values
368    pub fn new() -> Self {
369        Self::default()
370    }
371
372    /// Set the maximum total size
373    pub fn max_size(mut self, size: usize) -> Self {
374        self.max_size = size;
375        self
376    }
377
378    /// Set the maximum number of fields
379    pub fn max_fields(mut self, count: usize) -> Self {
380        self.max_fields = count;
381        self
382    }
383
384    /// Set the maximum file size
385    pub fn max_file_size(mut self, size: usize) -> Self {
386        self.max_file_size = size;
387        self
388    }
389
390    /// Set allowed content types for file uploads
391    pub fn allowed_content_types(mut self, types: Vec<String>) -> Self {
392        self.allowed_content_types = types;
393        self
394    }
395
396    /// Add an allowed content type
397    pub fn allow_content_type(mut self, content_type: impl Into<String>) -> Self {
398        self.allowed_content_types.push(content_type.into());
399        self
400    }
401}
402
403/// File data wrapper for convenient access to uploaded files
404#[derive(Clone)]
405pub struct UploadedFile {
406    /// Original filename
407    pub filename: String,
408    /// Content type (MIME type)
409    pub content_type: Option<String>,
410    /// File data
411    pub data: Bytes,
412}
413
414impl UploadedFile {
415    /// Create from a multipart field
416    pub fn from_field(field: &MultipartField) -> Option<Self> {
417        field.file_name().map(|filename| Self {
418            filename: filename.to_string(),
419            content_type: field.content_type().map(|s| s.to_string()),
420            data: field.data.clone(),
421        })
422    }
423
424    /// Get file size in bytes
425    pub fn size(&self) -> usize {
426        self.data.len()
427    }
428
429    /// Get file extension
430    pub fn extension(&self) -> Option<&str> {
431        self.filename.rsplit('.').next()
432    }
433
434    /// Save to disk with original filename
435    pub async fn save_to(&self, dir: impl AsRef<Path>) -> Result<String> {
436        let dir = dir.as_ref();
437
438        tokio::fs::create_dir_all(dir)
439            .await
440            .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
441
442        let safe_filename = sanitize_filename(&self.filename);
443        let file_path = dir.join(&safe_filename);
444
445        tokio::fs::write(&file_path, &self.data)
446            .await
447            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
448
449        Ok(file_path.to_string_lossy().to_string())
450    }
451
452    /// Save with a custom filename
453    pub async fn save_as(&self, path: impl AsRef<Path>) -> Result<()> {
454        let path = path.as_ref();
455
456        if let Some(parent) = path.parent() {
457            tokio::fs::create_dir_all(parent)
458                .await
459                .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
460        }
461
462        tokio::fs::write(path, &self.data)
463            .await
464            .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
465
466        Ok(())
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary extraction for unquoted, quoted and missing parameters.
    #[test]
    fn test_extract_boundary() {
        let ct = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
        assert_eq!(
            extract_boundary(ct),
            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
        );

        let ct_quoted = "multipart/form-data; boundary=\"----WebKitFormBoundary\"";
        assert_eq!(
            extract_boundary(ct_quoted),
            Some("----WebKitFormBoundary".to_string())
        );

        // No boundary parameter at all.
        assert_eq!(extract_boundary("multipart/form-data"), None);
    }

    /// Path separators, parent references and leading dots are neutralised.
    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("test.txt"), "test.txt");
        assert_eq!(sanitize_filename("../../../etc/passwd"), "______etc_passwd");
        // ..\..\windows\system32 -> .._.._windows_system32 -> ____windows_system32
        assert_eq!(
            sanitize_filename("..\\..\\windows\\system32"),
            "____windows_system32"
        );
        assert_eq!(sanitize_filename(".hidden"), "hidden");
    }

    /// A plain text field followed by a file field: checks both the metadata
    /// and the payload bytes of each.
    #[test]
    fn test_parse_simple_multipart() {
        let boundary = "----WebKitFormBoundary";
        let body = "------WebKitFormBoundary\r\n\
             Content-Disposition: form-data; name=\"field1\"\r\n\
             \r\n\
             value1\r\n\
             ------WebKitFormBoundary\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             file content\r\n\
             ------WebKitFormBoundary--\r\n";

        let fields = parse_multipart(&Bytes::from(body), boundary).unwrap();
        assert_eq!(fields.len(), 2);

        assert_eq!(fields[0].name(), Some("field1"));
        assert_eq!(fields[0].file_name(), None);
        assert_eq!(fields[0].content_type(), None);
        assert!(!fields[0].is_file());
        assert_eq!(&fields[0].data[..], b"value1");
        assert_eq!(fields[0].size(), 6);

        assert_eq!(fields[1].name(), Some("file"));
        assert_eq!(fields[1].file_name(), Some("test.txt"));
        assert_eq!(fields[1].content_type(), Some("text/plain"));
        assert!(fields[1].is_file());
        assert_eq!(&fields[1].data[..], b"file content");
    }

    /// Builder setters override the defaults and content types accumulate.
    #[test]
    fn test_multipart_config() {
        let config = MultipartConfig::new()
            .max_size(20 * 1024 * 1024)
            .max_fields(50)
            .max_file_size(5 * 1024 * 1024)
            .allow_content_type("image/png")
            .allow_content_type("image/jpeg");

        assert_eq!(config.max_size, 20 * 1024 * 1024);
        assert_eq!(config.max_fields, 50);
        assert_eq!(config.max_file_size, 5 * 1024 * 1024);
        assert_eq!(config.allowed_content_types.len(), 2);
    }
}