serial_test/
file_lock.rs

1use fslock::LockFile;
2#[cfg(feature = "logging")]
3use log::debug;
4use std::{
5    env,
6    fs::{self, File},
7    io::{Read, Write},
8    path::Path,
9    thread,
10    time::Duration,
11};
12
13pub(crate) struct Lock {
14    lockfile: LockFile,
15    pub(crate) parallel_count: u32,
16    path: String,
17}
18
19impl Lock {
20    // Can't use the same file as fslock truncates it
21    fn gen_count_file(path: &str) -> String {
22        format!("{}-count", path)
23    }
24
25    fn read_parallel_count(path: &str) -> u32 {
26        let parallel_count = match File::open(Lock::gen_count_file(path)) {
27            Ok(mut file) => {
28                let mut count_buf = [0; 4];
29                match file.read_exact(&mut count_buf) {
30                    Ok(_) => u32::from_ne_bytes(count_buf),
31                    Err(_err) => {
32                        #[cfg(feature = "logging")]
33                        debug!("Error loading count file: {}", _err);
34                        0u32
35                    }
36                }
37            }
38            Err(_) => 0,
39        };
40
41        #[cfg(feature = "logging")]
42        debug!("Parallel count for {:?} is {}", path, parallel_count);
43        parallel_count
44    }
45
46    pub(crate) fn new(path: &str) -> Lock {
47        if !Path::new(path).exists() {
48            fs::write(path, "").unwrap_or_else(|_| panic!("Lock file path was {:?}", path))
49        }
50        let mut lockfile = LockFile::open(path).unwrap();
51
52        #[cfg(feature = "logging")]
53        debug!("Waiting on {:?}", path);
54
55        lockfile.lock().unwrap();
56
57        #[cfg(feature = "logging")]
58        debug!("Locked for {:?}", path);
59
60        Lock {
61            lockfile,
62            parallel_count: Lock::read_parallel_count(path),
63            path: String::from(path),
64        }
65    }
66
67    pub(crate) fn start_serial(self: &mut Lock) {
68        loop {
69            if self.parallel_count == 0 {
70                return;
71            }
72            #[cfg(feature = "logging")]
73            debug!("Waiting because parallel count is {}", self.parallel_count);
74            // unlock here is safe because we re-lock before returning
75            self.unlock();
76            thread::sleep(Duration::from_secs(1));
77            self.lockfile.lock().unwrap();
78            #[cfg(feature = "logging")]
79            debug!("Locked for {:?}", self.path);
80            self.parallel_count = Lock::read_parallel_count(&self.path)
81        }
82    }
83
84    fn unlock(self: &mut Lock) {
85        #[cfg(feature = "logging")]
86        debug!("Unlocking {}", self.path);
87        self.lockfile.unlock().unwrap();
88    }
89
90    pub(crate) fn end_serial(mut self: Lock) {
91        self.unlock();
92    }
93
94    fn write_parallel(self: &Lock) {
95        let mut file = File::create(&Lock::gen_count_file(&self.path)).unwrap();
96        file.write_all(&self.parallel_count.to_ne_bytes()).unwrap();
97    }
98
99    pub(crate) fn start_parallel(self: &mut Lock) {
100        self.parallel_count += 1;
101        self.write_parallel();
102        self.unlock();
103    }
104
105    pub(crate) fn end_parallel(mut self: Lock) {
106        assert!(self.parallel_count > 0);
107        self.parallel_count -= 1;
108        self.write_parallel();
109        self.unlock();
110    }
111}
112
/// Returns the default lock file path for `name`: `serial-test-<name>` inside
/// the system temporary directory.
///
/// # Panics
///
/// Panics if the resulting path is not valid UTF-8.
pub(crate) fn path_for_name(name: &str) -> String {
    env::temp_dir()
        .join(format!("serial-test-{}", name))
        .into_os_string()
        .into_string()
        .unwrap()
}
118
119fn make_lock_for_name_and_path(name: &str, path: Option<&str>) -> Lock {
120    if let Some(opt_path) = path {
121        Lock::new(opt_path)
122    } else {
123        let default_path = path_for_name(name);
124        Lock::new(&default_path)
125    }
126}
127
128pub(crate) fn get_locks(names: &Vec<&str>, path: Option<&str>) -> Vec<Lock> {
129    if names.len() > 1 && path.is_some() {
130        panic!("Can't do file_parallel with both more than one name _and_ a specific path");
131    }
132    names
133        .iter()
134        .map(|name| make_lock_for_name_and_path(name, path))
135        .collect::<Vec<_>>()
136}