rustapi_core/
multipart.rs1use crate::error::{ApiError, Result};
24use crate::extract::FromRequest;
25use crate::request::Request;
26use bytes::Bytes;
27use std::path::Path;
28
/// Default size limit, in bytes, used for uploaded files: 10 MiB.
///
/// `MultipartConfig::default` uses this value for both `max_file_size` and `max_size`.
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 << 20;

/// Default maximum number of multipart fields, used by `MultipartConfig::default`.
pub const DEFAULT_MAX_FIELDS: usize = 100;
34
35pub struct Multipart {
54 fields: Vec<MultipartField>,
55 current_index: usize,
56}
57
58impl Multipart {
59 fn new(fields: Vec<MultipartField>) -> Self {
61 Self {
62 fields,
63 current_index: 0,
64 }
65 }
66
67 pub async fn next_field(&mut self) -> Result<Option<MultipartField>> {
69 if self.current_index >= self.fields.len() {
70 return Ok(None);
71 }
72 let field = self.fields.get(self.current_index).cloned();
73 self.current_index += 1;
74 Ok(field)
75 }
76
77 pub fn into_fields(self) -> Vec<MultipartField> {
79 self.fields
80 }
81
82 pub fn field_count(&self) -> usize {
84 self.fields.len()
85 }
86}
87
88#[derive(Clone)]
90pub struct MultipartField {
91 name: Option<String>,
92 file_name: Option<String>,
93 content_type: Option<String>,
94 data: Bytes,
95}
96
97impl MultipartField {
98 pub fn new(
100 name: Option<String>,
101 file_name: Option<String>,
102 content_type: Option<String>,
103 data: Bytes,
104 ) -> Self {
105 Self {
106 name,
107 file_name,
108 content_type,
109 data,
110 }
111 }
112
113 pub fn name(&self) -> Option<&str> {
115 self.name.as_deref()
116 }
117
118 pub fn file_name(&self) -> Option<&str> {
120 self.file_name.as_deref()
121 }
122
123 pub fn content_type(&self) -> Option<&str> {
125 self.content_type.as_deref()
126 }
127
128 pub fn is_file(&self) -> bool {
130 self.file_name.is_some()
131 }
132
133 pub async fn bytes(&self) -> Result<Bytes> {
135 Ok(self.data.clone())
136 }
137
138 pub async fn text(&self) -> Result<String> {
140 String::from_utf8(self.data.to_vec())
141 .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e)))
142 }
143
144 pub fn size(&self) -> usize {
146 self.data.len()
147 }
148
149 pub async fn save_to(&self, dir: impl AsRef<Path>, filename: Option<&str>) -> Result<String> {
164 let dir = dir.as_ref();
165
166 tokio::fs::create_dir_all(dir)
168 .await
169 .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
170
171 let final_filename = filename
173 .map(|s| s.to_string())
174 .or_else(|| self.file_name.clone())
175 .ok_or_else(|| {
176 ApiError::bad_request("No filename provided and field has no filename")
177 })?;
178
179 let safe_filename = sanitize_filename(&final_filename);
181 let file_path = dir.join(&safe_filename);
182
183 tokio::fs::write(&file_path, &self.data)
185 .await
186 .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
187
188 Ok(file_path.to_string_lossy().to_string())
189 }
190}
191
/// Makes a client-supplied filename safe to join onto an upload directory.
///
/// - `/`, `\`, `:` and NUL are replaced with `_`, so the result is a single path component.
///   `:` matters on Windows: `Path::join` lets a drive-relative name like `C:evil` replace
///   the base directory, and `name:stream` addresses an NTFS alternate data stream.
/// - Every `..` becomes `_`.
/// - Leading dots are stripped, so the result is never a parent reference or a dotfile.
///
/// The result may be empty (e.g. for `"."`); callers must handle that case.
fn sanitize_filename(filename: &str) -> String {
    filename
        .replace(['/', '\\', ':', '\0'], "_")
        .replace("..", "_")
        .trim_start_matches('.')
        .to_string()
}
201
202impl FromRequest for Multipart {
203 async fn from_request(req: &mut Request) -> Result<Self> {
204 let content_type = req
206 .headers()
207 .get(http::header::CONTENT_TYPE)
208 .and_then(|v| v.to_str().ok())
209 .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?;
210
211 if !content_type.starts_with("multipart/form-data") {
212 return Err(ApiError::bad_request(format!(
213 "Expected multipart/form-data, got: {}",
214 content_type
215 )));
216 }
217
218 let boundary = extract_boundary(content_type)
220 .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?;
221
222 let body = req
224 .take_body()
225 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
226
227 let fields = parse_multipart(&body, &boundary)?;
229
230 Ok(Multipart::new(fields))
231 }
232}
233
/// Returns the `boundary` parameter of a `Content-Type` header value, if present.
///
/// - The parameter name is matched case-insensitively (`Boundary=` is as valid as
///   `boundary=`).
/// - Surrounding double quotes are removed.
/// - An empty boundary is treated as missing, because it cannot delimit anything.
fn extract_boundary(content_type: &str) -> Option<String> {
    content_type.split(';').find_map(|param| {
        let (key, value) = param.split_once('=')?;
        if !key.trim().eq_ignore_ascii_case("boundary") {
            return None;
        }
        let boundary = value.trim().trim_matches('"');
        (!boundary.is_empty()).then(|| boundary.to_string())
    })
}
246
247fn parse_multipart(body: &Bytes, boundary: &str) -> Result<Vec<MultipartField>> {
249 let mut fields = Vec::new();
250 let delimiter = format!("--{}", boundary);
251 let end_delimiter = format!("--{}--", boundary);
252
253 let body_str = String::from_utf8_lossy(body);
256
257 let parts: Vec<&str> = body_str.split(&delimiter).collect();
259
260 for part in parts.iter().skip(1) {
261 let part = part.trim_start_matches("\r\n").trim_start_matches('\n');
263 if part.is_empty() || part.starts_with("--") {
264 continue;
265 }
266
267 let header_body_split = if let Some(pos) = part.find("\r\n\r\n") {
269 pos
270 } else if let Some(pos) = part.find("\n\n") {
271 pos
272 } else {
273 continue;
274 };
275
276 let headers_section = &part[..header_body_split];
277 let body_section = &part[header_body_split..]
278 .trim_start_matches("\r\n\r\n")
279 .trim_start_matches("\n\n");
280
281 let body_section = body_section
283 .trim_end_matches(&end_delimiter)
284 .trim_end_matches(&delimiter)
285 .trim_end_matches("\r\n")
286 .trim_end_matches('\n');
287
288 let mut name = None;
290 let mut filename = None;
291 let mut content_type = None;
292
293 for header_line in headers_section.lines() {
294 let header_line = header_line.trim();
295 if header_line.is_empty() {
296 continue;
297 }
298
299 if let Some((key, value)) = header_line.split_once(':') {
300 let key = key.trim().to_lowercase();
301 let value = value.trim();
302
303 match key.as_str() {
304 "content-disposition" => {
305 for part in value.split(';') {
307 let part = part.trim();
308 if part.starts_with("name=") {
309 name = Some(
310 part.trim_start_matches("name=")
311 .trim_matches('"')
312 .to_string(),
313 );
314 } else if part.starts_with("filename=") {
315 filename = Some(
316 part.trim_start_matches("filename=")
317 .trim_matches('"')
318 .to_string(),
319 );
320 }
321 }
322 }
323 "content-type" => {
324 content_type = Some(value.to_string());
325 }
326 _ => {}
327 }
328 }
329 }
330
331 fields.push(MultipartField::new(
332 name,
333 filename,
334 content_type,
335 Bytes::copy_from_slice(body_section.as_bytes()),
336 ));
337 }
338
339 Ok(fields)
340}
341
342#[derive(Clone)]
344pub struct MultipartConfig {
345 pub max_size: usize,
347 pub max_fields: usize,
349 pub max_file_size: usize,
351 pub allowed_content_types: Vec<String>,
353}
354
355impl Default for MultipartConfig {
356 fn default() -> Self {
357 Self {
358 max_size: DEFAULT_MAX_FILE_SIZE,
359 max_fields: DEFAULT_MAX_FIELDS,
360 max_file_size: DEFAULT_MAX_FILE_SIZE,
361 allowed_content_types: Vec::new(),
362 }
363 }
364}
365
366impl MultipartConfig {
367 pub fn new() -> Self {
369 Self::default()
370 }
371
372 pub fn max_size(mut self, size: usize) -> Self {
374 self.max_size = size;
375 self
376 }
377
378 pub fn max_fields(mut self, count: usize) -> Self {
380 self.max_fields = count;
381 self
382 }
383
384 pub fn max_file_size(mut self, size: usize) -> Self {
386 self.max_file_size = size;
387 self
388 }
389
390 pub fn allowed_content_types(mut self, types: Vec<String>) -> Self {
392 self.allowed_content_types = types;
393 self
394 }
395
396 pub fn allow_content_type(mut self, content_type: impl Into<String>) -> Self {
398 self.allowed_content_types.push(content_type.into());
399 self
400 }
401}
402
403#[derive(Clone)]
405pub struct UploadedFile {
406 pub filename: String,
408 pub content_type: Option<String>,
410 pub data: Bytes,
412}
413
414impl UploadedFile {
415 pub fn from_field(field: &MultipartField) -> Option<Self> {
417 field.file_name().map(|filename| Self {
418 filename: filename.to_string(),
419 content_type: field.content_type().map(|s| s.to_string()),
420 data: field.data.clone(),
421 })
422 }
423
424 pub fn size(&self) -> usize {
426 self.data.len()
427 }
428
429 pub fn extension(&self) -> Option<&str> {
431 self.filename.rsplit('.').next()
432 }
433
434 pub async fn save_to(&self, dir: impl AsRef<Path>) -> Result<String> {
436 let dir = dir.as_ref();
437
438 tokio::fs::create_dir_all(dir)
439 .await
440 .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?;
441
442 let safe_filename = sanitize_filename(&self.filename);
443 let file_path = dir.join(&safe_filename);
444
445 tokio::fs::write(&file_path, &self.data)
446 .await
447 .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
448
449 Ok(file_path.to_string_lossy().to_string())
450 }
451
452 pub async fn save_as(&self, path: impl AsRef<Path>) -> Result<()> {
454 let path = path.as_ref();
455
456 if let Some(parent) = path.parent() {
457 tokio::fs::create_dir_all(parent)
458 .await
459 .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?;
460 }
461
462 tokio::fs::write(path, &self.data)
463 .await
464 .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?;
465
466 Ok(())
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_boundary() {
        let cases = [
            (
                "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
                "----WebKitFormBoundary7MA4YWxkTrZu0gW",
            ),
            (
                "multipart/form-data; boundary=\"----WebKitFormBoundary\"",
                "----WebKitFormBoundary",
            ),
        ];
        for &(header, expected) in cases.iter() {
            assert_eq!(extract_boundary(header), Some(expected.to_string()));
        }
    }

    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("test.txt", "test.txt"),
            ("../../../etc/passwd", "______etc_passwd"),
            // Windows-style separators must be neutralised as well.
            ("..\\..\\windows\\system32", "____windows_system32"),
            (".hidden", "hidden"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(sanitize_filename(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_simple_multipart() {
        let boundary = "----WebKitFormBoundary";
        let body = concat!(
            "------WebKitFormBoundary\r\n",
            "Content-Disposition: form-data; name=\"field1\"\r\n",
            "\r\n",
            "value1\r\n",
            "------WebKitFormBoundary\r\n",
            "Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "file content\r\n",
            "------WebKitFormBoundary--\r\n"
        );

        let fields = parse_multipart(&Bytes::from(body), boundary).unwrap();
        assert_eq!(fields.len(), 2);

        let text_field = &fields[0];
        assert_eq!(text_field.name(), Some("field1"));
        assert!(!text_field.is_file());

        let file_field = &fields[1];
        assert_eq!(file_field.name(), Some("file"));
        assert_eq!(file_field.file_name(), Some("test.txt"));
        assert_eq!(file_field.content_type(), Some("text/plain"));
        assert!(file_field.is_file());
    }

    #[test]
    fn test_multipart_config() {
        const MIB: usize = 1024 * 1024;

        let config = MultipartConfig::new()
            .max_size(20 * MIB)
            .max_fields(50)
            .max_file_size(5 * MIB)
            .allow_content_type("image/png")
            .allow_content_type("image/jpeg");

        assert_eq!(config.max_size, 20 * MIB);
        assert_eq!(config.max_fields, 50);
        assert_eq!(config.max_file_size, 5 * MIB);
        assert_eq!(config.allowed_content_types.len(), 2);
    }
}