wechat_pub_rs/
utils.rs

//! Utility functions and helpers.
//!
//! This module provides security-focused utilities: input validation and
//! safe path handling intended to prevent common vulnerabilities.
5
6use std::path::{Component, Path, PathBuf};
7use std::sync::LazyLock;
8use std::{collections::HashSet, ffi::OsStr};
9use tracing::warn;
10
11/// Checks if a file exists and is readable with path validation.
12/// Returns false for invalid or potentially dangerous paths.
13pub async fn file_exists(path: &Path) -> bool {
14    // Validate path for security
15    if !is_safe_path(path) {
16        warn!("Unsafe path access attempt: {:?}", path);
17        return false;
18    }
19
20    tokio::fs::metadata(path).await.is_ok()
21}
22
/// Returns the extension of the path's final component, without the dot.
///
/// Yields `None` when there is no extension (including dot-files such as
/// `.gitignore`) or when the extension is not valid UTF-8.
pub fn get_file_extension(path: &Path) -> Option<&str> {
    let extension = path.extension()?;
    extension.to_str()
}
27
28/// Validates that a path points to a markdown file.
29pub fn is_markdown_file(path: &Path) -> bool {
30    match get_file_extension(path) {
31        Some(ext) => matches!(ext.to_lowercase().as_str(), "md" | "markdown"),
32        None => false,
33    }
34}
35
36/// Validates that a path points to an image file.
37pub fn is_image_file(path: &Path) -> bool {
38    match get_file_extension(path) {
39        Some(ext) => matches!(
40            ext.to_lowercase().as_str(),
41            "jpg" | "jpeg" | "png" | "gif" | "webp" | "bmp"
42        ),
43        None => false,
44    }
45}
46
/// Performs a format-only sanity check of WeChat app credentials.
///
/// No network call is made. The checks, in order, are:
/// 1. `app_id` is non-empty;
/// 2. `app_secret` is non-empty;
/// 3. `app_id` starts with `"wx"` and is exactly 18 bytes long;
/// 4. `app_secret` is exactly 32 bytes long.
///
/// # Errors
/// Returns a human-readable message describing the first check that failed.
pub fn validate_app_credentials(app_id: &str, app_secret: &str) -> Result<(), String> {
    const APP_ID_PREFIX: &str = "wx";
    const APP_ID_LEN: usize = 18;
    const APP_SECRET_LEN: usize = 32;

    let failure = if app_id.is_empty() {
        Some("App ID cannot be empty")
    } else if app_secret.is_empty() {
        Some("App secret cannot be empty")
    } else if !(app_id.starts_with(APP_ID_PREFIX) && app_id.len() == APP_ID_LEN) {
        Some("Invalid app ID format (should start with 'wx' and be 18 characters)")
    } else if app_secret.len() != APP_SECRET_LEN {
        Some("Invalid app secret format (should be 32 characters)")
    } else {
        None
    };

    match failure {
        Some(message) => Err(message.to_string()),
        None => Ok(()),
    }
}
71
/// Lowercase file extensions that [`is_safe_path`] always rejects.
///
/// Callers must lowercase the extension before looking it up.
static DANGEROUS_EXTENSIONS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        "exe", "bat", "cmd", "com", "scr", "pif", "vbs", "js", "jse", "wsf", "wsh", "msi", "dll",
        "scf", "lnk", "inf", "reg",
    ]
    .into_iter()
    .collect()
});

