1use std::fs::{File, OpenOptions};
58use std::io::{Read, Write};
59use std::path::{Path, PathBuf};
60use std::time::{Duration, Instant};
61
62use sochdb_core::SochDBError;
63
/// Errors produced while acquiring or managing a database lock file.
#[derive(Debug)]
pub enum LockError {
    /// The lock is currently held through another open handle.
    DatabaseLocked {
        /// PID recorded in the lock file, when one could be determined.
        holder_pid: Option<u32>,
        /// Path of the contended lock file.
        lock_path: PathBuf,
    },
    /// The lock was still held when the configured wait expired.
    Timeout {
        /// Time reported as spent waiting.
        elapsed: Duration,
        /// Configured upper bound on waiting.
        timeout: Duration,
    },
    /// A lock attributed to a process that is presumed to have exited.
    StaleLock {
        /// PID recorded for the presumed-dead holder.
        stale_pid: u32,
    },
    /// Underlying filesystem or locking-syscall failure.
    Io(std::io::Error),
}

impl std::fmt::Display for LockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseLocked {
                holder_pid: Some(pid),
                lock_path,
            } => write!(
                f,
                "Database is locked by process {} (lock file: {})",
                pid,
                lock_path.display()
            ),
            Self::DatabaseLocked {
                holder_pid: None,
                lock_path,
            } => write!(f, "Database is locked (lock file: {})", lock_path.display()),
            Self::Timeout { elapsed, timeout } => write!(
                f,
                "Lock acquisition timed out after {:?} (timeout: {:?})",
                elapsed, timeout
            ),
            Self::StaleLock { stale_pid } => {
                write!(f, "Stale lock detected from crashed process {}", stale_pid)
            }
            Self::Io(inner) => write!(f, "Lock I/O error: {}", inner),
        }
    }
}

impl std::error::Error for LockError {
    /// Only the I/O variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::DatabaseLocked { .. } | Self::Timeout { .. } | Self::StaleLock { .. } => None,
        }
    }
}

impl From<std::io::Error> for LockError {
    /// Wraps a raw I/O error so `?` can be used inside lock code.
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
131
132impl From<LockError> for SochDBError {
133 fn from(e: LockError) -> Self {
134 match e {
135 LockError::DatabaseLocked { holder_pid, lock_path } => {
136 SochDBError::LockError(format!(
137 "Database locked by PID {:?} (lock: {})",
138 holder_pid, lock_path.display()
139 ))
140 }
141 LockError::Timeout { elapsed, timeout } => {
142 SochDBError::LockError(format!(
143 "Lock timeout after {:?} (max: {:?})", elapsed, timeout
144 ))
145 }
146 LockError::StaleLock { stale_pid } => {
147 SochDBError::LockError(format!(
148 "Stale lock from crashed process {}", stale_pid
149 ))
150 }
151 LockError::Io(e) => SochDBError::Io(e),
152 }
153 }
154}
155
/// Tunables controlling how a database lock file is acquired.
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Upper bound on waiting for a contended lock; `None` fails immediately.
    pub timeout: Option<Duration>,
    /// Delay between successive acquisition attempts while waiting.
    pub retry_interval: Duration,
    /// Whether to check if the PID recorded in a contended lock file still
    /// refers to a live process.
    pub detect_stale_locks: bool,
    /// Name of the lock file inside the database directory.
    pub lock_file_name: String,
}

impl Default for LockConfig {
    /// Waits up to 5 s, polling every 100 ms, with PID liveness checks enabled
    /// and a lock file named `.lock`.
    fn default() -> Self {
        let timeout = Some(Duration::from_secs(5));
        let retry_interval = Duration::from_millis(100);
        Self {
            timeout,
            retry_interval,
            detect_stale_locks: true,
            lock_file_name: String::from(".lock"),
        }
    }
}

impl LockConfig {
    /// Default configuration that gives up at the first sign of contention.
    pub fn no_wait() -> Self {
        Self::with_optional_timeout(None)
    }

    /// Default configuration that keeps retrying for at most `timeout`.
    pub fn with_timeout(timeout: Duration) -> Self {
        Self::with_optional_timeout(Some(timeout))
    }

