walletkit_core/storage/
lock.rs1use std::fs::{self, File, OpenOptions};
4use std::path::Path;
5use std::sync::Arc;
6
7use super::error::{StorageError, StorageResult};
8
9#[derive(Debug, Clone)]
11pub struct StorageLock {
12 file: Arc<File>,
13}
14
15impl StorageLock {
16 pub fn open(path: &Path) -> StorageResult<Self> {
22 if let Some(parent) = path.parent() {
23 fs::create_dir_all(parent).map_err(|err| map_io_err(&err))?;
24 }
25 let file = OpenOptions::new()
26 .read(true)
27 .write(true)
28 .create(true)
29 .truncate(false)
30 .open(path)
31 .map_err(|err| map_io_err(&err))?;
32 Ok(Self {
33 file: Arc::new(file),
34 })
35 }
36
37 pub fn lock(&self) -> StorageResult<StorageLockGuard> {
43 lock_exclusive(&self.file).map_err(|err| map_io_err(&err))?;
44 Ok(StorageLockGuard {
45 file: Arc::clone(&self.file),
46 })
47 }
48
49 pub fn try_lock(&self) -> StorageResult<Option<StorageLockGuard>> {
56 if try_lock_exclusive(&self.file).map_err(|err| map_io_err(&err))? {
57 Ok(Some(StorageLockGuard {
58 file: Arc::clone(&self.file),
59 }))
60 } else {
61 Ok(None)
62 }
63 }
64}
65
66#[derive(Debug)]
68pub struct StorageLockGuard {
69 file: Arc<File>,
70}
71
72impl Drop for StorageLockGuard {
73 fn drop(&mut self) {
74 let _ = unlock(&self.file);
75 }
76}
77
78fn map_io_err(err: &std::io::Error) -> StorageError {
79 StorageError::Lock(err.to_string())
80}
81
82#[cfg(unix)]
83fn lock_exclusive(file: &File) -> std::io::Result<()> {
84 let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
85 let result = unsafe { flock(fd, LOCK_EX) };
86 if result == 0 {
87 Ok(())
88 } else {
89 Err(std::io::Error::last_os_error())
90 }
91}
92
93#[cfg(unix)]
94fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
95 let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
96 let result = unsafe { flock(fd, LOCK_EX | LOCK_NB) };
97 if result == 0 {
98 Ok(true)
99 } else {
100 let err = std::io::Error::last_os_error();
101 if err.kind() == std::io::ErrorKind::WouldBlock {
102 Ok(false)
103 } else {
104 Err(err)
105 }
106 }
107}
108
109#[cfg(unix)]
110fn unlock(file: &File) -> std::io::Result<()> {
111 let fd = std::os::unix::io::AsRawFd::as_raw_fd(file);
112 let result = unsafe { flock(fd, LOCK_UN) };
113 if result == 0 {
114 Ok(())
115 } else {
116 Err(std::io::Error::last_os_error())
117 }
118}
119
120#[cfg(unix)]
121use std::os::raw::c_int;
122
// Operation values handed to `flock`.
// NOTE(review): these are hard-coded rather than taken from a bindings crate;
// they are expected to match `<sys/file.h>` -- confirm for every unix target
// this crate is built for.

