1use anyhow::{Context, Result};
2use std::fs::{File, OpenOptions};
3use std::io::{Read, Write};
4use std::path::{Path, PathBuf};
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6use serde::{Deserialize, Serialize};
7
/// The two kinds of lock a holder can request.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LockType {
    /// Read lock; recorded as the text `"read"` when turned into a `LockInfo`.
    Read,
    /// Write lock; recorded as the text `"write"` when turned into a `LockInfo`.
    Write,
}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct LockInfo {
18 pub lock_type: String, pub process_id: u32,
20 pub acquired_at: u64,
21 pub expires_at: u64,
22 pub holder_info: String, }
24
25impl LockInfo {
26 pub fn new(lock_type: LockType, duration_secs: u64, holder_info: String) -> Self {
27 let now = SystemTime::now()
28 .duration_since(UNIX_EPOCH)
29 .unwrap()
30 .as_secs();
31
32 Self {
33 lock_type: match lock_type {
34 LockType::Read => "read".to_string(),
35 LockType::Write => "write".to_string(),
36 },
37 process_id: std::process::id(),
38 acquired_at: now,
39 expires_at: now + duration_secs,
40 holder_info,
41 }
42 }
43
44 pub fn is_expired(&self) -> bool {
45 let now = SystemTime::now()
46 .duration_since(UNIX_EPOCH)
47 .unwrap()
48 .as_secs();
49 now >= self.expires_at
50 }
51
52 pub fn is_write_lock(&self) -> bool {
53 self.lock_type == "write"
54 }
55
56 pub fn is_read_lock(&self) -> bool {
57 self.lock_type == "read"
58 }
59}
60
/// Lock manager that represents one write lock and any number of read locks
/// as files inside a `locks` directory under a collection path.
pub struct SWMRLockManager {
    /// Root directory of the collection being guarded.
    collection_path: PathBuf,
    /// `<collection_path>/locks`: directory that holds every lock file.
    locks_dir: PathBuf,
    /// `<locks_dir>/write.lock`: the single file representing the write lock.
    write_lock_path: PathBuf,
}
67
68impl SWMRLockManager {
69 pub fn new(collection_path: &Path) -> Self {
70 let locks_dir = collection_path.join("locks");
71 let write_lock_path = locks_dir.join("write.lock");
72
73 Self {
74 collection_path: collection_path.to_path_buf(),
75 locks_dir,
76 write_lock_path,
77 }
78 }
79
80 pub fn init(&self) -> Result<()> {
82 std::fs::create_dir_all(&self.locks_dir)
83 .context("Failed to create locks directory")?;
84 Ok(())
85 }
86
87 pub fn acquire_read_lock(&self, timeout_secs: u64, holder_info: String) -> Result<ReadLock> {
89 self.init()?;
90
91 let start_time = SystemTime::now();
92 let timeout = Duration::from_secs(timeout_secs);
93
94 loop {
95 if let Ok(write_lock_info) = self.read_write_lock() {
97 if !write_lock_info.is_expired() {
98 if start_time.elapsed().unwrap() >= timeout {
99 anyhow::bail!("Timeout waiting for read lock - write lock held by process {}",
100 write_lock_info.process_id);
101 }
102 std::thread::sleep(Duration::from_millis(100));
103 continue;
104 }
105 }
106
107 let lock_info = LockInfo::new(LockType::Read, 3600, holder_info); let read_lock_path = self.locks_dir.join(format!("read_{}.lock", std::process::id()));
110
111 self.write_lock_file(&read_lock_path, &lock_info)?;
112
113 return Ok(ReadLock {
114 manager: self,
115 lock_path: read_lock_path,
116 lock_info,
117 });
118 }
119 }
120
121 pub fn acquire_write_lock(&self, timeout_secs: u64, holder_info: String) -> Result<WriteLock> {
123 self.init()?;
124
125 let start_time = SystemTime::now();
126 let timeout = Duration::from_secs(timeout_secs);
127
128 loop {
129 if let Ok(existing_write) = self.read_write_lock() {
131 if !existing_write.is_expired() {
132 if start_time.elapsed().unwrap() >= timeout {
133 anyhow::bail!("Timeout waiting for write lock - held by process {}",
134 existing_write.process_id);
135 }
136 std::thread::sleep(Duration::from_millis(100));
137 continue;
138 }
139 }
140
141 let active_read_locks = self.get_active_read_locks()?;
143 if !active_read_locks.is_empty() {
144 if start_time.elapsed().unwrap() >= timeout {
145 anyhow::bail!("Timeout waiting for write lock - {} read locks active",
146 active_read_locks.len());
147 }
148 std::thread::sleep(Duration::from_millis(100));
149 continue;
150 }
151
152 let lock_info = LockInfo::new(LockType::Write, 1800, holder_info); self.write_lock_file(&self.write_lock_path, &lock_info)?;
156
157 return Ok(WriteLock {
158 manager: self,
159 lock_info,
160 });
161 }
162 }
163
164 pub fn is_write_locked(&self) -> Result<bool> {
166 match self.read_write_lock() {
167 Ok(lock_info) => Ok(!lock_info.is_expired()),
168 Err(_) => Ok(false),
169 }
170 }
171
172 pub fn active_read_lock_count(&self) -> Result<usize> {
174 Ok(self.get_active_read_locks()?.len())
175 }
176
177 fn read_write_lock(&self) -> Result<LockInfo> {
178 let content = std::fs::read_to_string(&self.write_lock_path)
179 .context("Failed to read write lock file")?;
180 let lock_info: LockInfo = serde_json::from_str(&content)
181 .context("Failed to parse write lock info")?;
182 Ok(lock_info)
183 }
184
185 fn get_active_read_locks(&self) -> Result<Vec<LockInfo>> {
186 let mut active_locks = Vec::new();
187
188 if !self.locks_dir.exists() {
189 return Ok(active_locks);
190 }
191
192 for entry in std::fs::read_dir(&self.locks_dir)? {
193 let entry = entry?;
194 let path = entry.path();
195
196 if let Some(filename) = path.file_name() {
197 if let Some(filename_str) = filename.to_str() {
198 if filename_str.starts_with("read_") && filename_str.ends_with(".lock") {
199 if let Ok(content) = std::fs::read_to_string(&path) {
200 if let Ok(lock_info) = serde_json::from_str::<LockInfo>(&content) {
201 if !lock_info.is_expired() {
202 active_locks.push(lock_info);
203 } else {
204 let _ = std::fs::remove_file(&path);
206 }
207 }
208 }
209 }
210 }
211 }
212 }
213
214 Ok(active_locks)
215 }
216
217 fn write_lock_file(&self, path: &Path, lock_info: &LockInfo) -> Result<()> {
218 let json = serde_json::to_string_pretty(lock_info)
219 .context("Failed to serialize lock info")?;
220 std::fs::write(path, json)
221 .context("Failed to write lock file")?;
222 Ok(())
223 }
224
225 fn release_read_lock(&self, lock_path: &Path) -> Result<()> {
226 if lock_path.exists() {
227 std::fs::remove_file(lock_path)
228 .context("Failed to remove read lock file")?;
229 }
230 Ok(())
231 }
232
233 fn release_write_lock(&self) -> Result<()> {
234 if self.write_lock_path.exists() {
235 std::fs::remove_file(&self.write_lock_path)
236 .context("Failed to remove write lock file")?;
237 }
238 Ok(())
239 }
240}
241
242pub struct ReadLock<'a> {
244 manager: &'a SWMRLockManager,
245 lock_path: PathBuf,
246 lock_info: LockInfo,
247}
248
249impl<'a> ReadLock<'a> {
250 pub fn lock_info(&self) -> &LockInfo {
251 &self.lock_info
252 }
253
254 pub fn extend(&mut self, additional_secs: u64) -> Result<()> {
256 self.lock_info.expires_at += additional_secs;
257 self.manager.write_lock_file(&self.lock_path, &self.lock_info)?;
258 Ok(())
259 }
260}
261
262impl<'a> Drop for ReadLock<'a> {
263 fn drop(&mut self) {
264 let _ = self.manager.release_read_lock(&self.lock_path);
265 }
266}
267
268pub struct WriteLock<'a> {
270 manager: &'a SWMRLockManager,
271 lock_info: LockInfo,
272}
273
274impl<'a> WriteLock<'a> {
275 pub fn lock_info(&self) -> &LockInfo {
276 &self.lock_info
277 }
278
279 pub fn extend(&mut self, additional_secs: u64) -> Result<()> {
281 self.lock_info.expires_at += additional_secs;
282 self.manager.write_lock_file(&self.manager.write_lock_path, &self.lock_info)?;
283 Ok(())
284 }
285}
286
287impl<'a> Drop for WriteLock<'a> {
288 fn drop(&mut self) {
289 let _ = self.manager.release_write_lock();
290 }
291}
292
293pub trait LockAware {
295 fn with_read_lock<F, R>(&self, timeout_secs: u64, operation: F) -> Result<R>
296 where
297 F: FnOnce() -> Result<R>;
298
299 fn with_write_lock<F, R>(&self, timeout_secs: u64, operation: F) -> Result<R>
300 where
301 F: FnOnce() -> Result<R>;
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned as well so the directory outlives the test body.
    fn manager_in_temp_dir() -> (TempDir, SWMRLockManager) {
        let dir = TempDir::new().unwrap();
        let manager = SWMRLockManager::new(dir.path());
        (dir, manager)
    }

    /// A freshly acquired read lock is counted as active while its guard lives.
    #[test]
    fn test_read_lock_acquisition() {
        let (_dir, manager) = manager_in_temp_dir();

        let _guard = manager.acquire_read_lock(5, "test".to_string()).unwrap();
        let active = manager.active_read_lock_count().unwrap();
        assert_eq!(active, 1);
    }

    /// While a write lock is held, a read lock request times out with an error.
    #[test]
    fn test_write_lock_exclusivity() {
        let (_dir, manager) = manager_in_temp_dir();

        let _writer = manager.acquire_write_lock(5, "test".to_string()).unwrap();

        let reader_attempt = manager.acquire_read_lock(1, "test2".to_string());
        assert!(reader_attempt.is_err());
    }
}