1use anyhow::{Result, anyhow};
7use std::path::{Path, PathBuf};
8
9fn expand_path(path: &str) -> Result<String> {
18 let trimmed = path.trim();
19 if trimmed == "~" {
20 return home_dir().map(|h| h.to_string_lossy().to_string());
21 }
22 if let Some(rest) = trimmed.strip_prefix("~/") {
23 let home = home_dir()?;
24 return Ok(format!("{}/{}", home.display(), rest));
25 }
26 Ok(trimmed.to_string())
27}
28
29fn canonicalize_existing(path: &Path) -> Result<PathBuf> {
31 path.canonicalize()
32 .map_err(|e| anyhow!("Cannot canonicalize path '{}': {}", path.display(), e))
33}
34
/// Returns `true` if `path` looks like a traversal attempt or carries a
/// control character that should never appear in a user-supplied path.
///
/// A path is flagged when:
/// - any segment is exactly `.` or `..`, with both `/` and `\` treated as
///   separators so Windows-style input is caught too; or
/// - it contains a NUL, line-feed, or carriage-return character.
///
/// Names that merely *contain* dots, such as `notes..txt`, `.hidden` or
/// `...`, are not traversal sequences and are not flagged.
fn contains_traversal(path: &str) -> bool {
    let has_control_char = path.contains(|c: char| matches!(c, '\0' | '\n' | '\r'));
    let has_dot_segment = path
        .split(|c: char| c == '/' || c == '\\')
        .any(|segment| segment == "." || segment == "..");
    has_control_char || has_dot_segment
}
44
45fn home_dir() -> Result<PathBuf> {
47 std::env::var("HOME")
48 .map(PathBuf::from)
49 .map_err(|_| anyhow!("Cannot determine home directory from $HOME"))
50}
51
52fn is_under_allowed_base(path: &Path) -> Result<bool> {
54 let home = home_dir()?;
55
56 if path.starts_with(&home) {
58 return Ok(true);
59 }
60
61 #[cfg(target_os = "macos")]
63 if path.starts_with("/Users") {
64 let components: Vec<_> = path.components().collect();
66 if components.len() >= 3 {
67 return Ok(true);
69 }
70 }
71
72 if path.starts_with("/tmp")
75 || path.starts_with("/var/folders")
76 || path.starts_with("/private/tmp")
77 || path.starts_with("/private/var/folders")
78 {
79 return Ok(true);
80 }
81
82 Ok(false)
83}
84
85pub fn sanitize_existing_path(path: &str) -> Result<PathBuf> {
95 if contains_traversal(path) {
97 return Err(anyhow!(
98 "Path contains invalid traversal sequence: {}",
99 path
100 ));
101 }
102
103 let expanded = expand_path(path)?;
104
105 if contains_traversal(&expanded) {
107 return Err(anyhow!(
108 "Expanded path contains invalid sequence: {}",
109 expanded
110 ));
111 }
112
113 let path_buf = PathBuf::from(&expanded);
119
120 let canonical = canonicalize_existing(&path_buf)?;
122
123 if !is_under_allowed_base(&canonical)? {
125 return Err(anyhow!(
126 "Path '{}' is not under an allowed directory",
127 canonical.display()
128 ));
129 }
130
131 Ok(canonical)
132}
133
134pub fn sanitize_new_path(path: &str) -> Result<PathBuf> {
139 if contains_traversal(path) {
141 return Err(anyhow!(
142 "Path contains invalid traversal sequence: {}",
143 path
144 ));
145 }
146
147 let expanded = expand_path(path)?;
148
149 if contains_traversal(&expanded) {
151 return Err(anyhow!(
152 "Expanded path contains invalid sequence: {}",
153 expanded
154 ));
155 }
156
157 let path_buf = PathBuf::from(&expanded);
163
164 if let Some(parent) = path_buf.parent() {
166 if parent.exists() {
167 let canonical_parent = canonicalize_existing(parent)?;
168 if !is_under_allowed_base(&canonical_parent)? {
169 return Err(anyhow!(
170 "Parent directory '{}' is not under an allowed directory",
171 canonical_parent.display()
172 ));
173 }
174 } else if let Some(grandparent) = parent.parent()
175 && grandparent.exists()
176 {
177 let canonical_gp = canonicalize_existing(grandparent)?;
179 if !is_under_allowed_base(&canonical_gp)? {
180 return Err(anyhow!(
181 "Path '{}' would be created outside allowed directories",
182 path_buf.display()
183 ));
184 }
185 }
186 }
187
188 Ok(path_buf)
189}
190
191pub fn validate_read_path(path: &Path) -> Result<PathBuf> {
193 if !path.exists() {
194 return Err(anyhow!("Path does not exist: {}", path.display()));
195 }
196
197 let canonical = canonicalize_existing(path)?;
198
199 if !is_under_allowed_base(&canonical)? {
200 return Err(anyhow!(
201 "Cannot read from path outside allowed directories: {}",
202 canonical.display()
203 ));
204 }
205
206 Ok(canonical)
207}
208
209pub fn validate_write_path(path: &Path) -> Result<PathBuf> {
211 let path_str = path.to_string_lossy();
213 if contains_traversal(&path_str) {
214 return Err(anyhow!("Path contains invalid traversal sequence"));
215 }
216
217 if path.exists() {
218 let canonical = canonicalize_existing(path)?;
220 if !is_under_allowed_base(&canonical)? {
221 return Err(anyhow!(
222 "Cannot write to path outside allowed directories: {}",
223 canonical.display()
224 ));
225 }
226 Ok(canonical)
227 } else {
228 sanitize_new_path(&path_str)
230 }
231}
232
233pub fn safe_read_to_string(path: &str) -> Result<(PathBuf, String)> {
244 let validated = sanitize_existing_path(path)?;
245 let contents = std::fs::read_to_string(&validated)
249 .map_err(|e| anyhow!("Failed to read '{}': {}", validated.display(), e))?;
250 Ok((validated, contents))
251}
252
253pub async fn safe_read_to_string_async(path: &Path) -> Result<(PathBuf, String)> {
255 let validated = validate_read_path(path)?;
256 let contents = tokio::fs::read_to_string(&validated)
260 .await
261 .map_err(|e| anyhow!("Failed to read '{}': {}", validated.display(), e))?;
262 Ok((validated, contents))
263}
264
265pub async fn safe_open_file_async(path: &Path) -> Result<(PathBuf, tokio::fs::File)> {
267 let validated = validate_read_path(path)?;
268 let file = tokio::fs::File::open(&validated)
272 .await
273 .map_err(|e| anyhow!("Failed to open '{}': {}", validated.display(), e))?;
274 Ok((validated, file))
275}
276
277pub async fn safe_read_dir(path: &Path) -> Result<(PathBuf, tokio::fs::ReadDir)> {
279 let validated = validate_read_path(path)?;
280 let entries = tokio::fs::read_dir(&validated)
284 .await
285 .map_err(|e| anyhow!("Failed to read directory '{}': {}", validated.display(), e))?;
286 Ok((validated, entries))
287}
288
289pub fn safe_copy(src: &Path, dst: &Path) -> Result<PathBuf> {
291 let safe_src = validate_read_path(src)?;
292 let safe_dst = validate_write_path(dst)?;
293 std::fs::copy(&safe_src, &safe_dst).map_err(|e| {
296 anyhow!(
297 "Failed to copy '{}' → '{}': {}",
298 safe_src.display(),
299 safe_dst.display(),
300 e
301 )
302 })?;
303 Ok(safe_dst)
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_traversal_detection() {
        // Inputs that must be flagged.
        for suspicious in ["../etc/passwd", "foo/../bar", "./hidden", "path\0with\0nulls"] {
            assert!(contains_traversal(suspicious));
        }
        // Inputs that must pass.
        for benign in ["/normal/path", "~/Documents"] {
            assert!(!contains_traversal(benign));
        }
    }

    #[test]
    fn test_sanitize_existing_path() {
        let tmp = tempdir().unwrap();
        let test_file = tmp.path().join("test.txt");
        fs::write(&test_file, "test").unwrap();

        let result = sanitize_existing_path(test_file.to_str().unwrap());
        assert!(
            result.is_ok(),
            "Failed for path: {:?}, error: {:?}",
            test_file,
            result
        );

        // Climbing out of the temp dir must be refused.
        let escape_attempt = format!("{}/../../../etc/passwd", tmp.path().display());
        assert!(sanitize_existing_path(&escape_attempt).is_err());
    }

    #[test]
    fn test_validate_read_path() {
        let tmp = tempdir().unwrap();
        let readable = tmp.path().join("readable.txt");
        fs::write(&readable, "content").unwrap();
        assert!(validate_read_path(&readable).is_ok());

        // A file that was never created must be rejected.
        assert!(validate_read_path(&tmp.path().join("missing.txt")).is_err());
    }

    #[test]
    fn test_validate_write_path() {
        let tmp = tempdir().unwrap();

        // A not-yet-existing file inside an allowed directory.
        assert!(validate_write_path(&tmp.path().join("new.txt")).is_ok());

        // An existing file inside an allowed directory.
        let existing = tmp.path().join("existing.txt");
        fs::write(&existing, "data").unwrap();
        assert!(validate_write_path(&existing).is_ok());
    }
}