util/
lock.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::fs::File;
9use std::io;
10use std::path::Path;
11
12use fs2::FileExt;
13
14use crate::errors::IOContext;
15use crate::file::open;
16
/// RAII lock on a filesystem path.
///
/// The lock is held for as long as this value is alive and is released when it
/// is dropped. Bind it to a named variable that lives for the whole critical
/// section: binding it to `_` (or discarding it) drops it, and so releases the
/// lock, immediately.
#[must_use = "the path lock is released as soon as this value is dropped"]
#[derive(Debug)]
pub struct PathLock {
    /// Open handle to the lock file on which the lock is held.
    file: File,
}
22
23impl PathLock {
24    /// Take an exclusive lock on `path`. The lock file will be created on
25    /// demand.
26    pub fn exclusive<P: AsRef<Path>>(path: P) -> io::Result<Self> {
27        let file = open(path.as_ref(), "wc").io_context("lock file")?;
28        file.lock_exclusive()
29            .path_context("error locking file", path.as_ref())?;
30        Ok(PathLock { file })
31    }
32
33    pub fn as_file(&self) -> &File {
34        &self.file
35    }
36}
37
38impl Drop for PathLock {
39    fn drop(&mut self) {
40        self.file.unlock().expect("unlock");
41    }
42}
43
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::thread;

    use super::*;

    /// Many threads contend for the same lock file. Each one sends its id
    /// twice while holding the lock, so the two sends must arrive
    /// back-to-back on the channel.
    #[test]
    fn test_path_lock() -> anyhow::Result<()> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("a");
        let (tx, rx) = channel();
        const N: usize = 50;
        let threads: Vec<_> = (0..N)
            .map(|i| {
                let path = path.clone();
                let tx = tx.clone();
                thread::spawn(move || {
                    // Write 2 values that are the same, protected by the lock.
                    // Unwrap so that a failure to lock fails the test instead
                    // of silently running this section unprotected.
                    let _locked = PathLock::exclusive(&path).expect("lock");
                    tx.send(i).unwrap();
                    tx.send(i).unwrap();
                })
            })
            .collect();
        // Drop the original sender so the channel disconnects once every
        // thread (each holding its own clone) has finished.
        drop(tx);

        for thread in threads {
            thread.join().expect("joined");
        }

        for _ in 0..N {
            // Read 2 values. They should be the same.
            let v1 = rx.recv().unwrap();
            let v2 = rx.recv().unwrap();
            assert_eq!(v1, v2);
        }
        // Exactly 2 * N values were sent; all senders are gone now.
        assert!(rx.recv().is_err());

        Ok(())
    }
}