/// `flock` operation requesting an exclusive lock.
#[cfg(unix)]
const LOCK_EX: c_int = 0x2;
/// Modifier OR-ed into a lock request so that it fails instead of waiting.
#[cfg(unix)]
const LOCK_NB: c_int = 0x4;
/// `flock` operation releasing a held lock.
#[cfg(unix)]
const LOCK_UN: c_int = 0x8;
129
#[cfg(unix)]
extern "C" {
    /// C `flock` entry point, declared by hand instead of through a bindings crate.
    ///
    /// Callers in this file treat a return value of `0` as success and read
    /// the failure reason from the last OS error otherwise.
    fn flock(descriptor: c_int, op: c_int) -> c_int;
}
134
/// Takes the exclusive lock on `file` through `lock_file`, passing no extra
/// flags (in particular not the fail-immediately flag used by the `try_`
/// variant).
///
/// # Errors
///
/// Propagates the OS error returned by `lock_file`.
#[cfg(windows)]
fn lock_exclusive(file: &File) -> std::io::Result<()> {
    const NO_EXTRA_FLAGS: u32 = 0;
    lock_file(file, NO_EXTRA_FLAGS)
}
139
/// Tries to take the exclusive lock on `file`, asking the OS to fail
/// immediately instead of waiting.
///
/// Returns `Ok(true)` when the lock was acquired and `Ok(false)` when the OS
/// reported `ERROR_LOCK_VIOLATION`.
///
/// # Errors
///
/// Returns the OS error for every other failure.
// NOTE(review): only `ERROR_LOCK_VIOLATION` is treated as contention; confirm
// whether `ERROR_IO_PENDING` can also be reported for the handles opened by
// this module.
#[cfg(windows)]
fn try_lock_exclusive(file: &File) -> std::io::Result<bool> {
    match lock_file(file, LOCKFILE_FAIL_IMMEDIATELY) {
        Ok(()) => Ok(true),
        Err(err) if err.raw_os_error() == Some(ERROR_LOCK_VIOLATION) => Ok(false),
        Err(err) => Err(err),
    }
}
153
/// Releases the one-byte region at offset 0 that `lock_file` locks.
///
/// # Errors
///
/// Returns the last OS error when `UnlockFileEx` reports failure (a zero
/// return value).
#[cfg(windows)]
fn unlock(file: &File) -> std::io::Result<()> {
    use std::os::windows::io::AsRawHandle;

    let handle = file.as_raw_handle() as HANDLE;
    // SAFETY: `OVERLAPPED` holds only integers and a raw pointer, all of which
    // are valid when zeroed; the zero offsets select the start of the file.
    let mut region: OVERLAPPED = unsafe { std::mem::zeroed() };
    // SAFETY: `handle` is borrowed from `file`, which stays open for the whole
    // call, and `region` is a live local for the duration of the call.
    let released = unsafe { UnlockFileEx(handle, 0, 1, 0, &mut region) } != 0;

    if released {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
165
/// Requests an exclusive lock on the first byte of `file` via `LockFileEx`.
///
/// `flags` is OR-ed with `LOCKFILE_EXCLUSIVE_LOCK`; callers pass `0` or
/// `LOCKFILE_FAIL_IMMEDIATELY`.
///
/// # Errors
///
/// Returns the last OS error when `LockFileEx` reports failure (a zero return
/// value).
#[cfg(windows)]
fn lock_file(file: &File, flags: u32) -> std::io::Result<()> {
    use std::os::windows::io::AsRawHandle;

    let handle = file.as_raw_handle() as HANDLE;
    // SAFETY: `OVERLAPPED` holds only integers and a raw pointer, all of which
    // are valid when zeroed; the zero offsets select the start of the file.
    let mut region: OVERLAPPED = unsafe { std::mem::zeroed() };
    let mode = LOCKFILE_EXCLUSIVE_LOCK | flags;

    // SAFETY: `handle` is borrowed from `file`, which stays open for the whole
    // call, and `region` is a live local for the duration of the call.
    let status = unsafe { LockFileEx(handle, mode, 0, 1, 0, &mut region) };
    if status == 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(())
}
186
/// Raw Win32 handle as passed to the file-locking calls declared in this file.
#[cfg(windows)]
type HANDLE = *mut std::os::raw::c_void;

/// Hand-written mirror of the Win32 `OVERLAPPED` structure.
///
/// Code in this file only ever creates it zero-initialised, so the fields are
/// never read or written by name; field order and types must not change
/// because the struct is passed to the OS by pointer.
// NOTE(review): the layout is declared manually rather than taken from a
// bindings crate -- confirm it against the Windows SDK definition.
#[cfg(windows)]
#[repr(C)]
struct OVERLAPPED {
    internal: usize,
    internal_high: usize,
    offset: u32,
    offset_high: u32,
    h_event: HANDLE,
}
199
/// `LockFileEx` flag that makes the call fail right away instead of waiting;
/// passed by `try_lock_exclusive`.
#[cfg(windows)]
const LOCKFILE_FAIL_IMMEDIATELY: u32 = 0x0000_0001;
/// `LockFileEx` flag requesting an exclusive lock; always set by `lock_file`.
#[cfg(windows)]
const LOCKFILE_EXCLUSIVE_LOCK: u32 = 0x0000_0002;
/// OS error code that `try_lock_exclusive` interprets as "lock is held
/// elsewhere".
#[cfg(windows)]
const ERROR_LOCK_VIOLATION: i32 = 33;
206
// Win32 byte-range locking entry points, declared by hand. Callers in this
// file treat a non-zero return value as success and read the failure reason
// from the last OS error otherwise.
#[cfg(windows)]
extern "system" {
    /// Locks a byte range of the file behind `file_handle`; the range start
    /// is taken from the offsets stored in `region`.
    fn LockFileEx(
        file_handle: HANDLE,
        lock_flags: u32,
        reserved: u32,
        length_low: u32,
        length_high: u32,
        region: *mut OVERLAPPED,
    ) -> i32;

    /// Unlocks a byte range previously locked with `LockFileEx`.
    fn UnlockFileEx(
        file_handle: HANDLE,
        reserved: u32,
        length_low: u32,
        length_high: u32,
        region: *mut OVERLAPPED,
    ) -> i32;
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, uniquely named lock-file path inside the system temp
    /// directory.
    fn temp_lock_path() -> std::path::PathBuf {
        let file_name = format!("walletkit-lock-{}.lock", Uuid::new_v4());
        let mut location = std::env::temp_dir();
        location.push(file_name);
        location
    }

    /// A second handle on the same path cannot take the lock while the first
    /// one holds it, and can once the first guard is dropped.
    #[test]
    fn test_lock_is_exclusive() {
        let path = temp_lock_path();

        let first = StorageLock::open(&path).expect("open lock");
        let held = first.lock().expect("acquire lock");

        let second = StorageLock::open(&path).expect("open lock");
        let attempt = second.try_lock().expect("try lock");
        assert!(attempt.is_none());

        drop(held);
        let reacquired = second.try_lock().expect("try lock");
        assert!(reacquired.is_some());

        let _ = std::fs::remove_file(path);
    }

    /// The lock held by a worker thread keeps the main thread out until the
    /// worker releases it.
    #[test]
    fn test_lock_serializes_across_threads() {
        use std::sync::mpsc::channel;

        let path = temp_lock_path();
        let worker_lock = StorageLock::open(&path).expect("open lock");

        // worker -> main: "I hold the lock"
        let (locked_tx, locked_rx) = channel();
        // main -> worker: "you may release now"
        let (release_tx, release_rx) = channel();
        // worker -> main: "the lock has been released"
        let (released_tx, released_rx) = channel();

        let worker_path = path.clone();
        let worker = std::thread::spawn(move || {
            let held = worker_lock.lock().expect("lock in thread");
            locked_tx.send(()).expect("signal locked");
            release_rx.recv().expect("wait release");
            drop(held);
            released_tx.send(()).expect("signal released");
            let _ = std::fs::remove_file(worker_path);
        });

        locked_rx.recv().expect("wait locked");
        let main_lock = StorageLock::open(&path).expect("open lock");
        let attempt = main_lock.try_lock().expect("try lock");
        assert!(attempt.is_none());

        release_tx.send(()).expect("release");
        released_rx.recv().expect("wait released");

        let reacquired = main_lock.try_lock().expect("try lock");
        assert!(reacquired.is_some());

        worker.join().expect("thread join");
    }
}