whiteout/storage/
atomic.rs

1use anyhow::{Context, Result};
2use std::fs::{self, File, OpenOptions};
3use std::io::{Read, Write};
4use std::path::{Path, PathBuf};
5use std::time::Duration;
6
7#[cfg(unix)]
8use std::os::unix::fs::PermissionsExt;
9
10/// Atomic file operations to prevent TOCTOU race conditions
11pub struct AtomicFile {
12    path: PathBuf,
13    temp_path: PathBuf,
14}
15
16impl AtomicFile {
17    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
18        let path = path.as_ref().to_path_buf();
19        let temp_path = Self::temp_path(&path)?;
20        
21        Ok(Self { path, temp_path })
22    }
23    
24    /// Generate a temporary file path for atomic operations
25    fn temp_path(path: &Path) -> Result<PathBuf> {
26        let file_name = path
27            .file_name()
28            .ok_or_else(|| anyhow::anyhow!("Invalid file path"))?;
29        
30        let temp_name = format!(
31            ".{}.tmp.{}",
32            file_name.to_string_lossy(),
33            std::process::id()
34        );
35        
36        Ok(path.with_file_name(temp_name))
37    }
38    
39    /// Atomically write content to file
40    pub fn write(&self, content: &[u8]) -> Result<()> {
41        // Write to temporary file first
42        let mut temp_file = OpenOptions::new()
43            .write(true)
44            .create(true)
45            .truncate(true)
46            .open(&self.temp_path)
47            .context("Failed to create temporary file")?;
48        
49        temp_file
50            .write_all(content)
51            .context("Failed to write to temporary file")?;
52        
53        temp_file
54            .sync_all()
55            .context("Failed to sync temporary file")?;
56        
57        // Set permissions on Unix systems
58        #[cfg(unix)]
59        {
60            let metadata = fs::metadata(&self.temp_path)?;
61            let mut permissions = metadata.permissions();
62            permissions.set_mode(0o644); // Read/write for owner, read for others
63            fs::set_permissions(&self.temp_path, permissions)?;
64        }
65        
66        // Atomically rename temp file to target
67        fs::rename(&self.temp_path, &self.path)
68            .context("Failed to atomically rename file")?;
69        
70        Ok(())
71    }
72    
73    /// Atomically read file with retry logic
74    pub fn read(&self) -> Result<Vec<u8>> {
75        const MAX_RETRIES: u32 = 3;
76        const RETRY_DELAY: Duration = Duration::from_millis(10);
77        
78        for attempt in 0..MAX_RETRIES {
79            match self.try_read() {
80                Ok(content) => return Ok(content),
81                Err(e) if attempt < MAX_RETRIES - 1 => {
82                    // Check if it's a temporary failure
83                    if e.to_string().contains("temporarily unavailable") ||
84                       e.to_string().contains("locked") {
85                        std::thread::sleep(RETRY_DELAY);
86                        continue;
87                    }
88                    return Err(e);
89                }
90                Err(e) => return Err(e),
91            }
92        }
93        
94        anyhow::bail!("Failed to read file after {} attempts", MAX_RETRIES)
95    }
96    
97    fn try_read(&self) -> Result<Vec<u8>> {
98        let mut file = File::open(&self.path)
99            .with_context(|| format!("Failed to open file: {}", self.path.display()))?;
100        
101        let mut content = Vec::new();
102        file.read_to_end(&mut content)
103            .context("Failed to read file content")?;
104        
105        Ok(content)
106    }
107    
108    /// Check if file exists with proper validation
109    pub fn exists(&self) -> bool {
110        self.path.exists() && self.path.is_file()
111    }
112    
113    /// Securely delete file
114    pub fn delete(&self) -> Result<()> {
115        if self.temp_path.exists() {
116            fs::remove_file(&self.temp_path)
117                .context("Failed to remove temporary file")?;
118        }
119        
120        if self.path.exists() {
121            fs::remove_file(&self.path)
122                .context("Failed to remove file")?;
123        }
124        
125        Ok(())
126    }
127}
128
129/// File locking mechanism to prevent concurrent access
130#[cfg(unix)]
131pub mod lock {
132    use std::fs::File;
133    use std::os::unix::io::AsRawFd;
134    use anyhow::Result;
135    
136    pub struct FileLock {
137        file: File,
138    }
139    
140    impl FileLock {
141        pub fn acquire(file: File) -> Result<Self> {
142            use libc::{flock, LOCK_EX};
143            
144            let fd = file.as_raw_fd();
145            let result = unsafe { flock(fd, LOCK_EX) };
146            
147            if result != 0 {
148                anyhow::bail!("Failed to acquire file lock");
149            }
150            
151            Ok(Self { file })
152        }
153        
154        pub fn try_acquire(file: File) -> Result<Option<Self>> {
155            use libc::{flock, LOCK_EX, LOCK_NB};
156            
157            let fd = file.as_raw_fd();
158            let result = unsafe { flock(fd, LOCK_EX | LOCK_NB) };
159            
160            if result == 0 {
161                Ok(Some(Self { file }))
162            } else if std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock {
163                Ok(None)
164            } else {
165                anyhow::bail!("Failed to try acquiring file lock");
166            }
167        }
168    }
169    
170    impl Drop for FileLock {
171        fn drop(&mut self) {
172            use libc::{flock, LOCK_UN};
173            
174            let fd = self.file.as_raw_fd();
175            unsafe { flock(fd, LOCK_UN) };
176        }
177    }
178}
179
/// Placeholder lock module for non-Unix targets.
///
/// It mirrors the Unix API so callers compile everywhere, but it performs no
/// locking at all: every request succeeds immediately.
#[cfg(not(unix))]
pub mod lock {
    use anyhow::Result;
    use std::fs::File;

