Skip to main content

walletkit_core/storage/
lock.rs

//! File-based storage lock for serializing writes across processes.
2
3use std::fs::{self, File, OpenOptions};
4use std::path::Path;
5use std::sync::Arc;
6
7use super::error::{StorageError, StorageResult};
8
9/// A file-backed lock that serializes storage mutations across processes.
10#[derive(Debug, Clone)]
11pub struct StorageLock {
12    file: Arc<File>,
13}
14
15impl StorageLock {
16    /// Opens or creates the lock file at `path`.
17    ///
18    /// # Errors
19    ///
20    /// Returns an error if the file cannot be opened or created.
21    pub fn open(path: &Path) -> StorageResult<Self> {
22        if let Some(parent) = path.parent() {
23            fs::create_dir_all(parent).map_err(|err| map_io_err(&err))?;
24        }
25        let file = OpenOptions::new()
26            .read(true)
27            .write(true)
28            .create(true)
29            .truncate(false)
30            .open(path)
31            .map_err(|err| map_io_err(&err))?;
32        Ok(Self {
33            file: Arc::new(file),
34        })
35    }
36
37    /// Acquires the exclusive lock.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if the lock cannot be acquired.
42    pub fn lock(&self) -> StorageResult<StorageLockGuard> {
43        lock_exclusive(&self.file).map_err(|err| map_io_err(&err))?;
44        Ok(StorageLockGuard {
45            file: Arc::clone(&self.file),
46        })
47    }
48
49    /// Attempts to acquire the exclusive lock without blocking.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if the lock attempt fails for reasons other than
54    /// the lock being held by another process.
55    pub fn try_lock(&self) -> StorageResult<Option<StorageLockGuard>> {
56        if try_lock_exclusive(&self.file).map_err(|err| map_io_err(&err))? {
57            Ok(Some(StorageLockGuard {
58                file: Arc::clone(&self.file),
59            }))
60        } else {
61            Ok(None)
62        }
63    }
64}
65
66/// Guard that holds an exclusive lock for its lifetime.
67#[derive(Debug)]
68pub struct StorageLockGuard {
69    file: Arc<File>,
70}
71
72impl Drop for StorageLockGuard {
73    fn drop(&mut self) {
74        let _ = unlock(&self.file);
75    }
76}
77
78fn map_io_err(err: &std::io::Error) -> StorageError {
79    StorageError::Lock(err.to_string())
80}
81
82#[cfg(unix)]
83fn lock_exclusive(file: &File) -> std::io::Result<()> {
84    let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
85    let result = unsafe { flock(fd, LOCK_EX) };
86    if result == 0 {
87        Ok(())
88    } else {
89        Err(std::io::Error::last_os_error())
90    }
91}
92
93#[cfg(unix)]
94fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
95    let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
96    let result = unsafe { flock(fd, LOCK_EX | LOCK_NB) };
97    if result == 0 {
98        Ok(true)
99    } else {
100        let err = std::io::Error::last_os_error();
101        if err.kind() == std::io::ErrorKind::WouldBlock {
102            Ok(false)
103        } else {
104            Err(err)
105        }
106    }
107}
108
109#[cfg(unix)]
110fn unlock(file: &File) -> std::io::Result<()> {
111    let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
112    let result = unsafe { flock(fd, LOCK_UN) };
113    if result == 0 {
114        Ok(())
115    } else {
116        Err(std::io::Error::last_os_error())
117    }
118}
119
120#[cfg(unix)]
121use std::os::raw::c_int;
122
// `flock(2)` operation values, hard-coded here instead of being imported from
// a bindings crate. NOTE(review): these match Linux, macOS and the BSDs —
// confirm before targeting another Unix.

/// Request an exclusive lock (a single holder at a time).
#[cfg(unix)]
const LOCK_EX: c_int = 0x2;
/// Do not wait if the lock is held; fail with `EWOULDBLOCK` instead.
#[cfg(unix)]
const LOCK_NB: c_int = 0x4;
/// Release the lock held through this descriptor.
#[cfg(unix)]
const LOCK_UN: c_int = 0x8;