    /// Default configuration with only the timeout replaced.
    fn with_optional_timeout(timeout: Option<Duration>) -> Self {
        Self {
            timeout,
            ..Self::default()
        }
    }
}
201
/// An exclusive advisory lock on a database directory, backed by a lock file.
///
/// The lock is released and the lock file removed when this value is dropped.
/// On platforms other than Unix and Windows no OS lock is actually taken.
pub struct DatabaseLock {
    /// Open handle to the lock file; the OS-level lock is tied to this handle.
    lock_file: File,
    /// Path of the lock file.
    path: PathBuf,
    /// PID of this process, written into the lock file on acquisition.
    our_pid: u32,
}
229
230impl DatabaseLock {
231 pub fn acquire<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
249 Self::acquire_with_config(db_path, &LockConfig::no_wait())
250 }
251
252 pub fn acquire_with_timeout<P: AsRef<Path>>(
261 db_path: P,
262 timeout: Duration
263 ) -> std::result::Result<Self, LockError> {
264 Self::acquire_with_config(db_path, &LockConfig::with_timeout(timeout))
265 }
266
267 pub fn acquire_with_config<P: AsRef<Path>>(
269 db_path: P,
270 config: &LockConfig,
271 ) -> std::result::Result<Self, LockError> {
272 let db_path = db_path.as_ref();
273 let lock_path = db_path.join(&config.lock_file_name);
274
275 if !db_path.exists() {
277 std::fs::create_dir_all(db_path)?;
278 }
279
280 let deadline = config.timeout.map(|t| Instant::now() + t);
281 let our_pid = std::process::id();
282
283 loop {
284 let file = OpenOptions::new()
286 .create(true)
287 .read(true)
288 .write(true)
289 .open(&lock_path)?;
290
291 match Self::try_flock(&file, false) {
293 Ok(()) => {
294 Self::write_pid(&file, our_pid)?;
296
297 return Ok(Self {
298 lock_file: file,
299 path: lock_path,
300 our_pid,
301 });
302 }
303 Err(LockError::DatabaseLocked { .. }) => {
304 let mut should_retry = false;
308 if config.detect_stale_locks {
309 if let Some(holder_pid) = Self::read_pid(&file) {
310 if !Self::process_exists(holder_pid) {
311 drop(file);
314
315 if std::fs::remove_file(&lock_path).is_ok() {
317 should_retry = true;
318 }
319 }
320 }
321 }
322
323 if should_retry {
324 continue; }
326
327 if let Some(deadline) = deadline {
329 if Instant::now() >= deadline {
330 return Err(LockError::Timeout {
331 elapsed: config.timeout.unwrap_or_default(),
332 timeout: config.timeout.unwrap_or_default(),
333 });
334 }
335
336 std::thread::sleep(config.retry_interval);
338 continue;
339 } else {
340 return Err(LockError::DatabaseLocked {
343 holder_pid: None,
344 lock_path
345 });
346 }
347 }
348 Err(e) => return Err(e),
349 }
350 }
351 }
352
353 pub fn path(&self) -> &Path {
355 &self.path
356 }
357
358 pub fn pid(&self) -> u32 {
360 self.our_pid
361 }
362
363 pub fn get_lock_holder<P: AsRef<Path>>(db_path: P) -> Option<u32> {
367 let lock_path = db_path.as_ref().join(".lock");
368 let file = File::open(&lock_path).ok()?;
369 Self::read_pid(&file)
370 }
371
372 fn write_pid(file: &File, pid: u32) -> std::result::Result<(), LockError> {
374 use std::io::Seek;
375 let mut file = file;
376 file.seek(std::io::SeekFrom::Start(0))?;
377 file.set_len(0)?;
378 writeln!(file, "{}", pid)?;
379 file.sync_all()?;
380 Ok(())
381 }
382
383 fn read_pid(file: &File) -> Option<u32> {
385 use std::io::Seek;
386 let mut file = file;
387 let _ = file.seek(std::io::SeekFrom::Start(0));
388 let mut contents = String::new();
389 file.read_to_string(&mut contents).ok()?;
390 contents.trim().parse().ok()
391 }
392
393 #[cfg(unix)]
395 fn process_exists(pid: u32) -> bool {
396 let result = unsafe { libc::kill(pid as libc::pid_t, 0) };
399 if result == 0 {
400 true
401 } else {
402 let errno = std::io::Error::last_os_error().raw_os_error();
404 errno != Some(libc::ESRCH)
405 }
406 }
407
408 #[cfg(windows)]
409 fn process_exists(pid: u32) -> bool {
410 unsafe {
411 let handle = windows_sys::Win32::System::Threading::OpenProcess(
412 windows_sys::Win32::System::Threading::PROCESS_QUERY_LIMITED_INFORMATION,
413 0,
414 pid,
415 );
416 if handle == 0 || handle == -1 {
417 false
418 } else {
419 windows_sys::Win32::Foundation::CloseHandle(handle);
420 true
421 }
422 }
423 }
424
425 #[cfg(not(any(unix, windows)))]
426 fn process_exists(_pid: u32) -> bool {
427 true
429 }
430
431 #[cfg(unix)]
433 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
434 use std::os::unix::io::AsRawFd;
435
436 let fd = file.as_raw_fd();
437 let operation = if blocking {
438 libc::LOCK_EX
439 } else {
440 libc::LOCK_EX | libc::LOCK_NB
441 };
442
443 let result = unsafe { libc::flock(fd, operation) };
444
445 if result == 0 {
446 Ok(())
447 } else {
448 let err = std::io::Error::last_os_error();
449 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
450 Err(LockError::DatabaseLocked {
451 holder_pid: None,
452 lock_path: PathBuf::new(),
453 })
454 } else {
455 Err(LockError::Io(err))
456 }
457 }
458 }
459
460 #[cfg(windows)]
461 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
462 use std::os::windows::io::AsRawHandle;
463
464 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
465
466 let flags = windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
467 | if blocking { 0 } else { windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY };
468
469 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
470
471 let result = unsafe {
472 windows_sys::Win32::Storage::FileSystem::LockFileEx(
473 handle,
474 flags,
475 0,
476 1,
477 0,
478 &mut overlapped,
479 )
480 };
481
482 if result != 0 {
483 Ok(())
484 } else {
485 let err = std::io::Error::last_os_error();
486 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
487 Err(LockError::DatabaseLocked {
488 holder_pid: None,
489 lock_path: PathBuf::new(),
490 })
491 } else {
492 Err(LockError::Io(err))
493 }
494 }
495 }
496
497 #[cfg(not(any(unix, windows)))]
498 fn try_flock(_file: &File, _blocking: bool) -> std::result::Result<(), LockError> {
499 Ok(())
502 }
503
504 #[cfg(unix)]
506 fn release(&self) {
507 use std::os::unix::io::AsRawFd;
508 let fd = self.lock_file.as_raw_fd();
509 unsafe { libc::flock(fd, libc::LOCK_UN) };
510 }
511
512 #[cfg(windows)]
513 fn release(&self) {
514 use std::os::windows::io::AsRawHandle;
515 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
516 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
517 unsafe {
518 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(
519 handle,
520 0,
521 1,
522 0,
523 &mut overlapped,
524 );
525 }
526 }
527
528 #[cfg(not(any(unix, windows)))]
529 fn release(&self) {
530 }
532}
533
534impl Drop for DatabaseLock {
535 fn drop(&mut self) {
536 self.release();
537 let _ = std::fs::remove_file(&self.path);
540 }
541}
542
/// Plain-data record of reader/writer lock bookkeeping.
///
/// `#[repr(C)]` fixes the field order and layout (four `u32`s, 16 bytes).
/// NOTE(review): not referenced by the lock types below; presumably intended
/// for a shared file or memory region — confirm with its users.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RwLockState {
    /// Presumably the number of active readers (inferred from the name).
    pub reader_count: u32,
    /// Presumably non-zero while a writer is waiting (inferred from the name).
    pub writer_intent: u32,
    /// Presumably non-zero while a writer holds the lock (inferred from the name).
    pub writer_active: u32,
    /// Padding/reserved field (by name); keeps the layout at four `u32`s.
    pub _padding: u32,
}
568
/// Access mode requested when acquiring an [`RwDatabaseLock`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionMode {
    /// Takes a shared lock; several read-only holders may coexist.
    ReadOnly,
    /// Takes an exclusive lock; excludes every other holder.
    ReadWrite,
}
577
/// A shared or exclusive advisory lock on a database directory, selected by
/// [`ConnectionMode`].
///
/// The OS lock is released when this value is dropped. Unlike
/// [`DatabaseLock`], no PID is written and the lock file is left in place.
pub struct RwDatabaseLock {
    /// Open handle to the lock file; the OS-level lock is tied to this handle.
    lock_file: File,
    /// Path of the lock file.
    path: PathBuf,
    /// Mode the lock was acquired in.
    mode: ConnectionMode,
    /// PID of the acquiring process.
    our_pid: u32,
}
594
595impl RwDatabaseLock {
596 pub fn acquire_shared<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
601 Self::acquire_with_mode(db_path, ConnectionMode::ReadOnly, &LockConfig::default())
602 }
603
604 pub fn acquire_exclusive<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
609 Self::acquire_with_mode(db_path, ConnectionMode::ReadWrite, &LockConfig::default())
610 }
611
612 pub fn acquire_with_mode<P: AsRef<Path>>(
614 db_path: P,
615 mode: ConnectionMode,
616 config: &LockConfig,
617 ) -> std::result::Result<Self, LockError> {
618 let db_path = db_path.as_ref();
619 let lock_path = db_path.join(&config.lock_file_name);
620
621 if !db_path.exists() {
622 std::fs::create_dir_all(db_path)?;
623 }
624
625 let file = OpenOptions::new()
626 .create(true)
627 .read(true)
628 .write(true)
629 .open(&lock_path)?;
630
631 let our_pid = std::process::id();
632 let deadline = config.timeout.map(|t| Instant::now() + t);
633
634 loop {
635 match mode {
636 ConnectionMode::ReadOnly => {
637 if Self::try_shared_lock(&file)? {
639 return Ok(Self {
640 lock_file: file,
641 path: lock_path,
642 mode,
643 our_pid,
644 });
645 }
646 }
647 ConnectionMode::ReadWrite => {
648 if Self::try_exclusive_lock(&file)? {
650 return Ok(Self {
651 lock_file: file,
652 path: lock_path,
653 mode,
654 our_pid,
655 });
656 }
657 }
658 }
659
660 if let Some(deadline) = deadline {
662 if Instant::now() >= deadline {
663 return Err(LockError::Timeout {
664 elapsed: config.timeout.unwrap_or_default(),
665 timeout: config.timeout.unwrap_or_default(),
666 });
667 }
668 std::thread::sleep(config.retry_interval);
669 } else {
670 return Err(LockError::DatabaseLocked {
671 holder_pid: None,
672 lock_path,
673 });
674 }
675 }
676 }
677
678 pub fn mode(&self) -> ConnectionMode {
680 self.mode
681 }
682
683 pub fn is_readonly(&self) -> bool {
685 self.mode == ConnectionMode::ReadOnly
686 }
687
688 #[cfg(unix)]
689 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
690 use std::os::unix::io::AsRawFd;
691 let fd = file.as_raw_fd();
692 let result = unsafe { libc::flock(fd, libc::LOCK_SH | libc::LOCK_NB) };
693 if result == 0 {
694 Ok(true)
695 } else {
696 let err = std::io::Error::last_os_error();
697 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
698 Ok(false)
699 } else {
700 Err(LockError::Io(err))
701 }
702 }
703 }
704
705 #[cfg(unix)]
706 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
707 use std::os::unix::io::AsRawFd;
708 let fd = file.as_raw_fd();
709 let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
710 if result == 0 {
711 Ok(true)
712 } else {
713 let err = std::io::Error::last_os_error();
714 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
715 Ok(false)
716 } else {
717 Err(LockError::Io(err))
718 }
719 }
720 }
721
722 #[cfg(windows)]
723 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
724 use std::os::windows::io::AsRawHandle;
725 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
726 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
727
728 let result = unsafe {
729 windows_sys::Win32::Storage::FileSystem::LockFileEx(
730 handle,
731 windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
732 0, 1, 0,
733 &mut overlapped,
734 )
735 };
736
737 if result != 0 {
738 Ok(true)
739 } else {
740 let err = std::io::Error::last_os_error();
741 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
742 Ok(false)
743 } else {
744 Err(LockError::Io(err))
745 }
746 }
747 }
748
749 #[cfg(windows)]
750 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
751 use std::os::windows::io::AsRawHandle;
752 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
753 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
754
755 let result = unsafe {
756 windows_sys::Win32::Storage::FileSystem::LockFileEx(
757 handle,
758 windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
759 | windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
760 0, 1, 0,
761 &mut overlapped,
762 )
763 };
764
765 if result != 0 {
766 Ok(true)
767 } else {
768 let err = std::io::Error::last_os_error();
769 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
770 Ok(false)
771 } else {
772 Err(LockError::Io(err))
773 }
774 }
775 }
776
777 #[cfg(not(any(unix, windows)))]
778 fn try_shared_lock(_file: &File) -> std::result::Result<bool, LockError> {
779 Ok(true)
780 }
781
782 #[cfg(not(any(unix, windows)))]
783 fn try_exclusive_lock(_file: &File) -> std::result::Result<bool, LockError> {
784 Ok(true)
785 }
786
787 #[cfg(unix)]
788 fn release(&self) {
789 use std::os::unix::io::AsRawFd;
790 let fd = self.lock_file.as_raw_fd();
791 unsafe { libc::flock(fd, libc::LOCK_UN) };
792 }
793
794 #[cfg(windows)]
795 fn release(&self) {
796 use std::os::windows::io::AsRawHandle;
797 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
798 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
799 unsafe {
800 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
801 }
802 }
803
804 #[cfg(not(any(unix, windows)))]
805 fn release(&self) {}
806}
807
808impl Drop for RwDatabaseLock {
809 fn drop(&mut self) {
810 self.release();
811 }
812}
813
#[cfg(test)]
mod tests {
    //! Behavioural tests for `DatabaseLock` and `RwDatabaseLock` using
    //! per-test temporary directories.

    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_exclusive_lock_basic() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = DatabaseLock::acquire(db_path);
        assert!(lock1.is_ok());

