1use std::path::{Component, Path, PathBuf};
7use std::sync::LazyLock;
8use std::{collections::HashSet, ffi::OsStr};
9use tracing::warn;
10
11pub async fn file_exists(path: &Path) -> bool {
14 if !is_safe_path(path) {
16 warn!("Unsafe path access attempt: {:?}", path);
17 return false;
18 }
19
20 tokio::fs::metadata(path).await.is_ok()
21}
22
/// Returns the extension of `path` as a string slice.
///
/// Yields `None` when the path has no extension (including dot-files such as
/// `.gitignore`) or when the extension is not valid UTF-8.
pub fn get_file_extension(path: &Path) -> Option<&str> {
    path.extension()?.to_str()
}
27
/// Returns `true` when the extension of `path` is `md` or `markdown`.
///
/// The comparison is done on the lower-cased extension, so `TEST.MD` matches.
/// Paths without a (UTF-8) extension are never Markdown files.
pub fn is_markdown_file(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_lowercase)
        .is_some_and(|ext| ext == "md" || ext == "markdown")
}
35
/// Returns `true` when the extension of `path` names a recognised image type.
///
/// Recognised extensions are `jpg`, `jpeg`, `png`, `gif`, `webp` and `bmp`,
/// compared after lower-casing. Paths without a (UTF-8) extension are rejected.
pub fn is_image_file(path: &Path) -> bool {
    const IMAGE_EXTENSIONS: [&str; 6] = ["jpg", "jpeg", "png", "gif", "webp", "bmp"];

    let Some(raw_ext) = path.extension().and_then(|ext| ext.to_str()) else {
        return false;
    };

    let lowered = raw_ext.to_lowercase();
    IMAGE_EXTENSIONS.contains(&lowered.as_str())
}
46
/// Checks the shape of an app ID / app secret pair.
///
/// The checks run in this order and the first failure wins:
/// 1. `app_id` must not be empty.
/// 2. `app_secret` must not be empty.
/// 3. `app_id` must start with `wx` and be exactly 18 bytes long.
/// 4. `app_secret` must be exactly 32 bytes long.
///
/// # Errors
///
/// Returns a human-readable message describing the first failed check.
pub fn validate_app_credentials(app_id: &str, app_secret: &str) -> Result<(), String> {
    let id_well_formed = app_id.starts_with("wx") && app_id.len() == 18;

    let problem = if app_id.is_empty() {
        "App ID cannot be empty"
    } else if app_secret.is_empty() {
        "App secret cannot be empty"
    } else if !id_well_formed {
        "Invalid app ID format (should start with 'wx' and be 18 characters)"
    } else if app_secret.len() != 32 {
        "Invalid app secret format (should be 32 characters)"
    } else {
        return Ok(());
    };

    Err(problem.to_string())
}
71
/// Lower-case file extensions that `is_safe_path` refuses to accept.
///
/// The set is built once, on first access.
static DANGEROUS_EXTENSIONS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "exe", "bat", "cmd", "com", "scr", "pif", "vbs", "js", "jse", "wsf", "wsh", "msi", "dll",
        "scf", "lnk", "inf", "reg",
    ])
});
94
95pub fn is_safe_path(path: &Path) -> bool {
97 if let Some(path_str) = path.to_str()
99 && (path_str.contains("/tmp/")
100 || path_str.contains("/var/folders/")
101 || path_str.contains("\\Temp\\"))
102 {
103 if let Some(extension) = path.extension().and_then(OsStr::to_str)
105 && DANGEROUS_EXTENSIONS.contains(&extension.to_lowercase().as_str())
106 {
107 return false;
108 }
109 return true;
110 }
111 if let Some(extension) = path.extension().and_then(OsStr::to_str)
113 && DANGEROUS_EXTENSIONS.contains(&extension.to_lowercase().as_str())
114 {
115 return false;
116 }
117
118 for component in path.components() {
120 match component {
121 Component::ParentDir => {
122 continue;
124 }
125 Component::Normal(name) => {
126 let name_str = name.to_string_lossy();
127
128 if name_str.starts_with('.') && name_str.len() > 1 {
130 if !matches!(name_str.as_ref(), ".gitignore" | ".env" | ".dockerignore")
132 && !name_str.starts_with(".tmp")
133 {
134 return false;
135 }
136 }
137
138 if name_str.contains('\0') || name_str.contains('\x01') {
140 return false;
141 }
142
143 if is_reserved_name(&name_str) {
145 return false;
146 }
147 }
148 Component::RootDir | Component::CurDir => {
149 continue;
151 }
152 Component::Prefix(_) => {
153 continue;
155 }
156 }
157 }
158
159 true
160}
161
/// Returns `true` when `name` collides with a reserved device name.
///
/// Only the part of `name` before the first `.` is considered, and the
/// comparison ignores case, so `con`, `PRN.txt` and `com1.dat` all match.
/// The reserved names are `CON`, `PRN`, `AUX`, `NUL`, `COM1`..`COM9` and
/// `LPT1`..`LPT9`.
fn is_reserved_name(name: &str) -> bool {
    const RESERVED: [&str; 22] = [
        "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
        "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
    ];

    let shouted = name.to_uppercase();
    match shouted.split('.').next() {
        Some(stem) => RESERVED.contains(&stem),
        None => false,
    }
}
193
/// Textual check for directory-traversal sequences in `path`.
///
/// Returns `true` when `path` is exactly `..` or contains any of `../`, `..\`,
/// `/..`, `\..` or `....`. This is a plain substring test: the path is not
/// parsed or resolved.
pub fn has_path_traversal(path: &str) -> bool {
    const TRAVERSAL_MARKERS: [&str; 5] = ["../", "..\\", "/..", "\\..", "...."];

    path == ".." || TRAVERSAL_MARKERS.iter().any(|&marker| path.contains(marker))
}
204
/// Turns an arbitrary string into something usable as a single file name.
///
/// The following steps are applied in order:
/// 1. The characters `< > : " | ? *` and ASCII control characters
///    (`0x00..=0x1F`) are removed.
/// 2. Path separators (`/` and `\`) are replaced with `_`.
/// 3. A leading `.` that is followed by at least one more byte is replaced
///    with `_`.
/// 4. An empty result becomes `"unnamed"`.
/// 5. A result longer than 255 bytes is cut to at most 252 bytes and `...` is
///    appended.
///
/// The cut in step 5 always lands on a UTF-8 character boundary, so long
/// names made of multi-byte characters are shortened instead of causing a
/// panic; in that case the result can be a few bytes shorter than 255.
pub fn sanitize_filename(filename: &str) -> String {
    const MAX_LEN: usize = 255;
    const ELLIPSIS: &str = "...";

    // Drop forbidden characters and neutralise separators in a single pass.
    let mut sanitized: String = filename
        .chars()
        .filter(|&c| !matches!(c, '<' | '>' | ':' | '"' | '|' | '?' | '*' | '\0'..='\x1F'))
        .map(|c| if matches!(c, '/' | '\\') { '_' } else { c })
        .collect();

    // '.' is a single byte, so replacing the range `..1` is always valid.
    if sanitized.len() > 1 && sanitized.starts_with('.') {
        sanitized.replace_range(..1, "_");
    }

    if sanitized.is_empty() {
        sanitized.push_str("unnamed");
    }

    if sanitized.len() > MAX_LEN {
        // `String::truncate` panics when the new length splits a character,
        // so back off to the nearest boundary first. Index 0 is always a
        // boundary, which guarantees the loop terminates.
        let mut cut = MAX_LEN - ELLIPSIS.len();
        while !sanitized.is_char_boundary(cut) {
            cut -= 1;
        }
        sanitized.truncate(cut);
        sanitized.push_str(ELLIPSIS);
    }

    sanitized
}
233
/// Ensures a file of `size` bytes does not exceed `max_size` bytes.
///
/// A file exactly at the limit is accepted.
///
/// # Errors
///
/// Returns a message naming `file_type`, the actual size and the limit when
/// `size` is greater than `max_size`.
pub fn validate_file_size(size: u64, max_size: u64, file_type: &str) -> Result<(), String> {
    if size <= max_size {
        Ok(())
    } else {
        Err(format!(
            "{file_type} file too large: {size} bytes (max: {max_size} bytes)"
        ))
    }
}
243
/// Returns the directory that contains `file_path`.
///
/// This is exactly [`Path::parent`]: a bare file name yields `Some("")`, and a
/// path that ends in a root (such as `/`) yields `None`.
pub fn get_base_directory(file_path: &Path) -> Option<&Path> {
    Path::parent(file_path)
}
248
249pub fn resolve_path(base_dir: &Path, relative_path: &str) -> Result<PathBuf, String> {
252 let relative = Path::new(relative_path);
253
254 if relative.is_absolute() {
256 if !is_safe_path(relative) {
257 return Err("Absolute path contains unsafe components".to_string());
258 }
259 return Ok(PathBuf::from(relative_path));
260 }
261
262 let resolved = base_dir.join(relative_path);
264
265 if !is_safe_path(&resolved) {
267 return Err("Resolved path contains unsafe components".to_string());
268 }
269
270 match resolved.canonicalize() {
272 Ok(canonical_resolved) => {
273 match base_dir.canonicalize() {
274 Ok(canonical_base) => {
275 if canonical_resolved.starts_with(&canonical_base) {
276 Ok(resolved)
277 } else {
278 Err("Path traversal attempt detected".to_string())
279 }
280 }
281 Err(_) => {
282 if has_path_traversal(relative_path) {
285 Err("Path contains traversal sequences".to_string())
286 } else {
287 Ok(resolved)
288 }
289 }
290 }
291 }
292 Err(_) => {
293 if has_path_traversal(relative_path) {
295 Err("Path contains traversal sequences".to_string())
296 } else {
297 Ok(resolved)
298 }
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // A credential pair with the expected shape: 18-byte ID starting with
    // "wx" and a 32-byte secret.
    const VALID_APP_ID: &str = "wx1234567890123456";
    const VALID_APP_SECRET: &str = "12345678901234567890123456789012";

    #[test]
    fn test_get_file_extension() {
        let cases = [
            ("test.md", Some("md")),
            ("test.markdown", Some("markdown")),
            ("test.jpg", Some("jpg")),
            ("test", None),
            (".gitignore", None),
        ];

        for (input, expected) in cases {
            assert_eq!(
                get_file_extension(Path::new(input)),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_is_markdown_file() {
        for accepted in ["test.md", "test.markdown", "TEST.MD"] {
            assert!(is_markdown_file(Path::new(accepted)), "input: {:?}", accepted);
        }

        for rejected in ["test.txt", "test"] {
            assert!(!is_markdown_file(Path::new(rejected)), "input: {:?}", rejected);
        }
    }

    #[test]
    fn test_is_image_file() {
        for accepted in ["test.jpg", "test.PNG", "test.gif"] {
            assert!(is_image_file(Path::new(accepted)), "input: {:?}", accepted);
        }

        for rejected in ["test.txt", "test"] {
            assert!(!is_image_file(Path::new(rejected)), "input: {:?}", rejected);
        }
    }

    #[test]
    fn test_validate_app_credentials() {
        assert!(validate_app_credentials(VALID_APP_ID, VALID_APP_SECRET).is_ok());

        // Each pair breaks exactly one side of the credentials.
        let rejected = [
            ("", VALID_APP_SECRET),
            ("invalid", VALID_APP_SECRET),
            ("wx123", VALID_APP_SECRET),
            (VALID_APP_ID, ""),
            (VALID_APP_ID, "short"),
        ];

        for (app_id, app_secret) in rejected {
            assert!(
                validate_app_credentials(app_id, app_secret).is_err(),
                "app_id: {:?}, app_secret: {:?}",
                app_id,
                app_secret
            );
        }
    }

    #[test]
    fn test_get_base_directory() {
        let cases = [
            ("/path/to/file.md", Some("/path/to")),
            ("file.md", Some("")),
            ("/", None),
        ];

        for (input, expected) in cases {
            assert_eq!(
                get_base_directory(Path::new(input)),
                expected.map(Path::new),
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_resolve_path() {
        let base = Path::new("/base/dir");

        let accepted = [
            ("relative.md", "/base/dir/relative.md"),
            ("/absolute.md", "/absolute.md"),
            ("./relative.md", "/base/dir/./relative.md"),
        ];
        for (input, expected) in accepted {
            assert_eq!(
                resolve_path(base, input),
                Ok(PathBuf::from(expected)),
                "input: {:?}",
                input
            );
        }

        // Traversal attempts first, then dangerous extensions.
        let rejected = [
            "../../../etc/passwd",
            "..\\..\\windows\\system32",
            "malware.exe",
            "script.bat",
        ];
        for input in rejected {
            assert!(resolve_path(base, input).is_err(), "input: {:?}", input);
        }
    }

    #[test]
    fn test_is_safe_path() {
        let safe = [
            "document.md",
            "image.jpg",
            "folder/file.txt",
            ".gitignore",
            ".env",
        ];
        for input in safe {
            assert!(is_safe_path(Path::new(input)), "input: {:?}", input);
        }

        let unsafe_inputs = [
            // dangerous extensions
            "malware.exe",
            "script.bat",
            "virus.scr",
            // reserved names
            "CON",
            "PRN.txt",
            "COM1.dat",
            // hidden file that is not on the allow-list
            ".hidden",
        ];
        for input in unsafe_inputs {
            assert!(!is_safe_path(Path::new(input)), "input: {:?}", input);
        }
    }

    #[test]
    fn test_has_path_traversal() {
        let traversal = [
            "../etc/passwd",
            "..\\windows\\system32",
            "folder/../../../etc",
            "..",
            "....",
        ];
        for input in traversal {
            assert!(has_path_traversal(input), "input: {:?}", input);
        }

        let harmless = ["normal/path/file.txt", "file.md", "folder/subfolder/file"];
        for input in harmless {
            assert!(!has_path_traversal(input), "input: {:?}", input);
        }
    }

    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("normal_file.txt", "normal_file.txt"),
            ("file<>:\"|?*.txt", "file.txt"),
            ("path/to/file.txt", "path_to_file.txt"),
            ("path\\to\\file.txt", "path_to_file.txt"),
            (".hidden", "_hidden"),
            ("", "unnamed"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_filename(input), expected, "input: {:?}", input);
        }

        let shortened = sanitize_filename(&"a".repeat(300));
        assert!(shortened.len() <= 255);
        assert!(shortened.ends_with("..."));
    }

    #[test]
    fn test_validate_file_size() {
        assert!(validate_file_size(1000, 2000, "test").is_ok());

        let message = validate_file_size(3000, 2000, "image")
            .expect_err("a file above the limit must be rejected");
        assert!(message.contains("image file too large"));
    }
}