#[cfg(unix)]
extern "C" {
    /// Applies or removes a lock on the open file `fd`. Returns 0 on success;
    /// on failure the cause is available through `errno`.
    fn flock(fd: c_int, operation: c_int) -> c_int;
}
134
/// Takes an exclusive lock on `file`, waiting until it is available.
#[cfg(windows)]
fn lock_exclusive(file: &File) -> std::io::Result<()> {
    // Omitting LOCKFILE_FAIL_IMMEDIATELY makes `LockFileEx` wait for the lock.
    let wait_for_lock: u32 = 0;
    lock_file(file, wait_for_lock)
}
139
/// Tries to take an exclusive lock on `file` without waiting.
///
/// Returns `Ok(false)` when Windows reports `ERROR_LOCK_VIOLATION` (the range
/// is locked elsewhere); other failures are returned as errors.
#[cfg(windows)]
fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
    match lock_file(file, LOCKFILE_FAIL_IMMEDIATELY) {
        Ok(()) => Ok(true),
        Err(contended) if contended.raw_os_error() == Some(ERROR_LOCK_VIOLATION) => Ok(false),
        Err(other) => Err(other),
    }
}
153
/// Releases the one-byte lock at offset 0 taken through `file`'s handle.
#[cfg(windows)]
fn unlock(file: &File) -> std::io::Result<()> {
    let handle = std::os::windows::io::AsRawHandle::as_raw_handle(file) as HANDLE;
    // SAFETY: OVERLAPPED holds only integers and a raw pointer, so all-zero is
    // a valid value; the zero offset addresses byte 0.
    let mut overlapped: OVERLAPPED = unsafe { std::mem::zeroed() };
    // Unlock exactly the region `lock_file` locks: length 1 (low), 0 (high).
    // SAFETY: `handle` stays valid while `file` is borrowed, and `overlapped`
    // lives until the call returns.
    let released = unsafe { UnlockFileEx(handle, 0, 1, 0, &mut overlapped) } != 0;
    if released {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
165
/// Takes an exclusive `LockFileEx` lock on the first byte of `file`, OR-ing
/// `flags` into `LOCKFILE_EXCLUSIVE_LOCK`.
///
/// Unless `flags` contains `LOCKFILE_FAIL_IMMEDIATELY`, the call waits until
/// the lock is free.
#[cfg(windows)]
fn lock_file(file: &File, flags: u32) -> std::io::Result<()> {
    let handle = std::os::windows::io::AsRawHandle::as_raw_handle(file) as HANDLE;
    // SAFETY: OVERLAPPED holds only integers and a raw pointer, so all-zero is
    // a valid value; the zero offset makes the locked range start at byte 0.
    let mut overlapped: OVERLAPPED = unsafe { std::mem::zeroed() };
    let lock_flags = LOCKFILE_EXCLUSIVE_LOCK | flags;
    // Lock a single byte (length low = 1, high = 0).
    // SAFETY: `handle` stays valid while `file` is borrowed, and `overlapped`
    // lives until the call returns. NOTE(review): assumes `file` was opened
    // without FILE_FLAG_OVERLAPPED, so the call completes synchronously.
    let locked = unsafe { LockFileEx(handle, lock_flags, 0, 1, 0, &mut overlapped) } != 0;
    if locked {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
186
/// Raw Win32 `HANDLE`.
#[cfg(windows)]
type HANDLE = *mut std::ffi::c_void;

/// Mirror of the Win32 `OVERLAPPED` structure. It is passed to `LockFileEx`
/// and `UnlockFileEx` to give the start offset of the locked byte range.
///
/// In the Windows headers, `offset`/`offset_high` share a union with a pointer.
/// Flattening them into two `u32`s yields the same field offsets and total
/// size on both 32- and 64-bit targets. Field order must not change.
#[cfg(windows)]
#[repr(C)]
struct OVERLAPPED {
    /// Reserved for the system (status code).
    internal: usize,
    /// Reserved for the system (bytes transferred).
    internal_high: usize,
    /// Low 32 bits of the byte offset where the range starts.
    offset: u32,
    /// High 32 bits of the byte offset where the range starts.
    offset_high: u32,
    /// Completion event; null when unused.
    h_event: HANDLE,
}
199
/// `LockFileEx` flag: request an exclusive lock rather than a shared one.
#[cfg(windows)]
const LOCKFILE_EXCLUSIVE_LOCK: u32 = 0x0000_0002;
/// `LockFileEx` flag: fail instead of waiting when the range is already locked.
#[cfg(windows)]
const LOCKFILE_FAIL_IMMEDIATELY: u32 = 0x0000_0001;
/// Win32 error code reported when the requested range is locked elsewhere.
#[cfg(windows)]
const ERROR_LOCK_VIOLATION: i32 = 33;

// kernel32 byte-range locking entry points. Callers treat a nonzero return as
// success and read the failure cause from the last OS error.
#[cfg(windows)]
extern "system" {
    fn LockFileEx(
        file: HANDLE,
        flags: u32,
        reserved: u32,
        len_low: u32,
        len_high: u32,
        overlapped: *mut OVERLAPPED,
    ) -> i32;
    fn UnlockFileEx(
        file: HANDLE,
        reserved: u32,
        len_low: u32,
        len_high: u32,
        overlapped: *mut OVERLAPPED,
    ) -> i32;
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, unique lock-file path inside the system temp directory.
    fn temp_lock_path() -> std::path::PathBuf {
        std::env::temp_dir().join(format!("walletkit-lock-{}.lock", Uuid::new_v4()))
    }

    #[test]
    fn test_lock_is_exclusive() {
        let path = temp_lock_path();

        let first = StorageLock::open(&path).expect("open lock");
        let held = first.lock().expect("acquire lock");

        // An independently opened handle must be refused while `held` lives.
        let second = StorageLock::open(&path).expect("open lock");
        let refused = second.try_lock().expect("try lock");
        assert!(refused.is_none());

        // After the release, the other handle can take the lock.
        drop(held);
        let reacquired = second.try_lock().expect("try lock");
        assert!(reacquired.is_some());

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_lock_serializes_across_threads() {
        use std::sync::mpsc::channel;

        let path = temp_lock_path();
        let owner_lock = StorageLock::open(&path).expect("open lock");

        let (acquired_tx, acquired_rx) = channel();
        let (go_tx, go_rx) = channel();
        let (done_tx, done_rx) = channel();

        let owner = {
            let path = path.clone();
            std::thread::spawn(move || {
                let held = owner_lock.lock().expect("lock in thread");
                acquired_tx.send(()).expect("signal locked");
                go_rx.recv().expect("wait release");
                drop(held);
                done_tx.send(()).expect("signal released");
                let _ = std::fs::remove_file(path);
            })
        };

        acquired_rx.recv().expect("wait locked");
        let contender = StorageLock::open(&path).expect("open lock");
        let refused = contender.try_lock().expect("try lock");
        assert!(refused.is_none());

        go_tx.send(()).expect("release");
        done_rx.recv().expect("wait released");

        let reacquired = contender.try_lock().expect("try lock");
        assert!(reacquired.is_some());

        owner.join().expect("thread join");
    }
}