        let lock2 = DatabaseLock::acquire(db_path);
        assert!(matches!(lock2, Err(LockError::DatabaseLocked { .. })));

        drop(lock1);
        let lock3 = DatabaseLock::acquire(db_path);
        assert!(lock3.is_ok());
    }

    #[test]
    fn test_lock_with_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().to_path_buf();

        let _lock = DatabaseLock::acquire(&db_path).unwrap();

        let start = Instant::now();
        let result = DatabaseLock::acquire_with_timeout(&db_path, Duration::from_millis(100));
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(LockError::Timeout { .. })));
        assert!(elapsed >= Duration::from_millis(100));
        assert!(elapsed < Duration::from_millis(500));
    }

    #[test]
    fn test_lock_pid_recorded() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock = DatabaseLock::acquire(db_path).unwrap();
        let our_pid = std::process::id();

        assert_eq!(lock.pid(), our_pid);

        let holder = DatabaseLock::get_lock_holder(db_path);
        assert_eq!(holder, Some(our_pid));
    }

    #[test]
    fn test_shared_lock_multiple_readers() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = RwDatabaseLock::acquire_shared(db_path);
        let lock2 = RwDatabaseLock::acquire_shared(db_path);

        assert!(lock1.is_ok());
        assert!(lock2.is_ok());
    }

    #[test]
    fn test_exclusive_blocks_shared() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _exclusive = RwDatabaseLock::acquire_exclusive(db_path).unwrap();

        let shared = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );

        assert!(matches!(shared, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_lock_file_removed_on_drop() {
        let dir = TempDir::new().unwrap();
        let lock = DatabaseLock::acquire(dir.path()).unwrap();
        let lock_path = lock.path().to_path_buf();
        assert!(lock_path.exists());

        drop(lock);
        assert!(!lock_path.exists());
    }

    #[test]
    fn test_no_wait_error_names_lock_file() {
        let dir = TempDir::new().unwrap();
        let _held = DatabaseLock::acquire(dir.path()).unwrap();

        let err = DatabaseLock::acquire(dir.path())
            .err()
            .expect("second acquire must fail while the lock is held");
        match err {
            LockError::DatabaseLocked { lock_path, .. } => {
                assert_eq!(lock_path, dir.path().join(".lock"));
            }
            other => panic!("unexpected error: {}", other),
        }
    }

    #[test]
    fn test_exclusive_blocks_exclusive() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _first = RwDatabaseLock::acquire_exclusive(db_path).unwrap();
        let second = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadWrite,
            &LockConfig::no_wait(),
        );

        assert!(matches!(second, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_shared_blocks_exclusive() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let reader = RwDatabaseLock::acquire_shared(db_path).unwrap();
        assert!(reader.is_readonly());

        let writer = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadWrite,
            &LockConfig::no_wait(),
        );

        assert!(matches!(writer, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_database_lock_blocks_rw_readers() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _exclusive = DatabaseLock::acquire(db_path).unwrap();
        let reader = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );

        assert!(matches!(reader, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_display_includes_holder_pid() {
        let lock_path = PathBuf::from("db").join(".lock");
        let err = LockError::DatabaseLocked {
            holder_pid: Some(42),
            lock_path: lock_path.clone(),
        };
        assert_eq!(
            err.to_string(),
            format!("Database is locked by process 42 (lock file: {})", lock_path.display())
        );
    }
}