Skip to main content

xet_runtime/file_utils/
safe_file_creator.rs

1use std::fs::{self, File, Metadata};
2use std::io::{self, BufWriter, Seek, SeekFrom, Write};
3use std::path::{Path, PathBuf};
4
5use rand::distr::Alphanumeric;
6use rand::{Rng, rng};
7
8use super::create_file;
9use super::file_metadata::set_file_metadata;
10
/// Writer that stages its output in a temporary file and moves it to the
/// destination path when closed (or dropped).
pub struct SafeFileCreator {
    /// Final location of the file; `None` until a destination has been chosen.
    dest_path: Option<PathBuf>,
    /// Path of the temporary file that receives all writes before the rename.
    temp_path: PathBuf,
    /// Metadata captured from a pre-existing destination, re-applied after the rename.
    original_metadata: Option<Metadata>,
    /// Buffered handle to the temporary file; `None` once closed or aborted.
    writer: Option<BufWriter<File>>,
}
17
18impl SafeFileCreator {
19    /// Safely creates a new file at a specific location.  Ensures the file is not created with elevated privileges,
20    /// and a temporary file is created then renamed on close.
21    pub fn new<P: AsRef<Path>>(dest_path: P) -> io::Result<Self> {
22        let dest_path = dest_path.as_ref().to_path_buf();
23
24        let parent = dest_path
25            .parent()
26            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path doesn't have a valid parent directory"))?;
27        let file_name = parent
28            .file_name()
29            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path doesn't have a valid file name"))?
30            .to_str();
31
32        let temp_path = Self::temp_file_path(parent, file_name);
33
34        // This matches the permissions and ownership of the parent directory
35        let file = create_file(&temp_path)?;
36        let writer = BufWriter::new(file);
37
38        Ok(SafeFileCreator {
39            dest_path: Some(dest_path),
40            temp_path,
41            original_metadata: None,
42            writer: Some(writer),
43        })
44    }
45
46    /// Safely creates a new file while a destination name can't be decided now. Users need to call
47    /// ```ignore
48    /// pub fn set_dest_path<P: AsRef<Path>>(dest_path: P)
49    /// ```
50    /// to set the destination before closing the file.
51    pub fn new_unnamed(temp_root: impl AsRef<Path>) -> io::Result<Self> {
52        let temp_path = Self::temp_file_path(temp_root, None);
53
54        // This matches the permissions and ownership of the parent directory
55        let file = create_file(&temp_path)?;
56        let writer = BufWriter::new(file);
57
58        Ok(SafeFileCreator {
59            dest_path: None,
60            temp_path,
61            original_metadata: None,
62            writer: Some(writer),
63        })
64    }
65
66    /// Safely replaces a new file at a specific location.  Ensures the file is not created with elevated privileges,
67    /// and additionally the metadata of the old one will match the new metadata.
68    pub fn replace_existing<P: AsRef<Path>>(dest_path: P) -> io::Result<Self> {
69        let mut s = Self::new(&dest_path)?;
70        s.original_metadata = fs::metadata(dest_path).ok();
71        Ok(s)
72    }
73
74    /// Generates a temporary file path in the same directory as the destination file
75    fn temp_file_path(dest_dir: impl AsRef<Path>, file: Option<&str>) -> PathBuf {
76        let mut rng = rng();
77        let random_hash: String = (0..10).map(|_| rng.sample(Alphanumeric)).map(char::from).collect();
78        let temp_file_name = if let Some(filename) = file {
79            format!(".{filename}.{random_hash}.tmp")
80        } else {
81            format!(".{random_hash}.tmp")
82        };
83        dest_dir.as_ref().join(temp_file_name)
84    }
85
86    pub fn set_dest_path<P: AsRef<Path>>(&mut self, dest_path: P) {
87        let dest_path = dest_path.as_ref().to_path_buf();
88        self.dest_path = Some(dest_path);
89    }
90
91    // abort the writing process and delete the temporary file
92    pub fn abort(&mut self) -> io::Result<()> {
93        if self.writer.is_none() {
94            return Ok(());
95        }
96        self.writer = None;
97        if self.temp_path.exists() {
98            fs::remove_file(&self.temp_path)?;
99        }
100        Ok(())
101    }
102
103    /// Closes the writer and replaces the original file with the temporary file
104    pub fn close(&mut self) -> io::Result<()> {
105        let Some(dest_path) = &self.dest_path else {
106            return Err(io::Error::new(io::ErrorKind::InvalidInput, "destination file name not set"));
107        };
108
109        let Some(mut writer) = self.writer.take() else {
110            return Ok(());
111        };
112
113        writer.flush()?;
114        drop(writer);
115
116        // Replace the original file with the new file
117        fs::rename(&self.temp_path, dest_path)?;
118
119        if let Some(metadata) = self.original_metadata.as_ref() {
120            set_file_metadata(dest_path, metadata, false)?;
121        }
122        let original_permissions = if dest_path.exists() {
123            Some(fs::metadata(dest_path)?.permissions())
124        } else {
125            None
126        };
127
128        // Set the original file's permissions to the new file if they exist
129        if let Some(permissions) = original_permissions {
130            fs::set_permissions(dest_path, permissions.clone())?;
131        }
132
133        Ok(())
134    }
135
136    fn writer(&mut self) -> io::Result<&mut BufWriter<File>> {
137        match &mut self.writer {
138            Some(wr) => Ok(wr),
139            None => Err(io::Error::new(
140                io::ErrorKind::BrokenPipe,
141                format!("Writing to {:?} already completed.", &self.dest_path),
142            )),
143        }
144    }
145}
146
147impl Write for SafeFileCreator {
148    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
149        self.writer()?.write(buf)
150    }
151
152    fn flush(&mut self) -> io::Result<()> {
153        self.writer()?.flush()
154    }
155}
156
157impl Seek for SafeFileCreator {
158    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
159        self.writer()?.seek(pos)
160    }
161}
162
163impl Drop for SafeFileCreator {
164    fn drop(&mut self) {
165        if let Err(e) = self.close() {
166            eprintln!("Error: Failed to close writer for {:?}: {}", &self.dest_path, e);
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;

    use tempfile::tempdir;

    use super::*;

    /// Reads the whole file at `path` as UTF-8 text, panicking on any I/O error.
    fn read_contents(path: &Path) -> String {
        let mut text = String::new();
        File::open(path).unwrap().read_to_string(&mut text).unwrap();
        text
    }

    /// Returns the permission bits (`rwxrwxrwx`) of the file at `path`.
    #[cfg(unix)]
    fn mode_bits(path: &Path) -> u32 {
        std::fs::metadata(path).unwrap().permissions().mode() & 0o777
    }

    /// Asserts that the owner can read and write the file at `path`.
    /// Group/other bits depend on the host umask and may vary (e.g. 0o600 vs 0o644).
    #[cfg(unix)]
    fn assert_owner_rw(path: &Path) {
        let mode = mode_bits(path);
        // Default creation mode is 0o666 masked by umask.
        assert!(mode & 0o600 == 0o600, "Owner should have rw permissions, got {mode:#o}");
    }

    /// Creates a file at `path` holding "Old content"; on unix its mode is set to 0o600.
    fn create_existing_file(path: &Path) {
        let mut file = File::create(path).unwrap();
        file.write_all(b"Old content").unwrap();
        #[cfg(unix)]
        {
            let mut perms = file.metadata().unwrap().permissions();
            perms.set_mode(0o600);
            std::fs::set_permissions(path, perms).unwrap();
        }
    }

    #[test]
    fn test_safe_file_creator_new() {
        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("new_file.txt");

        let mut creator = SafeFileCreator::new(&dest_path).unwrap();
        writeln!(creator, "Hello, world!").unwrap();
        creator.close().unwrap();

        // The data must have landed at the destination.
        assert_eq!(read_contents(&dest_path).trim(), "Hello, world!");

        #[cfg(unix)]
        assert_owner_rw(&dest_path);
    }

    #[test]
    fn test_safe_file_creator_new_unnamed() {
        let temp_root = tempdir().unwrap();
        let mut creator = SafeFileCreator::new_unnamed(temp_root.path()).unwrap();
        writeln!(creator, "Hello, world!").unwrap();

        // Closing without a destination must be rejected.
        assert!(creator.close().is_err());

        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("new_file.txt");
        creator.set_dest_path(&dest_path);
        creator.close().unwrap();

        // The data must have landed at the destination.
        assert_eq!(read_contents(&dest_path).trim(), "Hello, world!");

        #[cfg(unix)]
        assert_owner_rw(&dest_path);
    }

    #[test]
    fn test_safe_file_creator_replace_existing() {
        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("existing_file.txt");

        create_existing_file(&dest_path);

        let mut creator = SafeFileCreator::replace_existing(&dest_path).unwrap();
        writeln!(creator, "New content").unwrap();
        creator.close().unwrap();

        // The old content must have been replaced.
        assert_eq!(read_contents(&dest_path).trim(), "New content");

        // The replacement keeps the original file mode.
        #[cfg(unix)]
        assert_eq!(mode_bits(&dest_path), 0o600);
    }

    #[test]
    fn test_safe_file_creator_drop() {
        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("drop_file.txt");

        {
            let mut creator = SafeFileCreator::new(&dest_path).unwrap();
            writeln!(creator, "Hello, world!").unwrap();
            // `creator` goes out of scope here, which must publish the file.
        }

        assert_eq!(read_contents(&dest_path).trim(), "Hello, world!");
    }

    #[test]
    fn test_safe_file_creator_double_close() {
        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("double_close_file.txt");

        let mut creator = SafeFileCreator::new(&dest_path).unwrap();
        writeln!(creator, "Hello, world!").unwrap();
        creator.close().unwrap();
        creator.close().unwrap(); // Should be a no-op

        assert_eq!(read_contents(&dest_path).trim(), "Hello, world!");
    }

    #[test]
    #[cfg(unix)]
    fn test_safe_file_creator_set_metadata() {
        let dir = tempdir().unwrap();
        let dest_path = dir.path().join("metadata_file.txt");

        create_existing_file(&dest_path);

        let mut creator = SafeFileCreator::replace_existing(&dest_path).unwrap();
        writeln!(creator, "New content").unwrap();
        creator.close().unwrap();

        // The old content must have been replaced.
        assert_eq!(read_contents(&dest_path).trim(), "New content");

        // The replacement keeps the original file mode.
        assert_eq!(mode_bits(&dest_path), 0o600);
    }
}