whiteout/storage/
atomic.rs1use anyhow::{Context, Result};
2use std::fs::{self, File, OpenOptions};
3use std::io::{Read, Write};
4use std::path::{Path, PathBuf};
5use std::time::Duration;
6
7#[cfg(unix)]
8use std::os::unix::fs::PermissionsExt;
9
/// Handle for all-or-nothing replacement of a single file.
///
/// Writes go to a hidden temporary sibling in the same directory, which is then
/// renamed over `path`, so the target is never observed half-written.
#[derive(Debug)]
pub struct AtomicFile {
    /// The file that is read, replaced, or deleted.
    path: PathBuf,
    /// Hidden, process-specific sibling of `path` (computed when the handle is
    /// created); `delete` removes it if present.
    temp_path: PathBuf,
}
15
16impl AtomicFile {
17 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
18 let path = path.as_ref().to_path_buf();
19 let temp_path = Self::temp_path(&path)?;
20
21 Ok(Self { path, temp_path })
22 }
23
24 fn temp_path(path: &Path) -> Result<PathBuf> {
26 let file_name = path
27 .file_name()
28 .ok_or_else(|| anyhow::anyhow!("Invalid file path"))?;
29
30 let temp_name = format!(
31 ".{}.tmp.{}",
32 file_name.to_string_lossy(),
33 std::process::id()
34 );
35
36 Ok(path.with_file_name(temp_name))
37 }
38
39 pub fn write(&self, content: &[u8]) -> Result<()> {
41 let mut temp_file = OpenOptions::new()
43 .write(true)
44 .create(true)
45 .truncate(true)
46 .open(&self.temp_path)
47 .context("Failed to create temporary file")?;
48
49 temp_file
50 .write_all(content)
51 .context("Failed to write to temporary file")?;
52
53 temp_file
54 .sync_all()
55 .context("Failed to sync temporary file")?;
56
57 #[cfg(unix)]
59 {
60 let metadata = fs::metadata(&self.temp_path)?;
61 let mut permissions = metadata.permissions();
62 permissions.set_mode(0o644); fs::set_permissions(&self.temp_path, permissions)?;
64 }
65
66 fs::rename(&self.temp_path, &self.path)
68 .context("Failed to atomically rename file")?;
69
70 Ok(())
71 }
72
73 pub fn read(&self) -> Result<Vec<u8>> {
75 const MAX_RETRIES: u32 = 3;
76 const RETRY_DELAY: Duration = Duration::from_millis(10);
77
78 for attempt in 0..MAX_RETRIES {
79 match self.try_read() {
80 Ok(content) => return Ok(content),
81 Err(e) if attempt < MAX_RETRIES - 1 => {
82 if e.to_string().contains("temporarily unavailable") ||
84 e.to_string().contains("locked") {
85 std::thread::sleep(RETRY_DELAY);
86 continue;
87 }
88 return Err(e);
89 }
90 Err(e) => return Err(e),
91 }
92 }
93
94 anyhow::bail!("Failed to read file after {} attempts", MAX_RETRIES)
95 }
96
97 fn try_read(&self) -> Result<Vec<u8>> {
98 let mut file = File::open(&self.path)
99 .with_context(|| format!("Failed to open file: {}", self.path.display()))?;
100
101 let mut content = Vec::new();
102 file.read_to_end(&mut content)
103 .context("Failed to read file content")?;
104
105 Ok(content)
106 }
107
108 pub fn exists(&self) -> bool {
110 self.path.exists() && self.path.is_file()
111 }
112
113 pub fn delete(&self) -> Result<()> {
115 if self.temp_path.exists() {
116 fs::remove_file(&self.temp_path)
117 .context("Failed to remove temporary file")?;
118 }
119
120 if self.path.exists() {
121 fs::remove_file(&self.path)
122 .context("Failed to remove file")?;
123 }
124
125 Ok(())
126 }
127}
128
129#[cfg(unix)]
131pub mod lock {
132 use std::fs::File;
133 use std::os::unix::io::AsRawFd;
134 use anyhow::Result;
135
136 pub struct FileLock {
137 file: File,
138 }
139
140 impl FileLock {
141 pub fn acquire(file: File) -> Result<Self> {
142 use libc::{flock, LOCK_EX};
143
144 let fd = file.as_raw_fd();
145 let result = unsafe { flock(fd, LOCK_EX) };
146
147 if result != 0 {
148 anyhow::bail!("Failed to acquire file lock");
149 }
150
151 Ok(Self { file })
152 }
153
154 pub fn try_acquire(file: File) -> Result<Option<Self>> {
155 use libc::{flock, LOCK_EX, LOCK_NB};
156
157 let fd = file.as_raw_fd();
158 let result = unsafe { flock(fd, LOCK_EX | LOCK_NB) };
159
160 if result == 0 {
161 Ok(Some(Self { file }))
162 } else if std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock {
163 Ok(None)
164 } else {
165 anyhow::bail!("Failed to try acquiring file lock");
166 }
167 }
168 }
169
170 impl Drop for FileLock {
171 fn drop(&mut self) {
172 use libc::{flock, LOCK_UN};
173
174 let fd = self.file.as_raw_fd();
175 unsafe { flock(fd, LOCK_UN) };
176 }
177 }
178}
179
#[cfg(not(unix))]
pub mod lock {
    use std::fs::File;
    use anyhow::Result;