/// Validates that a path is safe to access (rejects dangerous files and suspicious names).
///
/// A path is rejected when:
/// - its extension (case-insensitive) is in [`DANGEROUS_EXTENSIONS`];
/// - any component contains an ASCII control character (`\0`..=`\x1F`);
/// - unless the path is a recognised temp location (see below), any component is
///   a hidden name other than `.gitignore`, `.env`, `.dockerignore` or a `.tmp*`
///   name, or is a Windows reserved device name (see [`is_reserved_name`]).
///
/// `..` components are not rejected here; callers joining untrusted relative
/// paths should also perform a containment check (as `resolve_path` does).
///
/// Paths containing `/tmp/`, `/var/folders/` or `\Temp\` skip the hidden-name
/// and reserved-name checks. That relaxation is not applied when the path also
/// contains a `..` component; otherwise e.g. `/tmp/../home/user/.ssh/id_rsa`
/// would bypass the hidden-file check.
pub fn is_safe_path(path: &Path) -> bool {
    if has_dangerous_extension(path) {
        return false;
    }

    let in_temp_dir = is_temp_location(path);

    path.components().all(|component| match component {
        Component::Normal(name) => is_safe_component(&name.to_string_lossy(), in_temp_dir),
        // `..` is left to path resolution; root, `.` and Windows drive prefixes
        // carry no file name of their own.
        Component::ParentDir | Component::RootDir | Component::CurDir | Component::Prefix(_) => {
            true
        }
    })
}

/// Returns `true` if the path's extension, lowercased, is in [`DANGEROUS_EXTENSIONS`].
fn has_dangerous_extension(path: &Path) -> bool {
    path.extension()
        .and_then(OsStr::to_str)
        .is_some_and(|ext| DANGEROUS_EXTENSIONS.contains(&ext.to_lowercase().as_str()))
}

/// Returns `true` when the path text contains a known temp-directory marker
/// AND the path has no `..` component that could climb back out of it.
fn is_temp_location(path: &Path) -> bool {
    const TEMP_MARKERS: [&str; 3] = ["/tmp/", "/var/folders/", "\\Temp\\"];

    let Some(path_str) = path.to_str() else {
        return false;
    };
    TEMP_MARKERS.iter().any(|&marker| path_str.contains(marker))
        && !path
            .components()
            .any(|component| matches!(component, Component::ParentDir))
}

/// Checks a single normal path component; `in_temp_dir` relaxes the
/// hidden-name and reserved-name checks but never the control-character check.
fn is_safe_component(name: &str, in_temp_dir: bool) -> bool {
    // NUL and other control characters are never legitimate in file names.
    if name.chars().any(|c| matches!(c, '\0'..='\x1F')) {
        return false;
    }
    if in_temp_dir {
        return true;
    }

    // NOTE(review): `.env` files often hold secrets — confirm allowing them is intended.
    let is_hidden = name.starts_with('.') && name.len() > 1;
    let hidden_allowed =
        matches!(name, ".gitignore" | ".env" | ".dockerignore") || name.starts_with(".tmp");
    if is_hidden && !hidden_allowed {
        return false;
    }

    !is_reserved_name(name)
}