    /// Owns the file handle for as long as the "lock" is held.
    pub struct FileLock {
        _file: File,
    }

    impl FileLock {
        /// Wraps `file` and reports success; no OS-level lock is taken.
        pub fn acquire(file: File) -> Result<Self> {
            // A real implementation for this platform would lock here.
            let guard = FileLock { _file: file };
            Ok(guard)
        }

        /// Non-blocking variant; always yields `Some` because nothing is
        /// ever contended in this placeholder.
        pub fn try_acquire(file: File) -> Result<Option<Self>> {
            Self::acquire(file).map(Some)
        }
    }
}
200
201/// Validate file path to prevent directory traversal
202pub fn validate_path<P: AsRef<Path>>(path: P, base_dir: &Path) -> Result<PathBuf> {
203    let path = path.as_ref();
204    
205    // Resolve to canonical path
206    let canonical = if path.exists() {
207        path.canonicalize()
208            .context("Failed to canonicalize path")?
209    } else {
210        // For non-existent files, canonicalize the parent and append filename
211        let parent = path.parent()
212            .ok_or_else(|| anyhow::anyhow!("Invalid path: no parent directory"))?;
213        
214        let parent_canonical = parent.canonicalize()
215            .context("Failed to canonicalize parent directory")?;
216        
217        let file_name = path.file_name()
218            .ok_or_else(|| anyhow::anyhow!("Invalid path: no file name"))?;
219        
220        parent_canonical.join(file_name)
221    };
222    
223    // Ensure path is within base directory
224    let base_canonical = base_dir.canonicalize()
225        .context("Failed to canonicalize base directory")?;
226    
227    if !canonical.starts_with(&base_canonical) {
228        anyhow::bail!(
229            "Path traversal detected: {} is outside of {}",
230            canonical.display(),
231            base_canonical.display()
232        );
233    }
234    
235    // Check for suspicious patterns
236    let path_str = canonical.to_string_lossy();
237    if path_str.contains("..") || path_str.contains("~") {
238        anyhow::bail!("Suspicious path pattern detected");
239    }
240    
241    Ok(canonical)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Bytes written through `AtomicFile::write` come back unchanged from
    /// `AtomicFile::read`.
    #[test]
    fn test_atomic_write_read() -> Result<()> {
        let scratch = TempDir::new()?;
        let target = AtomicFile::new(scratch.path().join("test.txt"))?;
        let payload = b"test content";

        target.write(payload)?;
        let round_trip = target.read()?;

        assert_eq!(payload, &round_trip[..]);
        Ok(())
    }

    /// A path below the base directory is accepted; one that climbs out of it
    /// via `..` is rejected.
    #[test]
    fn test_path_validation() -> Result<()> {
        let sandbox = TempDir::new()?;
        let root = sandbox.path();

        // Inside the base: the directory exists, the file does not yet.
        let nested_dir = root.join("subdir");
        std::fs::create_dir_all(&nested_dir)?;
        let inside = nested_dir.join("file.txt");
        assert!(validate_path(&inside, root).is_ok());

        // Outside the base: traversal attempt.
        let outside = root.join("..").join("outside.txt");
        assert!(validate_path(&outside, root).is_err());

        Ok(())
    }

    /// While one handle holds the lock, a non-blocking attempt through a
    /// second handle must not succeed (on Unix; other targets do not lock).
    #[test]
    fn test_file_locking() -> Result<()> {
        let scratch = TempDir::new()?;
        let locked_path = scratch.path().join("locked.txt");
        std::fs::write(&locked_path, "test")?;

        let holder = lock::FileLock::acquire(File::open(&locked_path)?)?;

        // Second, independent handle to the same file.
        let contender = lock::FileLock::try_acquire(File::open(&locked_path)?)?;

        assert!(cfg!(not(unix)) || contender.is_none());

        drop(holder);
        Ok(())
    }
}