    /// Fallback lock guard for non-Unix targets.
    ///
    /// WARNING: this implementation performs no locking at all. `acquire` and
    /// `try_acquire` always succeed immediately, so it gives no mutual exclusion
    /// between processes or threads; it only keeps `file` open for the guard's
    /// lifetime. It mirrors the API of the Unix `flock` implementation so callers
    /// compile unchanged.
    #[derive(Debug)]
    #[must_use = "the guard should be held for as long as the file is in use"]
    pub struct FileLock {
        _file: File,
    }

    impl FileLock {
        /// "Acquires" the lock: never blocks and never fails on this platform.
        pub fn acquire(file: File) -> Result<Self> {
            Ok(Self { _file: file })
        }

        /// Always returns `Ok(Some(_))` on this platform; contention is never detected.
        pub fn try_acquire(file: File) -> Result<Option<Self>> {
            Self::acquire(file).map(Some)
        }
    }
}
200
201pub fn validate_path<P: AsRef<Path>>(path: P, base_dir: &Path) -> Result<PathBuf> {
203 let path = path.as_ref();
204
205 let canonical = if path.exists() {
207 path.canonicalize()
208 .context("Failed to canonicalize path")?
209 } else {
210 let parent = path.parent()
212 .ok_or_else(|| anyhow::anyhow!("Invalid path: no parent directory"))?;
213
214 let parent_canonical = parent.canonicalize()
215 .context("Failed to canonicalize parent directory")?;
216
217 let file_name = path.file_name()
218 .ok_or_else(|| anyhow::anyhow!("Invalid path: no file name"))?;
219
220 parent_canonical.join(file_name)
221 };
222
223 let base_canonical = base_dir.canonicalize()
225 .context("Failed to canonicalize base directory")?;
226
227 if !canonical.starts_with(&base_canonical) {
228 anyhow::bail!(
229 "Path traversal detected: {} is outside of {}",
230 canonical.display(),
231 base_canonical.display()
232 );
233 }
234
235 let path_str = canonical.to_string_lossy();
237 if path_str.contains("..") || path_str.contains("~") {
238 anyhow::bail!("Suspicious path pattern detected");
239 }
240
241 Ok(canonical)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_atomic_write_read() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_path = temp_dir.path().join("test.txt");

        let atomic_file = AtomicFile::new(&file_path)?;
        let content = b"test content";

        atomic_file.write(content)?;
        let read_content = atomic_file.read()?;

        assert_eq!(content, &read_content[..]);
        Ok(())
    }

    #[test]
    fn test_overwrite_replaces_content() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let atomic_file = AtomicFile::new(temp_dir.path().join("data.bin"))?;

        // The second payload is shorter, so leftover bytes from the first would show.
        atomic_file.write(b"first version, longer")?;
        atomic_file.write(b"second")?;

        assert_eq!(atomic_file.read()?, b"second".to_vec());
        Ok(())
    }

    #[test]
    fn test_write_leaves_no_temp_files() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let atomic_file = AtomicFile::new(temp_dir.path().join("only.txt"))?;
        atomic_file.write(b"payload")?;

        let names = fs::read_dir(temp_dir.path())?
            .map(|entry| entry.map(|e| e.file_name()))
            .collect::<std::io::Result<Vec<_>>>()?;
        assert_eq!(names, vec![std::ffi::OsString::from("only.txt")]);
        Ok(())
    }

    #[test]
    fn test_exists_and_delete() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let atomic_file = AtomicFile::new(temp_dir.path().join("gone.txt"))?;

        assert!(!atomic_file.exists());
        atomic_file.write(b"x")?;
        assert!(atomic_file.exists());

        atomic_file.delete()?;
        assert!(!atomic_file.exists());

        // Deleting an already-absent file is not an error.
        atomic_file.delete()?;
        Ok(())
    }

    #[test]
    fn test_read_missing_file_fails() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let atomic_file = AtomicFile::new(temp_dir.path().join("missing.txt"))?;

        assert!(atomic_file.read().is_err());
        Ok(())
    }

    #[test]
    fn test_path_validation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let base = temp_dir.path();

        let valid_path = base.join("subdir").join("file.txt");
        std::fs::create_dir_all(valid_path.parent().unwrap())?;
        let validated = validate_path(&valid_path, base)?;
        assert!(validated.starts_with(base.canonicalize()?));
        assert!(validated.ends_with(Path::new("subdir").join("file.txt")));

        let invalid_path = base.join("..").join("outside.txt");
        let result = validate_path(&invalid_path, base);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_file_locking() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_path = temp_dir.path().join("locked.txt");

        std::fs::write(&file_path, "test")?;

        let file1 = File::open(&file_path)?;
        let lock1 = lock::FileLock::acquire(file1)?;

        // A second, independent open of the same file must not get the lock while
        // `lock1` is held (only enforced on Unix, where `flock` is used).
        let file2 = File::open(&file_path)?;
        let lock2 = lock::FileLock::try_acquire(file2)?;
        assert!(lock2.is_none() || cfg!(not(unix)));

        // Once released, the lock can be taken again.
        drop(lock1);
        let file3 = File::open(&file_path)?;
        assert!(lock::FileLock::try_acquire(file3)?.is_some());

        Ok(())
    }
}