/// Checks if a filename is a Windows reserved device name.
///
/// Windows reserves these names even when followed by an extension (`CON.txt`),
/// so only the text before the first `.` is compared, case-insensitively.
/// The list (`CON`, `PRN`, `AUX`, `NUL`, `COM0`-`COM9`, `LPT0`-`LPT9`, and the
/// superscript-digit forms `COM¹`..`LPT³`) follows Microsoft's file naming rules.
fn is_reserved_name(name: &str) -> bool {
    let upper_name = name.to_uppercase();
    let base_name = upper_name.split('.').next().unwrap_or("");

    if matches!(base_name, "CON" | "PRN" | "AUX" | "NUL") {
        return true;
    }

    let Some(port_suffix) = base_name
        .strip_prefix("COM")
        .or_else(|| base_name.strip_prefix("LPT"))
    else {
        return false;
    };
    // Exactly one digit (ASCII 0-9 or superscript 1-3) after COM/LPT.
    let mut suffix_chars = port_suffix.chars();
    matches!(
        (suffix_chars.next(), suffix_chars.next()),
        (Some('0'..='9' | '¹' | '²' | '³'), None)
    )
}
193
/// Lexically checks a path string for directory-traversal sequences.
///
/// Flags `..` on its own, `..` adjacent to either kind of separator
/// (`../`, `..\`, `/..`, `\..`), and runs of four dots. This is a conservative
/// text match, not a filesystem check, so it may flag some benign names.
pub fn has_path_traversal(path: &str) -> bool {
    const SUSPICIOUS: [&str; 5] = ["../", "..\\", "/..", "\\..", "...."];

    path == ".." || SUSPICIOUS.iter().any(|&needle| path.contains(needle))
}
204
/// Sanitizes a filename by removing or replacing dangerous characters.
///
/// - Drops characters invalid in Windows file names (`<`, `>`, `:`, `"`, `|`,
///   `?`, `*`) and ASCII control characters (`\0`..=`\x1F`).
/// - Replaces path separators (`/`, `\`) with `_` so the result is a single
///   path component.
/// - Replaces a leading `.` with `_`, so the result is never a hidden file and
///   never the special name `.`.
/// - Returns `"unnamed"` if nothing is left.
/// - Caps the result at 255 bytes: longer names are cut at the last UTF-8
///   character boundary at or before byte 252 and suffixed with `...`.
pub fn sanitize_filename(filename: &str) -> String {
    const MAX_LEN: usize = 255;
    const ELLIPSIS: &str = "...";

    let mut sanitized: String = filename
        .chars()
        .filter(|&c| !matches!(c, '<' | '>' | ':' | '"' | '|' | '?' | '*' | '\0'..='\x1F'))
        .map(|c| if matches!(c, '/' | '\\') { '_' } else { c })
        .collect();

    if sanitized.starts_with('.') {
        // '.' is a single byte, so 0..1 is always a valid char range.
        sanitized.replace_range(..1, "_");
    }

    if sanitized.is_empty() {
        sanitized.push_str("unnamed");
    }

    if sanitized.len() > MAX_LEN {
        // `String::truncate` panics when the cut is not on a char boundary,
        // which happens with multi-byte (e.g. CJK) names; back up to one.
        let mut cut = MAX_LEN - ELLIPSIS.len();
        while !sanitized.is_char_boundary(cut) {
            cut -= 1;
        }
        sanitized.truncate(cut);
        sanitized.push_str(ELLIPSIS);
    }

    sanitized
}
233
/// Rejects files larger than `max_size` bytes (a simple DoS guard).
///
/// A size exactly equal to `max_size` is accepted.
///
/// # Errors
/// Returns a message naming `file_type`, the actual size, and the limit.
pub fn validate_file_size(size: u64, max_size: u64, file_type: &str) -> Result<(), String> {
    if size <= max_size {
        Ok(())
    } else {
        Err(format!(
            "{file_type} file too large: {size} bytes (max: {max_size} bytes)"
        ))
    }
}
243
/// Returns the directory containing `file_path` (its parent path).
///
/// A bare file name yields `Some("")`; a root or empty path yields `None`.
pub fn get_base_directory(file_path: &Path) -> Option<&Path> {
    // `ancestors()` starts with the path itself, so the second item is the parent.
    file_path.ancestors().nth(1)
}
248
249/// Resolves relative paths against a base directory with security validation.
250/// Prevents path traversal attacks by validating the resolved path.
251pub fn resolve_path(base_dir: &Path, relative_path: &str) -> Result<PathBuf, String> {
252    let relative = Path::new(relative_path);
253
254    // Check for absolute paths
255    if relative.is_absolute() {
256        if !is_safe_path(relative) {
257            return Err("Absolute path contains unsafe components".to_string());
258        }
259        return Ok(PathBuf::from(relative_path));
260    }
261
262    // Resolve relative path
263    let resolved = base_dir.join(relative_path);
264
265    // Validate the resolved path
266    if !is_safe_path(&resolved) {
267        return Err("Resolved path contains unsafe components".to_string());
268    }
269
270    // Ensure the resolved path is still under the base directory
271    match resolved.canonicalize() {
272        Ok(canonical_resolved) => {
273            match base_dir.canonicalize() {
274                Ok(canonical_base) => {
275                    if canonical_resolved.starts_with(&canonical_base) {
276                        Ok(resolved)
277                    } else {
278                        Err("Path traversal attempt detected".to_string())
279                    }
280                }
281                Err(_) => {
282                    // Base directory doesn't exist or can't be canonicalized
283                    // Fall back to basic validation
284                    if has_path_traversal(relative_path) {
285                        Err("Path contains traversal sequences".to_string())
286                    } else {
287                        Ok(resolved)
288                    }
289                }
290            }
291        }
292        Err(_) => {
293            // File doesn't exist yet, validate the path structure
294            if has_path_traversal(relative_path) {
295                Err("Path contains traversal sequences".to_string())
296            } else {
297                Ok(resolved)
298            }
299        }
300    }
301}
302
// Unit tests for the path, filename, size and credential helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_get_file_extension() {
        assert_eq!(get_file_extension(Path::new("test.md")), Some("md"));
        assert_eq!(
            get_file_extension(Path::new("test.markdown")),
            Some("markdown")
        );
        assert_eq!(get_file_extension(Path::new("test.jpg")), Some("jpg"));
        assert_eq!(get_file_extension(Path::new("test")), None);
        assert_eq!(get_file_extension(Path::new(".gitignore")), None);
    }

    #[test]
    fn test_get_file_extension_uses_last_component() {
        // Only the final component's last dot counts.
        assert_eq!(get_file_extension(Path::new("archive.tar.gz")), Some("gz"));
        assert_eq!(get_file_extension(Path::new("dir.d/file")), None);
    }

    #[test]
    fn test_is_markdown_file() {
        assert!(is_markdown_file(Path::new("test.md")));
        assert!(is_markdown_file(Path::new("test.markdown")));
        assert!(is_markdown_file(Path::new("TEST.MD")));
        assert!(!is_markdown_file(Path::new("test.txt")));
        assert!(!is_markdown_file(Path::new("test")));
    }

    #[test]
    fn test_is_image_file() {
        assert!(is_image_file(Path::new("test.jpg")));
        assert!(is_image_file(Path::new("test.PNG")));
        assert!(is_image_file(Path::new("test.gif")));
        assert!(!is_image_file(Path::new("test.txt")));
        assert!(!is_image_file(Path::new("test")));
    }

    #[test]
    fn test_validate_app_credentials() {
        assert!(
            validate_app_credentials("wx1234567890123456", "12345678901234567890123456789012")
                .is_ok()
        );

        assert!(validate_app_credentials("", "12345678901234567890123456789012").is_err());
        assert!(validate_app_credentials("invalid", "12345678901234567890123456789012").is_err());
        assert!(validate_app_credentials("wx123", "12345678901234567890123456789012").is_err());

        assert!(validate_app_credentials("wx1234567890123456", "").is_err());
        assert!(validate_app_credentials("wx1234567890123456", "short").is_err());
    }

    #[test]
    fn test_validate_app_credentials_requires_wx_prefix() {
        // Right length, wrong prefix.
        assert!(
            validate_app_credentials("ab1234567890123456", "12345678901234567890123456789012")
                .is_err()
        );
    }

    #[test]
    fn test_get_base_directory() {
        assert_eq!(
            get_base_directory(Path::new("/path/to/file.md")),
            Some(Path::new("/path/to"))
        );
        assert_eq!(
            get_base_directory(Path::new("file.md")),
            Some(Path::new(""))
        );
        assert_eq!(get_base_directory(Path::new("/")), None);
    }

    #[test]
    fn test_resolve_path() {
        let base = Path::new("/base/dir");

        assert_eq!(
            resolve_path(base, "relative.md").unwrap(),
            PathBuf::from("/base/dir/relative.md")
        );

        assert_eq!(
            resolve_path(base, "/absolute.md").unwrap(),
            PathBuf::from("/absolute.md")
        );

        assert_eq!(
            resolve_path(base, "./relative.md").unwrap(),
            PathBuf::from("/base/dir/./relative.md")
        );

        assert!(resolve_path(base, "../../../etc/passwd").is_err());
        assert!(resolve_path(base, "..\\..\\windows\\system32").is_err());

        assert!(resolve_path(base, "malware.exe").is_err());
        assert!(resolve_path(base, "script.bat").is_err());
    }

    #[test]
    fn test_is_safe_path() {
        assert!(is_safe_path(Path::new("document.md")));
        assert!(is_safe_path(Path::new("image.jpg")));
        assert!(is_safe_path(Path::new("folder/file.txt")));

        assert!(!is_safe_path(Path::new("malware.exe")));
        assert!(!is_safe_path(Path::new("script.bat")));
        assert!(!is_safe_path(Path::new("virus.scr")));

        assert!(!is_safe_path(Path::new("CON")));
        assert!(!is_safe_path(Path::new("PRN.txt")));
        assert!(!is_safe_path(Path::new("COM1.dat")));

        assert!(!is_safe_path(Path::new(".hidden")));
        assert!(is_safe_path(Path::new(".gitignore")));
        assert!(is_safe_path(Path::new(".env")));
    }

    #[test]
    fn test_is_safe_path_temp_dirs_and_case() {
        // tempfile-style `.tmp*` names under a temp dir are allowed.
        assert!(is_safe_path(Path::new("/tmp/.tmpAbC123/post.md")));
        // Dangerous extensions are rejected even under temp dirs.
        assert!(!is_safe_path(Path::new("/tmp/payload.exe")));
        // Extension matching is case-insensitive.
        assert!(!is_safe_path(Path::new("MALWARE.EXE")));
    }

    #[test]
    fn test_has_path_traversal() {
        assert!(has_path_traversal("../etc/passwd"));
        assert!(has_path_traversal("..\\windows\\system32"));
        assert!(has_path_traversal("folder/../../../etc"));
        assert!(has_path_traversal(".."));
        assert!(has_path_traversal("...."));

        assert!(!has_path_traversal("normal/path/file.txt"));
        assert!(!has_path_traversal("file.md"));
        assert!(!has_path_traversal("folder/subfolder/file"));
    }

    #[test]
    fn test_has_path_traversal_edge_cases() {
        // Trailing parent reference is still traversal.
        assert!(has_path_traversal("folder/.."));
        // Double dots inside a file name are not adjacent to a separator.
        assert!(!has_path_traversal("file..md"));
    }

    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("normal_file.txt"), "normal_file.txt");

        assert_eq!(sanitize_filename("file<>:\"|?*.txt"), "file.txt");

        assert_eq!(sanitize_filename("path/to/file.txt"), "path_to_file.txt");
        assert_eq!(sanitize_filename("path\\to\\file.txt"), "path_to_file.txt");

        assert_eq!(sanitize_filename(".hidden"), "_hidden");
        assert_eq!(sanitize_filename(""), "unnamed");

        // Test very long filename
        let long_name = "a".repeat(300);
        let sanitized = sanitize_filename(&long_name);
        assert!(sanitized.len() <= 255);
        assert!(sanitized.ends_with("..."));
    }

    #[test]
    fn test_validate_file_size() {
        // Test valid size
        assert!(validate_file_size(1000, 2000, "test").is_ok());

        // Test oversized file
        let result = validate_file_size(3000, 2000, "image");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("image file too large"));
    }

    #[test]
    fn test_validate_file_size_boundary() {
        // The limit itself is inclusive.
        assert!(validate_file_size(2000, 2000, "image").is_ok());
        assert_eq!(
            validate_file_size(2001, 2000, "image").unwrap_err(),
            "image file too large: 2001 bytes (max: 2000 bytes)"
        );
    }
}
451        assert!(result.unwrap_err().contains("image file too large"));
452    }
453}